Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 56 additions & 1 deletion tests/test_web/test_local_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@

import tidy3d as td
from tests.test_components.autograd.test_autograd import ALL_KEY, get_functions, params0
from tests.test_web.test_webapi_mode import make_mode_sim
from tidy3d import config
from tidy3d.config import get_manager
from tidy3d.web import Job, common, run_async
from tidy3d.web.api import webapi as web
from tidy3d.web.api.container import WebContainer
from tidy3d.web.api.container import Batch, WebContainer
from tidy3d.web.api.webapi import load_simulation_if_cached
from tidy3d.web.cache import CACHE_ARTIFACT_NAME, clear, resolve_local_cache

Expand Down Expand Up @@ -135,6 +136,9 @@ def _fake_status(self):
"_Info", (), {"solverVersion": "solver-1", "taskType": "FDTD"}
)(),
)
monkeypatch.setattr(
web, "load_simulation", lambda task_id, *args, **kwargs: TASK_TO_SIM[task_id]
)
return counters


Expand Down Expand Up @@ -175,6 +179,56 @@ def _test_load_simulation_if_cached(monkeypatch, tmp_path, basic_simulation):
assert sim_data_from_cache_with_path.simulation == basic_simulation


def _test_mode_solver_caching(monkeypatch, tmp_path):
    """Check that mode solver results are stored in and served from the local cache.

    The run pipeline is patched via ``_patch_run_pipeline``. A first ``web.run``
    populates the cache; afterwards ``load_simulation_if_cached``, ``web.run``,
    ``Job.run`` and ``Batch.run`` are each expected to return the cached data
    without incrementing the ``"download"`` counter. Finally, storing through
    ``Job`` and ``Batch`` is verified after clearing the cache.
    """
    counters = _patch_run_pipeline(monkeypatch)

    # initial run: populates the cache and serves as the reference result
    mode_sim = make_mode_sim()
    reference = web.run(mode_sim)

    def _assert_cache_hit(result):
        # no download may have been counted, and the result must match the reference
        assert counters["download"] == 0
        assert isinstance(result, _FakeStubData)
        assert reference.simulation == result.simulation

    # direct lookup in the cache
    cached = load_simulation_if_cached(mode_sim)
    assert cached is not None
    assert isinstance(cached, _FakeStubData)
    assert reference.simulation == cached.simulation

    # loading through web.run
    _reset_counters(counters)
    rerun_data = web.run(mode_sim)
    _assert_cache_hit(rerun_data)

    # loading through a Job
    _reset_counters(counters)
    cached_job = Job(simulation=mode_sim, task_name="test")
    job_result = cached_job.run()
    _assert_cache_hit(job_result)

    # loading through a Batch
    _reset_counters(counters)
    cached_batch = Batch(simulations={"sim1": mode_sim})
    batch_results = cached_batch.run(path_dir=tmp_path)
    batch_entry = batch_results["sim1"]
    _assert_cache_hit(batch_entry)

    local_cache = resolve_local_cache(True)

    # storing through a Job
    local_cache.clear()
    Job(simulation=mode_sim, task_name="test").run()
    assert load_simulation_if_cached(mode_sim) is not None

    # storing through a Batch
    local_cache.clear()
    stored_batch_results = Batch(simulations={"sim1": mode_sim}).run(path_dir=tmp_path)
    _ = stored_batch_results["sim1"]  # item access is what triggers the store
    assert load_simulation_if_cached(mode_sim) is not None


def _test_run_cache_hit_async(monkeypatch, basic_simulation, tmp_path):
counters = _patch_run_pipeline(monkeypatch)
monkeypatch.setattr(config.local_cache, "max_entries", 128)
Expand Down Expand Up @@ -381,3 +435,4 @@ def test_cache_sequential(monkeypatch, tmp_path, tmp_path_factory, basic_simulat
_test_job_run_cache(monkeypatch, basic_simulation, tmp_path)
_test_autograd_cache(monkeypatch)
_test_configure_cache_roundtrip(monkeypatch, tmp_path)
_test_mode_solver_caching(monkeypatch, tmp_path)
19 changes: 15 additions & 4 deletions tidy3d/web/api/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
)
from tidy3d.web.api.tidy3d_stub import Tidy3dStub
from tidy3d.web.api.webapi import restore_simulation_if_cached
from tidy3d.web.cache import _store_mode_solver_in_cache
from tidy3d.web.core.constants import TaskId, TaskName
from tidy3d.web.core.task_core import Folder
from tidy3d.web.core.task_info import RunInfo, TaskInfo
Expand Down Expand Up @@ -490,7 +491,15 @@ def load(self, path: PathLike = DEFAULT_DATA_PATH) -> WorkflowDataType:
lazy=self.lazy,
)
if isinstance(self.simulation, ModeSolver):
if not self.load_if_cached:
_store_mode_solver_in_cache(
self.task_id,
self.simulation,
data,
path,
)
self.simulation._patch_data(data=data)

