Skip to content

Commit

Permalink
Added EME solver
Browse files Browse the repository at this point in the history
  • Loading branch information
caseyflex committed Feb 7, 2024
1 parent efaef27 commit 5abddff
Show file tree
Hide file tree
Showing 12 changed files with 1,524 additions and 0 deletions.
12 changes: 12 additions & 0 deletions tests/test_components/test_eme.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import pytest
import pydantic.v1 as pd
import numpy as np
from matplotlib import pyplot as plt

import tidy3d as td

from ..utils import STL_GEO, assert_log_level, log_capture


def test_eme_sim_data():
pass
29 changes: 29 additions & 0 deletions tidy3d/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,17 @@
from .components.heat.monitor import TemperatureMonitor
from .components.heat.grid import UniformUnstructuredGrid, DistanceUnstructuredGrid

# EME
from .components.eme.simulation import EMESimulation
from .components.eme.data.sim_data import EMESimulationData
from .components.eme.monitor import EMEMonitor, EMEModeSolverMonitor, EMEFieldMonitor
from .components.eme.monitor import EMECoefficientMonitor
from .components.data.data_array import EMESMatrixDataArray, EMEScalarFieldDataArray
from .components.eme.data.dataset import EMEFieldDataset, EMECoefficientDataset, EMESMatrixDataset
from .components.eme.data.monitor_data import EMEGridData
from .components.eme.data.monitor_data import EMEModeSolverData, EMEFieldData, EMECoefficientData
from .components.eme.grid import EMEGrid, EMEUniformGrid, EMECompositeGrid


def set_logging_level(level: str) -> None:
"""Raise a warning here instead of setting the logging level."""
Expand Down Expand Up @@ -327,4 +338,22 @@ def set_logging_level(level: str) -> None:
"TriangularGridDataset",
"TetrahedralGridDataset",
"medium_from_nk",
"EMESimulation",
"EMESimulationData",
"EMEMonitor",
"EMEModeSolverMonitor",
"EMEFieldMonitor",
"EMESMatrixDataArray",
"EMEFieldDataset",
"EMECoefficientDataset",
"EMESMatrixDataset",
"EMEModeSolverData",
"EMEFieldData",
"EMECoefficientData",
"EMECoefficientMonitor",
"EMEGrid",
"EMEUniformGrid",
"EMECompositeGrid",
"EMEGridData",
"EMEScalarFieldDataArray",
]
46 changes: 46 additions & 0 deletions tidy3d/components/data/data_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
"t": {"units": SECOND, "long_name": "time"},
"direction": {"long_name": "propagation direction"},
"mode_index": {"long_name": "mode index"},
"port_index": {"long_name": "port index"},
"mode_index_in": {"long_name": "mode index in"},
"mode_index_out": {"long_name": "mode index out"},
"theta": {"units": RADIAN, "long_name": "elevation angle"},
"phi": {"units": RADIAN, "long_name": "azimuth angle"},
"ux": {"long_name": "normalized kx"},
Expand Down Expand Up @@ -650,6 +653,47 @@ class HeatDataArray(DataArray):
_dims = "T"


class EMEScalarFieldDataArray(AbstractSpatialDataArray):
"""Spatial distribution of a mode in frequency-domain as a function of mode index
and port index.
Example
-------
>>> x = [1,2]
>>> y = [2,3,4]
>>> z = [3,4,5,6]
>>> f = [2e14, 3e14]
>>> mode_index = np.arange(5)
>>> port_index = [0, 1]
>>> coords = dict(x=x, y=y, z=z, f=f, mode_index=mode_index, port_index=port_index)
>>> fd = ScalarModeFieldDataArray((1+1j) * np.random.random((2,3,4,2,5,2)), coords=coords)
"""

__slots__ = ()
_dims = ("x", "y", "z", "f", "mode_index", "port_index")


class EMESMatrixDataArray(DataArray):
"""Scattering matrix elements for a fixed pair of ports.
Example
-------
>>> mode_index_in = [0, 1]
>>> mode_index_out = [0, 1, 2]
>>> f = [2e14]
>>> coords = dict(
... f=f,
... mode_index_out=mode_index_out,
... mode_index_in=mode_index_in,
... )
>>> fd = EMESMatrixDataArray((1 + 1j) * np.random.random((1, 3, 2)), coords=coords)
"""

