diff --git a/tests/test_web/test_local_cache.py b/tests/test_web/test_local_cache.py
index 1d44349eb0..d96cd1912a 100644
--- a/tests/test_web/test_local_cache.py
+++ b/tests/test_web/test_local_cache.py
@@ -7,16 +7,23 @@
 from pathlib import Path
 from types import SimpleNamespace

+import autograd as ag
 import pytest
+import xarray as xr
+from autograd.core import defvjp
 from rich.console import Console

 import tidy3d as td
 from tests.test_components.autograd.test_autograd import ALL_KEY, get_functions, params0
 from tests.test_web.test_webapi_mode import make_mode_sim
 from tidy3d import config
+from tidy3d.components.autograd.field_map import FieldMap
 from tidy3d.config import get_manager
 from tidy3d.web import Job, common, run, run_async
 from tidy3d.web.api import webapi as web
+from tidy3d.web.api.autograd import autograd, engine, io_utils
+from tidy3d.web.api.autograd.autograd import run as run_autograd
+from tidy3d.web.api.autograd.constants import SIM_VJP_FILE
 from tidy3d.web.api.container import Batch, WebContainer
 from tidy3d.web.api.webapi import load_simulation_if_cached
 from tidy3d.web.cache import CACHE_ARTIFACT_NAME, clear, get_cache_entry_dir, resolve_local_cache
@@ -39,6 +46,24 @@ class _FakeStubData:
     def __init__(self, simulation: td.Simulation):
         self.simulation = simulation

+    def __getitem__(self, key):
+        if key == "mode":
+            params = self.simulation.attrs["params_autograd"]
+            return SimpleNamespace(
+                amps=xr.DataArray(params, dims=["x"], coords={"x": list(range(len(params)))})
+            )
+
+    def _strip_traced_fields(self, *args, **kwargs):
+        """Fake _strip_traced_fields: return minimal valid autograd-style mapping."""
+        return {"params": self.simulation.attrs["params"]}
+
+    def _insert_traced_fields(self, field_mapping, *args, **kwargs):
+        self.simulation.attrs["params_autograd"] = field_mapping["params"]
+        return self
+
+    def _make_adjoint_sims(self, **kwargs):
+        return [self.simulation.updated_copy(run_time=self.simulation.run_time * 2)]
+

 @pytest.fixture
 def basic_simulation():
@@ -128,6 +153,18 @@ def _fake__check_folder(*args, **kwargs):
     def _fake_status(self):
         return "success"

+    def _fake_download_file(resource_id, remote_filename, to_file=None, **kwargs):
+        # Only count this download if it's the adjoint/VJP file
+        if str(remote_filename) == SIM_VJP_FILE:
+            counters["download"] += 1
+
+    def _fake_from_file(*args, **kwargs):
+        field_map = FieldMap(tracers=())
+        return field_map
+
+    monkeypatch.setattr(io_utils, "download_file", _fake_download_file)
+    monkeypatch.setattr(autograd, "postprocess_fwd", _fake_from_file)
+    monkeypatch.setattr(FieldMap, "from_file", _fake_from_file)
     monkeypatch.setattr(WebContainer, "_check_folder", _fake__check_folder)
     monkeypatch.setattr(web, "upload", _fake_upload)
     monkeypatch.setattr(web, "start", _fake_start)
@@ -135,6 +172,7 @@ def _fake_status(self):
     monkeypatch.setattr(web, "download", _fake_download)
     monkeypatch.setattr(web, "estimate_cost", lambda *args, **kwargs: 0.0)
     monkeypatch.setattr(Job, "status", property(_fake_status))
+    monkeypatch.setattr(engine, "upload_sim_fields_keys", lambda *args, **kwargs: None)
     monkeypatch.setattr(
         web,
         "get_info",
@@ -142,11 +180,23 @@ def _fake_status(self):
         lambda task_id, verbose=True: type(
             "_Info", (), {"solverVersion": "solver-1", "taskType": "FDTD"}
         )(),
     )
+    monkeypatch.setattr(
+        io_utils,
+        "get_info",
+        lambda task_id, verbose=True: type(
+            "_Info", (), {"solverVersion": "solver-1", "taskType": "FDTD"}
+        )(),
+    )
     monkeypatch.setattr(
         web, "load_simulation", lambda task_id, *args, **kwargs: TASK_TO_SIM[task_id]
     )
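+    # load_simulation is imported into io_utils at import time, so patch that binding too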
+    monkeypatch.setattr(
+        io_utils, "load_simulation", lambda task_id, *args, **kwargs: TASK_TO_SIM[task_id]
+    )
     monkeypatch.setattr(BatchTask, "is_batch", lambda *args, **kwargs: "success")
-    monkeypatch.setattr(BatchTask, "detail", lambda *args: SimpleNamespace(status="success"))
+    monkeypatch.setattr(
+        BatchTask, "detail", lambda *args, **kwargs: SimpleNamespace(status="success")
+    )

     return counters
@@ -371,23 +421,59 @@ def _test_job_run_cache(monkeypatch, basic_simulation, tmp_path):
     assert os.path.exists(out2_path)


-def _test_autograd_cache(monkeypatch):
+def _test_autograd_cache(monkeypatch, request):
     counters = _patch_run_pipeline(monkeypatch)
+
+    # "Original" rule: the one autograd uses by default
+    def _orig_make_dict_vjp(ans, keys, vals):
+        return lambda g: [g[key] for key in keys]
+
+    def _zero_make_dict_vjp(ans, keys, vals):
+        def vjp(g):
+            # One gradient per entry in `vals`, all zeros, matching shape/dtype
+            return [ag.numpy.zeros_like(v) for v in vals]
+
+        return vjp
+
+    # Install our zero-VJP (this is the thing that affects global state)
+    defvjp(
+        ag.builtins._make_dict,
+        _zero_make_dict_vjp,
+        argnums=(1,),  # gradient w.r.t. `vals`
+    )
+
+    # Make sure we restore it after the test
+    def _restore_make_dict_vjp():
+        defvjp(
+            ag.builtins._make_dict,
+            _orig_make_dict_vjp,
+            argnums=(1,),
+        )
+
+    request.addfinalizer(_restore_make_dict_vjp)
+
     cache = resolve_local_cache(use_cache=True)
     cache.clear()

     functions = get_functions(ALL_KEY, "mode")
     make_sim = functions["sim"]
-    sim = make_sim(params0)
-    web.run(sim)
-    assert counters["download"] == 1
-    assert len(cache) == 1
+    postprocess = functions["postprocess"]
+
+    def objective(params):
+        sim = make_sim(params)
+        sim.attrs["params"] = params
+        sim_data = run_autograd(sim)
+        value = postprocess(sim_data)
+        return value
+
+    ag.value_and_grad(objective)(params0)
+    assert counters["download"] == 2
+    assert len(cache) == 2

     _reset_counters(counters)
-    sim = make_sim(params0)
-    web.run(sim)
-    assert counters["download"] == 0
-    assert len(cache) == 1
+    ag.value_and_grad(objective)(params0)
+    assert counters["download"] == 1  # download field data
+    assert len(cache) == 2


 def _test_load_cache_hit(monkeypatch, tmp_path, basic_simulation, fake_data):
@@ -499,7 +585,9 @@ def _test_env_var_overrides(monkeypatch, tmp_path):
     manager._reload()


-def test_cache_sequential(monkeypatch, tmp_path, tmp_path_factory, basic_simulation, fake_data):
+def test_cache_sequential(
+    monkeypatch, tmp_path, tmp_path_factory, basic_simulation, fake_data, request
+):
     """Run all critical cache tests in sequence to ensure stability."""

     monkeypatch.setattr(config.local_cache, "enabled", True)
@@ -514,7 +602,7 @@ def test_cache_sequential(monkeypatch, tmp_path, tmp_path_factory, basic_simulat
     _test_cache_eviction_by_size(monkeypatch, tmp_path_factory, basic_simulation)
     _test_run_cache_hit_async(monkeypatch, basic_simulation, tmp_path)
     _test_job_run_cache(monkeypatch, basic_simulation, tmp_path)
-    _test_autograd_cache(monkeypatch)
+    _test_autograd_cache(monkeypatch, request)
     _test_configure_cache_roundtrip(monkeypatch, tmp_path)
     _test_mode_solver_caching(monkeypatch, tmp_path)
     _test_verbosity(monkeypatch, basic_simulation)
diff --git a/tests/test_web/test_webapi.py b/tests/test_web/test_webapi.py
index 2b7c6e8a1f..48ee90f4bf 100644
--- a/tests/test_web/test_webapi.py
+++ b/tests/test_web/test_webapi.py
@@ -69,6 +69,29 @@
 Env.dev.active()


+class FakeJob:
+    def __init__(self, task_id: str, statuses: list[str], events: list[str]):
+        self.task_id = task_id
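+        # statuses are returned one per poll; the final entry repeats once reached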
+        self._statuses = statuses
+        self._idx = 0
+        self.events = events
+
+    @property
+    def status(self):
+        status = self._statuses[self._idx]
+        if self._idx < len(self._statuses) - 1:
+            self._idx += 1
+        self.events.append((self.task_id, "status", status))
+        return status
+
+    def download(self, path: PathLike):
+        self.events.append((self.task_id, "download", str(path)))
+
+    @property
+    def load_if_cached(self):
+        return False
+
+
 class ImmediateExecutor:
     def __init__(self, *args, **kwargs):
         pass
@@ -728,23 +751,6 @@ def mock_start_interrupt(self, *args, **kwargs):
 def test_batch_monitor_downloads_on_success(monkeypatch, tmp_path):
     events = []

-    class FakeJob:
-        def __init__(self, task_id: str, statuses: list[str]):
-            self.task_id = task_id
-            self._statuses = statuses
-            self._idx = 0
-
-        @property
-        def status(self):
-            status = self._statuses[self._idx]
-            if self._idx < len(self._statuses) - 1:
-                self._idx += 1
-            events.append((self.task_id, "status", status))
-            return status
-
-        def download(self, path: PathLike):
-            events.append((self.task_id, "download", str(path)))
-
     monkeypatch.setattr("tidy3d.web.api.container.ThreadPoolExecutor", ImmediateExecutor)
     monkeypatch.setattr("tidy3d.web.api.container.time.sleep", lambda *_args, **_kwargs: None)

@@ -752,8 +758,8 @@ def download(self, path: PathLike):
     batch = Batch(simulations=sims, folder_name=PROJECT_NAME, verbose=False)
     batch._cached_properties = {}
     fake_jobs = {
-        "task_a": FakeJob("task_a_id", ["running", "success", "success"]),
-        "task_b": FakeJob("task_b_id", ["running", "running", "success"]),
+        "task_a": FakeJob("task_a_id", ["running", "success", "success"], events),
+        "task_b": FakeJob("task_b_id", ["running", "running", "success"], events),
     }
     batch._cached_properties["jobs"] = fake_jobs

@@ -786,23 +792,6 @@ def download(self, path: PathLike):
 def test_batch_monitor_skips_existing_download(monkeypatch, tmp_path):
     events = []

-    class FakeJob:
-        def __init__(self, task_id: str, statuses: list[str]):
-            self.task_id = task_id
-            self._statuses = statuses
-            self._idx = 0
-
-        @property
-        def status(self):
-            status = self._statuses[self._idx]
-            if self._idx < len(self._statuses) - 1:
-                self._idx += 1
-            events.append((self.task_id, "status", status))
-            return status
-
-        def download(self, path: PathLike):
-            events.append((self.task_id, "download", str(path)))
-
     monkeypatch.setattr("tidy3d.web.api.container.ThreadPoolExecutor", ImmediateExecutor)
     monkeypatch.setattr("tidy3d.web.api.container.time.sleep", lambda *_args, **_kwargs: None)

@@ -810,8 +799,8 @@ def download(self, path: PathLike):
     batch = Batch(simulations=sims, folder_name=PROJECT_NAME, verbose=False)
     batch._cached_properties = {}
     fake_jobs = {
-        "task_a": FakeJob("task_a_id", ["success", "success"]),
-        "task_b": FakeJob("task_b_id", ["running", "success"]),
+        "task_a": FakeJob("task_a_id", ["success", "success"], events),
+        "task_b": FakeJob("task_b_id", ["running", "success"], events),
     }
     batch._cached_properties["jobs"] = fake_jobs

diff --git a/tidy3d/web/api/autograd/autograd.py b/tidy3d/web/api/autograd/autograd.py
index 47bb1c7d8d..c2e4eb965c 100644
--- a/tidy3d/web/api/autograd/autograd.py
+++ b/tidy3d/web/api/autograd/autograd.py
@@ -19,6 +19,7 @@
 from tidy3d.web.api.asynchronous import run_async as run_async_webapi
 from tidy3d.web.api.container import BatchData
 from tidy3d.web.api.tidy3d_stub import Tidy3dStub
+from tidy3d.web.api.webapi import load, restore_simulation_if_cached
 from tidy3d.web.api.webapi import run as run_webapi
 from tidy3d.web.core.types import PayType

@@ -561,16 +562,31 @@ def _run_primitive(
             aux_data=aux_data,
         )
     else:
-        sim_combined.validate_pre_upload()
         sim_original = sim_original.updated_copy(simulation_type="autograd_fwd", deep=False)
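+        # consult the local cache before validating and uploading the forward simulation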
+        restored_path, task_id_fwd = restore_simulation_if_cached(
+            simulation=sim_original,
+            path=run_kwargs.get("path", None),
+            reduce_simulation=run_kwargs.get("reduce_simulation", "auto"),
+            verbose=run_kwargs.get("verbose", True),
+        )
+        if restored_path is None or task_id_fwd is None:
+            sim_combined.validate_pre_upload()
+            run_kwargs["simulation_type"] = "autograd_fwd"
+            run_kwargs["sim_fields_keys"] = list(sim_fields.keys())
+
+            sim_data_orig, task_id_fwd = _run_tidy3d(
+                sim_original,
+                task_name=task_name,
+                **run_kwargs,
+            )
+        else:
+            sim_data_orig = load(
+                task_id=None,
+                path=run_kwargs.get("path", None),
+                verbose=run_kwargs.get("verbose", None),
+                progress_callback=run_kwargs.get("progress_callback", None),
+                lazy=run_kwargs.get("lazy", None),
+            )

     # TODO: put this in postprocess?
     aux_data[AUX_KEY_FWD_TASK_ID] = task_id_fwd
diff --git a/tidy3d/web/api/autograd/io_utils.py b/tidy3d/web/api/autograd/io_utils.py
index e5b1f71c67..08d03b0f47 100644
--- a/tidy3d/web/api/autograd/io_utils.py
+++ b/tidy3d/web/api/autograd/io_utils.py
@@ -6,6 +6,8 @@
 import tidy3d as td
 from tidy3d.components.autograd import AutogradFieldMap
 from tidy3d.components.autograd.field_map import FieldMap, TracerKeys
+from tidy3d.web.api.webapi import get_info, load_simulation
+from tidy3d.web.cache import resolve_local_cache
 from tidy3d.web.core.s3utils import download_file, upload_file  # type: ignore

 from .constants import SIM_FIELDS_KEYS_FILE, SIM_VJP_FILE
@@ -39,6 +41,21 @@ def get_vjp_traced_fields(task_id_adj: str, verbose: bool) -> AutogradFieldMap:
     try:
         download_file(task_id_adj, SIM_VJP_FILE, to_file=fname, verbose=verbose)
         field_map = FieldMap.from_file(fname)
+
+        simulation_cache = resolve_local_cache()
+        if simulation_cache is not None:
+            info = get_info(task_id_adj, verbose=False)
+            workflow_type = getattr(info, "taskType", None)
+            simulation = None
+            with tempfile.NamedTemporaryFile(suffix=".hdf5") as tmp_file:
+                simulation = load_simulation(task_id_adj, path=tmp_file.name, verbose=False)
+            simulation_cache.store_result(
+                stub_data=field_map,
+                task_id=task_id_adj,
+                path=fname,
+                workflow_type=workflow_type,
+                simulation=simulation,
+            )
     except Exception as e:
         td.log.error(f"Error occurred while getting VJP traced fields: {e}")
         raise e
diff --git a/tidy3d/web/api/container.py b/tidy3d/web/api/container.py
index 1f77181b49..236a339122 100644
--- a/tidy3d/web/api/container.py
+++ b/tidy3d/web/api/container.py
@@ -257,12 +257,14 @@ class Job(WebContainer):
     )
     _stash_path: Optional[str] = PrivateAttr(default=None)
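+    # original task id of the cache entry this job was restored from; set by load_if_cached()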
+    _cached_task_id: Optional[TaskId] = PrivateAttr(default=None)
+
     @cached_property
     def _stash_path_for_job(self) -> str:
         """Stash file which is a temporary location for the cached-restored file."""
         stash_dir = Path(tempfile.gettempdir()) / "tidy3d_stash"
         stash_dir.mkdir(parents=True, exist_ok=True)
-        return str(Path(stash_dir / f"{self._cached_task_id}.hdf5"))
+        return str(Path(stash_dir / f"{uuid.uuid4()}.hdf5"))

     def _materialize_from_stash(self, dst_path: os.PathLike) -> None:
         """Atomic copy from stash to requested path."""
@@ -332,14 +334,15 @@ def run(
     def load_if_cached(self) -> bool:
         """Checks if results are cached and (if yes) restores them into our shared stash file."""
         # use temporary path as final destination is unknown
-        stash_path = self._stash_path_for_job()
+        stash_path = self._stash_path_for_job

-        restored = restore_simulation_if_cached(
+        restored, cached_task_id = restore_simulation_if_cached(
             simulation=self.simulation,
             path=stash_path,
             reduce_simulation=self.reduce_simulation,
             verbose=self.verbose,
         )
+        self._cached_task_id = cached_task_id
         if restored is None:
             return False

@@ -348,11 +351,6 @@ def load_if_cached(self) -> bool:
         atexit.register(self.clear_stash)
         return True

-    @cached_property
-    def _cached_task_id(self) -> TaskId:
-        """The task ID for jobs which are loaded from cache."""
-        return "cached_" + self.task_name + "_" + str(uuid.uuid4())
-
     @cached_property
     def task_id(self) -> TaskId:
         """The task ID for this ``Job``. Uploads the ``Job`` if it hasn't already been uploaded."""
@@ -1127,6 +1125,8 @@ def schedule_download(job: Job) -> None:

     # ----- continue condition & status formatting -------------------------------
     def check_continue_condition(job: Job) -> bool:
+        if job.load_if_cached:
+            return False
         status = job.status
         if not web._is_modeler_batch(job.task_id):
             return status not in END_STATES
@@ -1514,6 +1514,10 @@ def estimate_cost(self, verbose: bool = True) -> float:
            console = get_logging_console()
            if batch_cost is not None and batch_cost > 0:
                console.log(f"Maximum FlexCredit cost: {batch_cost:1.3f} for the whole batch.")
+           elif batch_cost == 0 and all(job.load_if_cached for job in self.jobs.values()):
+               console.log(
+                   "No FlexCredit cost for batch as all simulations were restored from the local cache."
+               )
            else:
                console.log("Could not get estimated batch cost!")

diff --git a/tidy3d/web/api/webapi.py b/tidy3d/web/api/webapi.py
index 1225d2ddf5..a57f0ebbfe 100644
--- a/tidy3d/web/api/webapi.py
+++ b/tidy3d/web/api/webapi.py
@@ -355,7 +355,7 @@ def restore_simulation_if_cached(
     path: Optional[PathLike] = None,
     reduce_simulation: Literal["auto", True, False] = "auto",
     verbose: bool = True,
-) -> Optional[PathLike]:
+) -> tuple[Optional[PathLike], Optional[TaskId]]:
     """
     Attempt to restore simulation data from a local cache entry, if available.

@@ -376,9 +379,12 @@
     -------
     Optional[PathLike]
         The path to the restored simulation data if found in cache, otherwise None. If no path is specified, the cache entry path is returned, otherwise the given path is returned.
+    Optional[TaskId]
+        The original task ID of the restored simulation data.
     """
     simulation_cache = resolve_local_cache()
     retrieved_simulation_path = None
+    cached_task_id = None
     if simulation_cache is not None:
         sim_for_cache = simulation
         if isinstance(simulation, (ModeSolver, ModeSimulation)):
@@ -399,7 +402,7 @@
            console.log(
                f"Loading simulation from local cache. View cached task using web UI at [link={url}]'{url}'[/link]."
            )
-    return retrieved_simulation_path
+    return retrieved_simulation_path, cached_task_id


 def load_simulation_if_cached(
@@ -428,7 +431,7 @@
     Optional[WorkflowDataType]
         The loaded simulation data if found in cache, otherwise None.
     """
-    restored_path = restore_simulation_if_cached(
+    restored_path, _ = restore_simulation_if_cached(
         simulation, path, reduce_simulation, verbose=verbose
     )
     if restored_path is not None:
@@ -547,7 +550,7 @@ def run(
     :meth:`tidy3d.web.api.container.Batch.monitor`
         Monitor progress of each of the running tasks.
     """
-    restored_path = restore_simulation_if_cached(
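+    # the cached task id is discarded here; Job and the autograd runner consume it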
""" - restored_path = restore_simulation_if_cached( + restored_path, _ = restore_simulation_if_cached( simulation=simulation, path=path, reduce_simulation=reduce_simulation, diff --git a/tidy3d/web/cache.py b/tidy3d/web/cache.py index ae9776888d..1ca2be0b07 100644 --- a/tidy3d/web/cache.py +++ b/tidy3d/web/cache.py @@ -184,7 +184,6 @@ def _store(self, key: str, source_path: Path, metadata: dict[str, Any]) -> Optio shutil.rmtree(final_dir) os.replace(tmp_dir, final_dir) entry = CacheEntry(key=key, root=self._root, metadata=metadata) - log.debug("Stored simulation cache entry '%s' (%d bytes).", key, file_size) return entry finally: try: @@ -390,6 +389,7 @@ def store_result( source_path=Path(path), metadata=metadata, ) + log.debug("Stored local cache entry for workflow type '%s'.", workflow_type) except Exception as e: log.error(f"Could not store cache entry: {e}") return False