Skip to content

Commit

Permalink
adjoint run_local disables restrictive validators
Browse files Browse the repository at this point in the history
  • Loading branch information
tylerflex authored and momchil-flex committed Apr 3, 2024
1 parent 2797fe6 commit cd7bb00
Show file tree
Hide file tree
Showing 8 changed files with 101 additions and 38 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]

### Added
- 2D heat simulations are now fully supported.
- `tidy3d.plugins.adjoint.web.run_local` used in place of `run` will skip validators that restrict the size or number of `input_structures`.

### Fixed
- Better error message when trying to transform a geometry with infinite bounds.
Expand Down
39 changes: 27 additions & 12 deletions tests/test_plugins/test_adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -1346,11 +1346,13 @@ def make_vertices(n: int) -> np.ndarray:
return np.stack((np.cos(angles), np.sin(angles)), axis=-1)

vertices_pass = make_vertices(MAX_NUM_VERTICES)
_ = JaxPolySlab(vertices=vertices_pass, slab_bounds=(-1, 1))
ps = JaxPolySlab(vertices=vertices_pass, slab_bounds=(-1, 1))
ps._validate_web_adjoint()

with pytest.raises(pydantic.ValidationError):
vertices_fail = make_vertices(MAX_NUM_VERTICES + 1)
_ = JaxPolySlab(vertices=vertices_fail, slab_bounds=(-1, 1))
vertices_fail = make_vertices(MAX_NUM_VERTICES + 1)
ps = JaxPolySlab(vertices=vertices_fail, slab_bounds=(-1, 1))
with pytest.raises(AdjointError):
ps._validate_web_adjoint()


def _test_custom_medium_3D(use_emulated_run):
Expand Down Expand Up @@ -1410,10 +1412,15 @@ def make_custom_medium(num_cells: int) -> JaxCustomMedium:
jax_eps_dataset = JaxPermittivityDataset(**field_components)
return JaxCustomMedium(eps_dataset=jax_eps_dataset)

make_custom_medium(num_cells=1)
make_custom_medium(num_cells=MAX_NUM_CELLS_CUSTOM_MEDIUM)
with pytest.raises(pydantic.ValidationError):
make_custom_medium(num_cells=MAX_NUM_CELLS_CUSTOM_MEDIUM + 1)
med = make_custom_medium(num_cells=1)
med._validate_web_adjoint()

med = make_custom_medium(num_cells=MAX_NUM_CELLS_CUSTOM_MEDIUM)
med._validate_web_adjoint()

med = make_custom_medium(num_cells=MAX_NUM_CELLS_CUSTOM_MEDIUM + 1)
with pytest.raises(td.exceptions.SetupError):
med._validate_web_adjoint()


def test_jax_sim_io(tmp_path):
Expand Down Expand Up @@ -1462,18 +1469,26 @@ def make_custom_medium(num_cells: int) -> JaxCustomMedium:
assert sim == sim2


def test_num_input_structures():
def test_num_input_structures(use_emulated_run, tmp_path):
"""Assert proper error is raised if number of input structures is too large."""

def make_sim_(num_input_structures: int) -> JaxSimulation:
sim = make_sim(permittivity=EPS, size=SIZE, vertices=VERTICES, base_eps_val=BASE_EPS_VAL)
struct = sim.input_structures[0]
return sim.updated_copy(input_structures=num_input_structures * [struct])

_ = make_sim_(num_input_structures=MAX_NUM_INPUT_STRUCTURES)
sim = make_sim_(num_input_structures=MAX_NUM_INPUT_STRUCTURES)
sim._validate_web_adjoint()

with pytest.raises(pydantic.ValidationError):
_ = make_sim_(num_input_structures=MAX_NUM_INPUT_STRUCTURES + 1)
sim = make_sim_(num_input_structures=MAX_NUM_INPUT_STRUCTURES + 1)
with pytest.raises(AdjointError):
sim._validate_web_adjoint()

# make sure that the remote web API fails whereas the local one passes
with pytest.raises(AdjointError):
sim_data = run(sim, task_name="test", path=str(tmp_path / RUN_FILE))

sim_data = run_local(sim, task_name="test", path=str(tmp_path / RUN_FILE))


@pytest.mark.parametrize("strict_binarize", (True, False))
Expand Down
18 changes: 18 additions & 0 deletions tidy3d/plugins/adjoint/components/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,20 @@
from .data.data_array import JaxDataArray, JAX_DATA_ARRAY_TAG


