Skip to content

Commit

Permalink
Various improvements to EME solver.
Browse files Browse the repository at this point in the history
  • Loading branch information
caseyflex committed May 30, 2024
1 parent b54df57 commit 5a4d532
Show file tree
Hide file tree
Showing 14 changed files with 1,106 additions and 212 deletions.
197 changes: 176 additions & 21 deletions tests/test_components/test_eme.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,11 +302,11 @@ def test_eme_simulation(log_capture): # noqa: F811
with AssertLogLevel(log_capture, "INFO"):
sim = sim.updated_copy(grid_spec=grid_spec)
# multiple freqs are ok, but not for autogrid
_ = sim.updated_copy(grid_spec=td.GridSpec.uniform(dl=1), freqs=[1e10] + sim.freqs)
_ = sim.updated_copy(grid_spec=td.GridSpec.uniform(dl=1), freqs=[1e10] + list(sim.freqs))
with pytest.raises(SetupError):
_ = td.EMESimulation(
size=sim.size,
freqs=sim.freqs + [1e10],
freqs=list(sim.freqs) + [1e10],
monitors=sim.monitors,
structures=sim.structures,
grid_spec=grid_spec,
Expand All @@ -315,12 +315,15 @@ def test_eme_simulation(log_capture): # noqa: F811
)

# test port offsets
with pytest.raises(pd.ValidationError):
with pytest.raises(SetupError):
_ = sim.updated_copy(port_offsets=[sim.size[sim.axis] * 2 / 3, sim.size[sim.axis] * 2 / 3])

# test duplicate freqs
with pytest.raises(pd.ValidationError):
_ = sim.updated_copy(freqs=sim.freqs + sim.freqs)
_ = sim.updated_copy(freqs=list(sim.freqs) + list(sim.freqs))
# test empty freqs
with pytest.raises(pd.ValidationError):
_ = sim.updated_copy(freqs=[])

# test unsupported media
# fully anisotropic
Expand Down Expand Up @@ -403,7 +406,7 @@ def test_eme_simulation(log_capture): # noqa: F811
_ = sim.updated_copy(size=(1000, 1000, 1000))
with pytest.raises(SetupError):
_ = sim.updated_copy(
freqs=sim.freqs + list(1e14 * np.linspace(1, 2, 1000)),
freqs=list(sim.freqs) + list(1e14 * np.linspace(1, 2, 1000)),
grid_spec=sim.grid_spec.updated_copy(wavelength=1),
)
large_monitor = sim.monitors[2].updated_copy(size=(td.inf, td.inf, td.inf))
Expand All @@ -421,8 +424,18 @@ def test_eme_simulation(log_capture): # noqa: F811

