Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 100 additions & 12 deletions tests/test_web/test_local_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,23 @@
from pathlib import Path
from types import SimpleNamespace

import autograd as ag
import pytest
import xarray as xr
from autograd.core import defvjp
from rich.console import Console

import tidy3d as td
from tests.test_components.autograd.test_autograd import ALL_KEY, get_functions, params0
from tests.test_web.test_webapi_mode import make_mode_sim
from tidy3d import config
from tidy3d.components.autograd.field_map import FieldMap
from tidy3d.config import get_manager
from tidy3d.web import Job, common, run, run_async
from tidy3d.web.api import webapi as web
from tidy3d.web.api.autograd import autograd, engine, io_utils
from tidy3d.web.api.autograd.autograd import run as run_autograd
from tidy3d.web.api.autograd.constants import SIM_VJP_FILE
from tidy3d.web.api.container import Batch, WebContainer
from tidy3d.web.api.webapi import load_simulation_if_cached
from tidy3d.web.cache import CACHE_ARTIFACT_NAME, clear, get_cache_entry_dir, resolve_local_cache
Expand All @@ -39,6 +46,24 @@ class _FakeStubData:
def __init__(self, simulation: td.Simulation):
self.simulation = simulation

def __getitem__(self, key):
    """Fake monitor-data lookup: only the "mode" key is handled.

    Builds a SimpleNamespace whose ``amps`` DataArray wraps the autograd
    params previously stashed on the simulation; any other key yields None
    (matching the implicit fall-through of a plain ``if``).
    """
    if key != "mode":
        return None
    values = self.simulation.attrs["params_autograd"]
    coords = {"x": list(range(len(values)))}
    return SimpleNamespace(amps=xr.DataArray(values, dims=["x"], coords=coords))

def _strip_traced_fields(self, *args, **kwargs):
"""Fake _strip_traced_fields: return minimal valid autograd-style mapping."""
return {"params": self.simulation.attrs["params"]}

def _insert_traced_fields(self, field_mapping, *args, **kwargs):
self.simulation.attrs["params_autograd"] = field_mapping["params"]
return self

def _make_adjoint_sims(self, **kwargs):
return [self.simulation.updated_copy(run_time=self.simulation.run_time * 2)]


@pytest.fixture
def basic_simulation():
Expand Down Expand Up @@ -128,25 +153,50 @@ def _fake__check_folder(*args, **kwargs):
def _fake_status(self):
return "success"

def _fake_download_file(resource_id, remote_filename, to_file=None, **kwargs):
    """Stub for io_utils.download_file: increments the shared ``counters``
    closure dict only when the requested file is the adjoint/VJP artifact."""
    # Only count this download if it's the adjoint/VJP file
    if str(remote_filename) == SIM_VJP_FILE:
        counters["download"] += 1

def _fake_from_file(*args, **kwargs):
    """Stub used for both FieldMap.from_file and autograd.postprocess_fwd:
    ignores all arguments and returns an empty FieldMap (no tracers)."""
    field_map = FieldMap(tracers=())
    return field_map