# end of the error message when a ``_validate_web_adjoint`` exception is raised
WEB_ADJOINT_MESSAGE = (
"You can still run this simulation through "
"'tidy3d.plugins.adjoint.web.run_local' or 'tidy3d.plugins.adjoint.web.run_local' "
", which are similar to 'run' / 'run_async', but "
"perform the gradient postprocessing calculation locally after the simulation runs. "
"Note that the postprocessing time can become "
"quite long (several minutes or more) if these restrictions are exceeded. "
"Furthermore, the local versions of 'adjoint' require downloading field data "
"inside of the 'input_structures', which can greatly increase the size of data "
"needing to be downloaded."
)


class JaxObject(Tidy3dBaseModel):
"""Abstract class that makes a :class:`.Tidy3dBaseModel` jax-compatible through inheritance."""

Expand Down Expand Up @@ -57,6 +71,10 @@ def jax_fields(self) -> dict:
jax_field_names = self.get_jax_field_names()
return {key: getattr(self, key) for key in jax_field_names}

def _validate_web_adjoint(self) -> None:
"""Run validators for this component, only if using ``tda.web.run()``."""
pass

"""Methods needed for jax to register arbitrary classes."""

def tree_flatten(self) -> Tuple[list, dict]:
Expand Down
15 changes: 9 additions & 6 deletions tidy3d/plugins/adjoint/components/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from ....constants import fp_eps, MICROMETER
from ....exceptions import AdjointError

from .base import JaxObject
from .base import JaxObject, WEB_ADJOINT_MESSAGE
from .types import JaxFloat

# number of integration points per unit wavelength in material
Expand Down Expand Up @@ -292,14 +292,17 @@ def no_dilation(cls, val):
raise AdjointError("'JaxPolySlab' does not support dilation.")
return val

@pd.validator("vertices", always=True)
def limit_number_of_vertices(cls, val):
def _validate_web_adjoint(self) -> None:
"""Run validators for this component, only if using ``tda.web.run()``."""
self._limit_number_of_vertices()

def _limit_number_of_vertices(self) -> None:
"""Limit the maximum number of vertices."""
if len(val) > MAX_NUM_VERTICES:
if len(self.vertices_jax) > MAX_NUM_VERTICES:
raise AdjointError(
f"For performance, a maximum of {MAX_NUM_VERTICES} are allowed in 'JaxPolySlab'."
f"For performance, a maximum of {MAX_NUM_VERTICES} are allowed in 'JaxPolySlab'. "
+ WEB_ADJOINT_MESSAGE
)
return val

def edge_contrib(
self,
Expand Down
18 changes: 11 additions & 7 deletions tidy3d/plugins/adjoint/components/medium.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from ....exceptions import SetupError
from ....constants import CONDUCTIVITY

from .base import JaxObject
from .base import JaxObject, WEB_ADJOINT_MESSAGE
from .types import JaxFloat
from .data.data_array import JaxDataArray
from .data.dataset import JaxPermittivityDataset
Expand Down Expand Up @@ -304,24 +304,28 @@ def _pre_deprecation_dataset(cls, values):
)
return values

@pd.validator("eps_dataset", always=True)
def _is_not_too_large(cls, val):
def _validate_web_adjoint(self) -> None:
"""Run validators for this component, only if using ``tda.web.run()``."""
self._is_not_too_large()

def _is_not_too_large(self):
"""Ensure number of pixels does not surpass a set amount."""

field_components = self.eps_dataset.field_components

for field_dim in "xyz":
field_name = f"eps_{field_dim}{field_dim}"
data_array = val.field_components[field_name]
data_array = field_components[field_name]
coord_lens = [len(data_array.coords[key]) for key in "xyz"]
num_cells_dim = np.prod(coord_lens)
if num_cells_dim > MAX_NUM_CELLS_CUSTOM_MEDIUM:
raise SetupError(
"For the adjoint plugin, each component of the 'JaxCustomMedium.eps_dataset' "
f"is restricted to have a maximum of {MAX_NUM_CELLS_CUSTOM_MEDIUM} cells. "
f"Detected {num_cells_dim} grid cells in the '{field_name}' component ."
f"Detected {num_cells_dim} grid cells in the '{field_name}' component. "
+ WEB_ADJOINT_MESSAGE
)

return val

