Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add support for `np.unwrap` in `tidy3d.plugins.autograd`.
- Add Nunley variant to germanium material library based on Nunley et al. 2016 data.

### Changed
- Switched to an analytical gradient calculation for spatially-varying pole-residue models (`CustomPoleResidue`).

### Fixed
- Arrow lengths are now scaled consistently in the X and Y directions, and their lengths no longer exceed the height of the plot window.
- Bug in `PlaneWave` defined with a negative `angle_theta` which would lead to wrong injection.
Expand Down Expand Up @@ -101,8 +104,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed
- Fixed `reverse` property of `td.Scene.plot_structures_property()` to also reverse the colorbar.

### Fixed
- Fixed bug in surface gradient computation where fields, instead of gradients, were being summed in frequency.

## [2.8.2] - 2025-04-09
Expand Down
17 changes: 9 additions & 8 deletions tests/test_components/test_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -1650,18 +1650,18 @@ def test_custom_pole_residue(monkeypatch):
custom_med_pole_res = td.CustomPoleResidue(eps_inf=eps_inf, poles=poles)

def J(eps):
return anp.sum(abs(eps))
return anp.sum(anp.abs(eps))

freq = 3e8
pr = td.CustomPoleResidue(eps_inf=eps_inf, poles=poles)
eps0 = pr.eps_model(freq)

dJ_deps = ag.holomorphic_grad(J)(eps0)
dJ_deps = np.conj(ag.holomorphic_grad(J)(eps0))

monkeypatch.setattr(
td.CustomPoleResidue,
"_derivative_field_cmp",
lambda self, E_der_map, eps_data, dim, freqs: dJ_deps,
lambda self, E_der_map, eps_data, dim, freqs: dJ_deps / 3.0,
)

import importlib
Expand Down Expand Up @@ -1703,17 +1703,18 @@ def f(eps_inf, poles):
eps = td.CustomPoleResidue._eps_model(eps_inf, poles, freq)
return J(eps)

gfn = ag.holomorphic_grad(f, argnum=(0, 1))
with warnings.catch_warnings():
warnings.simplefilter("ignore")
grad_eps_inf, grad_poles = gfn(eps_inf.values, poles_complex)
gfn = ag.grad(lambda x: f(x, poles_complex))
grad_eps_inf = gfn(eps_inf.values)

assert np.allclose(grads_computed[("eps_inf",)], grad_eps_inf)

gfn = ag.holomorphic_grad(lambda x: f(eps_inf.values, x))
grad_poles = gfn(poles_complex)

for i in range(len(poles)):
for j in range(2):
field_path = ("poles", i, j)
assert np.allclose(grads_computed[field_path], grad_poles[i][j])
assert np.allclose(grads_computed[field_path], np.conj(grad_poles[i][j]))


# @pytest.mark.timeout(18.0)
Expand Down
141 changes: 48 additions & 93 deletions tidy3d/components/medium.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,10 @@
from __future__ import annotations

import functools
import warnings
from abc import ABC, abstractmethod
from math import isclose
from typing import Callable, Optional, Union

import autograd as ag
import autograd.numpy as np

# TODO: it's hard to figure out which functions need this, for now all get it
Expand Down Expand Up @@ -3425,50 +3423,52 @@ def loss_upper_bound(self) -> float:
ep = ep[~np.isnan(ep)]
return max(ep.imag)

@staticmethod
def _get_vjps_from_params(
    dJ_deps_complex: Union[complex, np.ndarray],
    poles_vals: list[tuple[Union[complex, np.ndarray], Union[complex, np.ndarray]]],
    omega: float,
    requested_paths: list[tuple],
) -> AutogradFieldMap:
    """Compute VJPs for pole-residue parameters using the analytical chain rule.

    Each pole ``(a, c)`` is treated as contributing a term ``-c / (1j * omega + a)``
    to the permittivity, whose partial derivatives are ``c / (1j * omega + a) ** 2``
    with respect to ``a`` and ``-1 / (1j * omega + a)`` with respect to ``c``.
    Each partial is multiplied by ``dJ_deps_complex`` to form the returned value.

    Parameters
    ----------
    dJ_deps_complex : Union[complex, np.ndarray]
        Derivative of the objective with respect to the complex permittivity.
    poles_vals : list[tuple[Union[complex, np.ndarray], Union[complex, np.ndarray]]]
        ``(a, c)`` pairs, one per pole, in the same order as the pole indices
        used in ``requested_paths``.
    omega : float
        Angular frequency at which the derivatives are evaluated.
    requested_paths : list[tuple]
        Field paths to compute: ``("eps_inf",)``, ``("poles", i, 0)`` for ``a`` of
        pole ``i``, and ``("poles", i, 1)`` for ``c`` of pole ``i``. Any other path
        is ignored.

    Returns
    -------
    AutogradFieldMap
        Mapping from each recognized requested path to its derivative. The
        ``("eps_inf",)`` entry is the real part of ``dJ_deps_complex``.
    """
    # Build the lookup once; each path is then matched exactly, so malformed or
    # short paths such as ("poles",) are simply never matched instead of being indexed.
    requested = set(requested_paths)
    vjps = {}

    if ("eps_inf",) in requested:
        vjps[("eps_inf",)] = np.real(dJ_deps_complex)

    jw = 1j * omega
    for i, (a_val, c_val) in enumerate(poles_vals):
        path_a = ("poles", i, 0)
        path_c = ("poles", i, 1)

        if path_a in requested:
            # d/da of -c / (jw + a)
            deps_da = c_val / (jw + a_val) ** 2
            vjps[path_a] = dJ_deps_complex * deps_da

        if path_c in requested:
            # d/dc of -c / (jw + a)
            deps_dc = -1 / (jw + a_val)
            vjps[path_c] = dJ_deps_complex * deps_dc

    return vjps

