Skip to content

Commit

Permalink
fix field data flux
Browse files Browse the repository at this point in the history
  • Loading branch information
tylerflex committed Apr 12, 2024
1 parent 951e353 commit 955f5ec
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 6 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Bug in PolySlab intersection if slab bounds are `inf` on one side.
- Better error message when trying to transform a geometry with infinite bounds.
- `JaxSimulation.epsilon` properly handles `input_structures`.
- `FieldData.flux` in adjoint plugin properly returns `JaxDataArray` containing frequency coordinate `f` instead of summing over values.

## [2.6.3] - 2024-04-02

Expand Down
2 changes: 1 addition & 1 deletion tests/test_plugins/test_adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ def get_flux(x):

mnt_data2 = mnt_data.updated_copy(**fld_components)

return jnp.sum(mnt_data2.flux)
return jnp.sum(mnt_data2.flux.values)

_ = get_flux(1.0)

Expand Down
16 changes: 11 additions & 5 deletions tidy3d/plugins/adjoint/components/data/monitor_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,12 +160,18 @@ def package_colocate_results(self, centered_fields: Dict[str, ScalarFieldDataArr
"""How to package the dictionary of fields computed via self.colocate()."""
return self.updated_copy(**centered_fields)

def package_flux_results(self, flux_values: JaxDataArray) -> float:
def package_flux_results(self, flux_values: JaxDataArray) -> Union[float, JaxDataArray]:
"""How to package the dictionary of fields computed via self.colocate()."""
flux_data = flux_values
if isinstance(flux_data, JaxDataArray):
return jnp.sum(flux_data.values)
return jnp.sum(flux_data)

freqs = flux_values.coords.get("f")

# handle single frequency case separately for backwards compatibility
# return a float of the only value
if freqs is not None and len(freqs) == 1:
return jnp.sum(flux_values.values)

# for multi-frequency, return a JaxDataArray
return flux_values

@property
def intensity(self) -> ScalarFieldDataArray:
Expand Down

0 comments on commit 955f5ec

Please sign in to comment.