Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix field data flux #1606

Merged
merged 1 commit into from
Apr 13, 2024
Merged

fix field data flux #1606

merged 1 commit into from
Apr 13, 2024

Conversation

tylerflex
Copy link
Collaborator

@tylerflex tylerflex commented Apr 12, 2024

The adjoint plugin FieldData.flux previously returned a float that summed over all of the flux values.

For multi-frequency adjoint, this is not quite what we want, rather we want a JaxDataArray with frequency coordinate.

(I think this was added before multi-frequency adjoint and just wasn't fixed when that was introduced).

This PR properly returns a JaxDataArray (FluxDataArray).

A potential issue for backwards compatibility if users have .flux in their notebooks. I'm not sure how to resolve this. Options are

  1. Just change it to the "expected" behavior (return data array) and explain if users are confused.
  2. Do some automatic sum() if certain conditions are met (like a single frequency coordinate)
  3. Add a sum: bool=True kwarg so we can instruct users that want all the flux values to set this to False

fyi @tomflexcompute

@momchil-flex
Copy link
Collaborator

Ah. Kinda bad. I kinda prefer 1 (not sure how 2 differs, if this is an issue only with multi-freq). It's probably rare to want to sum the values, and even if you do, you can still do it very easily in your own function. Not sure how many people we'll have to explain this to but it feels like we should just pull the band-aid...

@tylerflex
Copy link
Collaborator Author

Option 2 would keep the current handling (return float) for single frequency, but return a DataArray for multi-frequency.

@tylerflex tylerflex force-pushed the tyler/fix/adjoint/flux_values branch 2 times, most recently from 955f5ec to 1b39cc2 Compare April 12, 2024 23:36
@tylerflex
Copy link
Collaborator Author

I just revised this PR to try the handling in 2. see the new code here:

def package_flux_results(self, flux_values: JaxDataArray) -> Union[float, JaxDataArray]:
"""How to package the dictionary of fields computed via self.colocate()."""
freqs = flux_values.coords.get("f")
# handle single frequency case separately for backwards compatibility
# return a float of the only value
if freqs is not None and len(freqs) == 1:
return jnp.sum(flux_values.values)
# for multi-frequency, return a JaxDataArray
return flux_values

also added a more proper test of both cases.

Maybe I'd prefer this for now?

@momchil-flex
Copy link
Collaborator

I think this makes sense. Do we really want to put it in a patch though? Since we're planning the 2.7.0rc1 for Monday is it better to just put it there?

@tylerflex
Copy link
Collaborator Author

Fine with me. I'll try to change it to pre/2.7 before Monday

@tylerflex tylerflex force-pushed the tyler/fix/adjoint/flux_values branch from 1b39cc2 to b3c257d Compare April 13, 2024 07:45
@tylerflex tylerflex added 2.7 will go into version 2.7.* and removed 2.6 labels Apr 13, 2024
@tylerflex tylerflex changed the base branch from develop to pre/2.7 April 13, 2024 07:46
@tylerflex
Copy link
Collaborator Author

@momchil-flex rebased against 2.7

@tylerflex tylerflex force-pushed the tyler/fix/adjoint/flux_values branch from b3c257d to 5265373 Compare April 13, 2024 07:46
@momchil-flex momchil-flex reopened this Apr 13, 2024
@momchil-flex momchil-flex merged commit cb85379 into pre/2.7 Apr 13, 2024
15 of 24 checks passed
@momchil-flex momchil-flex deleted the tyler/fix/adjoint/flux_values branch April 13, 2024 21:54
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
2.7 will go into version 2.7.*
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants