diff --git a/CHANGELOG.md b/CHANGELOG.md index a4ae0b7081..103e32a8fb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Add support for `np.unwrap` in `tidy3d.plugins.autograd`. - Add Nunley variant to germanium material library based on Nunley et al. 2016 data. +### Changed +- Switched to an analytical gradient calculation for spatially-varying pole-residue models (`CustomPoleResidue`). + ### Fixed - Arrow lengths are now scaled consistently in the X and Y directions, and their lengths no longer exceed the height of the plot window. - Bug in `PlaneWave` defined with a negative `angle_theta` which would lead to wrong injection. @@ -101,8 +104,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed - Fixed `reverse` property of `td.Scene.plot_structures_property()` to also reverse the colorbar. - -### Fixed - Fixed bug in surface gradient computation where fields, instead of gradients, were being summed in frequency. 
## [2.8.2] - 2025-04-09 diff --git a/tests/test_components/test_autograd.py b/tests/test_components/test_autograd.py index c46624648f..92f482d6aa 100644 --- a/tests/test_components/test_autograd.py +++ b/tests/test_components/test_autograd.py @@ -1650,18 +1650,18 @@ def test_custom_pole_residue(monkeypatch): custom_med_pole_res = td.CustomPoleResidue(eps_inf=eps_inf, poles=poles) def J(eps): - return anp.sum(abs(eps)) + return anp.sum(anp.abs(eps)) freq = 3e8 pr = td.CustomPoleResidue(eps_inf=eps_inf, poles=poles) eps0 = pr.eps_model(freq) - dJ_deps = ag.holomorphic_grad(J)(eps0) + dJ_deps = np.conj(ag.holomorphic_grad(J)(eps0)) monkeypatch.setattr( td.CustomPoleResidue, "_derivative_field_cmp", - lambda self, E_der_map, eps_data, dim, freqs: dJ_deps, + lambda self, E_der_map, eps_data, dim, freqs: dJ_deps / 3.0, ) import importlib @@ -1703,17 +1703,18 @@ def f(eps_inf, poles): eps = td.CustomPoleResidue._eps_model(eps_inf, poles, freq) return J(eps) - gfn = ag.holomorphic_grad(f, argnum=(0, 1)) - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - grad_eps_inf, grad_poles = gfn(eps_inf.values, poles_complex) + gfn = ag.grad(lambda x: f(x, poles_complex)) + grad_eps_inf = gfn(eps_inf.values) assert np.allclose(grads_computed[("eps_inf",)], grad_eps_inf) + gfn = ag.holomorphic_grad(lambda x: f(eps_inf.values, x)) + grad_poles = gfn(poles_complex) + for i in range(len(poles)): for j in range(2): field_path = ("poles", i, j) - assert np.allclose(grads_computed[field_path], grad_poles[i][j]) + assert np.allclose(grads_computed[field_path], np.conj(grad_poles[i][j])) # @pytest.mark.timeout(18.0) diff --git a/tidy3d/components/medium.py b/tidy3d/components/medium.py index d65aed8f35..ee371e916a 100644 --- a/tidy3d/components/medium.py +++ b/tidy3d/components/medium.py @@ -3,12 +3,10 @@ from __future__ import annotations import functools -import warnings from abc import ABC, abstractmethod from math import isclose from typing import Callable, Optional, 
@staticmethod
def _get_vjps_from_params(
    dJ_deps_complex: Union[complex, np.ndarray],
    poles_vals: list[tuple[Union[complex, np.ndarray], Union[complex, np.ndarray]]],
    omega: float,
    requested_paths: list[tuple],
) -> AutogradFieldMap:
    """Compute VJPs for ``eps_inf`` and the poles with the analytical chain rule.

    For every pole ``(a, c)`` the partial derivatives used here are
    ``c / (j*omega + a)**2`` with respect to ``a`` and ``-1 / (j*omega + a)``
    with respect to ``c``; each is multiplied by ``dJ_deps_complex``. These are
    the derivatives of a term ``-c / (j*omega + a)``.
    NOTE(review): presumably this mirrors ``_eps_model``, which is not visible
    here -- confirm the sign convention against it.

    Parameters
    ----------
    dJ_deps_complex : Union[complex, np.ndarray]
        Derivative of the objective with respect to the complex permittivity.
    poles_vals : list[tuple[Union[complex, np.ndarray], Union[complex, np.ndarray]]]
        ``(a, c)`` value pairs, one per pole, in pole-index order.
    omega : float
        Angular frequency multiplying ``1j`` in the pole denominators.
    requested_paths : list[tuple]
        Field paths to produce: ``("eps_inf",)`` and/or ``("poles", i, k)``
        where ``k`` is ``0`` for ``a`` and ``1`` for ``c``.

    Returns
    -------
    AutogradFieldMap
        Dictionary keyed by the requested paths only. The ``eps_inf`` entry is
        the real part of ``dJ_deps_complex``; pole entries are complex.
    """
    jw = 1j * omega

    # Build the lookup once: every path is tested by exact membership below, so a
    # separate "is any path for pole i requested" scan is unnecessary.
    wanted = set(requested_paths)
    vjps = {}

    if ("eps_inf",) in wanted:
        vjps[("eps_inf",)] = np.real(dJ_deps_complex)

    for i, (a_val, c_val) in enumerate(poles_vals):
        want_a = ("poles", i, 0) in wanted
        want_c = ("poles", i, 1) in wanted
        if not (want_a or want_c):
            continue

        # shared denominator of both partial derivatives for this pole
        denom = jw + a_val

        if want_a:
            deps_da = c_val / denom**2
            vjps[("poles", i, 0)] = dJ_deps_complex * deps_da
        if want_c:
            deps_dc = -1 / denom
            vjps[("poles", i, 1)] = dJ_deps_complex * deps_dc

    return vjps
