Skip to content

Commit

Permalink
Various improvements to EME solver.
Browse files Browse the repository at this point in the history
  • Loading branch information
caseyflex committed May 28, 2024
1 parent b54df57 commit 909a1d6
Show file tree
Hide file tree
Showing 13 changed files with 879 additions and 196 deletions.
75 changes: 54 additions & 21 deletions tests/test_components/test_eme.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,11 +302,11 @@ def test_eme_simulation(log_capture): # noqa: F811
with AssertLogLevel(log_capture, "INFO"):
sim = sim.updated_copy(grid_spec=grid_spec)
# multiple freqs are ok, but not for autogrid
_ = sim.updated_copy(grid_spec=td.GridSpec.uniform(dl=1), freqs=[1e10] + sim.freqs)
_ = sim.updated_copy(grid_spec=td.GridSpec.uniform(dl=1), freqs=[1e10] + list(sim.freqs))
with pytest.raises(SetupError):
_ = td.EMESimulation(
size=sim.size,
freqs=sim.freqs + [1e10],
freqs=list(sim.freqs) + [1e10],
monitors=sim.monitors,
structures=sim.structures,
grid_spec=grid_spec,
Expand All @@ -315,12 +315,12 @@ def test_eme_simulation(log_capture): # noqa: F811
)

# test port offsets
with pytest.raises(pd.ValidationError):
with pytest.raises(ValidationError):
_ = sim.updated_copy(port_offsets=[sim.size[sim.axis] * 2 / 3, sim.size[sim.axis] * 2 / 3])

# test duplicate freqs
with pytest.raises(pd.ValidationError):
_ = sim.updated_copy(freqs=sim.freqs + sim.freqs)
_ = sim.updated_copy(freqs=list(sim.freqs) + list(sim.freqs))

