Skip to content

Commit

Permalink
Various improvements to EME solver.
Browse files Browse the repository at this point in the history
  • Loading branch information
caseyflex committed May 16, 2024
1 parent e9ed1df commit 325f045
Show file tree
Hide file tree
Showing 12 changed files with 648 additions and 86 deletions.
2 changes: 1 addition & 1 deletion tests/test_components/test_eme.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ def test_eme_simulation(log_capture): # noqa: F811
)

# test port offsets
with pytest.raises(pd.ValidationError):
with pytest.raises(ValidationError):
_ = sim.updated_copy(port_offsets=[sim.size[sim.axis] * 2 / 3, sim.size[sim.axis] * 2 / 3])

# test duplicate freqs
Expand Down
3 changes: 2 additions & 1 deletion tidy3d/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@
from .components.eme.data.monitor_data import EMEModeSolverData, EMEFieldData, EMECoefficientData
from .components.eme.grid import EMEUniformGrid, EMECompositeGrid, EMEExplicitGrid
from .components.eme.grid import EMEGrid, EMEModeSpec
from .components.eme.sweep import EMELengthSweep, EMEModeSweep
from .components.eme.sweep import EMELengthSweep, EMEModeSweep, EMEFreqSweep


def set_logging_level(level: str) -> None:
Expand Down Expand Up @@ -380,4 +380,5 @@ def set_logging_level(level: str) -> None:
"EMESweepSpec",
"EMELengthSweep",
"EMEModeSweep",
"EMEFreqSweep",
]
141 changes: 121 additions & 20 deletions tidy3d/components/data/sim_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def at_boundaries(self, field_monitor_name: str) -> xr.Dataset:
# colocate to monitor grid boundaries
return self._at_boundaries(self.load_field_monitor(field_monitor_name))

def get_poynting_vector(self, field_monitor_name: str) -> xr.Dataset:
def _get_poynting_vector(self, field_monitor_data: AbstractFieldData) -> xr.Dataset:
"""return ``xarray.Dataset`` of the Poynting vector at Yee cell centers.
Calculated values represent the instantaneous Poynting vector for time-domain fields and the
Expand All @@ -124,19 +124,18 @@ def get_poynting_vector(self, field_monitor_name: str) -> xr.Dataset:
Parameters
----------
field_monitor_name : str
Name of field monitor used in the original :class:`Simulation`.
field_monitor_data: AbstractFieldData
Field monitor data from which to extract Poynting vector.
Returns
-------
xarray.DataArray
DataArray containing the Poynting vector calculated based on the field components
colocated at the center locations of the Yee grid.
"""
mon_data = self.load_field_monitor(field_monitor_name)
field_dataset = self._at_boundaries(mon_data)
field_dataset = self._at_boundaries(field_monitor_data)

time_domain = isinstance(self.monitor_data[field_monitor_name], FieldTimeData)
time_domain = isinstance(field_monitor_data, FieldTimeData)

poynting_components = {}

Expand All @@ -162,20 +161,49 @@ def get_poynting_vector(self, field_monitor_name: str) -> xr.Dataset:
# 2D monitors have grid correction factors that can be different from 1. For Poynting,
# it is always the product of a primal-located field and dual-located field, so the
# total grid correction factor is the product of the two
grid_correction = mon_data.grid_dual_correction * mon_data.grid_primal_correction
grid_correction = (
field_monitor_data.grid_dual_correction * field_monitor_data.grid_primal_correction
)
poynting_components["S" + dim] *= grid_correction

return xr.Dataset(poynting_components)

def get_poynting_vector(self, field_monitor_name: str) -> xr.Dataset:
"""return ``xarray.Dataset`` of the Poynting vector at Yee cell centers.
Calculated values represent the instantaneous Poynting vector for time-domain fields and the
complex vector for frequency-domain: ``S = 1/2 E × conj(H)``.
Only the available components are returned, e.g., if the indicated monitor doesn't include
field component `"Ex"`, then `"Sy"` and `"Sz"` will not be calculated.
Parameters
----------
field_monitor_name : str
Name of field monitor used in the original :class:`Simulation`.
Returns
-------
xarray.DataArray
DataArray containing the Poynting vector calculated based on the field components
colocated at the center locations of the Yee grid.
"""
field_monitor_data = self.load_field_monitor(field_monitor_name)
return self._get_poynting_vector(field_monitor_data=field_monitor_data)

def _get_scalar_field(
self, field_monitor_name: str, field_name: str, val: FieldVal, phase: float = 0.0
self,
field_monitor_data: AbstractFieldData,
field_name: str,
val: FieldVal,
phase: float = 0.0,
):
"""return ``xarray.DataArray`` of the scalar field of a given monitor at Yee cell centers.
Parameters
----------
field_monitor_name : str
Name of field monitor used in the original :class:`Simulation`.
field_monitor_data : AbstractFieldData
Field monitor data from which to extract scalar field.
field_name : str
Name of the derived field component: one of `('E', 'H', 'S', 'Sx', 'Sy', 'Sz')`.
val : Literal['real', 'imag', 'abs', 'abs^2', 'phase'] = 'real'
Expand All @@ -191,15 +219,15 @@ def _get_scalar_field(
"""