return data

def delete(self) -> None:
Expand Down Expand Up @@ -1405,7 +1414,7 @@ def load(
task_paths[task_name] = str(self._job_data_path(task_id=job.task_id, path_dir=path_dir))
task_ids[task_name] = self.jobs[task_name].task_id

loaded = {task_name: job.load_if_cached for task_name, job in self.jobs.items()}
loaded_from_cache = {task_name: job.load_if_cached for task_name, job in self.jobs.items()}

if not skip_download:
self.download(path_dir=path_dir, replace_existing=replace_existing)
Expand All @@ -1414,17 +1423,19 @@ def load(
task_paths=task_paths,
task_ids=task_ids,
verbose=self.verbose,
cached_tasks=loaded,
cached_tasks=loaded_from_cache,
lazy=self.lazy,
is_downloaded=True,
)

for task_name, job in self.jobs.items():
if isinstance(job.simulation, ModeSolver):
job_data = data[task_name]
if not loaded_from_cache[task_name]:
_store_mode_solver_in_cache(
task_ids[task_name], job.simulation, job_data, task_paths[task_name]
)
job.simulation._patch_data(data=job_data)
if not skip_download:
self.download(path_dir=path_dir, replace_existing=replace_existing)

return data

Expand Down
41 changes: 31 additions & 10 deletions tidy3d/web/api/webapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
POST_VALIDATE_STATES,
STATE_PROGRESS_PERCENTAGE,
)
from tidy3d.web.cache import CacheEntry, resolve_local_cache
from tidy3d.web.cache import CacheEntry, _store_mode_solver_in_cache, resolve_local_cache
from tidy3d.web.core.account import Account
from tidy3d.web.core.constants import (
CM_DATA_HDF5_GZ,
Expand All @@ -44,7 +44,7 @@
from tidy3d.web.core.http_util import http
from tidy3d.web.core.task_core import BatchDetail, BatchTask, Folder, SimulationTask
from tidy3d.web.core.task_info import AsyncJobDetail, ChargeType, TaskInfo
from tidy3d.web.core.types import PayType
from tidy3d.web.core.types import PayType, TaskType

from .connect_util import REFRESH_TIME, get_grid_points_str, get_time_steps_str, wait_for_connection
from .tidy3d_stub import Tidy3dStub, Tidy3dStubData
Expand Down Expand Up @@ -575,7 +575,10 @@ def run(
)

if isinstance(simulation, ModeSolver):
if task_id is not None:
_store_mode_solver_in_cache(task_id, simulation, data, path)
simulation._patch_data(data=data)

return data


Expand Down Expand Up @@ -1298,7 +1301,7 @@ def load_simulation(
task_id : str
Unique identifier of task on server. Returned by :meth:`upload`.
path : PathLike = "simulation.json"
Download path to .json file of simulation (including filename).
Download path to .json or .hdf5 file of simulation (including filename).
verbose : bool = True
If ``True``, will print progressbars and status, otherwise, will run silently.

Expand All @@ -1308,7 +1311,13 @@ def load_simulation(
Simulation loaded from downloaded json file.
"""
task = SimulationTask.get(task_id)
task.get_simulation_json(path, verbose=verbose)
path = Path(path)
if path.suffix == ".json":
task.get_simulation_json(path, verbose=verbose)
elif path.suffix == ".hdf5":
task.get_simulation_hdf5(path, verbose=verbose)
else:
raise ValueError("Path suffix must be '.json' or '.hdf5'")
return Tidy3dStub.from_file(path)


Expand Down Expand Up @@ -1414,12 +1423,24 @@ def load(
if simulation_cache is not None and task_id is not None:
info = get_info(task_id, verbose=False)
workflow_type = getattr(info, "taskType", None)
simulation_cache.store_result(
stub_data=stub_data,
task_id=task_id,
path=path,
workflow_type=workflow_type,
)
if (
workflow_type != TaskType.MODE_SOLVER.name
): # we cannot get the simulation from data or web for mode solver
simulation = None
if lazy: # get simulation via web to avoid unpacking of lazy object in store_result
try:
with tempfile.NamedTemporaryFile(suffix=".hdf5") as tmp_file:
simulation = load_simulation(task_id, path=tmp_file.name, verbose=False)
except Exception as e:
log.info(f"Failed to load simulation for storing results: {e}.")
return stub_data
simulation_cache.store_result(
stub_data=stub_data,
task_id=task_id,
path=path,
workflow_type=workflow_type,
simulation=simulation,
)

return stub_data

Expand Down
89 changes: 83 additions & 6 deletions tidy3d/web/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@
from typing import Any, Optional

from tidy3d import config
from tidy3d.components.mode.mode_solver import ModeSolver
from tidy3d.components.types.workflow import WorkflowDataType, WorkflowType
from tidy3d.log import log
from tidy3d.web.api.tidy3d_stub import Tidy3dStub
from tidy3d.web.core.constants import TaskId
from tidy3d.web.core.http_util import get_version as _get_protocol_version
from tidy3d.web.core.types import TaskType

CACHE_ARTIFACT_NAME = "simulation_data.hdf5"
CACHE_METADATA_NAME = "metadata.json"
Expand Down Expand Up @@ -316,17 +318,50 @@ def store_result(
task_id: TaskId,
path: str,
workflow_type: str,
) -> None:
simulation: Optional[WorkflowType] = None,
) -> bool:
"""
After we have the data (postprocess done), store it in the cache using the
canonical key (simulation hash + workflow type + environment + version).
Also records the task_id mapping for legacy lookups.
Stores completed workflow results in the local cache using a canonical cache key.

Parameters
----------
stub_data : :class:`.WorkflowDataType`
Object containing the workflow results, including references to the originating simulation.
task_id : str
Unique identifier of the finished workflow task.
path : str
Path to the results file on disk.
workflow_type : str
Type of workflow associated with the results (e.g., ``"SIMULATION"`` or ``"MODE_SOLVER"``).
simulation : Optional[:class:`.WorkflowDataType`]
Simulation object to use when computing the cache key. If not provided,
it will be inferred from ``stub_data.simulation`` when possible.

Returns
-------
bool
``True`` if the result was successfully stored in the local cache, ``False`` otherwise.

Notes
-----
The cache entry is keyed by the simulation hash, workflow type, environment, and protocol version.
This enables automatic reuse of identical simulation results across future runs.
Legacy task ID mappings are recorded to support backward lookup compatibility.
"""
try:
simulation_obj = getattr(stub_data, "simulation", None)
if simulation is not None:
simulation_obj = simulation
else:
simulation_obj = getattr(stub_data, "simulation", None)
if simulation_obj is None:
log.debug(
"Failed storing local cache entry: Could not find simulation data in stub_data."
)
return False
simulation_hash = simulation_obj._hash_self() if simulation_obj is not None else None
if not simulation_hash:
return
log.debug("Failed storing local cache entry: Could not hash simulation.")
return False

version = _get_protocol_version()

Expand All @@ -350,6 +385,8 @@ def store_result(
)
except Exception as e:
log.error(f"Could not store cache entry: {e}")
return False
return True


def _copy_and_hash(
Expand Down Expand Up @@ -510,4 +547,44 @@ def resolve_local_cache(use_cache: Optional[bool] = None) -> Optional[LocalCache
return None


def _store_mode_solver_in_cache(
    task_id: TaskId, simulation: ModeSolver, data: WorkflowDataType, path: os.PathLike
) -> bool:
    """
    Persist the results of a :class:`.ModeSolver` task in the local cache.

    Parameters
    ----------
    task_id : str
        Unique identifier of the mode solver task.
    simulation : :class:`.ModeSolver`
        Mode solver object passed to ``store_result`` as the ``simulation`` argument.
    data : :class:`.WorkflowDataType`
        Result data passed to ``store_result`` as ``stub_data``.
    path : PathLike
        Path to the result file on disk.

    Returns
    -------
    bool
        The value returned by ``store_result`` when a local cache is resolved;
        ``False`` when ``resolve_local_cache()`` returns ``None``.

    Notes
    -----
    The entry is always stored with workflow type ``TaskType.MODE_SOLVER.name``.
    """
    cache = resolve_local_cache()
    if cache is None:
        # no cache resolved: nothing to store
        return False

    return cache.store_result(
        stub_data=data,
        task_id=task_id,
        path=path,
        workflow_type=TaskType.MODE_SOLVER.name,
        simulation=simulation,
    )


# NOTE(review): module-level call executed at import time; the return value is
# discarded, so it is presumably made for its side effects (e.g. initializing
# the local cache from the current config) — confirm against resolve_local_cache.
resolve_local_cache()