# test sweep
_ = sim.updated_copy(sweep_spec=td.EMELengthSweep(scale_factors=list(np.linspace(1, 2, 10))))
_ = sim.updated_copy(
sweep_spec=td.EMELengthSweep(
scale_factors=np.stack((np.linspace(1, 2, 7), np.linspace(1, 2, 7)))
)
)
# second shape of length sweep must equal number of cells
with pytest.raises(SetupError):
_ = sim.updated_copy(sweep_spec=td.EMELengthSweep(scale_factors=np.array([[1, 2], [3, 4]])))
_ = sim.updated_copy(sweep_spec=td.EMEModeSweep(num_modes=list(np.arange(1, 5))))
# test sweep size limit
with pytest.raises(SetupError):
_ = sim.updated_copy(sweep_spec=td.EMELengthSweep(scale_factors=[]))
with pytest.raises(SetupError):
_ = sim.updated_copy(
sweep_spec=td.EMELengthSweep(scale_factors=list(np.linspace(1, 2, 200)))
Expand All @@ -431,6 +444,36 @@ def test_eme_simulation(log_capture): # noqa: F811
with pytest.raises(SetupError):
_ = sim.updated_copy(sweep_spec=td.EMEModeSweep(num_modes=list(np.arange(150, 200))))

# warn about num modes with constraint
with AssertLogLevel(log_capture, "INFO"):
_ = sim.updated_copy(
constraint="passive",
eme_grid_spec=td.EMEUniformGrid(num_cells=1, mode_spec=td.EMEModeSpec(num_modes=40)),
)
_ = sim.updated_copy(
constraint=None,
eme_grid_spec=td.EMEUniformGrid(num_cells=1, mode_spec=td.EMEModeSpec(num_modes=60)),
)
with AssertLogLevel(log_capture, "WARNING"):
_ = sim.updated_copy(
constraint="passive",
eme_grid_spec=td.EMEUniformGrid(num_cells=1, mode_spec=td.EMEModeSpec(num_modes=60)),
)

_ = sim.port_modes_monitor

# test freq sweep
sim = sim.updated_copy(sweep_spec=td.EMELengthSweep(scale_factors=[1, 2]))
assert sim.port_modes_monitor.num_sweep == 1
sim = sim.updated_copy(sweep_spec=td.EMEFreqSweep(freq_scale_factors=[1, 2]))
assert sim.port_modes_monitor.num_sweep == 2
assert sim._num_sweep == 2
assert sim._monitor_num_sweep(sim.monitors[0]) == 1
with pytest.raises(SetupError):
_ = sim.updated_copy(monitors=[monitor.updated_copy(num_sweep=4)])
with pytest.raises(ValidationError):
_ = sim.updated_copy(sweep_spec=td.EMEFreqSweep(freq_scale_factors=[1e-10, 2]))


def _get_eme_scalar_mode_field_data_array():
x = np.linspace(-1, 1, 68)
Expand All @@ -439,11 +482,16 @@ def _get_eme_scalar_mode_field_data_array():
f = [td.C_0, 3e14]
mode_index = np.arange(10)
eme_cell_index = np.arange(7)
coords = dict(x=x, y=y, z=z, f=f, mode_index=mode_index, eme_cell_index=eme_cell_index)
return td.EMEScalarFieldDataArray(
(1 + 1j) * np.random.random((len(x), len(y), 1, 2, len(mode_index), len(eme_cell_index))),
coords = dict(
x=x, y=y, z=z, f=f, sweep_index=[0], eme_cell_index=eme_cell_index, mode_index=mode_index
)
data = td.EMEScalarModeFieldDataArray(
(1 + 1j)
* np.random.random((len(x), len(y), 1, 2, 1, len(eme_cell_index), len(mode_index))),
coords=coords,
)
data = data.drop_vars("sweep_index")
return data


def test_eme_scalar_mode_field_data_array():
Expand All @@ -457,10 +505,14 @@ def _get_eme_scalar_field_data_array():
f = [td.C_0, 3e14]
mode_index = np.arange(5)
eme_port_index = [0, 1]
coords = dict(x=x, y=y, z=z, f=f, mode_index=mode_index, eme_port_index=eme_port_index)
return td.EMEScalarFieldDataArray(
(1 + 1j) * np.random.random((len(x), len(y), len(z), 2, 5, 2)), coords=coords
coords = dict(
x=x, y=y, z=z, f=f, sweep_index=[0], eme_port_index=eme_port_index, mode_index=mode_index
)
data = td.EMEScalarFieldDataArray(
(1 + 1j) * np.random.random((len(x), len(y), len(z), 2, 1, 2, 5)), coords=coords
)
data = data.drop_vars("sweep_index")
return data


def test_eme_scalar_field_data_array():
Expand Down Expand Up @@ -525,15 +577,28 @@ def _get_eme_coeff_data_array():
eme_port_index = [0, 1]
coords = dict(
f=f,
sweep_index=[0],
eme_port_index=eme_port_index,
eme_cell_index=eme_cell_index,
mode_index_out=mode_index_out,
mode_index_in=mode_index_in,
eme_cell_index=eme_cell_index,
eme_port_index=eme_port_index,
)
data = (1 + 1j) * np.random.random(
(len(f), len(mode_index_out), len(mode_index_in), len(eme_cell_index), len(eme_port_index))
data = td.EMECoefficientDataArray(
(1 + 1j)
* np.random.random(
(
len(f),
1,
len(eme_port_index),
len(eme_cell_index),
len(mode_index_out),
len(mode_index_in),
),
),
coords=coords,
)
return td.EMECoefficientDataArray(data, coords=coords)
data = data.drop_vars("sweep_index")
return data


def _get_eme_coeff_dataset():
Expand All @@ -550,9 +615,13 @@ def _get_eme_mode_index_data_array():
f = [td.C_0, 3e14]
mode_index = np.arange(10)
eme_cell_index = np.arange(7)
coords = dict(f=f, mode_index=mode_index, eme_cell_index=eme_cell_index)
data = (1 + 1j) * np.random.random((len(f), len(mode_index), len(eme_cell_index)))
return td.EMEModeIndexDataArray(data, coords=coords)
coords = dict(f=f, sweep_index=[0], eme_cell_index=eme_cell_index, mode_index=mode_index)
data = td.EMEModeIndexDataArray(
(1 + 1j) * np.random.random((len(f), 1, len(eme_cell_index), len(mode_index))),
coords=coords,
)
data = data.drop_vars("sweep_index")
return data


def test_eme_mode_index_data_array():
Expand Down Expand Up @@ -605,18 +674,23 @@ def _get_eme_mode_solver_data():
n_complex = _get_eme_mode_index_data_array()
kwargs.update({"n_complex": n_complex})
grid_primal_correction_data = np.ones(
(len(n_complex.f), len(n_complex.mode_index), len(n_complex.eme_cell_index))
(len(n_complex.f), 1, len(n_complex.eme_cell_index), len(n_complex.mode_index))
)
grid_dual_correction_data = grid_primal_correction_data
grid_correction_coords = dict(
f=n_complex.f, mode_index=n_complex.mode_index, eme_cell_index=n_complex.eme_cell_index
f=n_complex.f,
sweep_index=[0],
eme_cell_index=n_complex.eme_cell_index,
mode_index=n_complex.mode_index,
)
grid_primal_correction = td.components.data.data_array.EMEFreqModeDataArray(
grid_primal_correction_data, coords=grid_correction_coords
)
grid_dual_correction = td.components.data.data_array.EMEFreqModeDataArray(
grid_dual_correction_data, coords=grid_correction_coords
)
grid_primal_correction = grid_primal_correction.drop_vars("sweep_index")
grid_dual_correction = grid_dual_correction.drop_vars("sweep_index")
return td.EMEModeSolverData(
monitor=monitor,
grid_primal_correction=grid_primal_correction,
Expand Down Expand Up @@ -656,8 +730,10 @@ def _get_mode_solver_data(modes_out=False, num_modes=3):
mode_index = np.arange(num_modes)
kwargs = {key: field.isel(eme_cell_index=0, drop=True) for key, field in kwargs.items()}
kwargs = {key: field.isel(mode_index=mode_index) for key, field in kwargs.items()}
kwargs = {key: field.isel(sweep_index=0) for key, field in kwargs.items()}
n_complex = eme_mode_data.n_complex.isel(eme_cell_index=0, drop=True)
n_complex = n_complex.isel(mode_index=mode_index)
n_complex = n_complex.isel(sweep_index=0)
kwargs.update({"n_complex": n_complex})
sim = make_eme_sim()
grid_expanded = sim.discretize_monitor(monitor)
Expand Down Expand Up @@ -701,9 +777,20 @@ def test_eme_sim_data():
]
port_modes = _get_eme_port_modes()
smatrix = _get_eme_smatrix_dataset(num_modes_1=5, num_modes_2=5)

sim_data = td.EMESimulationData(simulation=sim, data=data, smatrix=smatrix, port_modes=None)
with pytest.raises(SetupError):
_ = sim_data.port_modes_tuple
with pytest.raises(SetupError):
_ = sim_data.port_modes_list

sim_data = td.EMESimulationData(
simulation=sim, data=data, smatrix=smatrix, port_modes=port_modes
)
print(port_modes.Ex.coords)
print(port_modes.Ex.dims)
_ = sim_data.port_modes_tuple
_ = sim_data.port_modes_list

# test smatrix_in_basis
smatrix_in_basis = sim_data.smatrix_in_basis(modes1=modes_in_data, modes2=modes_out_data)
Expand Down Expand Up @@ -757,6 +844,10 @@ def test_eme_sim_data():
_ = sim_data.updated_copy(port_modes=None).smatrix_in_basis(
modes1=modes_in_data, modes2=modes_out_data
)
with pytest.raises(SetupError):
_ = sim_data.updated_copy(port_modes=None).field_in_basis(
field=sim_data["field"], modes=modes_in_data, port_index=0
)

# test field in basis
field_in_basis = sim_data.field_in_basis(field=sim_data["field"], port_index=0)
Expand Down Expand Up @@ -812,6 +903,7 @@ def test_eme_sim_data():

# test smatrix in basis with sweep
smatrix = _get_eme_smatrix_dataset(num_modes_1=5, num_modes_2=5, num_sweep=10)
sim = sim.updated_copy(sweep_spec=td.EMELengthSweep(scale_factors=np.linspace(1, 2, 10)))
sim_data = td.EMESimulationData(
simulation=sim, data=data, smatrix=smatrix, port_modes=port_modes
)
Expand Down Expand Up @@ -863,3 +955,66 @@ def test_eme_sim_data():
assert len(smatrix_in_basis.S12.coords) == 2
assert len(smatrix_in_basis.S21.coords) == 2
assert len(smatrix_in_basis.S22.coords) == 2
smatrix_in_basis = sim_data.smatrix_in_basis(modes1=modes_in0)
assert len(smatrix_in_basis.S11.coords) == 2
assert len(smatrix_in_basis.S12.coords) == 3
assert len(smatrix_in_basis.S21.coords) == 3
assert len(smatrix_in_basis.S22.coords) == 4
smatrix_in_basis = sim_data.smatrix_in_basis(modes2=modes_out0)
assert len(smatrix_in_basis.S11.coords) == 4
assert len(smatrix_in_basis.S12.coords) == 3
assert len(smatrix_in_basis.S21.coords) == 3
assert len(smatrix_in_basis.S22.coords) == 2
smatrix_in_basis = sim_data.smatrix_in_basis()
assert len(smatrix_in_basis.S11.coords) == 4
assert len(smatrix_in_basis.S12.coords) == 4
assert len(smatrix_in_basis.S21.coords) == 4
assert len(smatrix_in_basis.S22.coords) == 4
_ = sim_data.port_modes_tuple
assert len(sim_data.port_modes_list) == 1

# test freq sweep smatrix_in_basis
sim = sim.updated_copy(sweep_spec=td.EMEFreqSweep(freq_scale_factors=np.linspace(1, 2, 10)))
sim_data = td.EMESimulationData(
simulation=sim, data=data, smatrix=smatrix, port_modes=port_modes
)
_ = sim_data.port_modes_tuple
assert len(sim_data.port_modes_list) == 10
smatrix_in_basis = sim_data.smatrix_in_basis(modes1=modes_in0, modes2=modes_out_data)
assert len(smatrix_in_basis.S11.coords) == 2
assert len(smatrix_in_basis.S12.coords) == 3
assert len(smatrix_in_basis.S21.coords) == 3
assert len(smatrix_in_basis.S22.coords) == 4
smatrix_in_basis = sim_data.smatrix_in_basis(modes1=modes_in_data, modes2=modes_out0)
assert len(smatrix_in_basis.S11.coords) == 4
assert len(smatrix_in_basis.S12.coords) == 3
assert len(smatrix_in_basis.S21.coords) == 3
assert len(smatrix_in_basis.S22.coords) == 2
smatrix_in_basis = sim_data.smatrix_in_basis(modes1=modes_in0, modes2=modes_out0)
assert len(smatrix_in_basis.S11.coords) == 2
assert len(smatrix_in_basis.S12.coords) == 2
assert len(smatrix_in_basis.S21.coords) == 2
assert len(smatrix_in_basis.S22.coords) == 2
smatrix_in_basis = sim_data.smatrix_in_basis(modes1=modes_in0)
assert len(smatrix_in_basis.S11.coords) == 2
assert len(smatrix_in_basis.S12.coords) == 3
assert len(smatrix_in_basis.S21.coords) == 3
assert len(smatrix_in_basis.S22.coords) == 4
smatrix_in_basis = sim_data.smatrix_in_basis(modes2=modes_out0)
assert len(smatrix_in_basis.S11.coords) == 4
assert len(smatrix_in_basis.S12.coords) == 3
assert len(smatrix_in_basis.S21.coords) == 3
assert len(smatrix_in_basis.S22.coords) == 2
smatrix_in_basis = sim_data.smatrix_in_basis()
assert len(smatrix_in_basis.S11.coords) == 4
assert len(smatrix_in_basis.S12.coords) == 4
assert len(smatrix_in_basis.S21.coords) == 4
assert len(smatrix_in_basis.S22.coords) == 4

# test field in basis with freq sweep
field_in_basis = sim_data.field_in_basis(field=sim_data["field"], port_index=0)
assert "mode_index" in field_in_basis.Ex.coords
field_in_basis = sim_data.field_in_basis(field=sim_data["field"], modes=modes_in0, port_index=0)
assert "mode_index" not in field_in_basis.Ex.coords
field_in_basis = sim_data.field_in_basis(field=sim_data["field"], modes=modes_in0, port_index=1)
assert "mode_index" not in field_in_basis.Ex.coords
4 changes: 2 additions & 2 deletions tests/test_plugins/test_adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -1990,8 +1990,8 @@ def test_no_poynting(use_emulated_run):

sim_data._get_scalar_field(mnt_name_static, "S", "abs")

with pytest.raises(NotImplementedError):
sim_data._get_scalar_field(mnt_name_differentiable, "S", "abs")
# with pytest.raises(NotImplementedError):
# sim_data._get_scalar_field(mnt_name_differentiable, "S", "abs")


def test_to_gds(tmp_path):
Expand Down
38 changes: 36 additions & 2 deletions tests/test_plugins/test_mode_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -812,7 +812,7 @@ def test_mode_solver_nan_pol_fraction():
md = ms.solve()
check_ms_reduction(ms)

assert list(np.where(np.isnan(md.pol_fraction.te))[1]) == [8, 9]
assert list(np.where(np.isnan(md.pol_fraction.te))[1]) == [9]


def test_mode_solver_method_defaults():
Expand Down Expand Up @@ -906,5 +906,39 @@ def test_mode_solver_web_run_batch(mock_remote_api):
# Run mode solver one at a time
results = msweb.run_batch(mode_solver_list, verbose=False, folder_name="Mode Solver")
[print(type(x)) for x in results]
assert all([isinstance(x, ModeSolverData) for x in results])
assert all(isinstance(x, ModeSolverData) for x in results)
assert (results[i].n_eff.shape == (num_freqs, i + 1) for i in range(num_of_sims))


def test_mode_solver_relative():
"""Relative mode solver"""

simulation = td.Simulation(
size=SIM_SIZE,
grid_spec=td.GridSpec(wavelength=1.0),
structures=[WAVEGUIDE],
run_time=1e-12,
symmetry=(0, 0, 1),
boundary_spec=td.BoundarySpec.all_sides(boundary=td.Periodic()),
sources=[SRC],
)
mode_spec = td.ModeSpec(
num_modes=3,
target_neff=2.0,
filter_pol="tm",
precision="double",
track_freq="lowest",
)
freqs = [td.C_0 / 0.9, td.C_0 / 1.0, td.C_0 / 1.1]
ms = ModeSolver(
simulation=simulation,
plane=PLANE,
mode_spec=mode_spec,
freqs=freqs,
direction="-",
colocate=False,
)
basis = ms.data_raw
new_freqs = np.array(freqs) * 1.01
ms = ms.updated_copy(freqs=new_freqs)
_ = ms._data_on_yee_grid_relative(basis=basis)
Loading

0 comments on commit 5a4d532

Please sign in to comment.