if field_name[0] == "S":
dataset = self.get_poynting_vector(field_monitor_name)
dataset = self._get_poynting_vector(field_monitor_data)
if len(field_name) > 1:
if field_name in dataset:
derived_data = dataset[field_name]
derived_data.name = field_name
return self._field_component_value(derived_data, val)
raise Tidy3dKeyError(f"Poynting component {field_name} not available")
else:
dataset = self.at_boundaries(field_monitor_name)
dataset = self._at_boundaries(field_monitor_data)

dataset = self.apply_phase(data=dataset, phase=phase)

Expand Down Expand Up @@ -336,9 +364,9 @@ def apply_phase(data: Union[xr.DataArray, xr.Dataset], phase: float = 0.0) -> xr
)
return data

def plot_field(
def _plot_field(
self,
field_monitor_name: str,
field_monitor_data: AbstractFieldData,
field_name: str,
val: FieldVal = "real",
scale: PlotScale = "lin",
Expand All @@ -354,9 +382,8 @@ def plot_field(
Parameters
----------
field_monitor_name : str
Name of :class:`.FieldMonitor`, :class:`.FieldTimeData`, or :class:`.ModeSolverData`
to plot.
field_monitor_data : AbstractFieldData
Field monitor data to plot.
field_name : str
Name of ``field`` component to plot (eg. `'Ex'`).
Also accepts ``'E'`` and ``'H'`` to plot the vector magnitudes of the electric and
Expand Down Expand Up @@ -408,10 +435,9 @@ def plot_field(

if field_name in ("E", "H") or field_name[0] == "S":
# Derived fields
field_data = self._get_scalar_field(field_monitor_name, field_name, val, phase=phase)
field_data = self._get_scalar_field(field_monitor_data, field_name, val, phase=phase)
else:
# Direct field component (e.g. Ex)
field_monitor_data = self.load_field_monitor(field_monitor_name)
if field_name not in field_monitor_data.field_components:
raise DataError(f"field_name '{field_name}' not found in data.")
field_component = field_monitor_data.field_components[field_name]
Expand Down Expand Up @@ -446,7 +472,7 @@ def plot_field(
)

# interp out any monitor.size==0 dimensions
monitor = self.simulation.get_monitor_by_name(field_monitor_name)
monitor = field_monitor_data.monitor
thin_dims = {
"xyz"[dim]: monitor.center[dim]
for dim in range(3)
Expand Down Expand Up @@ -541,6 +567,81 @@ def plot_field(
ax=ax,
)

def plot_field(
self,
field_monitor_name: str,
field_name: str,
val: FieldVal = "real",
scale: PlotScale = "lin",
eps_alpha: float = 0.2,
phase: float = 0.0,
robust: bool = True,
vmin: float = None,
vmax: float = None,
ax: Ax = None,
**sel_kwargs,
) -> Ax:
"""Plot the field data for a monitor with simulation plot overlaid.
Parameters
----------
field_monitor_name : str
Name of :class:`.FieldMonitor`, :class:`.FieldTimeData`, or :class:`.ModeSolverData`
to plot.
field_name : str
Name of ``field`` component to plot (eg. `'Ex'`).
Also accepts ``'E'`` and ``'H'`` to plot the vector magnitudes of the electric and
magnetic fields, and ``'S'`` for the Poynting vector.
val : Literal['real', 'imag', 'abs', 'abs^2', 'phase'] = 'real'
Which part of the field to plot.
scale : Literal['lin', 'dB']
Plot in linear or logarithmic (dB) scale.
eps_alpha : float = 0.2
Opacity of the structure permittivity.
Must be between 0 and 1 (inclusive).
phase : float = 0.0
Optional phase (radians) to apply to the fields.
Only has an effect on frequency-domain fields.
robust : bool = True
If True and vmin or vmax are absent, uses the 2nd and 98th percentiles of the data
to compute the color limits. This helps in visualizing the field patterns especially
in the presence of a source.
vmin : float = None
The lower bound of data range that the colormap covers. If ``None``, they are
inferred from the data and other keyword arguments.
vmax : float = None
The upper bound of data range that the colormap covers. If ``None``, they are
inferred from the data and other keyword arguments.
ax : matplotlib.axes._subplots.Axes = None
matplotlib axes to plot on, if not specified, one is created.
sel_kwargs : keyword arguments used to perform ``.sel()`` selection in the monitor data.
These kwargs can select over the spatial dimensions (``x``, ``y``, ``z``),
frequency or time dimensions (``f``, ``t``) or ``mode_index``, if applicable.
For the plotting to work appropriately, the resulting data after selection must contain
only two coordinates with len > 1.
Furthermore, these should be spatial coordinates (``x``, ``y``, or ``z``).
Returns
-------
matplotlib.axes._subplots.Axes
The supplied or created matplotlib axes.
"""

field_monitor_data = self.load_field_monitor(field_monitor_name)
return self._plot_field(
field_monitor_data=field_monitor_data,
field_name=field_name,
val=val,
scale=scale,
eps_alpha=eps_alpha,
phase=phase,
robust=robust,
vmin=vmin,
vmax=vmax,
ax=ax,
**sel_kwargs,
)

@equal_aspect
@add_ax_if_none
def plot_scalar_array(
Expand Down
Loading

0 comments on commit 325f045

Please sign in to comment.