Skip to content

Commit

Permalink
Various improvements to EME solver.
Browse files Browse the repository at this point in the history
  • Loading branch information
caseyflex committed May 10, 2024
1 parent 2b18197 commit f30427a
Show file tree
Hide file tree
Showing 8 changed files with 354 additions and 40 deletions.
2 changes: 1 addition & 1 deletion tests/test_components/test_eme.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ def test_eme_simulation(log_capture): # noqa: F811
)

# test port offsets
with pytest.raises(pd.ValidationError):
with pytest.raises(ValidationError):
_ = sim.updated_copy(port_offsets=[sim.size[sim.axis] * 2 / 3, sim.size[sim.axis] * 2 / 3])

# test duplicate freqs
Expand Down
3 changes: 2 additions & 1 deletion tidy3d/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@
from .components.eme.data.monitor_data import EMEModeSolverData, EMEFieldData, EMECoefficientData
from .components.eme.grid import EMEUniformGrid, EMECompositeGrid, EMEExplicitGrid
from .components.eme.grid import EMEGrid, EMEModeSpec
from .components.eme.sweep import EMELengthSweep, EMEModeSweep
from .components.eme.sweep import EMELengthSweep, EMEModeSweep, EMEFreqSweep


def set_logging_level(level: str) -> None:
Expand Down Expand Up @@ -380,4 +380,5 @@ def set_logging_level(level: str) -> None:
"EMESweepSpec",
"EMELengthSweep",
"EMEModeSweep",
"EMEFreqSweep",
]
27 changes: 22 additions & 5 deletions tidy3d/components/eme/data/sim_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from ..simulation import EMESimulation
from .monitor_data import EMEMonitorDataType, EMEModeSolverData, EMEFieldData
from .dataset import EMESMatrixDataset
from ...base import cached_property
from ...data.data_array import EMESMatrixDataArray, EMEScalarFieldDataArray
from ...data.sim_data import AbstractYeeGridSimulationData

