diff --git a/CHANGELOG.md b/CHANGELOG.md index 89030f540a..12a96d3465 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -50,6 +50,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added more RF-specific mode characteristics to `MicrowaveModeData`, including propagation constants (alpha, beta, gamma), phase/group velocities, wave impedance, and automatic mode classification with configurable polarization thresholds in `MicrowaveModeSpec`. - Introduce `tidy3d.rf` namespace to consolidate all RF classes. - Added support for custom colormaps in `plot_field`. +- Added `custom_vjp` and new custom run functions (`run_custom`, `run_async_custom`) that provide hooks into the adjoint computation for custom gradient calculations. ### Breaking Changes - Edge singularity correction at PEC and lossy metal edges defaults to `True`. diff --git a/tests/test_components/autograd/numerical/test_autograd_cm_custom_vjp.py b/tests/test_components/autograd/numerical/test_autograd_cm_custom_vjp.py new file mode 100644 index 0000000000..af11380471 --- /dev/null +++ b/tests/test_components/autograd/numerical/test_autograd_cm_custom_vjp.py @@ -0,0 +1,404 @@ +# tests custom_vjp autograd hooks for ComponentModeler and compares to numerically computed finite difference gradients +from __future__ import annotations + +import operator +import sys + +import autograd as ag +import matplotlib.pylab as plt +import numpy as np +import pytest +import xarray as xr + +import tidy3d as td +from tidy3d.plugins.smatrix import ComponentModeler, Port +from tidy3d.plugins.smatrix.run import _run_local +from tidy3d.web.api.autograd.types import CustomVJPConfig + +PLOT_FD_ADJ_COMPARISON = True +NUM_FINITE_DIFFERENCE = 10 +SAVE_FD_ADJ_DATA = True +SAVE_FD_LOC = 0 +SAVE_ADJ_LOC = 1 +LOCAL_GRADIENT = True +VERBOSE = False +NUMERICAL_RESULTS_SUBDIR = "numerical_cm_custom_vjp_test" +SHOW_PRINT_STATEMENTS = True + +OVERLAP_ERROR_THRESHOLD_DEG = 10.0 + +ADJOINT_PERMITTIVITY = 1.5**2 + +if PLOT_FD_ADJ_COMPARISON: + pytestmark =
pytest.mark.usefixtures("mpl_config_interactive") +else: + pytestmark = pytest.mark.usefixtures("mpl_config_noninteractive") + +if SHOW_PRINT_STATEMENTS: + sys.stdout = sys.stderr + + +SIMULATION_SIZE_MESH_WVL_FACTOR = 7 +SIMULATION_HEIGHT_WVL_FACTOR = 3 + +SPHERE_OFFSET_MAX_MESH_WVL_FACTOR = 0.25 +SPHERE_MIN_RADIUS_MESH_WVL_FACTOR = 0.3 +SPHERE_MAX_RADIUS_MESH_WVL_FACTOR = 0.4 + +FD_STEP_MESH_WVL_FACTOR = 1.0 / 75.0 + + +def get_sim_geometry(mesh_wvl_um): + return td.Box( + size=( + SIMULATION_SIZE_MESH_WVL_FACTOR * mesh_wvl_um, + SIMULATION_SIZE_MESH_WVL_FACTOR * mesh_wvl_um, + SIMULATION_HEIGHT_WVL_FACTOR * mesh_wvl_um, + ), + center=(0, 0, 0), + ) + + +def make_base_sim( + mesh_wvl_um, + adj_wvl_um, + monitor_bg_index=1.0, + run_time=2e-11, +): + sim_geometry = get_sim_geometry(mesh_wvl_um) + sim_size_um = sim_geometry.size + sim_center_um = sim_geometry.center + + input_waveguide = td.Structure( + geometry=td.Box( + center=(-0.35 * sim_size_um[0], sim_center_um[1], sim_center_um[2]), + size=(0.5 * sim_size_um[0], 0.35 * adj_wvl_um, 0.2 * adj_wvl_um), + ), + medium=td.Medium(permittivity=3.5**2), + ) + + output_waveguide = td.Structure( + geometry=td.Box( + center=(0.35 * sim_size_um[0], sim_center_um[1], sim_center_um[2]), + size=(0.5 * sim_size_um[0], 0.35 * adj_wvl_um, 0.2 * adj_wvl_um), + ), + medium=td.Medium(permittivity=3.5**2), + ) + + num_modes = 1 + + port_left = Port( + center=input_waveguide.geometry.center, + size=(0.0, adj_wvl_um, adj_wvl_um), + mode_spec=td.ModeSpec(num_modes=num_modes), + direction="+", + name="left", + ) + + port_right = Port( + center=output_waveguide.geometry.center, + size=(0.0, adj_wvl_um, adj_wvl_um), + mode_spec=td.ModeSpec(num_modes=num_modes), + direction="-", + name="right", + ) + + boundary_spec = td.BoundarySpec( + x=td.Boundary.pml(), + y=td.Boundary.pml(), + z=td.Boundary.pml(), + ) + + ports = [port_left, port_right] + + return ports, td.Simulation( + center=sim_center_um, + size=sim_size_um, + 
grid_spec=td.GridSpec.auto( + min_steps_per_wvl=30, + wavelength=1.5, + ), + boundary_spec=boundary_spec, + sources=[], + monitors=[], + structures=[input_waveguide, output_waveguide], + run_time=1e-11, + ) + + +def vjp_sphere(sphere, derivative_info): + max_frequency = np.max(derivative_info.frequencies) + min_wvl = td.C_0 / max_frequency + + step_size = min_wvl / 20.0 + + ps_paths = set() + ps_paths.update({("permittivity",)}) + + update_kwargs = { + "paths": list(ps_paths), + "deep": False, + } + + def finite_difference_gradient(perturb_up, perturb_down, derivative_info_): + eps_up = derivative_info.updated_epsilon(perturb_up) + eps_down = derivative_info.updated_epsilon(perturb_down) + eps_grad = (eps_up - eps_down) / (2 * step_size) + + derivative_info_custom_medium = derivative_info_.updated_copy(**update_kwargs) + + custom_medium = td.CustomMedium(permittivity=xr.ones_like(eps_grad.isel(f=0, drop=True))) + vjps_custom_medium = custom_medium._compute_derivatives(derivative_info_custom_medium) + + total_grad = np.real(np.sum(eps_grad.sum("f").data * vjps_custom_medium[("permittivity",)])) + + return total_grad + + vjps = {} + for path in derivative_info.paths: + if path[0:2] == ( + "geometry", + "radius", + ): + sphere_up = sphere.updated_copy(radius=sphere.radius + step_size) + sphere_down = sphere.updated_copy(radius=sphere.radius - step_size) + vjps[path] = finite_difference_gradient(sphere_up, sphere_down, derivative_info) + elif path[0:2] == ("geometry", "center"): + if len(path) == 2: + center_indices = (0, 1, 2) + else: + _, center_index = path[1:] + center_indices = [center_index] + + vjp_result = [] + for center_index in center_indices: + center_up = list(sphere.center) + center_down = list(sphere.center) + + center_up[center_index] += step_size + center_down[center_index] -= step_size + + sphere_up = sphere.updated_copy(center=center_up) + sphere_down = sphere.updated_copy(center=center_down) + + vjp_result.append( + 
finite_difference_gradient(sphere_up, sphere_down, derivative_info) + ) + + vjps[path] = vjp_result if len(path) == 2 else vjp_result[0] + + return vjps + + +def create_objective_function(geometry, create_sim_base, adj_wvl_um, sim_path_dir): + def objective(geom_parameters_lists): + ports, sim_base = create_sim_base() + + simulation_dict = {} + geom_dict = {} + for idx, geom_parameters in enumerate(geom_parameters_lists): + sphere_structure = td.Structure( + geometry=td.Sphere(center=geom_parameters[0:3], radius=geom_parameters[3]), + medium=td.Medium(permittivity=ADJOINT_PERMITTIVITY), + ) + + sim_with_sphere = sim_base.updated_copy( + structures=(*sim_base.structures, sphere_structure) + ) + + simulation_dict[f"numerical_custom_vjp_testing_{idx}"] = sim_with_sphere.copy() + geom_dict[f"numerical_custom_vjp_testing_{idx}"] = geom_parameters + + sim_data = {} + for key, sim_val in simulation_dict.items(): + modeler = ComponentModeler( + simulation=sim_val, + ports=ports, + freqs=[td.C_0 / adj_wvl_um], + ) + + custom_vjp_single = CustomVJPConfig( + structure_index=td.Sphere, + compute_derivatives=vjp_sphere, + ) + + sim_data[key] = _run_local( + modeler, + local_gradient=LOCAL_GRADIENT, + verbose=VERBOSE, + custom_vjp=custom_vjp_single, + ) + + objective_vals = [] + for idx in range(len(geom_parameters_lists)): + smatrix = sim_data[f"numerical_custom_vjp_testing_{idx}"] + objective_vals.append(np.sum(np.abs(smatrix.smatrix().values) ** 2)) + + if len(geom_parameters_lists) == 1: + return objective_vals[0] + + return objective_vals + + return objective + + +background_indices = [1.0] +mesh_wvls_um = [1.5] +adj_wvls_um = [1.5] + +test_parameters = [] + +test_number = 0 +for idx in range(len(mesh_wvls_um)): + mesh_wvl_um = mesh_wvls_um[idx] + adj_wvl_um = adj_wvls_um[idx] + + for monitor_bg_index in background_indices: + test_parameters.append( + { + "mesh_wvl_um": mesh_wvl_um, + "adj_wvl_um": adj_wvl_um, + "monitor_bg_index": monitor_bg_index, + "test_number": 
test_number, + } + ) + + test_number += 1 + + +@pytest.mark.numerical +@pytest.mark.parametrize("test_parameters", test_parameters) +def test_finite_difference_custom_vjp(test_parameters, rng, numerical_case_dir): + """Test a variety of autograd permittivity gradients for DiffractionData by""" + """comparing them to numerical finite difference.""" + + test_number = test_parameters["test_number"] + + ( + mesh_wvl_um, + adj_wvl_um, + monitor_bg_index, + test_number, + ) = operator.itemgetter( + "mesh_wvl_um", + "adj_wvl_um", + "monitor_bg_index", + "test_number", + )(test_parameters) + + sim_geometry = get_sim_geometry(mesh_wvl_um) + + dim_um = mesh_wvl_um + thickness_um = 0.5 * mesh_wvl_um + block = td.Box( + center=(sim_geometry.center[0], sim_geometry.center[1], 0), + size=(dim_um, dim_um, thickness_um), + ) + + sim_path_dir = numerical_case_dir / "simulations" / f"test{test_number}" + sim_path_dir.mkdir(parents=True, exist_ok=True) + + objective = create_objective_function( + block, + lambda mesh_wvl_um=mesh_wvl_um, + adj_wvl_um=adj_wvl_um, + monitor_bg_index=monitor_bg_index: make_base_sim( + mesh_wvl_um=mesh_wvl_um, + adj_wvl_um=adj_wvl_um, + monitor_bg_index=monitor_bg_index, + ), + adj_wvl_um, + sim_path_dir=str(sim_path_dir), + ) + + obj_val_and_grad = ag.value_and_grad(objective) + + sphere_init = [ + *rng.uniform( + low=-SPHERE_OFFSET_MAX_MESH_WVL_FACTOR * mesh_wvl_um, + high=SPHERE_OFFSET_MAX_MESH_WVL_FACTOR * mesh_wvl_um, + size=2, + ), + 0.0, + *rng.uniform( + low=SPHERE_MIN_RADIUS_MESH_WVL_FACTOR * mesh_wvl_um, + high=SPHERE_MAX_RADIUS_MESH_WVL_FACTOR * mesh_wvl_um, + size=1, + ), + ] + + geom_init = sphere_init + test_results = np.zeros((2, len(geom_init))) + + obj, adj_grad = obj_val_and_grad([geom_init]) + adj_grad = np.squeeze(np.array(adj_grad)) + + # empirical step size for finite difference calculation + fd_step = FD_STEP_MESH_WVL_FACTOR * mesh_wvl_um + + all_params = [] + + for fd_idx in range(len(geom_init)): + geom_up = geom_init.copy() + 
geom_down = geom_init.copy() + + geom_up[fd_idx] += fd_step + geom_down[fd_idx] -= fd_step + + all_params.append(geom_up) + all_params.append(geom_down) + + all_obj = objective(all_params) + + fd_grad = np.zeros(len(geom_init)) + for fd_idx in range(len(geom_init)): + obj_up_location = 2 * fd_idx + obj_down_location = 2 * fd_idx + 1 + + fd_grad[fd_idx] = (all_obj[obj_up_location] - all_obj[obj_down_location]) / (2 * fd_step) + + rms_error = np.linalg.norm(fd_grad - adj_grad) + fd_mag = np.linalg.norm(fd_grad) + adj_mag = np.linalg.norm(adj_grad) + + dot = np.sum((fd_grad / fd_mag) * (adj_grad / adj_mag)) + overlap_deg = np.arccos(dot) * 180.0 / np.pi + + print("\n" * 3) + print("-" * 20) + print(f"Numerical test #{test_number}") + print(f"Mesh and adjoint wavelengths: {mesh_wvl_um}, {adj_wvl_um}") + print(f"Background index for monitor: {monitor_bg_index}") + print(f"RMS Error: {rms_error}") + print(f"Gradient overlap (deg): {overlap_deg}") + print(f"FD, Adj magnitudes: {fd_mag}, {adj_mag}") + print("-" * 20) + print("\n" * 3) + + test_results[SAVE_FD_LOC, :] = fd_grad + test_results[SAVE_ADJ_LOC, :] = adj_grad + + save_idx = test_number + 1 + save_path = None + if SAVE_FD_ADJ_DATA: + results_dir = numerical_case_dir / NUMERICAL_RESULTS_SUBDIR + results_dir.mkdir(parents=True, exist_ok=True) + save_path = results_dir / f"results_{save_idx}.npy" + + try: + assert overlap_deg < OVERLAP_ERROR_THRESHOLD_DEG, ( + "Adjoint and finite difference gradients misaligned." 
+ ) + finally: + if save_path is not None: + np.save(save_path, test_results) + + test_number += 1 + + if PLOT_FD_ADJ_COMPARISON: + plt.plot(adj_grad, color="g", linewidth=2.0) + plt.plot(fd_grad, color="b", linewidth=1.5, linestyle="--") + plt.legend(["Adjoint", "Finite difference"]) + plt.xlabel("Sample number") + plt.ylabel("Gradient value") + plt.show() diff --git a/tests/test_components/autograd/numerical/test_autograd_custom_vjp.py b/tests/test_components/autograd/numerical/test_autograd_custom_vjp.py new file mode 100644 index 0000000000..a53de5d502 --- /dev/null +++ b/tests/test_components/autograd/numerical/test_autograd_custom_vjp.py @@ -0,0 +1,453 @@ +# tests custom_vjp autograd hook for run_custom and run_async_custom and compares to numerically computed finite difference gradients +from __future__ import annotations + +import operator +import sys + +import autograd as ag +import matplotlib.pylab as plt +import numpy as np +import pytest +import xarray as xr + +import tidy3d as td +from tidy3d.web.api.autograd.autograd import run_async_custom, run_custom +from tidy3d.web.api.autograd.types import CustomVJPConfig + +PLOT_FD_ADJ_COMPARISON = True +NUM_FINITE_DIFFERENCE = 10 +SAVE_FD_ADJ_DATA = True +SAVE_FD_LOC = 0 +SAVE_ADJ_LOC = 1 +LOCAL_GRADIENT = True +VERBOSE = False +NUMERICAL_RESULTS_SUBDIR = "numerical_custom_vjp_test" +SHOW_PRINT_STATEMENTS = True + +OVERLAP_ERROR_THRESHOLD_DEG = 10.0 + +ADJOINT_SPHERE_PERMITTIVITY = 1.5**2 + +if PLOT_FD_ADJ_COMPARISON: + pytestmark = pytest.mark.usefixtures("mpl_config_interactive") +else: + pytestmark = pytest.mark.usefixtures("mpl_config_noninteractive") + +if SHOW_PRINT_STATEMENTS: + sys.stdout = sys.stderr + + +SIMULATION_SIZE_MESH_WVL_FACTOR = 3.5 +SIMULATION_HEIGHT_WVL_FACTOR = 5 + +SPHERE_OFFSET_MAX_MESH_WVL_FACTOR = 0.25 +SPHERE_MIN_RADIUS_MESH_WVL_FACTOR = 0.3 +SPHERE_MAX_RADIUS_MESH_WVL_FACTOR = 0.4 + +FD_STEP_MESH_WVL_FACTOR = 1.0 / 75.0 + + +def get_sim_geometry(mesh_wvl_um): + return td.Box( + 
size=( + SIMULATION_SIZE_MESH_WVL_FACTOR * mesh_wvl_um, + SIMULATION_SIZE_MESH_WVL_FACTOR * mesh_wvl_um, + SIMULATION_HEIGHT_WVL_FACTOR * mesh_wvl_um, + ), + center=(0, 0, 0), + ) + + +def make_base_sim( + mesh_wvl_um, + adj_wvl_um, + pw_angle_deg, + monitor_bg_index=1.0, + run_time=2e-11, +): + sim_geometry = get_sim_geometry(mesh_wvl_um) + sim_size_um = sim_geometry.size + sim_center_um = sim_geometry.center + + src_size = sim_size_um[0:2] + (0,) + + wl_min_src_um = 0.9 * adj_wvl_um + wl_max_src_um = 1.1 * adj_wvl_um + + fwidth_src = td.C_0 * ((1.0 / wl_min_src_um) - (1.0 / wl_max_src_um)) + freq0 = td.C_0 / adj_wvl_um + + pulse = td.GaussianPulse(freq0=freq0, fwidth=fwidth_src) + + src = td.PlaneWave( + center=(sim_center_um[0], sim_center_um[1], -2.0), + size=[td.inf, td.inf, 0], + source_time=pulse, + direction="+", + angle_theta=(pw_angle_deg * np.pi / 180.0), + ) + + boundary_spec = td.BoundarySpec( + x=td.Boundary.pml(), + y=td.Boundary.pml(), + z=td.Boundary.pml(), + ) + + field_monitor = td.FieldMonitor( + center=( + sim_center_um[0], + sim_center_um[1], + mesh_wvl_um / 1.5, + ), + size=(mesh_wvl_um, mesh_wvl_um, 0), + name="monitor_fields", + freqs=[freq0], + ) + + monitor_index_block = td.Box( + center=(sim_center_um[0], sim_center_um[1], 0.25 * sim_size_um[2] + mesh_wvl_um), + size=(*tuple(2 * size for size in sim_size_um[0:2]), mesh_wvl_um + 0.5 * sim_size_um[2]), + ) + monitor_index_block_structure = td.Structure( + geometry=monitor_index_block, medium=td.Medium(permittivity=monitor_bg_index**2) + ) + + sim_base = td.Simulation( + center=sim_center_um, + size=sim_size_um, + grid_spec=td.GridSpec.auto( + min_steps_per_wvl=30, + wavelength=mesh_wvl_um, + ), + structures=[monitor_index_block_structure], + sources=[src], + monitors=[field_monitor], + run_time=run_time, + boundary_spec=boundary_spec, + subpixel=True, + ) + + return sim_base + + +def vjp_sphere(sphere, derivative_info): + max_frequency = np.max(derivative_info.frequencies) + min_wvl = 
td.C_0 / max_frequency + + step_size = min_wvl / 20.0 + + ps_paths = set() + ps_paths.update({("permittivity",)}) + + update_kwargs = { + "paths": list(ps_paths), + "deep": False, + } + + def finite_difference_gradient(perturb_up, perturb_down, derivative_info_): + eps_up = derivative_info.updated_epsilon(perturb_up) + eps_down = derivative_info.updated_epsilon(perturb_down) + eps_grad = (eps_up - eps_down) / (2 * step_size) + + derivative_info_custom_medium = derivative_info_.updated_copy(**update_kwargs) + + custom_medium = td.CustomMedium(permittivity=xr.ones_like(eps_grad.isel(f=0, drop=True))) + vjps_custom_medium = custom_medium._compute_derivatives(derivative_info_custom_medium) + + total_grad = np.real(np.sum(eps_grad.sum("f").data * vjps_custom_medium[("permittivity",)])) + + return total_grad + + vjps = {} + for path in derivative_info.paths: + if path[0:2] == ( + "geometry", + "radius", + ): + sphere_up = sphere.updated_copy(radius=sphere.radius + step_size) + sphere_down = sphere.updated_copy(radius=sphere.radius - step_size) + vjps[path] = finite_difference_gradient(sphere_up, sphere_down, derivative_info) + elif path[0:2] == ("geometry", "center"): + if len(path) == 2: + center_indices = (0, 1, 2) + else: + _, center_index = path[1:] + center_indices = [center_index] + + vjp_result = [] + for center_index in center_indices: + center_up = list(sphere.center) + center_down = list(sphere.center) + + center_up[center_index] += step_size + center_down[center_index] -= step_size + + sphere_up = sphere.updated_copy(center=center_up) + sphere_down = sphere.updated_copy(center=center_down) + + vjp_result.append( + finite_difference_gradient(sphere_up, sphere_down, derivative_info) + ) + + vjps[path] = vjp_result if len(path) == 2 else vjp_result[0] + + return vjps + + +def create_objective_function(geometry, create_sim_base, eval_fn, run_fn, sim_path_dir): + def objective(sphere_parameters_lists): + sim_base = create_sim_base() + + simulation_dict = {} + for 
idx, sphere_parameters in enumerate(sphere_parameters_lists): + sphere_structure = td.Structure( + geometry=td.Sphere(center=sphere_parameters[0:3], radius=sphere_parameters[3]), + medium=td.Medium(permittivity=ADJOINT_SPHERE_PERMITTIVITY), + ) + + sim_with_sphere = sim_base.updated_copy( + structures=(*sim_base.structures, sphere_structure) + ) + + simulation_dict[f"numerical_custom_vjp_testing_{idx}"] = sim_with_sphere.copy() + + custom_vjp_single = CustomVJPConfig( + structure_index=1, + compute_derivatives=vjp_sphere, + ) + + assert (run_fn == "run_custom") or (run_fn == "run_async_custom"), ( + "Unrecognized run function!" + ) + if run_fn == "run_custom": + sim_data = {} + for key, sim_val in simulation_dict.items(): + sim_data[key] = run_custom( + sim_val, + local_gradient=LOCAL_GRADIENT, + verbose=VERBOSE, + custom_vjp=custom_vjp_single, + ) + elif run_fn == "run_async_custom": + sim_data = run_async_custom( + simulation_dict, + path_dir=sim_path_dir, + local_gradient=LOCAL_GRADIENT, + verbose=VERBOSE, + custom_vjp=custom_vjp_single, + ) + + objective_vals = [] + for idx in range(len(sphere_parameters_lists)): + objective_vals.append(eval_fn(sim_data[f"numerical_custom_vjp_testing_{idx}"])) + + if len(sphere_parameters_lists) == 1: + return objective_vals[0] + + return objective_vals + + return objective + + +def make_eval_fns(): + def transmission(sim_data): + total = 0.0 + + return np.sum(np.abs(sim_data["monitor_fields"].flux.data) ** 2) + + eval_fns = [transmission] + eval_fn_names = ["transmission"] + + return eval_fns, eval_fn_names + + +background_indices = [1.0] +mesh_wvls_um = [1.5] +adj_wvls_um = [1.5] + +orders_x = [(1,)] +orders_y = [(0,)] +polarizations = ["p"] + +pw_angles_deg = [0.0] + +run_functions = ["run_custom", "run_async_custom"] + +test_parameters = [] + +test_number = 0 +for idx in range(len(mesh_wvls_um)): + mesh_wvl_um = mesh_wvls_um[idx] + adj_wvl_um = adj_wvls_um[idx] + + eval_fns, eval_fn_names = make_eval_fns() + + for 
pw_angle_deg in pw_angles_deg: + for monitor_bg_index in background_indices: + for eval_fn_idx, eval_fn in enumerate(eval_fns): + for run_fn in run_functions: + test_parameters.append( + { + "mesh_wvl_um": mesh_wvl_um, + "adj_wvl_um": adj_wvl_um, + "monitor_bg_index": monitor_bg_index, + "pw_angle_deg": pw_angle_deg, + "eval_fn": eval_fn, + "eval_fn_name": eval_fn_names[eval_fn_idx], + "run_fn": run_fn, + "test_number": test_number, + } + ) + + test_number += 1 + + +@pytest.mark.numerical +@pytest.mark.parametrize("test_parameters", test_parameters) +def test_finite_difference_custom_vjp(test_parameters, rng, numerical_case_dir): + """Test a variety of autograd permittivity gradients for DiffractionData by""" + """comparing them to numerical finite difference.""" + + ( + mesh_wvl_um, + adj_wvl_um, + monitor_bg_index, + pw_angle_deg, + eval_fn, + eval_fn_name, + run_fn, + test_number, + ) = operator.itemgetter( + "mesh_wvl_um", + "adj_wvl_um", + "monitor_bg_index", + "pw_angle_deg", + "eval_fn", + "eval_fn_name", + "run_fn", + "test_number", + )(test_parameters) + + sim_geometry = get_sim_geometry(mesh_wvl_um) + + dim_um = mesh_wvl_um + thickness_um = 0.5 * mesh_wvl_um + block = td.Box( + center=(sim_geometry.center[0], sim_geometry.center[1], 0), + size=(dim_um, dim_um, thickness_um), + ) + + sim_path_dir = numerical_case_dir / "simulations" / f"test{test_number}" + sim_path_dir.mkdir(parents=True, exist_ok=True) + + objective = create_objective_function( + block, + lambda mesh_wvl_um=mesh_wvl_um, + adj_wvl_um=adj_wvl_um, + pw_angle_deg=pw_angle_deg, + monitor_bg_index=monitor_bg_index: make_base_sim( + mesh_wvl_um=mesh_wvl_um, + adj_wvl_um=adj_wvl_um, + pw_angle_deg=pw_angle_deg, + monitor_bg_index=monitor_bg_index, + ), + eval_fn, + run_fn, + sim_path_dir=str(sim_path_dir), + ) + + obj_val_and_grad = ag.value_and_grad(objective) + + sphere_init = [ + *rng.uniform( + low=-SPHERE_OFFSET_MAX_MESH_WVL_FACTOR * mesh_wvl_um, + high=SPHERE_OFFSET_MAX_MESH_WVL_FACTOR * 
mesh_wvl_um, + size=2, + ), + 0.0, + *rng.uniform( + low=SPHERE_MIN_RADIUS_MESH_WVL_FACTOR * mesh_wvl_um, + high=SPHERE_MAX_RADIUS_MESH_WVL_FACTOR * mesh_wvl_um, + size=1, + ), + ] + + test_results = np.zeros((2, len(sphere_init))) + + obj, adj_grad = obj_val_and_grad([sphere_init]) + adj_grad = np.squeeze(np.array(adj_grad)) + + # empirical step size from running other finite difference tests for field + # cases with permittivity + fd_step = FD_STEP_MESH_WVL_FACTOR * mesh_wvl_um + + all_spheres = [] + # pattern_dot_adj_gradient = np.zeros(len(sphere_init)) + + for fd_idx in range(len(sphere_init)): + sphere_up = sphere_init.copy() + sphere_down = sphere_init.copy() + + sphere_up[fd_idx] += fd_step + sphere_down[fd_idx] -= fd_step + + all_spheres.append(sphere_up) + all_spheres.append(sphere_down) + + all_obj = objective(all_spheres) + + fd_grad = np.zeros(len(sphere_init)) + for fd_idx in range(len(sphere_init)): + obj_up_location = 2 * fd_idx + obj_down_location = 2 * fd_idx + 1 + + fd_grad[fd_idx] = (all_obj[obj_up_location] - all_obj[obj_down_location]) / (2 * fd_step) + + rms_error = np.linalg.norm(fd_grad - adj_grad) + fd_mag = np.linalg.norm(fd_grad) + adj_mag = np.linalg.norm(adj_grad) + + dot = np.sum((fd_grad / fd_mag) * (adj_grad / adj_mag)) + overlap_deg = np.arccos(dot) * 180.0 / np.pi + + print("\n" * 3) + print("-" * 20) + print(f"Numerical test #{test_number}") + print(f"Mesh and adjoint wavelengths: {mesh_wvl_um}, {adj_wvl_um}") + print(f"Input plane wave angle (deg): {pw_angle_deg}") + print(f"Background index for monitor: {monitor_bg_index}") + print(f"Eval function: {eval_fn_name}") + print(f"RMS Error: {rms_error}") + print(f"Gradient overlap (deg): {overlap_deg}") + print(f"FD, Adj magnitudes: {fd_mag}, {adj_mag}") + print("-" * 20) + print("\n" * 3) + + test_results[SAVE_FD_LOC, :] = fd_grad + test_results[SAVE_ADJ_LOC, :] = adj_grad + + save_idx = test_number + 1 + save_path = None + if SAVE_FD_ADJ_DATA: + results_dir = numerical_case_dir / 
NUMERICAL_RESULTS_SUBDIR + results_dir.mkdir(parents=True, exist_ok=True) + save_path = results_dir / f"results_{save_idx}.npy" + + try: + assert overlap_deg < OVERLAP_ERROR_THRESHOLD_DEG, ( + "Adjoint and finite difference gradients misaligned." + ) + finally: + if save_path is not None: + np.save(save_path, test_results) + + test_number += 1 + + if PLOT_FD_ADJ_COMPARISON: + plt.plot(adj_grad, color="g", linewidth=2.0) + plt.plot(fd_grad, color="b", linewidth=1.5, linestyle="--") + plt.title(f"Gradient for objective: {eval_fn_name}") + plt.legend(["Adjoint", "Finite difference"]) + plt.xlabel("Sample number") + plt.ylabel("Gradient value") + plt.show() diff --git a/tests/test_components/autograd/numerical/test_autograd_periodic_numerical.py b/tests/test_components/autograd/numerical/test_autograd_periodic_numerical.py index 1ef3607ad2..2a953e1863 100644 --- a/tests/test_components/autograd/numerical/test_autograd_periodic_numerical.py +++ b/tests/test_components/autograd/numerical/test_autograd_periodic_numerical.py @@ -122,7 +122,6 @@ def make_base_sim( ) else: diffraction_monitor = td.DiffractionMonitor( - # center=(0, 0, -0.35 * sim_size_um[2]), center=(sim_center_um[0], sim_center_um[1], -0.35 * sim_size_um[2]), size=(np.inf, np.inf, 0), name="monitor_diffraction", diff --git a/tests/test_components/autograd/test_autograd.py b/tests/test_components/autograd/test_autograd.py index 8621eaded7..63340110da 100644 --- a/tests/test_components/autograd/test_autograd.py +++ b/tests/test_components/autograd/test_autograd.py @@ -3,6 +3,7 @@ import copy import cProfile +import re import typing import warnings from importlib import reload @@ -28,8 +29,12 @@ from tidy3d.config import config from tidy3d.exceptions import AdjointError from tidy3d.plugins.polyslab import ComplexPolySlab +from tidy3d.plugins.smatrix import ComponentModeler, Port +from tidy3d.plugins.smatrix.run import _run_local from tidy3d.web import run, run_async from tidy3d.web.api.autograd import autograd 
as autograd_module +from tidy3d.web.api.autograd.autograd import run_async_custom, run_custom +from tidy3d.web.api.autograd.types import CustomVJPConfig from ...utils import SIM_FULL, AssertLogLevel, run_emulated, tracer_arr @@ -101,6 +106,7 @@ def _make_di(paths, freq): ), bounds_intersect=((-1, -1, -1), (1, 1, 1)), simulation_bounds=((-2, -2, -2), (2, 2, 2)), + updated_epsilon=None, ) @@ -116,6 +122,7 @@ def _make_di(paths, freq): IS_3D = False POLYSLAB_AXIS = 2 +POLYSLAB_SELECT_VERTICES = 0 # angle of the measurement waveguide ROT_ANGLE_WG = 0 * np.pi / 4 @@ -239,7 +246,6 @@ def emulated_run_fwd(simulation, task_name, **run_kwargs) -> td.SimulationData: def emulated_run_bwd(simulation, task_name, **run_kwargs) -> td.SimulationData: """What gets called instead of ``web/api/autograd/autograd.py::_run_tidy3d_bwd``.""" - task_name_fwd = "".join(task_name.partition("_adjoint")[:-2]) # run the adjoint sim @@ -259,6 +265,7 @@ def emulated_run_bwd(simulation, task_name, **run_kwargs) -> td.SimulationData: sim_data_orig=sim_data_orig, sim_data_fwd=sim_data_fwd, sim_fields_keys=sim_fields_keys, + custom_vjp=None, ) return traced_fields_vjp @@ -306,7 +313,9 @@ def emulated_run_async_bwd(simulations, **run_kwargs) -> td.SimulationData: return emulated_run_fwd, emulated_run_bwd -def make_structures(params: anp.ndarray) -> dict[str, td.Structure]: +def make_structures( + params: anp.ndarray, polyslab_axis: int = POLYSLAB_AXIS +) -> dict[str, td.Structure]: """Make a dictionary of the structures given the parameters.""" np.random.seed(0) @@ -406,8 +415,7 @@ def make_structures(params: anp.ndarray) -> dict[str, td.Structure]: matrix = np.random.random((N_PARAMS,)) - 0.5 params_01 = 0.5 * (anp.tanh(matrix @ params / 3) + 1) - free_param = "vertices" if POLYSLAB_AXIS == 0 else "slab_bounds" - + free_param = "vertices" if polyslab_axis == POLYSLAB_SELECT_VERTICES else "slab_bounds" if free_param == "vertices": radii = 0.5 + 0.5 * params_01 slab_bounds = (-0.5, 0.5) @@ -415,8 
+423,6 @@ def make_structures(params: anp.ndarray) -> dict[str, td.Structure]: radii = 1.0 shift = 0.1 * params_01 slab_bounds = (-0.5 + shift, 0.5 + shift) - # slab_bounds = (-0.5 + shift, 0.5) - # slab_bounds = (-0.5, 0.5 + shift) phis = 2 * anp.pi * anp.linspace(0, 1, NUM_VERTICES + 1)[:NUM_VERTICES] xs = radii * anp.cos(phis) @@ -427,7 +433,7 @@ def make_structures(params: anp.ndarray) -> dict[str, td.Structure]: geometry=td.PolySlab( vertices=vertices, slab_bounds=slab_bounds, - axis=POLYSLAB_AXIS, + axis=polyslab_axis, sidewall_angle=0.00, dilation=0.00, ), @@ -438,7 +444,7 @@ def make_structures(params: anp.ndarray) -> dict[str, td.Structure]: geometry=td.PolySlab( vertices=vertices, slab_bounds=slab_bounds, - axis=POLYSLAB_AXIS, + axis=polyslab_axis, sidewall_angle=0.00, dilation=0.00, ), @@ -658,9 +664,6 @@ def plot_sim(sim: td.Simulation, plot_eps: bool = True) -> None: args = [("polyslab", "mode")] -# args = [("polyslab", "mode")] - - def get_functions(structure_key: str, monitor_key: str) -> dict[str, typing.Callable]: if structure_key == ALL_KEY: structure_keys = structure_keys_ @@ -681,10 +684,10 @@ def get_functions(structure_key: str, monitor_key: str) -> dict[str, typing.Call monitors.append(monitor_traced) monitor_pp_fns[monitor_key] = monitor_pp_fn - def make_sim(*args) -> td.Simulation: + def make_sim(*args, polyslab_axis=POLYSLAB_AXIS) -> td.Simulation: """Make the simulation with all of the fields.""" - structures_traced_dict = make_structures(*args) + structures_traced_dict = make_structures(*args, polyslab_axis=polyslab_axis) structures = list(SIM_BASE.structures) for structure_key in structure_keys: @@ -727,6 +730,538 @@ def test_polyslab_axis_ops(axis): basis_vecs = p.edge_basis_vectors(edges=edges) +def make_polyslab_custom_vjp(custom_vjp_val): + def polyslab_custom_vjp(polyslab, derivative_info): + vjps = {} + + for path in derivative_info.paths: + if path[0:2] == ("geometry", "vertices"): + vjps[path] = custom_vjp_val * 
np.ones(polyslab.vertices.shape) + elif path[0:2] == ("geometry", "slab_bounds"): + if len(path) == 3: + vjps[path] = (custom_vjp_val, custom_vjp_val)[path[2]] + else: + vjps[path] = (custom_vjp_val, custom_vjp_val) + + return vjps + + return polyslab_custom_vjp + + +@pytest.mark.parametrize("structure_key, monitor_key", [("polyslab", "mode")]) +@pytest.mark.parametrize("polyslab_axis", [0, 1, 2]) +@pytest.mark.parametrize("use_run_async", [True, False]) +@pytest.mark.parametrize("use_task_names", [True, False]) +@pytest.mark.parametrize( + "use_single_custom_vjp, specify_custom_vjp_by_type", + [(True, True), (True, False), (False, False)], +) +@pytest.mark.parametrize("local_gradient", [True, False]) +def test_autograd_custom_vjp( + use_emulated_run, + structure_key, + monitor_key, + polyslab_axis, + use_run_async, + use_task_names, + use_single_custom_vjp, + specify_custom_vjp_by_type, + local_gradient, +): + """Test that we can override a vjp with a user defined function.""" + + fn_dict = get_functions(structure_key, monitor_key) + make_sim = fn_dict["sim"] + postprocess = fn_dict["postprocess"] + + task_names = {"test_a", "adjoint", "_test"} + + def make_objective(custom_vjp_val): + polyslab_custom_vjp = make_polyslab_custom_vjp(custom_vjp_val) + + structure_index = td.PolySlab if specify_custom_vjp_by_type else 1 + + path_key = ( + "geometry", + "vertices" if polyslab_axis == POLYSLAB_SELECT_VERTICES else "slab_bounds", + ) + + custom_vjp_tuple = ( + CustomVJPConfig( + structure_index=structure_index, + compute_derivatives=polyslab_custom_vjp, + path_key=path_key, + ), + ) + + custom_vjp_single = CustomVJPConfig( + structure_index=structure_index, + compute_derivatives=polyslab_custom_vjp, + ) + + custom_vjp_element = custom_vjp_single if use_single_custom_vjp else custom_vjp_tuple + + def objective(*args): + if use_task_names: + sims = { + task_name: make_sim(*args, polyslab_axis=polyslab_axis) + for task_name in task_names + } + custom_vjp = 
dict.fromkeys(sims.keys(), custom_vjp_element) + else: + sims = [make_sim(*args, polyslab_axis=polyslab_axis)] * len(task_names) + custom_vjp = [custom_vjp_element] * len(task_names) + batch_data = {} + if use_run_async: + batch_data = run_async_custom( + sims, custom_vjp=custom_vjp, local_gradient=local_gradient + ) + else: + if use_task_names: + for task_name, sim in sims.items(): + batch_data[task_name] = run_custom( + sim, + task_name, + custom_vjp=custom_vjp[task_name], + local_gradient=local_gradient, + ) + else: + for idx, sim in enumerate(sims): + batch_data[idx] = run_custom( + sim, custom_vjp=custom_vjp[idx], local_gradient=local_gradient + ) + + value = 0.0 + + for _, sim_data in batch_data.items(): + value += postprocess(sim_data) + return value + + return objective + + custom_vjp_val = 1.0 + custom_vjp_val_scale = 10.0 * custom_vjp_val + + if not local_gradient: + with pytest.raises( + td.exceptions.AdjointError, + match="custom_vjp specified for a remote gradient not supported.", + ): + val, grad = ag.value_and_grad(make_objective(custom_vjp_val))(params0) + else: + val, grad = ag.value_and_grad(make_objective(custom_vjp_val))(params0) + val_scale, grad_scale = ag.value_and_grad(make_objective(custom_vjp_val_scale))(params0) + + assert np.isclose( + np.sum(np.abs(grad * (custom_vjp_val_scale / custom_vjp_val) - grad_scale)), 0.0 + ), "Gradients were not set by the user vjp" + + +@pytest.mark.parametrize("structure_key, monitor_key", [("polyslab", "mode")]) +@pytest.mark.parametrize("polyslab_axis", [0, 1, 2]) +@pytest.mark.parametrize("use_run_async", [True, False]) +@pytest.mark.parametrize("use_task_names", [True, False]) +@pytest.mark.parametrize("use_single_custom_vjp", [True, False]) +def test_autograd_custom_vjp_selective( + use_emulated_run, + structure_key, + monitor_key, + polyslab_axis, + use_run_async, + use_task_names, + use_single_custom_vjp, +): + """Test that we can selectively override a vjp with a user defined function that covers 
some of, but not all, gradient keys.""" + + fn_dict = get_functions(structure_key, monitor_key) + make_sim = fn_dict["sim"] + postprocess = fn_dict["postprocess"] + + task_names = {"test_a", "adjoint", "_test"} + + def make_objective(custom_vjp_val): + polyslab_custom_vjp = make_polyslab_custom_vjp(custom_vjp_val) + + custom_vjp_tuple = ( + CustomVJPConfig( + structure_index=1, + compute_derivatives=polyslab_custom_vjp, + path_key=( + ( + "geometry", + "vertices", + ) + ), + ), + ) + + custom_vjp_single = CustomVJPConfig( + structure_index=1, + compute_derivatives=polyslab_custom_vjp, + path_key=( + ( + "geometry", + "vertices", + ) + ), + ) + + custom_vjp_element = custom_vjp_single if use_single_custom_vjp else custom_vjp_tuple + if not (polyslab_axis == POLYSLAB_SELECT_VERTICES): + custom_vjp_element = None + + def objective(*args): + if custom_vjp_element: + if use_task_names: + custom_vjp = dict.fromkeys(task_names, custom_vjp_element) + else: + sims = [make_sim(*args, polyslab_axis=polyslab_axis)] * len(task_names) + custom_vjp = [custom_vjp_element] * len(task_names) + else: + custom_vjp = None + + if use_task_names: + sims = { + task_name: make_sim(*args, polyslab_axis=polyslab_axis) + for task_name in task_names + } + else: + sims = [make_sim(*args, polyslab_axis=polyslab_axis)] * len(task_names) + + batch_data = {} + if use_run_async: + batch_data = run_async_custom(sims, custom_vjp=custom_vjp, local_gradient=True) + else: + if use_task_names: + for task_name, sim in sims.items(): + batch_data[task_name] = run_custom( + sim, + task_name, + custom_vjp=custom_vjp and custom_vjp[task_name], + local_gradient=True, + ) + else: + for idx, sim in enumerate(sims): + batch_data[idx] = run_custom( + sim, custom_vjp=custom_vjp and custom_vjp[idx], local_gradient=True + ) + + value = 0.0 + + for _, sim_data in batch_data.items(): + value += postprocess(sim_data) + return value + + return objective + + custom_vjp_val = 1.0 + custom_vjp_val_scale = 10.0 * 
custom_vjp_val + + val, grad = ag.value_and_grad(make_objective(custom_vjp_val))(params0) + val_scale, grad_scale = ag.value_and_grad(make_objective(custom_vjp_val_scale))(params0) + + if polyslab_axis == POLYSLAB_SELECT_VERTICES: + assert np.isclose( + np.sum(np.abs(grad * (custom_vjp_val_scale / custom_vjp_val) - grad_scale)), 0.0 + ), "Gradients were not set by the user vjp when they should have been" + else: + assert not np.isclose( + np.sum(np.abs(grad * (custom_vjp_val_scale / custom_vjp_val) - grad_scale)), 0.0 + ), "Gradients were set by the user vjp when they should not have been" + + +@pytest.mark.parametrize("structure_key, monitor_key", [("polyslab", "mode")]) +@pytest.mark.parametrize("use_run_async", [True, False]) +def test_autograd_error_custom_vjp_indices_and_paths( + use_emulated_run, + structure_key, + monitor_key, + use_run_async, +): + """Test error checking for custom_vjp when structure index or path does not + exist in traced vjp structure paths. + """ + + fn_dict = get_functions(structure_key, monitor_key) + make_sim = fn_dict["sim"] + postprocess = fn_dict["postprocess"] + + task_names = {"test_a", "adjoint", "_test"} + custom_vjp_paths = { + "test_a": ("geometry", "vertices"), + "adjoint": ("geometry", "slab_bounds"), + "_test": ("geometry", "verslab"), + } + custom_vjp_structure_indices = {"test_a": 1, "adjoint": 5, "_test": 1} + + def make_objective(custom_vjp_val): + polyslab_custom_vjp = make_polyslab_custom_vjp(custom_vjp_val) + + custom_vjp_dict = {} + custom_vjp_dict_bad_structure_index = {} + custom_vjp_dict_bad_structure_path = {} + for task_name in task_names: + custom_vjp_dict[task_name] = CustomVJPConfig( + structure_index=custom_vjp_structure_indices[task_name], + compute_derivatives=polyslab_custom_vjp, + path_key=custom_vjp_paths[task_name], + ) + + custom_vjp_dict_bad_structure_index[task_name] = CustomVJPConfig( + structure_index=custom_vjp_structure_indices[task_name], + compute_derivatives=polyslab_custom_vjp, + 
path_key=("geometry", "vertices"), + ) + + custom_vjp_dict_bad_structure_path[task_name] = CustomVJPConfig( + structure_index=1, + compute_derivatives=polyslab_custom_vjp, + path_key=custom_vjp_paths[task_name], + ) + + def objective(*args): + sims = {task_name: make_sim(*args, polyslab_axis=0) for task_name in task_names} + + batch_data = {} + if use_run_async: + with pytest.raises( + td.exceptions.AdjointError, + match=f"CustomVJPConfig structure index {custom_vjp_structure_indices['adjoint']} not in traced structure indices.", + ): + batch_data = run_async_custom( + sims, custom_vjp=custom_vjp_dict_bad_structure_index, local_gradient=True + ) + with pytest.raises( + td.exceptions.AdjointError, + match=re.escape( + f"CustomVJPConfig path {custom_vjp_paths['_test']} not in traced structure paths." + ), + ): + batch_data = run_async_custom( + sims, custom_vjp=custom_vjp_dict_bad_structure_path, local_gradient=True + ) + else: + for task_name, sim in sims.items(): + if task_name == "adjoint": + with pytest.raises( + td.exceptions.AdjointError, + match=f"CustomVJPConfig structure index {custom_vjp_structure_indices[task_name]} not in traced structure indices.", + ): + batch_data[task_name] = run_custom( + sim, + task_name, + custom_vjp=custom_vjp_dict[task_name], + local_gradient=True, + ) + + elif task_name == "_test": + with pytest.raises( + td.exceptions.AdjointError, + match=re.escape( + f"CustomVJPConfig path {custom_vjp_paths[task_name]} not in traced structure paths." 
+ ), + ): + batch_data[task_name] = run_custom( + sim, + task_name, + custom_vjp=custom_vjp_dict[task_name], + local_gradient=True, + ) + else: + batch_data[task_name] = run_custom( + sim, + task_name, + custom_vjp=custom_vjp_dict[task_name], + local_gradient=True, + ) + + value = 0.0 + + for _, sim_data in batch_data.items(): + value += postprocess(sim_data) + return value + + return objective + + custom_vjp_val = 1.0 + val, grad = ag.value_and_grad(make_objective(custom_vjp_val))(params0) + + +@pytest.mark.parametrize("structure_key, monitor_key", [("polyslab", "mode")]) +@pytest.mark.parametrize("use_run_async", [True, False]) +@pytest.mark.parametrize("error_type", ["num_args", "arg_name"]) +def test_autograd_error_custom_vjp_function( + use_emulated_run, + structure_key, + monitor_key, + use_run_async, + error_type, +): + """Test error checking for custom_vjp when compute_derivatives function signature is wrong.""" + + fn_dict = get_functions(structure_key, monitor_key) + make_sim = fn_dict["sim"] + postprocess = fn_dict["postprocess"] + + task_names = {"test_a", "adjoint", "_test"} + + def polyslab_custom_vjp_bad_num_args(polyslab, derivative_info, extra_arg): + return {} + + def polyslab_custom_vjp_bad_arg_name(polyslab, d_info): + return {} + + choose_compute_derivatives = ( + polyslab_custom_vjp_bad_num_args + if (error_type == "num_args") + else polyslab_custom_vjp_bad_arg_name + ) + error_msg = ( + ( + "CustomVJPConfig compute_derivatives function should accept two arguments and it currently accepts 3 arguments." + ) + if (error_type == "num_args") + else ( + "CustomVJPConfig compute_derivatives function second argument name is d_info but it should be derivative_info." 
+ ) + ) + + def make_objective(custom_vjp_val): + polyslab_custom_vjp = make_polyslab_custom_vjp(custom_vjp_val) + + custom_vjp = CustomVJPConfig( + structure_index=td.PolySlab, compute_derivatives=choose_compute_derivatives + ) + + def objective(*args): + sims = {task_name: make_sim(*args, polyslab_axis=0) for task_name in task_names} + + batch_data = {} + if use_run_async: + with pytest.raises(td.exceptions.AdjointError, match=error_msg): + batch_data = run_async_custom(sims, custom_vjp=custom_vjp, local_gradient=True) + else: + for task_name, sim in sims.items(): + with pytest.raises(td.exceptions.AdjointError, match=error_msg): + batch_data[task_name] = run_custom( + sim, + task_name, + custom_vjp=custom_vjp, + local_gradient=True, + ) + + value = 0.0 + + for _, sim_data in batch_data.items(): + value += postprocess(sim_data) + return value + + return objective + + custom_vjp_val = 1.0 + val, grad = ag.value_and_grad(make_objective(custom_vjp_val))(params0) + + +@pytest.mark.parametrize("structure_key, monitor_key", [("polyslab", "mode")]) +@pytest.mark.parametrize("polyslab_axis", [0, 1, 2]) +@pytest.mark.parametrize("run_function", [_run_local, run_custom]) +@pytest.mark.parametrize( + "use_single_custom_vjp, specify_custom_vjp_by_type", + [(True, True), (True, False), (False, False)], +) +@pytest.mark.parametrize("local_gradient", [True, False]) +def test_autograd_cm_custom_vjp( + use_emulated_run, + structure_key, + monitor_key, + polyslab_axis, + run_function, + use_single_custom_vjp, + specify_custom_vjp_by_type, + local_gradient, +): + """Test that we can override a vjp with a user defined function in component modeler simulations.""" + + fn_dict = get_functions(structure_key, monitor_key) + make_sim = fn_dict["sim"] + postprocess = fn_dict["postprocess"] + + def make_objective(custom_vjp_val): + polyslab_custom_vjp = make_polyslab_custom_vjp(custom_vjp_val) + + structure_index = td.PolySlab if specify_custom_vjp_by_type else 1 + + path_key = ( + 
"geometry", + "vertices" if polyslab_axis == POLYSLAB_SELECT_VERTICES else "slab_bounds", + ) + + custom_vjp_tuple = ( + CustomVJPConfig( + structure_index=structure_index, + compute_derivatives=polyslab_custom_vjp, + path_key=path_key, + ), + ) + + custom_vjp_single = CustomVJPConfig( + structure_index=structure_index, + compute_derivatives=polyslab_custom_vjp, + ) + + custom_vjp_element = custom_vjp_single if use_single_custom_vjp else custom_vjp_tuple + + def objective(*args): + base_sim = make_sim(*args, polyslab_axis=polyslab_axis) + find_mode_monitors = [ + monitor for monitor in base_sim.monitors if isinstance(monitor, td.ModeMonitor) + ] + + select_mode_monitor = find_mode_monitors[0] + + stripped_sim = base_sim.updated_copy(sources=[], monitors=[]) + + input_port = Port( + center=select_mode_monitor.center, + size=select_mode_monitor.size, + mode_spec=select_mode_monitor.mode_spec, + direction="-", + name="input_port", + ) + + modeler = ComponentModeler( + simulation=stripped_sim, + ports=[input_port], + freqs=select_mode_monitor.freqs, + ) + + smatrix = run_function( + modeler, + custom_vjp=custom_vjp_element, + local_gradient=local_gradient, + ) + return np.sum(np.abs(smatrix.smatrix().values) ** 2) + + return objective + + custom_vjp_val = 1.0 + custom_vjp_val_scale = 10.0 * custom_vjp_val + + if not local_gradient: + with pytest.raises( + td.exceptions.AdjointError, + match="custom_vjp specified for a remote gradient not supported.", + ): + val, grad = ag.value_and_grad(make_objective(custom_vjp_val))(params0) + else: + val, grad = ag.value_and_grad(make_objective(custom_vjp_val))(params0) + val_scale, grad_scale = ag.value_and_grad(make_objective(custom_vjp_val_scale))(params0) + + assert np.isclose( + np.sum(np.abs(grad * (custom_vjp_val_scale / custom_vjp_val) - grad_scale)), 0.0 + ), "Gradients were not set by the user vjp" + + @pytest.mark.skipif(not RUN_NUMERICAL, reason="Numerical gradient tests runs through web API.") 
@pytest.mark.parametrize("structure_key, monitor_key", (_NUMERICAL_COMBINATION,)) def test_autograd_numerical(structure_key, monitor_key): @@ -1847,6 +2382,7 @@ def J(eps): ), bounds_intersect=((-1, -1, -1), (1, 1, 1)), simulation_bounds=((-2, -2, -2), (2, 2, 2)), + updated_epsilon=None, ) grads_computed = pr._compute_derivatives(derivative_info=info) @@ -1889,6 +2425,7 @@ def test_adaptive_spacing(eps_real): eps_inf_structure={}, bounds_intersect=((-1, -1, -1), (1, 1, 1)), simulation_bounds=((-2, -2, -2), (2, 2, 2)), + updated_epsilon=None, ) with AssertLogLevel("WARNING", contains_str="Based on the material, the adaptive spacing"): @@ -1919,6 +2456,7 @@ def test_cylinder_discretization(eps_real): eps_inf_structure={}, bounds_intersect=((-1, -1, -1), (1, 1, 1)), simulation_bounds=((-2, -2, -2), (2, 2, 2)), + updated_epsilon=None, ) with AssertLogLevel( @@ -2000,6 +2538,7 @@ def J(eps): ), bounds_intersect=((-1, -1, -1), (1, 1, 1)), simulation_bounds=((-2, -2, -2), (2, 2, 2)), + updated_epsilon=None, ) grads_computed = pr._compute_derivatives(derivative_info=info) diff --git a/tests/test_components/autograd/test_autograd_custom_dispersive_vjps.py b/tests/test_components/autograd/test_autograd_custom_dispersive_vjps.py index 3f23a7e98f..6c07125338 100644 --- a/tests/test_components/autograd/test_autograd_custom_dispersive_vjps.py +++ b/tests/test_components/autograd/test_autograd_custom_dispersive_vjps.py @@ -48,6 +48,7 @@ def _deriv_info(freq): "eps_inf_structure": eps_inf, "bounds_intersect": ((-1, -1, -1), (1, 1, 1)), "simulation_bounds": ((-2, -2, -2), (2, 2, 2)), + "updated_epsilon": None, } diff --git a/tidy3d/components/autograd/derivative_utils.py b/tidy3d/components/autograd/derivative_utils.py index 7c36444687..f7b216526a 100644 --- a/tidy3d/components/autograd/derivative_utils.py +++ b/tidy3d/components/autograd/derivative_utils.py @@ -115,6 +115,9 @@ class DerivativeInfo: frequencies: ArrayLike """Frequencies at which the adjoint gradient should be 
computed.""" + updated_epsilon: Optional[Callable] + """Function to return the permittivity upon geometry replacement in the simulation (may be ``None``).""" + H_der_map: Optional[FieldData] = None """Magnetic field gradient map. Dataset where the field components ("Hx", "Hy", "Hz") store the multiplication diff --git a/tidy3d/components/geometry/primitives.py b/tidy3d/components/geometry/primitives.py index 667ef5cb1a..313c91a57a 100644 --- a/tidy3d/components/geometry/primitives.py +++ b/tidy3d/components/geometry/primitives.py @@ -42,6 +42,13 @@ class Sphere(base.Centered, base.Circular): >>> b = Sphere(center=(1,2,3), radius=2) """ + radius: TracedSize1D = pydantic.Field( + ..., + title="Radius", + description="Radius of geometry.", + units=MICROMETER, + ) + def inside( self, x: np.ndarray[float], y: np.ndarray[float], z: np.ndarray[float] ) -> np.ndarray[bool]: diff --git a/tidy3d/components/structure.py b/tidy3d/components/structure.py index 6bd11d8e9e..fa87a1341c 100644 --- a/tidy3d/components/structure.py +++ b/tidy3d/components/structure.py @@ -3,6 +3,7 @@ from __future__ import annotations import pathlib +import typing from collections import defaultdict from functools import cmp_to_key from os import PathLike @@ -357,8 +358,14 @@ def _make_adjoint_monitors( return mnt_fld, mnt_eps - def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap: - """Compute adjoint gradients given the forward and adjoint fields""" + def _compute_derivatives( + self, + derivative_info: DerivativeInfo, + vjp_fns: typing.Optional[dict[tuple[str, ...], typing.Callable[..., typing.Any]]] = None, + ) -> AutogradFieldMap: + """Compute adjoint gradients given the forward and adjoint fields provided in derivative_info. + vjp_fns provide alternate derivative computation paths for the geometry or medium derivatives. 
+ """ # generate a mapping from the 'medium', or 'geometry' tag to the list of fields for VJP structure_fields_map = defaultdict(list) @@ -380,8 +387,31 @@ def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradField for med_or_geo, field_paths in structure_fields_map.items(): # grab derivative values {field_name -> vjp_value} med_or_geo_field = self.medium if med_or_geo == "medium" else self.geometry - info = derivative_info.updated_copy(paths=field_paths, deep=False) - derivative_values_map = med_or_geo_field._compute_derivatives(derivative_info=info) + + collect_paths_by_keys = {} + for path in field_paths: + if path[0] in collect_paths_by_keys: + collect_paths_by_keys[path[0]].append(path) + else: + collect_paths_by_keys[path[0]] = [path] + + derivative_values_map = {} + for path_key, paths in collect_paths_by_keys.items(): + info = derivative_info.updated_copy(paths=paths, deep=False) + + full_path = (med_or_geo, path_key) + if (vjp_fns is not None) and (full_path in vjp_fns): + full_paths = [(med_or_geo, *path) for path in paths]  # list, not a one-shot generator: the user function may iterate paths more than once + info = derivative_info.updated_copy(paths=full_paths, deep=False) + + vjp = vjp_fns[full_path](med_or_geo_field, info) + vjp_strip_med_or_geo = {key[1:]: val for key, val in vjp.items()} + + derivative_values_map.update(vjp_strip_med_or_geo) + else: + derivative_values_map.update( + med_or_geo_field._compute_derivatives(derivative_info=info) + ) # construct map of {field path -> derivative value} for field_path, derivative_value in derivative_values_map.items(): diff --git a/tidy3d/plugins/smatrix/run.py b/tidy3d/plugins/smatrix/run.py index ff4b9d39f1..7481e57c85 100644 --- a/tidy3d/plugins/smatrix/run.py +++ b/tidy3d/plugins/smatrix/run.py @@ -1,8 +1,10 @@ from __future__ import annotations -from typing import Any +import json +import typing from tidy3d.components.data.index import SimulationDataMap +from tidy3d.exceptions import AdjointError from tidy3d.log import log from 
tidy3d.plugins.smatrix.component_modelers.modal import ModalComponentModeler from tidy3d.plugins.smatrix.component_modelers.terminal import TerminalComponentModeler @@ -11,6 +13,8 @@ from tidy3d.plugins.smatrix.data.terminal import TerminalComponentModelerData from tidy3d.plugins.smatrix.data.types import ComponentModelerDataType from tidy3d.web import Batch, BatchData +from tidy3d.web.api.autograd.autograd import expand_custom_vjp +from tidy3d.web.api.autograd.types import CustomVJPConfig DEFAULT_DATA_DIR = "." @@ -87,7 +91,7 @@ def compose_modeler_data_from_batch_data( def create_batch( modeler: ComponentModelerType, - **kwargs: Any, + **kwargs: typing.Any, ) -> Batch: """Create a simulation Batch from a component modeler. @@ -114,7 +118,8 @@ def create_batch( def _run_local( modeler: ComponentModelerType, path_dir: str = DEFAULT_DATA_DIR, - **kwargs: Any, + custom_vjp: typing.Optional[typing.Union[CustomVJPConfig, tuple[CustomVJPConfig]]] = None, + **kwargs: typing.Any, ) -> ComponentModelerDataType: """Execute the full simulation workflow for a given component modeler. @@ -129,6 +134,9 @@ def _run_local( The component modeler defining the simulations to be run. path_dir : str, optional The directory where the batch file will be saved. Defaults to ".". + custom_vjp : typing.Union[CustomVJPConfig, tuple[CustomVJPConfig]] = None + Specification of alternate gradient function for certain structures in the simulation. + This can be a single vjp configuration or multiple specified in a tuple. **kwargs Extra keyword arguments propagated to the Batch creation. 
@@ -143,7 +151,10 @@ def _run_local( from tidy3d.web.api.autograd import autograd as web_ag sims = modeler.sim_dict - if any(web_ag.is_valid_for_autograd(sim) for sim in sims.values()): + + should_use_autograd = any(web_ag.is_valid_for_autograd(sim) for sim in sims.values()) + + if should_use_autograd: if len(modeler.element_mappings) > 0: log.warning( "Element mappings are used to populate S-matrix values, but autograd gradients " @@ -159,7 +170,27 @@ def _run_local( kwargs.setdefault("simulation_type", "tidy3d_autograd_async") kwargs.setdefault("path_dir", path_dir) - sim_data_map = _run_async(simulations=sims, **kwargs) + local_gradient = kwargs.get("local_gradient", True) + + if not local_gradient: + if custom_vjp is not None: + raise AdjointError("custom_vjp specified for a remote gradient not supported.") + + if isinstance(custom_vjp, CustomVJPConfig): + custom_vjp = (custom_vjp,) + + # stays None when no custom_vjp is given, so nothing is expanded in that case + expanded_custom_vjp_dict = None + if custom_vjp: + expanded_custom_vjp_dict = { + sim_key: expand_custom_vjp(custom_vjp, sim) for sim_key, sim in sims.items() + } + + sim_data_map = _run_async( + simulations=sims, + custom_vjp=expanded_custom_vjp_dict, + **kwargs, + ) return compose_modeler_data_from_batch_data(modeler=modeler, batch_data=sim_data_map) diff --git a/tidy3d/web/api/autograd/__init__.py b/tidy3d/web/api/autograd/__init__.py index e69de29bb2..77387b1fbc 100644 --- a/tidy3d/web/api/autograd/__init__.py +++ b/tidy3d/web/api/autograd/__init__.py @@ -0,0 +1,5 @@ +# from __future__ import annotations + +# from tidy3d.web.api.autograd.autograd import expand_custom_vjp + +# __all__ = ["expand_custom_vjp"] diff --git a/tidy3d/web/api/autograd/autograd.py b/tidy3d/web/api/autograd/autograd.py index c2e4eb965c..ff0cdd201f 100644 --- a/tidy3d/web/api/autograd/autograd.py +++ b/tidy3d/web/api/autograd/autograd.py @@ -1,7 +1,9 @@ # autograd wrapper for web functions from __future__ import annotations 
+import inspect import typing +from dataclasses import replace from os import PathLike from pathlib import Path from typing import Any @@ -12,6 +14,8 @@ import tidy3d as td from tidy3d.components.autograd import AutogradFieldMap from tidy3d.components.base import TRACED_FIELD_KEYS_ATTR +from tidy3d.components.geometry.utils import GeometryType +from tidy3d.components.medium import MediumType from tidy3d.components.types.workflow import WorkflowDataType, WorkflowType from tidy3d.config import config from tidy3d.exceptions import AdjointError @@ -50,6 +54,7 @@ from .io_utils import ( upload_sim_fields_keys as _upload_sim_fields_keys_impl, ) +from .types import CustomVJPConfig, SetupRunResult def _resolve_local_gradient(value: typing.Optional[bool]) -> bool: @@ -100,7 +105,104 @@ def is_valid_for_autograd_async(simulations: dict[str, td.Simulation]) -> bool: return True -def run( +def expand_custom_vjp( + custom_vjp: tuple[CustomVJPConfig, ...], simulation: td.Simulation +) -> tuple[CustomVJPConfig, ...]: + """Expand custom_vjp for entries where structure_index is a GeometryType or MediumType + into multiple entries tagged by an integer structure_index. 
+ """ + expanded_custom_vjp = [] + geometry_types_seen = [] + medium_types_seen = [] + + custom_vjp_indices = [ + vjp_config.structure_index + for vjp_config in custom_vjp + if isinstance(vjp_config.structure_index, int) + ] + + allowed_classes_geometry = typing.get_args(GeometryType) + allowed_classes_medium = typing.get_args(MediumType) + + for vjp_config in custom_vjp: + if isinstance(vjp_config.structure_index, type) and issubclass( + vjp_config.structure_index, allowed_classes_geometry + ): + if vjp_config.structure_index in geometry_types_seen: + raise AdjointError( + f"custom_vjp assigned multiple times for geometry type {vjp_config.structure_index}" + ) + + geometry_types_seen.append(vjp_config.structure_index) + + for structure_idx, structure in enumerate(simulation.structures): + if isinstance(structure.geometry, vjp_config.structure_index) and ( + structure_idx not in custom_vjp_indices + ): + updated_vjp_config = replace(vjp_config, structure_index=structure_idx) + + expanded_custom_vjp.append(updated_vjp_config) + + elif isinstance(vjp_config.structure_index, type) and issubclass( + vjp_config.structure_index, allowed_classes_medium + ): + if vjp_config.structure_index in medium_types_seen: + raise AdjointError( + f"custom_vjp assigned multiple times for medium type {vjp_config.structure_index}" + ) + + medium_types_seen.append(vjp_config.structure_index) + + for structure_idx, structure in enumerate(simulation.structures): + if isinstance(structure.medium, vjp_config.structure_index) and ( + structure_idx not in custom_vjp_indices + ): + updated_vjp_config = replace(vjp_config, structure_index=structure_idx) + expanded_custom_vjp.append(updated_vjp_config) + + else: + expanded_custom_vjp.append(vjp_config) + + return tuple(expanded_custom_vjp) + + +def verify_custom_vjp( + custom_vjp: tuple[CustomVJPConfig, ...], traced_fields: AutogradFieldMap +) -> None: + """Check that the provided custom_vjp is targeting structure indices and vjp paths that exist + in the 
traced structures specified by traced_fields. Also check the function signature for the + compute_derivatives function to make sure it has the right number of arguments and the second + argument is named derivative_info. + """ + + custom_vjp_index_options = [full_path[1] for full_path in traced_fields] + custom_vjp_path_options = [full_path[2:4] for full_path in traced_fields] + + for vjp_config in custom_vjp: + sig = inspect.signature(vjp_config.compute_derivatives) + argument_names = list(sig.parameters.keys()) + + if not (len(argument_names) == 2): + raise AdjointError( + f"CustomVJPConfig compute_derivatives function should accept two arguments and it currently accepts {len(argument_names)} arguments." + ) + if not (argument_names[1] == "derivative_info"): + raise AdjointError( + f"CustomVJPConfig compute_derivatives function second argument name is {argument_names[1]} but it should be derivative_info." + ) + + if vjp_config.structure_index not in custom_vjp_index_options: + raise AdjointError( + f"CustomVJPConfig structure index {vjp_config.structure_index} not in traced structure indices." + ) + + if vjp_config.path_key and (vjp_config.path_key not in custom_vjp_path_options): + raise AdjointError( + f"CustomVJPConfig path {vjp_config.path_key} not in traced structure paths." + ) + + +def run_custom( simulation: WorkflowType, task_name: typing.Optional[str] = None, folder_name: str = "default", @@ -119,6 +221,7 @@ def run( pay_type: typing.Union[PayType, str] = PayType.AUTO, priority: typing.Optional[int] = None, lazy: typing.Optional[bool] = None, + custom_vjp: typing.Optional[typing.Union[CustomVJPConfig, tuple[CustomVJPConfig, ...]]] = None, ) -> WorkflowDataType: """ Submits a :class:`.Simulation` to server, starts running, monitors progress, downloads, @@ -166,6 +269,10 @@ def run( lazy: Optional[bool] = None Whether to return lazy data proxies. Defaults to ``False`` for single runs when unspecified, matching :func:`tidy3d.web.run`. 
+ custom_vjp : typing.Optional[typing.Union[CustomVJPConfig, tuple[CustomVJPConfig, ...]]] = None + Specification of alternate gradient function for certain structures in the simulation. + This can be a single vjp configuration or multiple specified in a tuple. + Returns ------- Union[:class:`.SimulationData`, :class:`.HeatSimulationData`, :class:`.EMESimulationData`, :class:`.ModalComponentModelerData`, :class:`.TerminalComponentModelerData`] @@ -224,6 +331,10 @@ def run( stub = Tidy3dStub(simulation=simulation) task_name = stub.get_default_task_name() + if custom_vjp is not None: + if isinstance(custom_vjp, CustomVJPConfig): + custom_vjp = (custom_vjp,) + # component modeler path: route autograd-valid modelers to local run from tidy3d.plugins.smatrix.component_modelers.types import ComponentModelerType @@ -245,9 +356,22 @@ def run( priority=priority, local_gradient=local_gradient, max_num_adjoint_per_fwd=max_num_adjoint_per_fwd, + custom_vjp=custom_vjp, ) - if isinstance(simulation, td.Simulation) and is_valid_for_autograd(simulation): + should_use_autograd = False + if isinstance(simulation, td.Simulation): + should_use_autograd = is_valid_for_autograd(simulation) + + if should_use_autograd: + if (custom_vjp is not None) and (not local_gradient): + raise AdjointError("custom_vjp specified for a remote gradient not supported.") + + if custom_vjp is not None: + expanded_custom_vjp = expand_custom_vjp(custom_vjp, simulation) + else: + expanded_custom_vjp = None + return _run( simulation=simulation, task_name=task_name, @@ -263,6 +387,7 @@ def run( parent_tasks=parent_tasks, local_gradient=local_gradient, max_num_adjoint_per_fwd=max_num_adjoint_per_fwd, + custom_vjp=expanded_custom_vjp, pay_type=pay_type, priority=priority, lazy=lazy, @@ -288,7 +413,51 @@ def run( ) -def run_async( +def run( + simulation: WorkflowType, + task_name: typing.Optional[str] = None, + folder_name: str = "default", + path: PathLike = "simulation_data.hdf5", + callback_url: 
typing.Optional[str] = None, + verbose: bool = True, + progress_callback_upload: typing.Optional[typing.Callable[[float], None]] = None, + progress_callback_download: typing.Optional[typing.Callable[[float], None]] = None, + solver_version: typing.Optional[str] = None, + worker_group: typing.Optional[str] = None, + simulation_type: str = "tidy3d", + parent_tasks: typing.Optional[list[str]] = None, + local_gradient: typing.Optional[bool] = None, + max_num_adjoint_per_fwd: typing.Optional[int] = None, + reduce_simulation: typing.Literal["auto", True, False] = "auto", + pay_type: typing.Union[PayType, str] = PayType.AUTO, + priority: typing.Optional[int] = None, + lazy: typing.Optional[bool] = None, +) -> WorkflowDataType: + """Wrapper for run_custom for usage without custom_vjp for public facing API.""" + return run_custom( + simulation=simulation, + task_name=task_name, + folder_name=folder_name, + path=path, + callback_url=callback_url, + verbose=verbose, + progress_callback_upload=progress_callback_upload, + progress_callback_download=progress_callback_download, + solver_version=solver_version, + worker_group=worker_group, + simulation_type=simulation_type, + parent_tasks=parent_tasks, + local_gradient=local_gradient, + max_num_adjoint_per_fwd=max_num_adjoint_per_fwd, + reduce_simulation=reduce_simulation, + pay_type=pay_type, + priority=priority, + lazy=lazy, + custom_vjp=None, + ) + + +def run_async_custom( simulations: typing.Union[dict[str, td.Simulation], tuple[td.Simulation], list[td.Simulation]], folder_name: str = "default", path_dir: PathLike = DEFAULT_DATA_DIR, @@ -304,6 +473,15 @@ def run_async( pay_type: typing.Union[PayType, str] = PayType.AUTO, priority: typing.Optional[int] = None, lazy: typing.Optional[bool] = None, + custom_vjp: typing.Optional[ + typing.Union[ + CustomVJPConfig, + dict[str, CustomVJPConfig], + typing.Sequence[CustomVJPConfig], + dict[str, typing.Sequence[CustomVJPConfig]], + typing.Sequence[typing.Sequence[CustomVJPConfig]], + ] 
+ ] = None, ) -> BatchData: """Submits a set of Union[:class:`.Simulation`, :class:`.HeatSimulation`, :class:`.EMESimulation`] objects to server, starts running, monitors progress, downloads, and loads results as a :class:`.BatchData` object. @@ -344,6 +522,18 @@ def run_async( lazy: Optional[bool] = None Whether to return lazy data proxies. Defaults to ``True`` for batch runs when unspecified, matching :func:`tidy3d.web.run`. + custom_vjp: typing.Optional[typing.Union[ + CustomVJPConfig, + dict[str, CustomVJPConfig], + typing.Sequence[CustomVJPConfig], + dict[str, typing.Sequence[CustomVJPConfig]], + typing.Sequence[typing.Sequence[CustomVJPConfig]], + ]] = None + Specification of alternate gradient function for certain structures in the simulation. Different + custom_vjp's can be added for different simulations or the same set can be broadcast to all simulations. + Specifying a single config will broadcast to all simulations. Specifying a dict or a sequence with single configs + as values will set one config for each simulation. Most generally, multiple custom_vjp's can be specified for each + simulation by specifying a dict with sequence values or a sequence of sequences. 
Returns ------ @@ -371,16 +561,88 @@ def run_async( lazy = True if lazy is None else bool(lazy) + def validate_and_expand( + fn_arg: CustomVJPConfig, + fn_arg_name: str, + base_type: type[CustomVJPConfig], + orig_sim_arg: typing.Union[ + dict[str, td.Simulation], tuple[td.Simulation], list[td.Simulation] + ], + sim_dict: dict[str, tuple[td.Simulation]], + ) -> dict[str, typing.Sequence[CustomVJPConfig]]: + """Check and validate the provided custom_vjp type and expand as necessary to + match the provided simulation specification.""" + if fn_arg is None: + return fn_arg + + if isinstance(fn_arg, base_type): + expanded = dict.fromkeys(sim_dict.keys(), (fn_arg,)) + return expanded + + expanded = {} + if not isinstance(fn_arg, type(orig_sim_arg)): + raise AdjointError( + f"{fn_arg_name} type ({type(fn_arg)}) should match simulations type ({type(simulations)})" + ) + + if isinstance(orig_sim_arg, dict): + check_keys = fn_arg.keys() == sim_dict.keys() + + if not check_keys: + raise AdjointError(f"{fn_arg_name} keys do not match simulations keys") + + for key, val in fn_arg.items(): + if isinstance(val, base_type): + expanded[key] = (val,) + else: + expanded[key] = val + + elif isinstance(orig_sim_arg, (list, tuple)): + if not (len(fn_arg) == len(orig_sim_arg)): + raise AdjointError( + f"{fn_arg_name} is not the same length as simulations ({len(fn_arg)} vs. 
{len(simulations)})" + ) + + for idx, key in enumerate(sim_dict.keys()): + val = fn_arg[idx] + if isinstance(val, (list, tuple)): + expanded[key] = val + else: + expanded[key] = (val,) + + return expanded + if isinstance(simulations, (tuple, list)): sim_dict = {} for i, sim in enumerate(simulations, 1): task_name = Tidy3dStub(simulation=sim).get_default_task_name() + f"_{i}" sim_dict[task_name] = sim - simulations = sim_dict + else: + sim_dict = simulations + + custom_vjp = validate_and_expand( + custom_vjp, "custom_vjp", CustomVJPConfig, simulations, sim_dict + ) + + simulations = sim_dict path_dir = Path(path_dir) - if is_valid_for_autograd_async(simulations): + should_use_autograd_async = is_valid_for_autograd_async(simulations) + + if should_use_autograd_async: + if (custom_vjp is not None) and (not local_gradient): + raise AdjointError("custom_vjp specified for a remote gradient not supported.") + + if custom_vjp is not None: + expanded_custom_vjp_dict = {} + for sim_key, custom_vjp_entry in custom_vjp.items(): + expanded_custom_vjp_dict[sim_key] = expand_custom_vjp( + custom_vjp_entry, simulations[sim_key] + ) + else: + expanded_custom_vjp_dict = None + return _run_async( simulations=simulations, folder_name=folder_name, @@ -393,6 +655,7 @@ def run_async( parent_tasks=parent_tasks, local_gradient=local_gradient, max_num_adjoint_per_fwd=max_num_adjoint_per_fwd, + custom_vjp=expanded_custom_vjp_dict, pay_type=pay_type, priority=priority, lazy=lazy, @@ -415,6 +678,44 @@ def run_async( ) +def run_async( + simulations: typing.Union[dict[str, td.Simulation], tuple[td.Simulation], list[td.Simulation]], + folder_name: str = "default", + path_dir: PathLike = DEFAULT_DATA_DIR, + callback_url: typing.Optional[str] = None, + num_workers: typing.Optional[int] = None, + verbose: bool = True, + simulation_type: str = "tidy3d", + solver_version: typing.Optional[str] = None, + parent_tasks: typing.Optional[dict[str, list[str]]] = None, + local_gradient: typing.Optional[bool] 
= None, + max_num_adjoint_per_fwd: typing.Optional[int] = None, + reduce_simulation: typing.Literal["auto", True, False] = "auto", + pay_type: typing.Union[PayType, str] = PayType.AUTO, + priority: typing.Optional[int] = None, + lazy: typing.Optional[bool] = None, +) -> BatchData: + """Wrapper for run_async_custom for usage without custom_vjp for public facing API.""" + return run_async_custom( + simulations=simulations, + folder_name=folder_name, + path_dir=path_dir, + callback_url=callback_url, + num_workers=num_workers, + verbose=verbose, + simulation_type=simulation_type, + solver_version=solver_version, + parent_tasks=parent_tasks, + local_gradient=local_gradient, + max_num_adjoint_per_fwd=max_num_adjoint_per_fwd, + reduce_simulation=reduce_simulation, + pay_type=pay_type, + priority=priority, + lazy=lazy, + custom_vjp=None, + ) + + """ User-facing ``run`` and `run_async`` functions, compatible with ``autograd`` """ @@ -423,11 +724,17 @@ def _run( task_name: str, local_gradient: bool = False, max_num_adjoint_per_fwd: typing.Optional[int] = None, + custom_vjp: typing.Optional[tuple[CustomVJPConfig, ...]] = None, **run_kwargs: Any, ) -> td.SimulationData: """User-facing ``web.run`` function, compatible with ``autograd`` differentiation.""" - traced_fields_sim = setup_run(simulation=simulation) + setup_result = setup_run(simulation=simulation) + if custom_vjp: + verify_custom_vjp(custom_vjp, setup_result.sim_fields) + + traced_fields_sim = setup_result.sim_fields + simulation = setup_result.simulation # if we register this as not needing adjoint at all (no tracers), call regular run function if not traced_fields_sim: @@ -456,6 +763,7 @@ def _run( aux_data=aux_data, local_gradient=local_gradient, max_num_adjoint_per_fwd=max_num_adjoint_per_fwd, + custom_vjp=custom_vjp, **run_kwargs, ) @@ -466,39 +774,56 @@ def _run_async( simulations: dict[str, td.Simulation], local_gradient: bool = False, max_num_adjoint_per_fwd: typing.Optional[int] = None, + custom_vjp: 
typing.Optional[dict[str, typing.Sequence[CustomVJPConfig]]] = None, **run_async_kwargs: Any, ) -> dict[str, td.SimulationData]: """User-facing ``web.run_async`` function, compatible with ``autograd`` differentiation.""" - task_names = simulations.keys() traced_fields_sim_dict: dict[str, AutogradFieldMap] = {} sims_original: dict[str, td.Simulation] = {} + sims_prepared: dict[str, td.Simulation] = {} + + if max_num_adjoint_per_fwd is None: + max_num_adjoint_per_fwd = config.adjoint.max_adjoint_per_fwd + + aux_data_dict = {task_name: {} for task_name in task_names} + for task_name in task_names: sim = simulations[task_name] - traced_fields = setup_run(simulation=sim) + setup_result = setup_run(simulation=sim) + + if custom_vjp: + verify_custom_vjp(custom_vjp[task_name], setup_result.sim_fields) + + sim_prepared = setup_result.simulation + traced_fields = setup_result.sim_fields + + sims_prepared[task_name] = sim_prepared + traced_fields_sim_dict[task_name] = traced_fields - payload = sim._serialized_traced_field_keys(traced_fields) - sim_static = sim.to_static() + payload = sim_prepared._serialized_traced_field_keys(traced_fields) + sim_static = sim_prepared.to_static() if payload: sim_static.attrs[TRACED_FIELD_KEYS_ATTR] = payload + sims_original[task_name] = sim_static - traced_fields_sim_dict = dict_ag(traced_fields_sim_dict) # TODO: shortcut primitive running for any items with no tracers? + traced_fields_sim_dict = dict_ag(traced_fields_sim_dict) + sims_original = {name: sims_original[name] for name in traced_fields_sim_dict.keys()} - aux_data_dict = {task_name: {} for task_name in task_names} traced_fields_data_dict = _run_async_primitive( traced_fields_sim_dict, # if you pass as a kwarg it will not trace :/ sims_original=sims_original, aux_data_dict=aux_data_dict, local_gradient=local_gradient, max_num_adjoint_per_fwd=max_num_adjoint_per_fwd, + custom_vjp=custom_vjp, **run_async_kwargs, ) - # TODO: package this as a Batch? 
it might be not possible as autograd tracers lose their - # powers when we save them to file. + # TODO: package this as a Batch? it might be not possible as autograd tracers lose their powers when we save them to file. sim_data_dict = {} for task_name in task_names: traced_fields_data = traced_fields_data_dict[task_name] @@ -509,14 +834,22 @@ def _run_async( return sim_data_dict -def setup_run(simulation: td.Simulation) -> AutogradFieldMap: - """Process a user-supplied ``Simulation`` into inputs to ``_run_primitive``.""" +def setup_run( + simulation: td.Simulation, +) -> SetupRunResult: + """Prepare simulation and traced fields, including numerical structure insertions.""" + + sim_prepared = simulation - # get a mapping of all the traced fields in the provided simulation - return simulation._strip_traced_fields( + sim_fields_map = sim_prepared._strip_traced_fields( include_untraced_data_arrays=False, starting_path=("structures",) ) + return SetupRunResult( + sim_fields=sim_fields_map, + simulation=sim_prepared, + ) + def postprocess_run(traced_fields_data: AutogradFieldMap, aux_data: dict) -> td.SimulationData: """Process the return from ``_run_primitive`` into ``SimulationData`` for user.""" @@ -537,6 +870,7 @@ def _run_primitive( aux_data: dict, local_gradient: bool, max_num_adjoint_per_fwd: int, + custom_vjp: typing.Optional[typing.Union[CustomVJPConfig, tuple[CustomVJPConfig, ...]]] = None, **run_kwargs: Any, ) -> AutogradFieldMap: """Autograd-traced 'run()' function: runs simulation, strips tracer data, caches fwd data.""" @@ -605,6 +939,7 @@ def _run_async_primitive( aux_data_dict: dict[dict[str, typing.Any]], local_gradient: bool, max_num_adjoint_per_fwd: int, + custom_vjp: typing.Optional[dict[str, typing.Sequence[CustomVJPConfig]],] = None, **run_async_kwargs: Any, ) -> dict[str, AutogradFieldMap]: task_names = sim_fields_dict.keys() @@ -710,6 +1045,7 @@ def _run_bwd( aux_data: dict, local_gradient: bool, max_num_adjoint_per_fwd: int, + custom_vjp: 
tuple[CustomVJPConfig, ...], **run_kwargs: Any, ) -> typing.Callable[[AutogradFieldMap], AutogradFieldMap]: """VJP-maker for ``_run_primitive()``. Constructs and runs adjoint simulations, computes grad.""" @@ -784,6 +1120,7 @@ def vjp(data_fields_vjp: AutogradFieldMap) -> AutogradFieldMap: sim_data_orig=sim_data_orig, sim_data_fwd=sim_data_fwd, sim_fields_keys=sim_fields_keys, + custom_vjp=custom_vjp, ) else: td.log.info("Starting server-side batch of adjoint simulations ...") @@ -835,6 +1172,7 @@ def _run_async_bwd( aux_data_dict: dict[str, dict[str, typing.Any]], local_gradient: bool, max_num_adjoint_per_fwd: int, + custom_vjp: typing.Optional[dict[str, typing.Sequence[CustomVJPConfig]]] = None, **run_async_kwargs: Any, ) -> typing.Callable[[dict[str, AutogradFieldMap]], dict[str, AutogradFieldMap]]: """VJP-maker for ``_run_primitive()``. Constructs and runs adjoint simulation, computes grad.""" @@ -844,6 +1182,7 @@ def _run_async_bwd( task_names = data_fields_original_dict.keys() + custom_vjp = custom_vjp or {} # get the fwd epsilon and field data from the cached aux_data sim_data_orig_dict = {} sim_data_fwd_dict = {} @@ -856,7 +1195,7 @@ def _run_async_bwd( if local_gradient: sim_data_fwd_dict[task_name] = aux_data[AUX_KEY_SIM_DATA_FWD] - td.log.info("constructing custom vjp function for backwards pass.") + td.log.info("Constructing custom VJP function for backwards pass.") def vjp(data_fields_dict_vjp: dict[str, AutogradFieldMap]) -> dict[str, AutogradFieldMap]: """dJ/d{sim.traced_fields()} as a function of Function of dJ/d{data.traced_fields()}""" @@ -920,11 +1259,16 @@ def vjp(data_fields_dict_vjp: dict[str, AutogradFieldMap]) -> dict[str, Autograd sim_fields_keys = sim_fields_keys_dict[task_name] # Compute VJP contribution + task_custom_vjp = custom_vjp.get(task_name) + if isinstance(task_custom_vjp, CustomVJPConfig): + task_custom_vjp = (task_custom_vjp,) + vjp_results[adj_task_name] = postprocess_adj( sim_data_adj=sim_data_adj, 
sim_data_orig=sim_data_orig, sim_data_fwd=sim_data_fwd, sim_fields_keys=sim_fields_keys, + custom_vjp=task_custom_vjp, ) else: # Set up parent tasks mapping for all adjoint simulations @@ -990,6 +1334,7 @@ def postprocess_adj( sim_data_orig: td.SimulationData, sim_data_fwd: td.SimulationData, sim_fields_keys: list[tuple], + custom_vjp: tuple[CustomVJPConfig, ...], ) -> AutogradFieldMap: """Postprocess adjoint results into VJPs (delegated).""" return _postprocess_adj_impl( @@ -997,6 +1342,7 @@ def postprocess_adj( sim_data_orig=sim_data_orig, sim_data_fwd=sim_data_fwd, sim_fields_keys=sim_fields_keys, + custom_vjp=custom_vjp, ) diff --git a/tidy3d/web/api/autograd/backward.py b/tidy3d/web/api/autograd/backward.py index 0c596f61dd..b9d01963e8 100644 --- a/tidy3d/web/api/autograd/backward.py +++ b/tidy3d/web/api/autograd/backward.py @@ -1,5 +1,7 @@ from __future__ import annotations +import functools +import typing from collections import defaultdict import numpy as np @@ -9,11 +11,14 @@ from tidy3d import Medium from tidy3d.components.autograd import AutogradFieldMap, get_static from tidy3d.components.autograd.derivative_utils import DerivativeInfo -from tidy3d.components.data.data_array import DataArray +from tidy3d.components.data.data_array import DataArray, FreqDataArray, ScalarFieldDataArray +from tidy3d.components.geometry.base import Box +from tidy3d.components.geometry.utils import GeometryType from tidy3d.config import config from tidy3d.exceptions import AdjointError from tidy3d.packaging import disable_local_subpixel +from .types import CustomVJPConfig from .utils import E_to_D, get_derivative_maps @@ -105,18 +110,49 @@ def postprocess_adj( sim_data_orig: td.SimulationData, sim_data_fwd: td.SimulationData, sim_fields_keys: list[tuple], + custom_vjp: typing.Optional[tuple[CustomVJPConfig, ...]] = None, ) -> AutogradFieldMap: """Postprocess some data from the adjoint simulation into the VJP for the original sim flds.""" - # map of index into 'structures' to 
the list of paths we need vjps for + def get_all_paths(match_structure_index: int) -> tuple[tuple[str, str, int]]: + """Get all the paths that may appear in autograd for this structure index. This allows a + custom_vjp to be called for all autograd paths for the structure. + """ + all_paths = tuple( + tuple(structure_path) + for namespace, structure_index, *structure_path in sim_fields_keys + if structure_index == match_structure_index + ) + + return all_paths + + custom_vjp_lookup: dict[int, dict[tuple[str, str], typing.Callable[..., typing.Any]]] = {} + if custom_vjp: + for vjp_config in custom_vjp: + structure_index = vjp_config.structure_index + vjp_fn = vjp_config.compute_derivatives + path = vjp_config.path_key + + if path is None: + for match_path in get_all_paths(structure_index): + custom_vjp_lookup.setdefault(structure_index, {})[match_path[0:2]] = vjp_fn + else: + custom_vjp_lookup.setdefault(structure_index, {})[path] = vjp_fn + + # map of index into 'structures' to the paths we need VJPs for sim_vjp_map = defaultdict(list) - for _, structure_index, *structure_path in sim_fields_keys: + for namespace, structure_index, *structure_path in sim_fields_keys: structure_path = tuple(structure_path) - sim_vjp_map[structure_index].append(structure_path) + if namespace == "structures": + sim_vjp_map[structure_index].append(structure_path) # store the derivative values given the forward and adjoint data sim_fields_vjp = {} - for structure_index, structure_paths in sim_vjp_map.items(): + all_structure_indices = sorted(set(sim_vjp_map.keys())) + + for structure_index in all_structure_indices: + structure_paths = tuple(sim_vjp_map.get(structure_index, ())) + # grab the forward and adjoint data fld_fwd = sim_data_fwd._get_adjoint_data(structure_index, data_type="fld") eps_fwd = sim_data_fwd._get_adjoint_data(structure_index, data_type="eps") @@ -162,8 +198,8 @@ def postprocess_adj( eps_background = None # auto permittivity detection for non-box geometries + sim_orig = 
sim_data_orig.simulation if not isinstance(structure.geometry, td.Box): - sim_orig = sim_data_orig.simulation plane_eps = eps_fwd.monitor.geometry sim_orig_grid_spec = td.components.grid.grid_spec.GridSpec.from_grid(sim_orig.grid) @@ -215,6 +251,42 @@ def postprocess_adj( rmax_intersect = tuple([min(a, b) for a, b in zip(rmax_sim, rmax_struct)]) bounds_intersect = (rmin_intersect, rmax_intersect) + def updated_epsilon_full_impl( + replacement_geometry: GeometryType, + adjoint_frequencies: typing.Optional[FreqDataArray], + structure_index: typing.Optional[int], + eps_box: typing.Optional[Box], + sim_orig: td.Simulation, + ) -> ScalarFieldDataArray: + """Return the simulation permittivity for eps_box after replacing the geometry + for this structure with a new geometry. This is helpful for carrying out finite + difference permittivity computations. + """ + update_sim = sim_orig.updated_copy( + structures=[ + sim_orig.structures[idx].updated_copy(geometry=replacement_geometry) + if idx == structure_index + else sim_orig.structures[idx] + for idx in range(len(sim_orig.structures)) + ], + grid_spec=td.components.grid.grid_spec.GridSpec.from_grid(sim_orig.grid), + ) + + eps_by_f = [ + update_sim.epsilon(box=eps_box, coord_key="centers", freq=f) + for f in adjoint_frequencies + ] + + return xr.concat(eps_by_f, dim="f").assign_coords(f=adjoint_frequencies) + + updated_epsilon_full = functools.partial( + updated_epsilon_full_impl, + adjoint_frequencies=adjoint_frequencies, + structure_index=structure_index, + eps_box=eps_fwd.monitor.geometry, + sim_orig=sim_orig, + ) + # get chunk size - if None, process all frequencies as one chunk freq_chunk_size = config.adjoint.solver_freq_chunk_size n_freqs = len(adjoint_frequencies) @@ -277,44 +349,64 @@ def postprocess_adj( else None ) - # create derivative info with sliced data - derivative_info = DerivativeInfo( - paths=structure_paths, - E_der_map=E_der_map_chunk, - D_der_map=D_der_map_chunk, - H_der_map=H_der_map_chunk, - 
E_fwd=E_fwd_chunk, - E_adj=E_adj_chunk, - D_fwd=D_fwd_chunk, - D_adj=D_adj_chunk, - H_fwd=H_fwd_chunk, - H_adj=H_adj_chunk, - eps_data=eps_data_chunk, - eps_in=eps_in_chunk, - eps_out=eps_out_chunk, - eps_background=eps_background_chunk, - frequencies=select_adjoint_freqs, # only chunk frequencies - eps_no_structure=eps_no_structure_chunk, - eps_inf_structure=eps_inf_structure_chunk, - bounds=struct_bounds, - bounds_intersect=bounds_intersect, - simulation_bounds=sim_data_orig.simulation.bounds, - is_medium_pec=structure.medium.is_pec, + def updated_epsilon_wrapper( + replacement_geometry: GeometryType, + select_adjoint_freqs: typing.Optional[FreqDataArray], + updated_epsilon_full: typing.Optional[typing.Callable], + ) -> ScalarFieldDataArray: + # Get permittivity function for a subset of frequencies + return updated_epsilon_full(replacement_geometry).sel(f=select_adjoint_freqs) + + updated_epsilon = functools.partial( + updated_epsilon_wrapper, + select_adjoint_freqs=select_adjoint_freqs, + updated_epsilon_full=updated_epsilon_full, ) - # compute derivatives for chunk - vjp_chunk = structure._compute_derivatives(derivative_info) + common_kwargs = { + "E_der_map": E_der_map_chunk, + "D_der_map": D_der_map_chunk, + "H_der_map": H_der_map_chunk, + "E_fwd": E_fwd_chunk, + "E_adj": E_adj_chunk, + "D_fwd": D_fwd_chunk, + "D_adj": D_adj_chunk, + "H_fwd": H_fwd_chunk, + "H_adj": H_adj_chunk, + "eps_data": eps_data_chunk, + "eps_in": eps_in_chunk, + "eps_out": eps_out_chunk, + "eps_background": eps_background_chunk, + "frequencies": select_adjoint_freqs, + "eps_no_structure": eps_no_structure_chunk, + "eps_inf_structure": eps_inf_structure_chunk, + "updated_epsilon": updated_epsilon, + "bounds": struct_bounds, + "bounds_intersect": bounds_intersect, + "simulation_bounds": sim_data_orig.simulation.bounds, + "is_medium_pec": structure.medium.is_pec, + } + + if structure_paths: + derivative_info_struct = DerivativeInfo( + paths=structure_paths, + **common_kwargs, + ) - # 
accumulate results - for path, value in vjp_chunk.items(): - if path in vjp_value_map: - val = vjp_value_map[path] - if isinstance(val, (list, tuple)) and isinstance(value, (list, tuple)): - vjp_value_map[path] = type(val)(x + y for x, y in zip(val, value)) + vjp_fns = custom_vjp_lookup.get(structure_index) + vjp_chunk = structure._compute_derivatives(derivative_info_struct, vjp_fns=vjp_fns) + + for path, value in vjp_chunk.items(): + if path in vjp_value_map: + existing = vjp_value_map[path] + if isinstance(existing, (list, tuple)) and isinstance(value, (list, tuple)): + vjp_value_map[path] = type(existing)( + x + y for x, y in zip(existing, value) + ) + else: + vjp_value_map[path] = existing + value else: - vjp_value_map[path] += value - else: - vjp_value_map[path] = value + vjp_value_map[path] = value # store vjps in output map for structure_path, vjp_value in vjp_value_map.items(): diff --git a/tidy3d/web/api/autograd/types.py b/tidy3d/web/api/autograd/types.py new file mode 100644 index 0000000000..5fb762f92b --- /dev/null +++ b/tidy3d/web/api/autograd/types.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +import typing +from dataclasses import dataclass + +import tidy3d as td +from tidy3d.components.autograd import AutogradFieldMap +from tidy3d.components.geometry.utils import GeometryType +from tidy3d.components.medium import MediumType + + +@dataclass +class CustomVJPConfig: + structure_index: typing.Union[int, type[GeometryType], type[MediumType]] + """Index for structure to replace vjp or specification of geometry or medium type. If a type is provided, + the custom vjp will be applied to all structures in the simulation with the geometry or medium type. + """ + + compute_derivatives: typing.Callable + """Function for computing the targeted vjp value. The function should accept the geometry or medium in the + structure depending on if this is a geometry or medium path (see path_key) as the first argument. 
The second + argument should be named derivative_info and accept a DerivativeInfo object that contains important information for computing + the gradient. The function should return a dict object that maps the path to the computed gradient value. + """ + + path_key: typing.Optional[tuple[str, ...]] = None + """Path key corresponding to the vjp. For example, this could be ('geometry', 'radius') if you are targeting + the radius parameter in the given structure geometry. It can also target the medium by specifying medium first + (i.e., ('medium', 'permittivity') will target the permittivity variable in the structure's medium). If not + specified or set to None, the supplied function applies for all possible vjp paths. + """ + + +class SetupRunResult(typing.NamedTuple): + sim_fields: AutogradFieldMap + simulation: td.Simulation + + +__all__ = [ + "SetupRunResult", +]