From cd7bb003f757a4e7c9c73c05b597cf01521c3762 Mon Sep 17 00:00:00 2001 From: Tyler Hughes Date: Mon, 1 Apr 2024 16:18:44 +0900 Subject: [PATCH] adjoint run_local disables restrictive validators --- CHANGELOG.md | 3 ++ tests/test_plugins/test_adjoint.py | 39 +++++++++++++------ tidy3d/plugins/adjoint/components/base.py | 18 +++++++++ tidy3d/plugins/adjoint/components/geometry.py | 15 ++++--- tidy3d/plugins/adjoint/components/medium.py | 18 +++++---- .../plugins/adjoint/components/simulation.py | 29 +++++++------- .../plugins/adjoint/components/structure.py | 7 ++++ tidy3d/plugins/adjoint/web.py | 10 +++++ 8 files changed, 101 insertions(+), 38 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 973e13fc2..25625cb86 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,7 +4,10 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased] + +### Added - 2D heat simulations are now fully supported. +- `tidy3d.plugins.adjoint.web.run_local` used in place of `run` will skip validators that restrict the size or number of `input_structures`. ### Fixed - Better error message when trying to transform a geometry with infinite bounds. diff --git a/tests/test_plugins/test_adjoint.py b/tests/test_plugins/test_adjoint.py index 5c4a5fd02..4387c446c 100644 --- a/tests/test_plugins/test_adjoint.py +++ b/tests/test_plugins/test_adjoint.py @@ -1346,11 +1346,13 @@ def make_vertices(n: int) -> np.ndarray: return np.stack((np.cos(angles), np.sin(angles)), axis=-1) vertices_pass = make_vertices(MAX_NUM_VERTICES) - _ = JaxPolySlab(vertices=vertices_pass, slab_bounds=(-1, 1)) + ps = JaxPolySlab(vertices=vertices_pass, slab_bounds=(-1, 1)) + ps._validate_web_adjoint() - with pytest.raises(pydantic.ValidationError): - vertices_fail = make_vertices(MAX_NUM_VERTICES + 1) - _ = JaxPolySlab(vertices=vertices_fail, slab_bounds=(-1, 1)) + vertices_fail = make_vertices(MAX_NUM_VERTICES + 1) + ps = JaxPolySlab(vertices=vertices_fail, slab_bounds=(-1, 1)) + with pytest.raises(AdjointError): + ps._validate_web_adjoint() def _test_custom_medium_3D(use_emulated_run): @@ -1410,10 +1412,15 @@ def make_custom_medium(num_cells: int) -> JaxCustomMedium: jax_eps_dataset = JaxPermittivityDataset(**field_components) return JaxCustomMedium(eps_dataset=jax_eps_dataset) - make_custom_medium(num_cells=1) - make_custom_medium(num_cells=MAX_NUM_CELLS_CUSTOM_MEDIUM) - with pytest.raises(pydantic.ValidationError): - make_custom_medium(num_cells=MAX_NUM_CELLS_CUSTOM_MEDIUM + 1) + med = make_custom_medium(num_cells=1) + med._validate_web_adjoint() + + med = make_custom_medium(num_cells=MAX_NUM_CELLS_CUSTOM_MEDIUM) + med._validate_web_adjoint() + + med = make_custom_medium(num_cells=MAX_NUM_CELLS_CUSTOM_MEDIUM + 1) + with pytest.raises(td.exceptions.SetupError): + med._validate_web_adjoint() def test_jax_sim_io(tmp_path): @@ -1462,7 +1469,7 @@ def make_custom_medium(num_cells: int) -> JaxCustomMedium: assert sim == sim2 -def test_num_input_structures(): +def test_num_input_structures(use_emulated_run, tmp_path): """Assert proper error is raised if number of input structures is too large.""" def make_sim_(num_input_structures: int) -> JaxSimulation: @@ -1470,10 +1477,18 @@ def make_sim_(num_input_structures: int) -> JaxSimulation: struct = sim.input_structures[0] return sim.updated_copy(input_structures=num_input_structures * [struct]) - _ = make_sim_(num_input_structures=MAX_NUM_INPUT_STRUCTURES) + sim = make_sim_(num_input_structures=MAX_NUM_INPUT_STRUCTURES) + sim._validate_web_adjoint() - with pytest.raises(pydantic.ValidationError): - _ = make_sim_(num_input_structures=MAX_NUM_INPUT_STRUCTURES + 1) + sim = make_sim_(num_input_structures=MAX_NUM_INPUT_STRUCTURES + 1) + with pytest.raises(AdjointError): + sim._validate_web_adjoint() + + # make sure that the remote web API fails whereas the local one passes + with pytest.raises(AdjointError): + sim_data = run(sim, task_name="test", path=str(tmp_path / RUN_FILE)) + + sim_data = run_local(sim, task_name="test", path=str(tmp_path / RUN_FILE)) @pytest.mark.parametrize("strict_binarize", (True, False)) diff --git a/tidy3d/plugins/adjoint/components/base.py b/tidy3d/plugins/adjoint/components/base.py index aaca05d9f..3b1e515e5 100644 --- a/tidy3d/plugins/adjoint/components/base.py +++ b/tidy3d/plugins/adjoint/components/base.py @@ -15,6 +15,20 @@ from .data.data_array import JaxDataArray, JAX_DATA_ARRAY_TAG +# end of the error message when a ``_validate_web_adjoint`` exception is raised +WEB_ADJOINT_MESSAGE = ( + "You can still run this simulation through " + "'tidy3d.plugins.adjoint.web.run_local' or 'tidy3d.plugins.adjoint.web.run_local' " + ", which are similar to 'run' / 'run_async', but " + "perform the gradient postprocessing calculation locally after the simulation runs. " + "Note that the postprocessing time can become " + "quite long (several minutes or more) if these restrictions are exceeded. " + "Furthermore, the local versions of 'adjoint' require downloading field data " + "inside of the 'input_structures', which can greatly increase the size of data " + "needing to be downloaded." +) + + class JaxObject(Tidy3dBaseModel): """Abstract class that makes a :class:`.Tidy3dBaseModel` jax-compatible through inheritance.""" @@ -57,6 +71,10 @@ def jax_fields(self) -> dict: jax_field_names = self.get_jax_field_names() return {key: getattr(self, key) for key in jax_field_names} + def _validate_web_adjoint(self) -> None: + """Run validators for this component, only if using ``tda.web.run()``.""" + pass + """Methods needed for jax to register arbitrary classes.""" def tree_flatten(self) -> Tuple[list, dict]: diff --git a/tidy3d/plugins/adjoint/components/geometry.py b/tidy3d/plugins/adjoint/components/geometry.py index 4069afd8b..ca353f9c5 100644 --- a/tidy3d/plugins/adjoint/components/geometry.py +++ b/tidy3d/plugins/adjoint/components/geometry.py @@ -21,7 +21,7 @@ from ....constants import fp_eps, MICROMETER from ....exceptions import AdjointError -from .base import JaxObject +from .base import JaxObject, WEB_ADJOINT_MESSAGE from .types import JaxFloat # number of integration points per unit wavelength in material @@ -292,14 +292,17 @@ def no_dilation(cls, val): raise AdjointError("'JaxPolySlab' does not support dilation.") return val - @pd.validator("vertices", always=True) - def limit_number_of_vertices(cls, val): + def _validate_web_adjoint(self) -> None: + """Run validators for this component, only if using ``tda.web.run()``.""" + self._limit_number_of_vertices() + + def _limit_number_of_vertices(self) -> None: """Limit the maximum number of vertices.""" - if len(val) > MAX_NUM_VERTICES: + if len(self.vertices_jax) > MAX_NUM_VERTICES: raise AdjointError( - f"For performance, a maximum of {MAX_NUM_VERTICES} are allowed in 'JaxPolySlab'." + f"For performance, a maximum of {MAX_NUM_VERTICES} are allowed in 'JaxPolySlab'. " + + WEB_ADJOINT_MESSAGE ) - return val def edge_contrib( self, diff --git a/tidy3d/plugins/adjoint/components/medium.py b/tidy3d/plugins/adjoint/components/medium.py index bcab0b3ef..ad1790058 100644 --- a/tidy3d/plugins/adjoint/components/medium.py +++ b/tidy3d/plugins/adjoint/components/medium.py @@ -16,7 +16,7 @@ from ....exceptions import SetupError from ....constants import CONDUCTIVITY -from .base import JaxObject +from .base import JaxObject, WEB_ADJOINT_MESSAGE from .types import JaxFloat from .data.data_array import JaxDataArray from .data.dataset import JaxPermittivityDataset @@ -304,24 +304,28 @@ def _pre_deprecation_dataset(cls, values): ) return values - @pd.validator("eps_dataset", always=True) - def _is_not_too_large(cls, val): + def _validate_web_adjoint(self) -> None: + """Run validators for this component, only if using ``tda.web.run()``.""" + self._is_not_too_large() + + def _is_not_too_large(self): """Ensure number of pixels does not surpass a set amount.""" + field_components = self.eps_dataset.field_components + for field_dim in "xyz": field_name = f"eps_{field_dim}{field_dim}" - data_array = val.field_components[field_name] + data_array = field_components[field_name] coord_lens = [len(data_array.coords[key]) for key in "xyz"] num_cells_dim = np.prod(coord_lens) if num_cells_dim > MAX_NUM_CELLS_CUSTOM_MEDIUM: raise SetupError( "For the adjoint plugin, each component of the 'JaxCustomMedium.eps_dataset' " f"is restricted to have a maximum of {MAX_NUM_CELLS_CUSTOM_MEDIUM} cells. " - f"Detected {num_cells_dim} grid cells in the '{field_name}' component ." + f"Detected {num_cells_dim} grid cells in the '{field_name}' component. " + + WEB_ADJOINT_MESSAGE ) - return val - @pd.validator("eps_dataset", always=True) def _eps_dataset_single_frequency(cls, val): """Override of inherited validator. (still needed)""" diff --git a/tidy3d/plugins/adjoint/components/simulation.py b/tidy3d/plugins/adjoint/components/simulation.py index aded60645..b4f27f5d4 100644 --- a/tidy3d/plugins/adjoint/components/simulation.py +++ b/tidy3d/plugins/adjoint/components/simulation.py @@ -20,7 +20,7 @@ from ....constants import HERTZ, SECOND from ....exceptions import AdjointError -from .base import JaxObject +from .base import JaxObject, WEB_ADJOINT_MESSAGE from .structure import ( JaxStructure, JaxStructureType, @@ -176,18 +176,6 @@ def _subpixel_is_on(cls, val): raise AdjointError("'JaxSimulation.subpixel' must be 'True' to use adjoint plugin.") return val - @pd.validator("input_structures", always=True) - def _restrict_input_structures(cls, val): - """Restrict number of input structures.""" - num_input_structures = len(val) - if num_input_structures > MAX_NUM_INPUT_STRUCTURES: - raise AdjointError( - "For performance, adjoint plugin restricts the number of input structures to " - f"{MAX_NUM_INPUT_STRUCTURES}. Found {num_input_structures}." - ) - - return val - @pd.validator("input_structures", always=True) @skip_if_fields_missing(["structures"]) def _warn_overlap(cls, val, values): @@ -281,6 +269,21 @@ def _warn_nonlinear_input_structure(cls, val): log.warning(f"Nonlinear medium detected in input_structures[{i}]. " + NL_WARNING) return val + def _restrict_input_structures(self) -> None: + """Restrict number of input structures.""" + num_input_structures = len(self.input_structures) + if num_input_structures > MAX_NUM_INPUT_STRUCTURES: + raise AdjointError( + "For performance, adjoint plugin restricts the number of input structures to " + f"{MAX_NUM_INPUT_STRUCTURES}. Found {num_input_structures}. " + WEB_ADJOINT_MESSAGE + ) + + def _validate_web_adjoint(self) -> None: + """Run validators for this component, only if using ``tda.web.run()``.""" + self._restrict_input_structures() + for structure in self.input_structures: + structure._validate_web_adjoint() + @staticmethod def get_freqs_adjoint(output_monitors: List[Monitor]) -> List[float]: """Return sorted list of unique frequencies stripped from a collection of monitors.""" diff --git a/tidy3d/plugins/adjoint/components/structure.py b/tidy3d/plugins/adjoint/components/structure.py index 599bc2f68..6db6ced16 100644 --- a/tidy3d/plugins/adjoint/components/structure.py +++ b/tidy3d/plugins/adjoint/components/structure.py @@ -38,6 +38,13 @@ def _check_2d_geometry(cls, val, values): """Override validator checking 2D geometry, which triggers unnecessarily for gradients.""" return val + def _validate_web_adjoint(self) -> None: + """Run validators for this component, only if using ``tda.web.run()``.""" + if "geometry" in self._differentiable_fields: + self.geometry._validate_web_adjoint() + if "medium" in self._differentiable_fields: + self.medium._validate_web_adjoint() + @property def jax_fields(self): """The fields that are jax-traced for this class.""" diff --git a/tidy3d/plugins/adjoint/web.py b/tidy3d/plugins/adjoint/web.py index 225ce1ac4..43a9c51b6 100644 --- a/tidy3d/plugins/adjoint/web.py +++ b/tidy3d/plugins/adjoint/web.py @@ -111,6 +111,8 @@ def run( Object containing solver results for the supplied :class:`.JaxSimulation`. """ + simulation._validate_web_adjoint() + sim, jax_info = simulation.to_simulation() sim_data = tidy3d_run_fn( @@ -134,6 +136,8 @@ def run_fwd( ) -> Tuple[JaxSimulationData, Tuple[RunResidual]]: """Run forward pass and stash extra objects for the backwards pass.""" + simulation._validate_web_adjoint() + sim_fwd, jax_info_fwd, jax_info_orig = simulation.to_simulation_fwd() sim_data_orig, task_id = webapi_run_adjoint_fwd( @@ -396,6 +400,9 @@ def run_async( Contains the :class:`.JaxSimulationData` of each :class:`.JaxSimulation`. """ + for simulation in simulations: + simulation._validate_web_adjoint() + # get task names, the td.Simulation, and JaxInfo for all supplied simulations task_names = [str(_task_name_orig(i)) for i in range(len(simulations))] task_info = [jax_sim.to_simulation() for jax_sim in simulations] @@ -437,6 +444,9 @@ def run_async_fwd( ) -> Tuple[Tuple[JaxSimulationData, ...], RunResidualBatch]: """Run forward pass and stash extra objects for the backwards pass.""" + for simulation in simulations: + simulation._validate_web_adjoint() + jax_infos_orig = [] sims_fwd = [] jax_infos_fwd = []