Skip to content

Commit

Permalink
adjoint JaxFieldData.flux returns float for single frequency and …
Browse files Browse the repository at this point in the history
…`JaxDataArray` for multi-freq
  • Loading branch information
tylerflex authored and momchil-flex committed Apr 13, 2024
1 parent 951e353 commit cb85379
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 6 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Bug in PolySlab intersection if slab bounds are `inf` on one side.
- Better error message when trying to transform a geometry with infinite bounds.
- `JaxSimulation.epsilon` properly handles `input_structures`.
- `FieldData.flux` in adjoint plugin properly returns `JaxDataArray` containing frequency coordinate `f` instead of summing over values.

## [2.6.3] - 2024-04-02

Expand Down
19 changes: 18 additions & 1 deletion tests/test_plugins/test_adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@
from tidy3d.plugins.adjoint.components.simulation import JaxSimulation, JaxInfo, RUN_TIME_FACTOR
from tidy3d.plugins.adjoint.components.simulation import MAX_NUM_INPUT_STRUCTURES
from tidy3d.plugins.adjoint.components.data.sim_data import JaxSimulationData
from tidy3d.plugins.adjoint.components.data.monitor_data import JaxModeData, JaxDiffractionData
from tidy3d.plugins.adjoint.components.data.monitor_data import (
JaxModeData,
JaxDiffractionData,
JaxFieldData,
)
from tidy3d.plugins.adjoint.components.data.data_array import JaxDataArray, JAX_DATA_ARRAY_TAG
from tidy3d.plugins.adjoint.components.data.dataset import JaxPermittivityDataset
from tidy3d.plugins.adjoint.web import run, run_async
Expand Down Expand Up @@ -1797,3 +1801,16 @@ def test_sidewall_angle_validator(log_capture, sidewall_angle, log_expected):

with AssertLogLevel(log_capture, log_expected, contains_str="sidewall"):
jax_polyslab1.updated_copy(sidewall_angle=sidewall_angle)


def test_package_flux():
"""Test handling of packaging flux data for single and multi-freq."""

value = 1.0
da_single = JaxDataArray(values=[value], coords=dict(f=[1.0]))
res_single = JaxFieldData.package_flux_results(None, da_single)
assert res_single == value

da_multi = JaxDataArray(values=[1.0, 2.0], coords=dict(f=[1.0, 2.0]))
res_multi = JaxFieldData.package_flux_results(None, da_multi)
assert res_multi == da_multi
16 changes: 11 additions & 5 deletions tidy3d/plugins/adjoint/components/data/monitor_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,12 +160,18 @@ def package_colocate_results(self, centered_fields: Dict[str, ScalarFieldDataArr
"""How to package the dictionary of fields computed via self.colocate()."""
return self.updated_copy(**centered_fields)

def package_flux_results(self, flux_values: JaxDataArray) -> float:
def package_flux_results(self, flux_values: JaxDataArray) -> Union[float, JaxDataArray]:
"""How to package the dictionary of fields computed via self.colocate()."""
flux_data = flux_values
if isinstance(flux_data, JaxDataArray):
return jnp.sum(flux_data.values)
return jnp.sum(flux_data)

freqs = flux_values.coords.get("f")

# handle single frequency case separately for backwards compatibility
# return a float of the only value
if freqs is not None and len(freqs) == 1:
return jnp.sum(flux_values.values)

# for multi-frequency, return a JaxDataArray
return flux_values

@property
def intensity(self) -> ScalarFieldDataArray:
Expand Down

0 comments on commit cb85379

Please sign in to comment.