From 807044d510dc334570f8ccc5b4975ff311a80c45 Mon Sep 17 00:00:00 2001 From: Tyler Hughes Date: Tue, 9 Jan 2024 13:20:37 -0500 Subject: [PATCH] add grid_expanded to run_emulated mode data --- docs/notebooks | 2 +- tests/test_plugins/test_adjoint.py | 15 +++++++-------- tests/utils.py | 7 ++++++- 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/docs/notebooks b/docs/notebooks index dfc784c8a..d7f3b021d 160000 --- a/docs/notebooks +++ b/docs/notebooks @@ -1 +1 @@ -Subproject commit dfc784c8a2775c6dda3b9dd4c0169a85a523c2ad +Subproject commit d7f3b021d1a9832602735d3440a5b93c8f2851ea diff --git a/tests/test_plugins/test_adjoint.py b/tests/test_plugins/test_adjoint.py index 82e41d31b..628f17f10 100644 --- a/tests/test_plugins/test_adjoint.py +++ b/tests/test_plugins/test_adjoint.py @@ -1649,16 +1649,15 @@ def make_sim(eps): if not has_adj_src: monkeypatch.setattr(JaxModeData, "to_adjoint_sources", lambda *args, **kwargs: []) - def J(x): - sim = make_sim(eps=x) - data = run(sim, task_name="test", path=str(tmp_path / RUN_FILE)) + x = 2.0 + sim = make_sim(eps=x) + data = run(sim, task_name="test", path=str(tmp_path / RUN_FILE)) + + # check whether we got a warning for no sources? + with AssertLogLevel(log_capture, log_level_expected, contains_str="No adjoint sources"): data.make_adjoint_simulation(fwidth=src.source_time.fwidth, run_time=sim.run_time) - power = jnp.sum(jnp.abs(jnp.array(data["mnt"].amps.values)) ** 2) - return power - grad_J = grad(J) - grad_J(2.0) - assert_log_level(log_capture, log_level_expected) + power = jnp.sum(jnp.abs(jnp.array(data["mnt"].amps.values)) ** 2) def test_nonlinear_warn(log_capture): diff --git a/tests/utils.py b/tests/utils.py index 7259cb2c1..7aa265858 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -583,7 +583,12 @@ def make_mode_data(monitor: td.ModeMonitor) -> td.ModeData: coords_amps = dict(direction=["+", "-"]) coords_amps.update(coords_ind) amps = make_data(coords=coords_amps, data_array_type=td.ModeAmpsDataArray, is_complex=True) - return td.ModeData(monitor=monitor, n_complex=n_complex, amps=amps) + return td.ModeData( + monitor=monitor, + n_complex=n_complex, + amps=amps, + grid_expanded=simulation.discretize_monitor(monitor), + ) MONITOR_MAKER_MAP = { td.FieldMonitor: make_field_data,