__slots__ = ()
_dims = ("f", "mode_index_out", "mode_index_in")
_data_attrs = {"long_name": "scattering matrix element"}


class ChargeDataArray(DataArray):
"""Charge data array.
Expand Down Expand Up @@ -728,6 +772,8 @@ class IndexedDataArray(DataArray):
FreqModeDataArray,
TriangleMeshDataArray,
HeatDataArray,
EMEScalarFieldDataArray,
EMESMatrixDataArray,
ChargeDataArray,
PointDataArray,
CellDataArray,
Expand Down
27 changes: 27 additions & 0 deletions tidy3d/components/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from ...exceptions import DataError, ValidationError, Tidy3dNotImplementedError
from ...constants import PICOSECOND_PER_NANOMETER_PER_KILOMETER
from ...log import log
from ..geometry.base import Box


class Dataset(Tidy3dBaseModel, ABC):
Expand Down Expand Up @@ -198,6 +199,32 @@ def symmetry_eigenvalues(self) -> Dict[str, Callable[[Axis], float]]:
Hz=lambda dim: +1 if (dim == 2) else -1,
)

def _restrict_to_box(self, box: Box):
"""Restrict to a box."""
components = self.field_components.values()
xmin, ymin, zmin = box.bounds[0]
xmax, ymax, zmax = box.bounds[1]
restricted_components = []
for component in components:
restricted_component = component.where(
(xmin < component.x)
& (component.x < xmax)
& (ymin < component.y)
& (component.y < ymax)
& (zmin < component.z)
& (component.z < zmax),
drop=True,
)
restricted_components.append(restricted_component)
return self.updated_copy(
Ex=restricted_components[0],
Ey=restricted_components[1],
Ez=restricted_components[2],
Hx=restricted_components[3],
Hy=restricted_components[4],
Hz=restricted_components[5],
)


class FieldDataset(ElectromagneticFieldDataset):
"""Dataset storing a collection of the scalar components of E and H fields in the freq. domain
Expand Down
Empty file.
Empty file.
95 changes: 95 additions & 0 deletions tidy3d/components/eme/data/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
"""EME dataset"""
from __future__ import annotations


import pydantic.v1 as pd

from ...data.dataset import Dataset, ElectromagneticFieldDataset
from ...data.data_array import EMEScalarFieldDataArray, EMESMatrixDataArray


class EMESMatrixDataset(Dataset):
"""Dataset storing S matrix."""

S11: EMESMatrixDataArray = pd.Field(
...,
title="S11 matrix",
description="S matrix relating output modes at port 1 to input modes at port 1.",
)
S12: EMESMatrixDataArray = pd.Field(
...,
title="S12 matrix",
description="S matrix relating output modes at port 1 to input modes at port 2.",
)
S21: EMESMatrixDataArray = pd.Field(
...,
title="S21 matrix",
description="S matrix relating output modes at port 2 to input modes at port 1.",
)
S22: EMESMatrixDataArray = pd.Field(
...,
title="S22 matrix",
description="S matrix relating output modes at port 2 to input modes at port 2.",
)


class EMECoefficientDataset(Dataset):
"""Dataset storing expansion coefficients for the modes in a cell.
These are defined at the cell centers.
"""

A1: EMESMatrixDataArray = pd.Field(
...,
title="A1 coefficient",
description="Coefficient for forward mode in this cell " "when excited from port 1.",
)
B1: EMESMatrixDataArray = pd.Field(
...,
title="B1 coefficient",
description="Coefficient for backward mode in this cell " "when excited from port 1.",
)
A2: EMESMatrixDataArray = pd.Field(
...,
title="A2 coefficient",
description="Coefficient for forward mode in this cell " "when excited from port 2.",
)
B2: EMESMatrixDataArray = pd.Field(
...,
title="B2 coefficient",
description="Coefficient for backward mode in this cell " "when excited from port 2.",
)