@pd.validator("eps_dataset", always=True)
def _eps_dataset_single_frequency(cls, val):
"""Override of inherited validator. (still needed)"""
Expand Down
29 changes: 16 additions & 13 deletions tidy3d/plugins/adjoint/components/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from ....constants import HERTZ, SECOND
from ....exceptions import AdjointError

from .base import JaxObject
from .base import JaxObject, WEB_ADJOINT_MESSAGE
from .structure import (
JaxStructure,
JaxStructureType,
Expand Down Expand Up @@ -176,18 +176,6 @@ def _subpixel_is_on(cls, val):
raise AdjointError("'JaxSimulation.subpixel' must be 'True' to use adjoint plugin.")
return val

@pd.validator("input_structures", always=True)
def _restrict_input_structures(cls, val):
"""Restrict number of input structures."""
num_input_structures = len(val)
if num_input_structures > MAX_NUM_INPUT_STRUCTURES:
raise AdjointError(
"For performance, adjoint plugin restricts the number of input structures to "
f"{MAX_NUM_INPUT_STRUCTURES}. Found {num_input_structures}."
)

return val

@pd.validator("input_structures", always=True)
@skip_if_fields_missing(["structures"])
def _warn_overlap(cls, val, values):
Expand Down Expand Up @@ -281,6 +269,21 @@ def _warn_nonlinear_input_structure(cls, val):
log.warning(f"Nonlinear medium detected in input_structures[{i}]. " + NL_WARNING)
return val

def _restrict_input_structures(self) -> None:
"""Restrict number of input structures."""
num_input_structures = len(self.input_structures)
if num_input_structures > MAX_NUM_INPUT_STRUCTURES:
raise AdjointError(
"For performance, adjoint plugin restricts the number of input structures to "
f"{MAX_NUM_INPUT_STRUCTURES}. Found {num_input_structures}. " + WEB_ADJOINT_MESSAGE
)

def _validate_web_adjoint(self) -> None:
"""Run validators for this component, only if using ``tda.web.run()``."""
self._restrict_input_structures()
for structure in self.input_structures:
structure._validate_web_adjoint()

@staticmethod
def get_freqs_adjoint(output_monitors: List[Monitor]) -> List[float]:
"""Return sorted list of unique frequencies stripped from a collection of monitors."""
Expand Down
7 changes: 7 additions & 0 deletions tidy3d/plugins/adjoint/components/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,13 @@ def _check_2d_geometry(cls, val, values):
"""Override validator checking 2D geometry, which triggers unnecessarily for gradients."""
return val

def _validate_web_adjoint(self) -> None:
"""Run validators for this component, only if using ``tda.web.run()``."""
if "geometry" in self._differentiable_fields:
self.geometry._validate_web_adjoint()
if "medium" in self._differentiable_fields:
self.medium._validate_web_adjoint()

@property
def jax_fields(self):
"""The fields that are jax-traced for this class."""
Expand Down
10 changes: 10 additions & 0 deletions tidy3d/plugins/adjoint/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ def run(
Object containing solver results for the supplied :class:`.JaxSimulation`.
"""

simulation._validate_web_adjoint()

sim, jax_info = simulation.to_simulation()

sim_data = tidy3d_run_fn(
Expand All @@ -134,6 +136,8 @@ def run_fwd(
) -> Tuple[JaxSimulationData, Tuple[RunResidual]]:
"""Run forward pass and stash extra objects for the backwards pass."""

simulation._validate_web_adjoint()

sim_fwd, jax_info_fwd, jax_info_orig = simulation.to_simulation_fwd()

sim_data_orig, task_id = webapi_run_adjoint_fwd(
Expand Down Expand Up @@ -396,6 +400,9 @@ def run_async(
Contains the :class:`.JaxSimulationData` of each :class:`.JaxSimulation`.
"""

for simulation in simulations:
simulation._validate_web_adjoint()

# get task names, the td.Simulation, and JaxInfo for all supplied simulations
task_names = [str(_task_name_orig(i)) for i in range(len(simulations))]
task_info = [jax_sim.to_simulation() for jax_sim in simulations]
Expand Down Expand Up @@ -437,6 +444,9 @@ def run_async_fwd(
) -> Tuple[Tuple[JaxSimulationData, ...], RunResidualBatch]:
"""Run forward pass and stash extra objects for the backwards pass."""

for simulation in simulations:
simulation._validate_web_adjoint()

jax_infos_orig = []
sims_fwd = []
jax_infos_fwd = []
Expand Down

0 comments on commit cd7bb00

Please sign in to comment.