Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix field data flux #1606

Merged
merged 1 commit into from
Apr 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Bug in PolySlab intersection if slab bounds are `inf` on one side.
- Better error message when trying to transform a geometry with infinite bounds.
- `JaxSimulation.epsilon` properly handles `input_structures`.
- `FieldData.flux` in adjoint plugin properly returns `JaxDataArray` containing frequency coordinate `f` instead of summing over values.

## [2.6.3] - 2024-04-02

Expand Down
19 changes: 18 additions & 1 deletion tests/test_plugins/test_adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@
from tidy3d.plugins.adjoint.components.simulation import JaxSimulation, JaxInfo, RUN_TIME_FACTOR
from tidy3d.plugins.adjoint.components.simulation import MAX_NUM_INPUT_STRUCTURES
from tidy3d.plugins.adjoint.components.data.sim_data import JaxSimulationData
from tidy3d.plugins.adjoint.components.data.monitor_data import JaxModeData, JaxDiffractionData
from tidy3d.plugins.adjoint.components.data.monitor_data import (
JaxModeData,
JaxDiffractionData,
JaxFieldData,
)
from tidy3d.plugins.adjoint.components.data.data_array import JaxDataArray, JAX_DATA_ARRAY_TAG
from tidy3d.plugins.adjoint.components.data.dataset import JaxPermittivityDataset
from tidy3d.plugins.adjoint.web import run, run_async
Expand Down Expand Up @@ -1797,3 +1801,16 @@ def test_sidewall_angle_validator(log_capture, sidewall_angle, log_expected):

with AssertLogLevel(log_capture, log_expected, contains_str="sidewall"):
jax_polyslab1.updated_copy(sidewall_angle=sidewall_angle)


def test_package_flux():
"""Test handling of packaging flux data for single and multi-freq."""

value = 1.0
da_single = JaxDataArray(values=[value], coords=dict(f=[1.0]))
res_single = JaxFieldData.package_flux_results(None, da_single)
assert res_single == value

da_multi = JaxDataArray(values=[1.0, 2.0], coords=dict(f=[1.0, 2.0]))
res_multi = JaxFieldData.package_flux_results(None, da_multi)
assert res_multi == da_multi
16 changes: 11 additions & 5 deletions tidy3d/plugins/adjoint/components/data/monitor_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,12 +160,18 @@ def package_colocate_results(self, centered_fields: Dict[str, ScalarFieldDataArr
"""How to package the dictionary of fields computed via self.colocate()."""
return self.updated_copy(**centered_fields)

def package_flux_results(self, flux_values: JaxDataArray) -> float:
def package_flux_results(self, flux_values: JaxDataArray) -> Union[float, JaxDataArray]:
"""How to package the dictionary of fields computed via self.colocate()."""
flux_data = flux_values
if isinstance(flux_data, JaxDataArray):
return jnp.sum(flux_data.values)
return jnp.sum(flux_data)

freqs = flux_values.coords.get("f")

# handle single frequency case separately for backwards compatibility
# return a float of the only value
if freqs is not None and len(freqs) == 1:
return jnp.sum(flux_values.values)

# for multi-frequency, return a JaxDataArray
return flux_values

@property
def intensity(self) -> ScalarFieldDataArray:
Expand Down
Loading