def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap:
"""Compute adjoint derivatives for each of the ``fields`` given the multiplied E and D."""
"""Compute adjoint derivatives by preparing scalar data and calling the static helper."""

# compute all derivatives beforehand
dJ_deps = self._derivative_eps_complex_volume(
dJ_deps_complex = self._derivative_eps_complex_volume(
E_der_map=derivative_info.E_der_map,
bounds=derivative_info.bounds,
freqs=np.atleast_1d(derivative_info.frequency),
)

dJ_deps = complex(dJ_deps)

# TODO: fix for multi-frequency
frequency = derivative_info.frequency
poles_complex = [(complex(a), complex(c)) for a, c in self.poles]
poles_complex = np.stack(poles_complex, axis=0)

# compute gradients of eps_model with respect to eps_inf and poles
grad_eps_model = ag.holomorphic_grad(self._eps_model, argnum=(0, 1))
with warnings.catch_warnings():
# ignore warnings about holomorphic grad being passed a non-complex input (poles)
warnings.simplefilter("ignore")
deps_deps_inf, deps_dpoles = grad_eps_model(
complex(self.eps_inf), poles_complex, complex(frequency)
)

# multiply with partial dJ/deps to give full gradients

dJ_deps_inf = dJ_deps * deps_deps_inf
dJ_dpoles = [(dJ_deps * a, dJ_deps * c) for a, c in deps_dpoles]

# get vjps w.r.t. permittivity and conductivity of the bulk
derivative_map = {}
for field_path in derivative_info.paths:
field_name, *rest = field_path
poles_vals = [(complex(a), complex(c)) for a, c in self.poles]

if field_name == "eps_inf":
derivative_map[field_path] = float(np.real(dJ_deps_inf))

elif field_name == "poles":
pole_index, a_or_c = rest
derivative_map[field_path] = complex(dJ_dpoles[pole_index][a_or_c])

return derivative_map
return self._get_vjps_from_params(
dJ_deps_complex=complex(dJ_deps_complex),
poles_vals=poles_vals,
omega=2 * np.pi * derivative_info.frequency,
requested_paths=derivative_info.paths,
)

@classmethod
def _real_partial_fraction_decomposition(
Expand Down Expand Up @@ -3903,73 +3903,28 @@ def _sel_custom_data_inside(self, bounds: Bound):
return self.updated_copy(eps_inf=eps_inf_reduced, poles=poles_reduced)

def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap:
"""Compute adjoint derivatives for each of the ``fields`` given the multiplied E and D."""
"""Compute adjoint derivatives by preparing array data and calling the static helper."""

dJ_deps = 0.0
dJ_deps_complex = 0.0
for dim in "xyz":
dJ_deps += self._derivative_field_cmp(
dJ_deps_complex += self._derivative_field_cmp(
E_der_map=derivative_info.E_der_map,
eps_data=self.eps_inf,
dim=dim,
freqs=np.atleast_1d(derivative_info.frequency),
)

# TODO: fix for multi-frequency
frequency = derivative_info.frequency

poles_complex = [
poles_vals = [
(np.array(a.values, dtype=complex), np.array(c.values, dtype=complex))
for a, c in self.poles
]
poles_complex = np.stack(poles_complex, axis=0)

def eps_model_r(
eps_inf: complex, poles: list[tuple[complex, complex]], frequency: float
) -> float:
"""Real part of ``eps_model`` evaluated on ``self`` fields."""
return np.real(self._eps_model(eps_inf, poles, frequency))

def eps_model_i(
eps_inf: complex, poles: list[tuple[complex, complex]], frequency: float
) -> float:
"""Real part of ``eps_model`` evaluated on ``self`` fields."""
return np.imag(self._eps_model(eps_inf, poles, frequency))

# compute the gradients w.r.t. each real and imaginary parts for eps_inf and poles
grad_eps_model_r = ag.elementwise_grad(eps_model_r, argnum=(0, 1))
grad_eps_model_i = ag.elementwise_grad(eps_model_i, argnum=(0, 1))
deps_deps_inf_r, deps_dpoles_r = grad_eps_model_r(
self.eps_inf.values, poles_complex, frequency
)
deps_deps_inf_i, deps_dpoles_i = grad_eps_model_i(
self.eps_inf.values, poles_complex, frequency
)

# multiply with dJ_deps partial derivative to give full gradients

deps_deps_inf = deps_deps_inf_r + 1j * deps_deps_inf_i
dJ_deps_inf = dJ_deps * deps_deps_inf / 3.0 # mysterious 3

dJ_dpoles = []
for (da_r, dc_r), (da_i, dc_i) in zip(deps_dpoles_r, deps_dpoles_i):
da = da_r + 1j * da_i
dc = dc_r + 1j * dc_i
dJ_da = dJ_deps * da / 2.0 # mysterious 2
dJ_dc = dJ_deps * dc / 2.0 # mysterious 2
dJ_dpoles.append((dJ_da, dJ_dc))

derivative_map = {}
for field_path in derivative_info.paths:
field_name, *rest = field_path

if field_name == "eps_inf":
derivative_map[field_path] = np.real(dJ_deps_inf)

elif field_name == "poles":
pole_index, a_or_c = rest
derivative_map[field_path] = dJ_dpoles[pole_index][a_or_c]

return derivative_map
return PoleResidue._get_vjps_from_params(
dJ_deps_complex=dJ_deps_complex,
poles_vals=poles_vals,
omega=2 * np.pi * derivative_info.frequency,
requested_paths=derivative_info.paths,
)


class Sellmeier(DispersiveMedium):
Expand Down