Expand Down Expand Up @@ -52,7 +53,8 @@ def _extract_mode_solver_data(
update_dict = dict(data._grid_correction_dict, **data.field_components)
update_dict.update({"n_complex": data.n_complex})
update_dict = {
key: field.sel(eme_cell_index=eme_cell_index) for key, field in update_dict.items()
key: field.sel(eme_cell_index=eme_cell_index, drop=True)
for key, field in update_dict.items()
}
monitor = self.simulation.mode_solver_monitors[eme_cell_index]
monitor = monitor.updated_copy(
Expand All @@ -61,6 +63,24 @@ def _extract_mode_solver_data(
grid_expanded = self.simulation.discretize_monitor(monitor=monitor)
return ModeSolverData(**update_dict, monitor=monitor, grid_expanded=grid_expanded)

@cached_property
def port_modes_tuple(self) -> Tuple[ModeSolverData, ModeSolverData]:
"""Port modes as a tuple ``(port_modes_1, port_modes_2)``."""
if self.port_modes is None:
raise SetupError(
"The field 'port_modes' is 'None'. Please set 'store_port_modes' "
"to 'True' and re-run the simulation."
)

num_cells = self.simulation.eme_grid.num_cells

port_modes_1 = self._extract_mode_solver_data(data=self.port_modes, eme_cell_index=0)
port_modes_2 = self._extract_mode_solver_data(
data=self.port_modes, eme_cell_index=num_cells - 1
)

return (port_modes_1, port_modes_2)

def smatrix_in_basis(
self, modes1: Union[FieldData, ModeData] = None, modes2: Union[FieldData, ModeData] = None
) -> EMESMatrixDataset:
Expand Down Expand Up @@ -89,10 +109,7 @@ def smatrix_in_basis(
"to 'True' and re-run the simulation."
)

port_modes1 = self._extract_mode_solver_data(data=self.port_modes, eme_cell_index=0)
port_modes2 = self._extract_mode_solver_data(
data=self.port_modes, eme_cell_index=self.simulation.eme_grid.num_cells - 1
)
port_modes1, port_modes2 = self.port_modes_tuple

if modes1 is None:
modes1 = port_modes1
Expand Down
32 changes: 19 additions & 13 deletions tidy3d/components/eme/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,13 @@ class EMESimulation(AbstractYeeGridSimulation):
"Required to find scattering matrix in basis besides the computational basis.",
)

normalize: bool = pd.Field(
True,
title="Normalize Scattering Matrix",
description="Whether to normalize the port modes to unity flux, "
"thereby normalizing the scattering matrix and expansion coefficients.",
)

port_offsets: Tuple[pd.NonNegativeFloat, pd.NonNegativeFloat] = pd.Field(
(0, 0),
title="Port Offsets",
Expand Down Expand Up @@ -221,19 +228,6 @@ def _validate_auto_grid_wavelength(cls, val, values):
# this is handled instead post-init to ensure freqs is defined
return val

@pd.validator("port_offsets", always=True)
def _validate_port_offsets(cls, val, values):
"""Port offsets cannot jointly exceed simulation length."""
total_offset = val[0] + val[1]
size = values["size"]
axis = values["axis"]
if size[axis] < total_offset:
raise SetupError(
"The sum of the two 'port_offset' fields "
"cannot exceed the simulation 'size' in the 'axis' direction."
)
return val

@pd.validator("freqs", always=True)
def _validate_freqs(cls, val):
"""Freqs cannot contain duplicates."""
Expand Down Expand Up @@ -549,9 +543,21 @@ def _post_init_validators(self) -> None:
self._validate_modes_size()
self._validate_sweep_spec()
self._validate_symmetry()
self._validate_port_offsets()
# self._warn_monitor_interval()
log.end_capture(self)

def _validate_port_offsets(self):
"""Port offsets cannot jointly exceed simulation length."""
total_offset = self.port_offsets[0] + self.port_offsets[1]
size = self.size
axis = self.axis
if size[axis] < total_offset:
raise SetupError(
"The sum of the two 'port_offset' fields "
"cannot exceed the simulation 'size' in the 'axis' direction."
)

def _validate_symmetry(self):
"""Symmetry in propagation direction is not supported."""
if self.symmetry[self.axis] != 0:
Expand Down
12 changes: 11 additions & 1 deletion tidy3d/components/eme/sweep.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,14 @@ class EMEModeSweep(EMESweepSpec):
)


EMESweepSpecType = Union[EMELengthSweep, EMEModeSweep]
class EMEFreqSweep(EMESweepSpec):
"""Spec for sweeping frequency in EME propagation step.
Unlike ``sim.freqs``, the frequency sweep is approximate, using a
perturbative mode solver relative to the simulation EME modes."""

freq_scale_factors: List[pd.PositiveFloat] = pd.Field(
..., title="Frequency Scale Factors", description="Frequency scale factors"
)


EMESweepSpecType = Union[EMELengthSweep, EMEModeSweep, EMEFreqSweep]
173 changes: 173 additions & 0 deletions tidy3d/plugins/mode/mode_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,26 @@ def _get_data_with_group_index(self) -> ModeSolverData:

return mode_solver.data_raw._group_index_post_process(self.mode_spec.group_index_step)

def _get_data_with_group_index_relative(self, basis: ModeSolverData) -> ModeSolverData:
""":class:`.ModeSolverData` with fields, effective and group indices on unexpanded grid.
Returns
-------
ModeSolverData
:class:`.ModeSolverData` object containing the effective and group indices, and mode
fields.
"""

# create a copy with the required frequencies for numerical differentiation
mode_spec = self.mode_spec.copy(update={"group_index_step": False})
mode_solver = self.copy(
update={"freqs": self._freqs_for_group_index(), "mode_spec": mode_spec}
)

return mode_solver.data_raw_relative(basis)._group_index_post_process(
self.mode_spec.group_index_step
)

@cached_property
def grid_snapped(self) -> Grid:
"""The solver grid snapped to the plane normal and to simulation 0-sized dims if any."""
Expand Down Expand Up @@ -353,6 +373,76 @@ def _data_on_yee_grid(self) -> ModeSolverData:

return mode_solver_data

def _data_on_yee_grid_relative(self, basis: ModeSolverData) -> ModeSolverData:
"""Solve for all modes, and construct data with fields on the Yee grid."""
_, _solver_coords = self.plane.pop_axis(
self._solver_grid.boundaries.to_list, axis=self.normal_axis
)

basis_fields = []
for freq_ind in range(len(basis.n_complex.f)):
basis_fields_freq = {}
for field_name in ("Ex", "Ey", "Ez", "Hx", "Hy", "Hz"):
basis_fields_freq[field_name] = (
basis.field_components[field_name].isel(f=freq_ind).to_numpy()
)
basis_fields.append(basis_fields_freq)

# Compute and store the modes at all frequencies
n_complex, fields, eps_spec = self._solve_all_freqs_relative(
coords=_solver_coords, symmetry=self.solver_symmetry, basis_fields=basis_fields
)

# start a dictionary storing the data arrays for the ModeSolverData
index_data = ModeIndexDataArray(
np.stack(n_complex, axis=0),
coords=dict(
f=list(self.freqs),
mode_index=np.arange(self.mode_spec.num_modes),
),
)
data_dict = {"n_complex": index_data}

# Construct the field data on Yee grid
for field_name in ("Ex", "Ey", "Ez", "Hx", "Hy", "Hz"):
xyz_coords = self.grid_snapped[field_name].to_list
scalar_field_data = ScalarModeFieldDataArray(
np.stack([field_freq[field_name] for field_freq in fields], axis=-2),
coords=dict(
x=xyz_coords[0],
y=xyz_coords[1],
z=xyz_coords[2],
f=list(self.freqs),
mode_index=np.arange(self.mode_spec.num_modes),
),
)
data_dict[field_name] = scalar_field_data

# finite grid corrections
grid_factors = self._grid_correction(
simulation=self.simulation,
plane=self.plane,
mode_spec=self.mode_spec,
n_complex=index_data,
direction=self.direction,
)

# make mode solver data on the Yee grid
mode_solver_monitor = self.to_mode_solver_monitor(name=MODE_MONITOR_NAME, colocate=False)
grid_expanded = self.simulation.discretize_monitor(mode_solver_monitor)
mode_solver_data = ModeSolverData(
monitor=mode_solver_monitor,
symmetry=self.simulation.symmetry,
symmetry_center=self.simulation.center,
grid_expanded=grid_expanded,
grid_primal_correction=grid_factors[0],
grid_dual_correction=grid_factors[1],
eps_spec=eps_spec,
**data_dict,
)

return mode_solver_data

def _colocate_data(self, mode_solver_data: ModeSolverData) -> ModeSolverData:
"""Colocate data to Yee grid boundaries."""

Expand Down Expand Up @@ -426,6 +516,17 @@ def data(self) -> ModeSolverData:
mode_solver_data = self.data_raw
return mode_solver_data.symmetry_expanded_copy

def data_relative(self, basis: ModeSolverData) -> ModeSolverData:
""":class:`.ModeSolverData` containing the field and effective index data.
Returns
-------
ModeSolverData
:class:`.ModeSolverData` object containing the effective index and mode fields.
"""
mode_solver_data = self.data_raw_relative(basis)
return mode_solver_data.symmetry_expanded_copy

@cached_property
def sim_data(self) -> SimulationData:
""":class:`.SimulationData` object containing the :class:`.ModeSolverData` for this object.
Expand Down Expand Up @@ -525,6 +626,26 @@ def _solve_all_freqs(
fields.append(fields_freq)
n_complex.append(n_freq)
eps_spec.append(eps_spec_freq)
return n_complex, fields, eps_spec

def _solve_all_freqs_relative(
self,
coords: Tuple[ArrayFloat1D, ArrayFloat1D],
symmetry: Tuple[Symmetry, Symmetry],
basis_fields: List[Dict[str, ArrayComplex4D]],
) -> Tuple[List[float], List[Dict[str, ArrayComplex4D]], List[EpsSpecType]]:
"""Call the mode solver at all requested frequencies."""

fields = []
n_complex = []
eps_spec = []
for freq, basis_fields_freq in zip(self.freqs, basis_fields):
n_freq, fields_freq, eps_spec_freq = self._solve_single_freq_relative(
freq=freq, coords=coords, symmetry=symmetry, basis_fields=basis_fields_freq
)
fields.append(fields_freq)
n_complex.append(n_freq)
eps_spec.append(eps_spec_freq)

return n_complex, fields, eps_spec

Expand Down Expand Up @@ -570,6 +691,58 @@ def _solve_single_freq(
fields = self._postprocess_solver_fields(solver_fields)
return n_complex, fields, eps_spec

def _rotate_field_coords_inverse(self, field: FIELD) -> FIELD:
"""Move the propagation axis to the z axis in the array."""
f_x, f_y, f_z = np.moveaxis(field, source=1 + self.normal_axis, destination=3)
f_n, f_ts = self.plane.pop_axis((f_x, f_y, f_z), axis=self.normal_axis)
return np.stack(self.plane.unpop_axis(f_n, f_ts, axis=2), axis=0)

def _postprocess_solver_fields_inverse(self, fields):
"""Convert ``fields`` to ``solver_fields``. Doesn't change gauge."""
E = [fields[key] for key in ("Ex", "Ey", "Ez")]
H = [fields[key] for key in ("Hx", "Hy", "Hz")]

(Ex, Ey, Ez) = self._rotate_field_coords_inverse(E)
(Hx, Hy, Hz) = self._rotate_field_coords_inverse(H)

# apply -1 to H fields if a reflection was involved in the rotation
if self.normal_axis == 1:
Hx *= -1
Hy *= -1
Hz *= -1

solver_fields = np.stack((Ex, Ey, Ez, Hx, Hy, Hz), axis=0)
return solver_fields

def _solve_single_freq_relative(
self,
freq: float,
coords: Tuple[ArrayFloat1D, ArrayFloat1D],
symmetry: Tuple[Symmetry, Symmetry],
basis_fields: Dict[str, ArrayComplex4D],
) -> Tuple[float, Dict[str, ArrayComplex4D], EpsSpecType]:
"""Call the mode solver at a single frequency.
Modes are computed as linear combinations of ``basis_fields``.
"""

if not LOCAL_SOLVER_IMPORTED:
raise ImportError(IMPORT_ERROR_MSG)

solver_basis_fields = self._postprocess_solver_fields_inverse(basis_fields)

solver_fields, n_complex, eps_spec = compute_modes(
eps_cross=self._solver_eps(freq),
coords=coords,
freq=freq,
mode_spec=self.mode_spec,
symmetry=symmetry,
direction=self.direction,
solver_basis_fields=solver_basis_fields,
)

fields = self._postprocess_solver_fields(solver_fields)
return n_complex, fields, eps_spec

def _rotate_field_coords(self, field: FIELD) -> FIELD:
"""Move the propagation axis=z to the proper order in the array."""
f_x, f_y, f_z = np.moveaxis(field, source=3, destination=1 + self.normal_axis)
Expand Down
Loading

0 comments on commit f30427a

Please sign in to comment.