class EMEFieldDataset(ElectromagneticFieldDataset):
"""Dataset storing scalar components of E and H fields as a function of freq, mode_index, and port_index."""

Ex: EMEScalarFieldDataArray = pd.Field(
...,
title="Ex",
description="Spatial distribution of the x-component of the electric field of the mode.",
)
Ey: EMEScalarFieldDataArray = pd.Field(
...,
title="Ey",
description="Spatial distribution of the y-component of the electric field of the mode.",
)
Ez: EMEScalarFieldDataArray = pd.Field(
...,
title="Ez",
description="Spatial distribution of the z-component of the electric field of the mode.",
)
Hx: EMEScalarFieldDataArray = pd.Field(
...,
title="Hx",
description="Spatial distribution of the x-component of the magnetic field of the mode.",
)
Hy: EMEScalarFieldDataArray = pd.Field(
...,
title="Hy",
description="Spatial distribution of the y-component of the magnetic field of the mode.",
)
Hz: EMEScalarFieldDataArray = pd.Field(
...,
title="Hz",
description="Spatial distribution of the z-component of the magnetic field of the mode.",
)
88 changes: 88 additions & 0 deletions tidy3d/components/eme/data/monitor_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
"""EME monitor data"""
from __future__ import annotations

from abc import ABC

from typing import Union, List

import pydantic.v1 as pd

from ...base_sim.data.monitor_data import AbstractMonitorData
from ..monitor import EMEModeSolverMonitor, EMEFieldMonitor, EMECoefficientMonitor
from ...data.monitor_data import ModeSolverData, ElectromagneticFieldData
from ...base import Tidy3dBaseModel
from ....exceptions import ValidationError

from .dataset import EMEFieldDataset, EMECoefficientDataset


class EMEGridData(Tidy3dBaseModel, ABC):
"""Abstract class defining data indexed by a subset of the cells in the EME grid."""

cell_indices: List[pd.NonNegativeInt] = pd.Field(
..., title="Cell indices", description="Cell indices"
)


class EMEModeSolverData(AbstractMonitorData, EMEGridData):
"""Data associated with an EME mode solver monitor."""

monitor: EMEModeSolverMonitor = pd.Field(
...,
title="EME Mode Solver Monitor",
description="EME mode solver monitor associated with this data.",
)

modes: List[ModeSolverData] = pd.Field(
...,
title="Modes",
description="Modes recorded by the EME mode solver monitor. "
"The corresponding cell indices are stored in 'cell_indices'. "
"A mode is recorded if its mode plane is contained in the monitor geometry.",
)

@pd.validator("modes", always=True)
def _validate_num_modes(cls, val, values):
"""Check that the number of modes equals the number of cells inside the monitor."""
num_cells = len(values["cell_indices"])
num_modes = len(val)
if num_cells != num_modes:
raise ValidationError("The number of 'modes' must equal the number of 'cell_indices'.")
return val


class EMEFieldData(EMEFieldDataset, ElectromagneticFieldData):
"""Data associated with an EME field monitor."""

monitor: EMEFieldMonitor = pd.Field(
..., title="EME Field Monitor", description="EME field monitor associated with this data."
)


class EMECoefficientData(AbstractMonitorData, EMEGridData):
"""Data associated with an EME coefficient monitor."""

monitor: EMECoefficientMonitor = pd.Field(
...,
title="EME Coefficient Monitor",
description="EME coefficient monitor associated with this data.",
)

coeffs: List[EMECoefficientDataset] = pd.Field(
...,
title="Coefficients",
description="Coefficients of the forward and backward traveling modes in each cell "
"contained in the monitor geometry.",
)

@pd.validator("coeffs", always=True)
def _validate_num_coeffs(cls, val, values):
"""Check that the number of coeffs equals the number of cells inside the monitor."""
num_cells = len(values["cell_indices"])
num_coeffs = len(val)
if num_cells != num_coeffs:
raise ValidationError("The number of 'coeffs' must equal the number of 'cell_indices'.")
return val


EMEMonitorDataType = Union[EMEModeSolverData, EMEFieldData, EMECoefficientData]
Loading

0 comments on commit 5abddff

Please sign in to comment.