Skip to content

Commit

Permalink
Various improvements to EME solver.
Browse files Browse the repository at this point in the history
  • Loading branch information
caseyflex committed Jun 4, 2024
1 parent 2cf9de0 commit bc4a166
Show file tree
Hide file tree
Showing 15 changed files with 1,300 additions and 309 deletions.
365 changes: 317 additions & 48 deletions tests/test_components/test_eme.py

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions tests/test_plugins/test_adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -1990,8 +1990,8 @@ def test_no_poynting(use_emulated_run):

sim_data._get_scalar_field(mnt_name_static, "S", "abs")

with pytest.raises(NotImplementedError):
sim_data._get_scalar_field(mnt_name_differentiable, "S", "abs")
# with pytest.raises(NotImplementedError):
# sim_data._get_scalar_field(mnt_name_differentiable, "S", "abs")


def test_to_gds(tmp_path):
Expand Down
38 changes: 36 additions & 2 deletions tests/test_plugins/test_mode_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -812,7 +812,7 @@ def test_mode_solver_nan_pol_fraction():
md = ms.solve()
check_ms_reduction(ms)

assert list(np.where(np.isnan(md.pol_fraction.te))[1]) == [8, 9]
assert list(np.where(np.isnan(md.pol_fraction.te))[1]) == [9]


def test_mode_solver_method_defaults():
Expand Down Expand Up @@ -906,5 +906,39 @@ def test_mode_solver_web_run_batch(mock_remote_api):
# Run mode solver one at a time
results = msweb.run_batch(mode_solver_list, verbose=False, folder_name="Mode Solver")
[print(type(x)) for x in results]
assert all([isinstance(x, ModeSolverData) for x in results])
assert all(isinstance(x, ModeSolverData) for x in results)
assert (results[i].n_eff.shape == (num_freqs, i + 1) for i in range(num_of_sims))


def test_mode_solver_relative():
"""Relative mode solver"""

simulation = td.Simulation(
size=SIM_SIZE,
grid_spec=td.GridSpec(wavelength=1.0),
structures=[WAVEGUIDE],
run_time=1e-12,
symmetry=(0, 0, 1),
boundary_spec=td.BoundarySpec.all_sides(boundary=td.Periodic()),
sources=[SRC],
)
mode_spec = td.ModeSpec(
num_modes=3,
target_neff=2.0,
filter_pol="tm",
precision="double",
track_freq="lowest",
)
freqs = [td.C_0 / 0.9, td.C_0 / 1.0, td.C_0 / 1.1]
ms = ModeSolver(
simulation=simulation,
plane=PLANE,
mode_spec=mode_spec,
freqs=freqs,
direction="-",
colocate=False,
)
basis = ms.data_raw
new_freqs = np.array(freqs) * 1.01
ms = ms.updated_copy(freqs=new_freqs)
_ = ms._data_on_yee_grid_relative(basis=basis)
3 changes: 2 additions & 1 deletion tidy3d/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@
from .components.eme.data.monitor_data import EMEModeSolverData, EMEFieldData, EMECoefficientData
from .components.eme.grid import EMEUniformGrid, EMECompositeGrid, EMEExplicitGrid
from .components.eme.grid import EMEGrid, EMEModeSpec
from .components.eme.sweep import EMELengthSweep, EMEModeSweep
from .components.eme.sweep import EMELengthSweep, EMEModeSweep, EMEFreqSweep


def set_logging_level(level: str) -> None:
Expand Down Expand Up @@ -387,4 +387,5 @@ def set_logging_level(level: str) -> None:
"EMESweepSpec",
"EMELengthSweep",
"EMEModeSweep",
"EMEFreqSweep",
]
21 changes: 14 additions & 7 deletions tidy3d/components/data/data_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -779,7 +779,7 @@ class EMEScalarModeFieldDataArray(AbstractSpatialDataArray):
"""

__slots__ = ()
_dims = ("x", "y", "z", "f", "mode_index", "eme_cell_index")
_dims = ("x", "y", "z", "f", "sweep_index", "eme_cell_index", "mode_index")


class EMEFreqModeDataArray(DataArray):
Expand All @@ -795,7 +795,7 @@ class EMEFreqModeDataArray(DataArray):
"""

__slots__ = ()
_dims = ("f", "mode_index", "eme_cell_index")
_dims = ("f", "sweep_index", "eme_cell_index", "mode_index")


class EMEScalarFieldDataArray(AbstractSpatialDataArray):
Expand All @@ -815,7 +815,7 @@ class EMEScalarFieldDataArray(AbstractSpatialDataArray):
"""

__slots__ = ()
_dims = ("x", "y", "z", "f", "mode_index", "eme_port_index")
_dims = ("x", "y", "z", "f", "sweep_index", "eme_port_index", "mode_index")


class EMECoefficientDataArray(DataArray):
Expand All @@ -836,11 +836,18 @@ class EMECoefficientDataArray(DataArray):
... eme_cell_index=eme_cell_index,
... eme_port_index=eme_port_index
... )
>>> fd = EMESMatrixDataArray((1 + 1j) * np.random.random((1, 2, 2, 5, 2)), coords=coords)
>>> fd = EMECoefficientDataArray((1 + 1j) * np.random.random((1, 2, 2, 5, 2)), coords=coords)
"""

__slots__ = ()
_dims = ("f", "mode_index_out", "mode_index_in", "eme_cell_index", "eme_port_index")
_dims = (
"f",
"sweep_index",
"eme_port_index",
"eme_cell_index",
"mode_index_out",
"mode_index_in",
)
_data_attrs = {"long_name": "mode expansion coefficient"}


Expand All @@ -864,7 +871,7 @@ class EMESMatrixDataArray(DataArray):
"""

__slots__ = ()
_dims = ("f", "mode_index_out", "mode_index_in", "sweep_index")
_dims = ("f", "sweep_index", "mode_index_out", "mode_index_in")
_data_attrs = {"long_name": "scattering matrix element"}


Expand All @@ -882,7 +889,7 @@ class EMEModeIndexDataArray(DataArray):
"""

__slots__ = ()
_dims = ("f", "mode_index", "eme_cell_index")
_dims = ("f", "sweep_index", "eme_cell_index", "mode_index")
_data_attrs = {"long_name": "Propagation index"}


Expand Down
Loading

0 comments on commit bc4a166

Please sign in to comment.