# test unsupported media
# fully anisotropic
Expand Down Expand Up @@ -403,7 +403,7 @@ def test_eme_simulation(log_capture): # noqa: F811
_ = sim.updated_copy(size=(1000, 1000, 1000))
with pytest.raises(SetupError):
_ = sim.updated_copy(
freqs=sim.freqs + list(1e14 * np.linspace(1, 2, 1000)),
freqs=list(sim.freqs) + list(1e14 * np.linspace(1, 2, 1000)),
grid_spec=sim.grid_spec.updated_copy(wavelength=1),
)
large_monitor = sim.monitors[2].updated_copy(size=(td.inf, td.inf, td.inf))
Expand Down Expand Up @@ -439,11 +439,16 @@ def _get_eme_scalar_mode_field_data_array():
f = [td.C_0, 3e14]
mode_index = np.arange(10)
eme_cell_index = np.arange(7)
coords = dict(x=x, y=y, z=z, f=f, mode_index=mode_index, eme_cell_index=eme_cell_index)
return td.EMEScalarFieldDataArray(
(1 + 1j) * np.random.random((len(x), len(y), 1, 2, len(mode_index), len(eme_cell_index))),
coords = dict(
x=x, y=y, z=z, f=f, sweep_index=[0], eme_cell_index=eme_cell_index, mode_index=mode_index
)
data = td.EMEScalarModeFieldDataArray(
(1 + 1j)
* np.random.random((len(x), len(y), 1, 2, 1, len(eme_cell_index), len(mode_index))),
coords=coords,
)
data = data.drop_vars("sweep_index")
return data


def test_eme_scalar_mode_field_data_array():
Expand All @@ -457,10 +462,14 @@ def _get_eme_scalar_field_data_array():
f = [td.C_0, 3e14]
mode_index = np.arange(5)
eme_port_index = [0, 1]
coords = dict(x=x, y=y, z=z, f=f, mode_index=mode_index, eme_port_index=eme_port_index)
return td.EMEScalarFieldDataArray(
(1 + 1j) * np.random.random((len(x), len(y), len(z), 2, 5, 2)), coords=coords
coords = dict(
x=x, y=y, z=z, f=f, sweep_index=[0], eme_port_index=eme_port_index, mode_index=mode_index
)
data = td.EMEScalarFieldDataArray(
(1 + 1j) * np.random.random((len(x), len(y), len(z), 2, 1, 2, 5)), coords=coords
)
data = data.drop_vars("sweep_index")
return data


def test_eme_scalar_field_data_array():
Expand Down Expand Up @@ -525,15 +534,28 @@ def _get_eme_coeff_data_array():
eme_port_index = [0, 1]
coords = dict(
f=f,
sweep_index=[0],
eme_port_index=eme_port_index,
eme_cell_index=eme_cell_index,
mode_index_out=mode_index_out,
mode_index_in=mode_index_in,
eme_cell_index=eme_cell_index,
eme_port_index=eme_port_index,
)
data = (1 + 1j) * np.random.random(
(len(f), len(mode_index_out), len(mode_index_in), len(eme_cell_index), len(eme_port_index))
data = td.EMECoefficientDataArray(
(1 + 1j)
* np.random.random(
(
len(f),
1,
len(eme_port_index),
len(eme_cell_index),
len(mode_index_out),
len(mode_index_in),
),
),
coords=coords,
)
return td.EMECoefficientDataArray(data, coords=coords)
data = data.drop_vars("sweep_index")
return data


def _get_eme_coeff_dataset():
Expand All @@ -550,9 +572,13 @@ def _get_eme_mode_index_data_array():
f = [td.C_0, 3e14]
mode_index = np.arange(10)
eme_cell_index = np.arange(7)
coords = dict(f=f, mode_index=mode_index, eme_cell_index=eme_cell_index)
data = (1 + 1j) * np.random.random((len(f), len(mode_index), len(eme_cell_index)))
return td.EMEModeIndexDataArray(data, coords=coords)
coords = dict(f=f, sweep_index=[0], eme_cell_index=eme_cell_index, mode_index=mode_index)
data = td.EMEModeIndexDataArray(
(1 + 1j) * np.random.random((len(f), 1, len(eme_cell_index), len(mode_index))),
coords=coords,
)
data = data.drop_vars("sweep_index")
return data


def test_eme_mode_index_data_array():
Expand Down Expand Up @@ -605,18 +631,23 @@ def _get_eme_mode_solver_data():
n_complex = _get_eme_mode_index_data_array()
kwargs.update({"n_complex": n_complex})
grid_primal_correction_data = np.ones(
(len(n_complex.f), len(n_complex.mode_index), len(n_complex.eme_cell_index))
(len(n_complex.f), 1, len(n_complex.eme_cell_index), len(n_complex.mode_index))
)
grid_dual_correction_data = grid_primal_correction_data
grid_correction_coords = dict(
f=n_complex.f, mode_index=n_complex.mode_index, eme_cell_index=n_complex.eme_cell_index
f=n_complex.f,
sweep_index=[0],
eme_cell_index=n_complex.eme_cell_index,
mode_index=n_complex.mode_index,
)
grid_primal_correction = td.components.data.data_array.EMEFreqModeDataArray(
grid_primal_correction_data, coords=grid_correction_coords
)
grid_dual_correction = td.components.data.data_array.EMEFreqModeDataArray(
grid_dual_correction_data, coords=grid_correction_coords
)
grid_primal_correction = grid_primal_correction.drop_vars("sweep_index")
grid_dual_correction = grid_dual_correction.drop_vars("sweep_index")
return td.EMEModeSolverData(
monitor=monitor,
grid_primal_correction=grid_primal_correction,
Expand Down Expand Up @@ -656,8 +687,10 @@ def _get_mode_solver_data(modes_out=False, num_modes=3):
mode_index = np.arange(num_modes)
kwargs = {key: field.isel(eme_cell_index=0, drop=True) for key, field in kwargs.items()}
kwargs = {key: field.isel(mode_index=mode_index) for key, field in kwargs.items()}
kwargs = {key: field.isel(sweep_index=0) for key, field in kwargs.items()}
n_complex = eme_mode_data.n_complex.isel(eme_cell_index=0, drop=True)
n_complex = n_complex.isel(mode_index=mode_index)
n_complex = n_complex.isel(sweep_index=0)
kwargs.update({"n_complex": n_complex})
sim = make_eme_sim()
grid_expanded = sim.discretize_monitor(monitor)
Expand Down
3 changes: 2 additions & 1 deletion tidy3d/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@
from .components.eme.data.monitor_data import EMEModeSolverData, EMEFieldData, EMECoefficientData
from .components.eme.grid import EMEUniformGrid, EMECompositeGrid, EMEExplicitGrid
from .components.eme.grid import EMEGrid, EMEModeSpec
from .components.eme.sweep import EMELengthSweep, EMEModeSweep
from .components.eme.sweep import EMELengthSweep, EMEModeSweep, EMEFreqSweep


def set_logging_level(level: str) -> None:
Expand Down Expand Up @@ -381,4 +381,5 @@ def set_logging_level(level: str) -> None:
"EMESweepSpec",
"EMELengthSweep",
"EMEModeSweep",
"EMEFreqSweep",
]
21 changes: 14 additions & 7 deletions tidy3d/components/data/data_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,7 +718,7 @@ class EMEScalarModeFieldDataArray(AbstractSpatialDataArray):
"""

__slots__ = ()
_dims = ("x", "y", "z", "f", "mode_index", "eme_cell_index")
_dims = ("x", "y", "z", "f", "sweep_index", "eme_cell_index", "mode_index")


class EMEFreqModeDataArray(DataArray):
Expand All @@ -734,7 +734,7 @@ class EMEFreqModeDataArray(DataArray):
"""

__slots__ = ()
_dims = ("f", "mode_index", "eme_cell_index")
_dims = ("f", "sweep_index", "eme_cell_index", "mode_index")


class EMEScalarFieldDataArray(AbstractSpatialDataArray):
Expand All @@ -754,7 +754,7 @@ class EMEScalarFieldDataArray(AbstractSpatialDataArray):
"""

__slots__ = ()
_dims = ("x", "y", "z", "f", "mode_index", "eme_port_index")
_dims = ("x", "y", "z", "f", "sweep_index", "eme_port_index", "mode_index")


class EMECoefficientDataArray(DataArray):
Expand All @@ -775,11 +775,18 @@ class EMECoefficientDataArray(DataArray):
... eme_cell_index=eme_cell_index,
... eme_port_index=eme_port_index
... )
>>> fd = EMESMatrixDataArray((1 + 1j) * np.random.random((1, 2, 2, 5, 2)), coords=coords)
>>> fd = EMECoefficientDataArray((1 + 1j) * np.random.random((1, 2, 2, 5, 2)), coords=coords)
"""

__slots__ = ()
_dims = ("f", "mode_index_out", "mode_index_in", "eme_cell_index", "eme_port_index")
_dims = (
"f",
"sweep_index",
"eme_port_index",
"eme_cell_index",
"mode_index_out",
"mode_index_in",
)
_data_attrs = {"long_name": "mode expansion coefficient"}


Expand All @@ -803,7 +810,7 @@ class EMESMatrixDataArray(DataArray):
"""

__slots__ = ()
_dims = ("f", "mode_index_out", "mode_index_in", "sweep_index")
_dims = ("f", "sweep_index", "mode_index_out", "mode_index_in")
_data_attrs = {"long_name": "scattering matrix element"}


Expand All @@ -821,7 +828,7 @@ class EMEModeIndexDataArray(DataArray):
"""

__slots__ = ()
_dims = ("f", "mode_index", "eme_cell_index")
_dims = ("f", "sweep_index", "eme_cell_index", "mode_index")
_data_attrs = {"long_name": "Propagation index"}


Expand Down
Loading

0 comments on commit 909a1d6

Please sign in to comment.