def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap:
    """Compute adjoint derivatives for the scalar (non-custom) pole-residue medium.

    The permittivity derivative from ``_derivative_eps_complex_volume`` is reduced
    to a single ``complex`` value, the poles are converted to ``complex`` pairs,
    and the per-field VJPs are obtained from ``_get_vjps_from_params``.
    """
    # derivative of the objective w.r.t. the complex permittivity of the volume
    volume_derivative = self._derivative_eps_complex_volume(
        E_der_map=derivative_info.E_der_map,
        bounds=derivative_info.bounds,
        freqs=np.atleast_1d(derivative_info.frequency),
    )

    # plain complex scalars for every (a, c) pole pair
    scalar_poles = []
    for a, c in self.poles:
        scalar_poles.append((complex(a), complex(c)))

    # a single frequency value from ``derivative_info`` sets the angular frequency
    angular_frequency = 2 * np.pi * derivative_info.frequency

    return self._get_vjps_from_params(
        dJ_deps_complex=complex(volume_derivative),
        poles_vals=scalar_poles,
        omega=angular_frequency,
        requested_paths=derivative_info.paths,
    )
def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap:
    """Compute adjoint derivatives for the spatially-varying pole-residue medium.

    The contributions returned by ``_derivative_field_cmp`` for the ``x``, ``y``
    and ``z`` components are accumulated, the ``.values`` of every pole ``(a, c)``
    are converted to complex arrays, and the per-field VJPs are obtained from
    ``PoleResidue._get_vjps_from_params``.
    """

    def _as_complex(data_array):
        """Return the ``.values`` of ``data_array`` as a complex array."""
        return np.array(data_array.values, dtype=complex)

    # accumulate the permittivity derivative over the three field components
    total = 0.0
    for axis in "xyz":
        component = self._derivative_field_cmp(
            E_der_map=derivative_info.E_der_map,
            eps_data=self.eps_inf,
            dim=axis,
            freqs=np.atleast_1d(derivative_info.frequency),
        )
        total += component

    array_poles = [(_as_complex(a), _as_complex(c)) for a, c in self.poles]

    # a single frequency value from ``derivative_info`` sets the angular frequency
    angular_frequency = 2 * np.pi * derivative_info.frequency

    return PoleResidue._get_vjps_from_params(
        dJ_deps_complex=total,
        poles_vals=array_poles,
        omega=angular_frequency,
        requested_paths=derivative_info.paths,
    )