Skip to content

Commit

Permalink
Adding num_proc arg to adjoint run_local, and properly avoiding valid…
Browse files Browse the repository at this point in the history
…ators
  • Loading branch information
momchil-flex authored and tylerflex committed May 13, 2024
1 parent 2f6f8c6 commit 6316bad
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 19 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Classmethods in `DispersionFitter` to load complex-valued permittivity or loss tangent data.
- Pre-upload validator to check that mode sources overlap with more than 2 grid cells.
- Support `2DMedium` for `Transformed`/`GeometryGroup`/`ClipOperation` geometries.
- `num_proc` argument to `tidy3d.plugins.adjoint.web.run_local` to control the number of processes used on the local machine for gradient processing.

### Changed
- `tidy3d convert` from `.lsf` files to tidy3d scripts is moved to another repository at `https://github.com/hirako22/Lumerical-to-Tidy3D-Converter`.
Expand Down
37 changes: 30 additions & 7 deletions tests/test_plugins/test_adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
JaxStructureStaticGeometry,
)
from tidy3d.plugins.adjoint.components.simulation import JaxSimulation, JaxInfo, RUN_TIME_FACTOR
from tidy3d.plugins.adjoint.components.simulation import MAX_NUM_INPUT_STRUCTURES
from tidy3d.plugins.adjoint.components import simulation
from tidy3d.plugins.adjoint.components.data.sim_data import JaxSimulationData
from tidy3d.plugins.adjoint.components.data.monitor_data import (
JaxModeData,
Expand Down Expand Up @@ -1564,26 +1564,49 @@ def make_custom_medium(num_cells: int) -> JaxCustomMedium:
assert sim == sim2


def test_num_input_structures(use_emulated_run, tmp_path):
def test_num_input_structures(use_emulated_run, tmp_path, monkeypatch):
"""Assert proper error is raised if number of input structures is too large."""

def make_sim_(num_input_structures: int) -> JaxSimulation:
sim = make_sim(permittivity=EPS, size=SIZE, vertices=VERTICES, base_eps_val=BASE_EPS_VAL)
test_max_num_structs = 3 # monkeypatch for easier testing
monkeypatch.setattr(simulation, "MAX_NUM_INPUT_STRUCTURES", test_max_num_structs)

def make_sim_(permittivity=EPS, num_input_structures: int = 1) -> JaxSimulation:
sim = make_sim(
permittivity=permittivity, size=SIZE, vertices=VERTICES, base_eps_val=BASE_EPS_VAL
)
struct = sim.input_structures[0]
return sim.updated_copy(input_structures=num_input_structures * [struct])

sim = make_sim_(num_input_structures=MAX_NUM_INPUT_STRUCTURES)
sim = make_sim_(num_input_structures=test_max_num_structs)
sim._validate_web_adjoint()

sim = make_sim_(num_input_structures=MAX_NUM_INPUT_STRUCTURES + 1)
sim = make_sim_(num_input_structures=test_max_num_structs + 1)
with pytest.raises(AdjointError):
sim._validate_web_adjoint()

# make sure that the remote web API fails whereas the local one passes
with pytest.raises(AdjointError):
run(sim, task_name="test", path=str(tmp_path / RUN_FILE))

run_local(sim, task_name="test", path=str(tmp_path / RUN_FILE))
# make sure it also errors in a gradient computation
def f(permittivity):
sim = make_sim_(permittivity=permittivity, num_input_structures=test_max_num_structs + 1)
sim_data = run(sim, task_name="test", path=str(tmp_path / RUN_FILE))
return objective(extract_amp(sim_data))

with pytest.raises(AdjointError):
_ = grad(f)(EPS)

# no error when calling run_local directly
sim_data = run_local(sim, task_name="test", path=str(tmp_path / RUN_FILE))

# no error when calling it inside a gradient computation
def f(permittivity):
sim = make_sim_(permittivity=permittivity, num_input_structures=test_max_num_structs + 1)
sim_data = run_local(sim, task_name="test", path=str(tmp_path / RUN_FILE), num_proc=2)
return objective(extract_amp(sim_data))

_ = grad(f)(EPS)


@pytest.mark.parametrize("strict_binarize", (True, False))
Expand Down
56 changes: 44 additions & 12 deletions tidy3d/plugins/adjoint/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from ...components.types import Literal

from .components.base import JaxObject
from .components.simulation import JaxSimulation, JaxInfo
from .components.simulation import JaxSimulation, JaxInfo, NUM_PROC_LOCAL
from .components.data.sim_data import JaxSimulationData


Expand Down Expand Up @@ -76,6 +76,31 @@ def tidy3d_run_async_fn(simulations: Dict[str, Simulation], **kwargs) -> BatchDa
""" Running a single simulation using web.run. """


def _run(
simulation: JaxSimulation,
task_name: str,
folder_name: str = "default",
path: str = "simulation_data.hdf5",
callback_url: str = None,
verbose: bool = True,
) -> JaxSimulationData:
"""Split the provided ``JaxSimulation`` into a regular ``Simulation`` and a ``JaxInfo`` part,
run using ``tidy3d_run_fn``, which runs on the server by default but can be monkeypatched,
and recombine into a ``JaxSimulationData``.
"""
sim, jax_info = simulation.to_simulation()

sim_data = tidy3d_run_fn(
simulation=sim,
task_name=str(task_name),
folder_name=folder_name,
path=path,
callback_url=callback_url,
verbose=verbose,
)
return JaxSimulationData.from_sim_data(sim_data, jax_info)


@partial(custom_vjp, nondiff_argnums=tuple(range(1, 6)))
def run(
simulation: JaxSimulation,
Expand Down Expand Up @@ -113,18 +138,15 @@ def run(

simulation._validate_web_adjoint()

sim, jax_info = simulation.to_simulation()

sim_data = tidy3d_run_fn(
simulation=sim,
task_name=str(task_name),
# TODO: add task_id
return _run(
simulation=simulation,
task_name=task_name,
folder_name=folder_name,
path=path,
callback_url=callback_url,
verbose=verbose,
)
# TODO: add task_id
return JaxSimulationData.from_sim_data(sim_data, jax_info)


def run_fwd(
Expand Down Expand Up @@ -583,14 +605,15 @@ def webapi_run_async_adjoint_bwd(
""" Options to do the previous but all client side (mainly for testing / debugging)."""


@partial(custom_vjp, nondiff_argnums=tuple(range(1, 6)))
@partial(custom_vjp, nondiff_argnums=tuple(range(1, 7)))
def run_local(
simulation: JaxSimulation,
task_name: str,
folder_name: str = "default",
path: str = "simulation_data.hdf5",
callback_url: str = None,
verbose: bool = True,
num_proc: int = NUM_PROC_LOCAL,
) -> JaxSimulationData:
"""Submits a :class:`.JaxSimulation` to server, starts running, monitors progress, downloads,
and loads results as a :class:`.JaxSimulationData` object.
Expand All @@ -611,6 +634,8 @@ def run_local(
fields ``{'id', 'status', 'name', 'workUnit', 'solverVersion'}``.
verbose : bool = True
If `True`, will print progressbars and status, otherwise, will run silently.
num_proc: int = 1
Number of processes to use for the gradient computations.
Returns
-------
Expand Down Expand Up @@ -643,6 +668,7 @@ def run_local_fwd(
path: str,
callback_url: str,
verbose: bool,
num_proc: int,
) -> Tuple[JaxSimulationData, tuple]:
"""Run forward pass and stash extra objects for the backwards pass."""

Expand All @@ -651,7 +677,7 @@ def run_local_fwd(
input_structures=simulation.input_structures, freqs_adjoint=simulation.freqs_adjoint
)
sim_fwd = simulation.updated_copy(**grad_mnts)
sim_data_fwd = run(
sim_data_fwd = _run(
simulation=sim_fwd,
task_name=_task_name_fwd(task_name),
folder_name=folder_name,
Expand All @@ -671,6 +697,7 @@ def run_local_bwd(
path: str,
callback_url: str,
verbose: bool,
num_proc: int,
res: tuple,
sim_data_vjp: JaxSimulationData,
) -> Tuple[JaxSimulation]:
Expand All @@ -685,7 +712,7 @@ def run_local_bwd(
fwidth_adj = sim_data_fwd.simulation._fwidth_adjoint
run_time_adj = sim_data_fwd.simulation._run_time_adjoint
sim_adj = sim_data_vjp.make_adjoint_simulation(fwidth=fwidth_adj, run_time=run_time_adj)
sim_data_adj = run(
sim_data_adj = _run(
simulation=sim_adj,
task_name=_task_name_adj(task_name),
folder_name=folder_name,
Expand All @@ -699,7 +726,12 @@ def run_local_bwd(
grad_data_adj = sim_data_adj.grad_data_symmetry

# get gradient and insert into the resulting simulation structure medium
sim_vjp = sim_data_vjp.simulation.store_vjp(grad_data_fwd, grad_data_adj, grad_eps_data_fwd)
sim_vjp = sim_data_vjp.simulation.store_vjp(
grad_data_fwd=grad_data_fwd,
grad_data_adj=grad_data_adj,
grad_eps_data=grad_eps_data_fwd,
num_proc=num_proc,
)

return (sim_vjp,)

Expand Down

0 comments on commit 6316bad

Please sign in to comment.