monkeypatch.setattr(io_utils, "download_file", _fake_download_file)
monkeypatch.setattr(autograd, "postprocess_fwd", _fake_from_file)
monkeypatch.setattr(FieldMap, "from_file", _fake_from_file)
monkeypatch.setattr(WebContainer, "_check_folder", _fake__check_folder)
monkeypatch.setattr(web, "upload", _fake_upload)
monkeypatch.setattr(web, "start", _fake_start)
monkeypatch.setattr(web, "monitor", _fake_monitor)
monkeypatch.setattr(web, "download", _fake_download)
monkeypatch.setattr(web, "estimate_cost", lambda *args, **kwargs: 0.0)
monkeypatch.setattr(Job, "status", property(_fake_status))
monkeypatch.setattr(engine, "upload_sim_fields_keys", lambda *args, **kwargs: None)
monkeypatch.setattr(
web,
"get_info",
lambda task_id, verbose=True: type(
"_Info", (), {"solverVersion": "solver-1", "taskType": "FDTD"}
)(),
)
monkeypatch.setattr(
io_utils,
"get_info",
lambda task_id, verbose=True: type(
"_Info", (), {"solverVersion": "solver-1", "taskType": "FDTD"}
)(),
)
monkeypatch.setattr(
web, "load_simulation", lambda task_id, *args, **kwargs: TASK_TO_SIM[task_id]
)
monkeypatch.setattr(
io_utils, "load_simulation", lambda task_id, *args, **kwargs: TASK_TO_SIM[task_id]
)
monkeypatch.setattr(BatchTask, "is_batch", lambda *args, **kwargs: "success")
monkeypatch.setattr(BatchTask, "detail", lambda *args: SimpleNamespace(status="success"))
monkeypatch.setattr(
BatchTask, "detail", lambda *args, **kwargs: SimpleNamespace(status="success")
)
return counters


Expand Down Expand Up @@ -371,23 +421,59 @@ def _test_job_run_cache(monkeypatch, basic_simulation, tmp_path):
assert os.path.exists(out2_path)


def _test_autograd_cache(monkeypatch):
def _test_autograd_cache(monkeypatch, request):
counters = _patch_run_pipeline(monkeypatch)

# "Original" rule: the one autograd uses by default
def _orig_make_dict_vjp(ans, keys, vals):
return lambda g: [g[key] for key in keys]

def _zero_make_dict_vjp(ans, keys, vals):
def vjp(g):
# One gradient per entry in `vals`, all zeros, matching shape/dtype
return [ag.numpy.zeros_like(v) for v in vals]

return vjp

# Install our zero-VJP (this is the thing that affects global state)
defvjp(
ag.builtins._make_dict,
_zero_make_dict_vjp,
argnums=(1,), # gradient w.r.t. `vals`
)

# Make sure we restore it after the test
def _restore_make_dict_vjp():
defvjp(
ag.builtins._make_dict,
_orig_make_dict_vjp,
argnums=(1,),
)

request.addfinalizer(_restore_make_dict_vjp)

cache = resolve_local_cache(use_cache=True)
cache.clear()

functions = get_functions(ALL_KEY, "mode")
make_sim = functions["sim"]
sim = make_sim(params0)
web.run(sim)
assert counters["download"] == 1
assert len(cache) == 1
postprocess = functions["postprocess"]

def objective(params):
sim = make_sim(params)
sim.attrs["params"] = params
sim_data = run_autograd(sim)
value = postprocess(sim_data)
return value

ag.value_and_grad(objective)(params0)
assert counters["download"] == 2
assert len(cache) == 2

_reset_counters(counters)
sim = make_sim(params0)
web.run(sim)
assert counters["download"] == 0
assert len(cache) == 1
ag.value_and_grad(objective)(params0)
assert counters["download"] == 1 # download field data
assert len(cache) == 2


def _test_load_cache_hit(monkeypatch, tmp_path, basic_simulation, fake_data):
Expand Down Expand Up @@ -499,7 +585,9 @@ def _test_env_var_overrides(monkeypatch, tmp_path):
manager._reload()


def test_cache_sequential(monkeypatch, tmp_path, tmp_path_factory, basic_simulation, fake_data):
def test_cache_sequential(
monkeypatch, tmp_path, tmp_path_factory, basic_simulation, fake_data, request
):
"""Run all critical cache tests in sequence to ensure stability."""
monkeypatch.setattr(config.local_cache, "enabled", True)

Expand All @@ -514,7 +602,7 @@ def test_cache_sequential(monkeypatch, tmp_path, tmp_path_factory, basic_simulat
_test_cache_eviction_by_size(monkeypatch, tmp_path_factory, basic_simulation)
_test_run_cache_hit_async(monkeypatch, basic_simulation, tmp_path)
_test_job_run_cache(monkeypatch, basic_simulation, tmp_path)
_test_autograd_cache(monkeypatch)
_test_autograd_cache(monkeypatch, request)
_test_configure_cache_roundtrip(monkeypatch, tmp_path)
_test_mode_solver_caching(monkeypatch, tmp_path)
_test_verbosity(monkeypatch, basic_simulation)
65 changes: 27 additions & 38 deletions tests/test_web/test_webapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,29 @@
Env.dev.active()


class FakeJob:
    """Minimal Job stand-in that replays a scripted status sequence and
    records every status query / download request into a shared event list."""

    def __init__(self, task_id: str, statuses: list[str], events: list[str]):
        self.task_id = task_id
        self._statuses = statuses
        self._idx = 0
        self.events = events

    @property
    def status(self):
        """Return the current scripted status, log it, and advance the
        cursor (it sticks at the final entry once exhausted)."""
        current = self._statuses[self._idx]
        last_index = len(self._statuses) - 1
        if self._idx < last_index:
            self._idx += 1
        self.events.append((self.task_id, "status", current))
        return current

    def download(self, path: PathLike):
        """Record the download request instead of fetching anything."""
        self.events.append((self.task_id, "download", str(path)))

    @property
    def load_if_cached(self):
        # Fake jobs never restore results from the local cache.
        return False


class ImmediateExecutor:
def __init__(self, *args, **kwargs):
pass
Expand Down Expand Up @@ -728,32 +751,15 @@ def mock_start_interrupt(self, *args, **kwargs):
def test_batch_monitor_downloads_on_success(monkeypatch, tmp_path):
events = []

class FakeJob:
def __init__(self, task_id: str, statuses: list[str]):
self.task_id = task_id
self._statuses = statuses
self._idx = 0

@property
def status(self):
status = self._statuses[self._idx]
if self._idx < len(self._statuses) - 1:
self._idx += 1
events.append((self.task_id, "status", status))
return status

def download(self, path: PathLike):
events.append((self.task_id, "download", str(path)))

monkeypatch.setattr("tidy3d.web.api.container.ThreadPoolExecutor", ImmediateExecutor)
monkeypatch.setattr("tidy3d.web.api.container.time.sleep", lambda *_args, **_kwargs: None)

sims = {"task_a": make_sim(), "task_b": make_sim()}
batch = Batch(simulations=sims, folder_name=PROJECT_NAME, verbose=False)
batch._cached_properties = {}
fake_jobs = {
"task_a": FakeJob("task_a_id", ["running", "success", "success"]),
"task_b": FakeJob("task_b_id", ["running", "running", "success"]),
"task_a": FakeJob("task_a_id", ["running", "success", "success"], events),
"task_b": FakeJob("task_b_id", ["running", "running", "success"], events),
}
batch._cached_properties["jobs"] = fake_jobs

Expand Down Expand Up @@ -786,32 +792,15 @@ def download(self, path: PathLike):
def test_batch_monitor_skips_existing_download(monkeypatch, tmp_path):
events = []

class FakeJob:
def __init__(self, task_id: str, statuses: list[str]):
self.task_id = task_id
self._statuses = statuses
self._idx = 0

@property
def status(self):
status = self._statuses[self._idx]
if self._idx < len(self._statuses) - 1:
self._idx += 1
events.append((self.task_id, "status", status))
return status

def download(self, path: PathLike):
events.append((self.task_id, "download", str(path)))

monkeypatch.setattr("tidy3d.web.api.container.ThreadPoolExecutor", ImmediateExecutor)
monkeypatch.setattr("tidy3d.web.api.container.time.sleep", lambda *_args, **_kwargs: None)

sims = {"task_a": make_sim(), "task_b": make_sim()}
batch = Batch(simulations=sims, folder_name=PROJECT_NAME, verbose=False)
batch._cached_properties = {}
fake_jobs = {
"task_a": FakeJob("task_a_id", ["success", "success"]),
"task_b": FakeJob("task_b_id", ["running", "success"]),
"task_a": FakeJob("task_a_id", ["success", "success"], events),
"task_b": FakeJob("task_b_id", ["running", "success"], events),
}
batch._cached_properties["jobs"] = fake_jobs

Expand Down
32 changes: 24 additions & 8 deletions tidy3d/web/api/autograd/autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from tidy3d.web.api.asynchronous import run_async as run_async_webapi
from tidy3d.web.api.container import BatchData
from tidy3d.web.api.tidy3d_stub import Tidy3dStub
from tidy3d.web.api.webapi import load, restore_simulation_if_cached
from tidy3d.web.api.webapi import run as run_webapi
from tidy3d.web.core.types import PayType

Expand Down Expand Up @@ -561,16 +562,31 @@ def _run_primitive(
aux_data=aux_data,
)
else:
sim_combined.validate_pre_upload()
sim_original = sim_original.updated_copy(simulation_type="autograd_fwd", deep=False)
run_kwargs["simulation_type"] = "autograd_fwd"
run_kwargs["sim_fields_keys"] = list(sim_fields.keys())

sim_data_orig, task_id_fwd = _run_tidy3d(
sim_original,
task_name=task_name,
**run_kwargs,
restored_path, task_id_fwd = restore_simulation_if_cached(
simulation=sim_original,
path=run_kwargs.get("path", None),
reduce_simulation=run_kwargs.get("reduce_simulation", "auto"),
verbose=run_kwargs.get("verbose", True),
)
if restored_path is None or task_id_fwd is None:
sim_combined.validate_pre_upload()
run_kwargs["simulation_type"] = "autograd_fwd"
run_kwargs["sim_fields_keys"] = list(sim_fields.keys())

sim_data_orig, task_id_fwd = _run_tidy3d(
sim_original,
task_name=task_name,
**run_kwargs,
)
else:
sim_data_orig = load(
task_id=None,
path=run_kwargs.get("path", None),
verbose=run_kwargs.get("verbose", None),
progress_callback=run_kwargs.get("progress_callback", None),
lazy=run_kwargs.get("lazy", None),
)

# TODO: put this in postprocess?
aux_data[AUX_KEY_FWD_TASK_ID] = task_id_fwd
Expand Down
17 changes: 17 additions & 0 deletions tidy3d/web/api/autograd/io_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import tidy3d as td
from tidy3d.components.autograd import AutogradFieldMap
from tidy3d.components.autograd.field_map import FieldMap, TracerKeys
from tidy3d.web.api.webapi import get_info, load_simulation
from tidy3d.web.cache import resolve_local_cache
from tidy3d.web.core.s3utils import download_file, upload_file # type: ignore

from .constants import SIM_FIELDS_KEYS_FILE, SIM_VJP_FILE
Expand Down Expand Up @@ -39,6 +41,21 @@ def get_vjp_traced_fields(task_id_adj: str, verbose: bool) -> AutogradFieldMap:
try:
download_file(task_id_adj, SIM_VJP_FILE, to_file=fname, verbose=verbose)
field_map = FieldMap.from_file(fname)

simulation_cache = resolve_local_cache()
if simulation_cache is not None:
info = get_info(task_id_adj, verbose=False)
workflow_type = getattr(info, "taskType", None)
simulation = None
with tempfile.NamedTemporaryFile(suffix=".hdf5") as tmp_file:
simulation = load_simulation(task_id_adj, path=tmp_file.name, verbose=False)
simulation_cache.store_result(
stub_data=field_map,
task_id=task_id_adj,
path=fname,
workflow_type=workflow_type,
simulation=simulation,
)
except Exception as e:
td.log.error(f"Error occurred while getting VJP traced fields: {e}")
raise e
Expand Down
Loading