From 3478ad3906961370f2bc08d9d4385fae649ffe55 Mon Sep 17 00:00:00 2001 From: mahlau-flex Date: Wed, 26 Nov 2025 13:04:42 +0100 Subject: [PATCH] chore(invdes): fixed smoothed projection gradient with beta=inf --- .../autograd/invdes/test_projections.py | 49 +++++++++++++------ tidy3d/plugins/autograd/invdes/projections.py | 6 +-- 2 files changed, 36 insertions(+), 19 deletions(-) diff --git a/tests/test_plugins/autograd/invdes/test_projections.py b/tests/test_plugins/autograd/invdes/test_projections.py index 192612173c..9a43e592f0 100644 --- a/tests/test_plugins/autograd/invdes/test_projections.py +++ b/tests/test_plugins/autograd/invdes/test_projections.py @@ -2,23 +2,34 @@ import autograd import numpy as np +import pytest +from autograd.test_util import check_grads from tidy3d.plugins.autograd.invdes.filters import ConicFilter from tidy3d.plugins.autograd.invdes.projections import smoothed_projection, tanh_projection -def test_smoothed_projection_beta_inf(): - nx, ny = 50, 50 - arr = np.zeros((50, 50), dtype=float) +def create_circle(nx, ny, radius): + # 1. Initialize array + arr = np.zeros((nx, ny), dtype=float) - center_x, center_y = 25, 25 - radius = 10 + # 2. Logic to create circle + center_x, center_y = nx / 2, ny / 2 x = np.arange(nx) y = np.arange(ny) - X, Y = np.meshgrid(x, y) + # Note: indexing='ij' ensures x corresponds to rows (nx) and y to cols (ny) + X, Y = np.meshgrid(x, y, indexing="ij") distance = np.sqrt((X - center_x) ** 2 + (Y - center_y) ** 2) arr[distance <= radius] = 1 + return arr + + +def test_smoothed_projection_beta_inf(): + nx, ny = 50, 50 + radius = 10 + + arr = create_circle(nx, ny, radius) filter = ConicFilter(kernel_size=5) arr_filtered = filter(arr) @@ -29,7 +40,7 @@ def test_smoothed_projection_beta_inf(): eta=0.5, ) assert not np.any(np.isinf(result) | np.isnan(result)) - assert np.isclose(result[center_x, center_y], 1) + assert np.isclose(result[round(nx / 2), round(ny / 2)], 1) assert np.isclose(result[0, -1], 0) assert np.isclose(result[0, 0], 0) assert np.isclose(result[-1, 0], 0) @@ -46,16 +57,9 @@ def test_smoothed_projection_beta_inf(): def test_smoothed_projection_beta_non_inf(): nx, ny = 50, 50 - arr = np.zeros((50, 50), dtype=float) - - center_x, center_y = 25, 25 radius = 10 - x = np.arange(nx) - y = np.arange(ny) - X, Y = np.meshgrid(x, y) - distance = np.sqrt((X - center_x) ** 2 + (Y - center_y) ** 2) - arr[distance <= radius] = 1 + arr = create_circle(nx, ny, radius) # fully discrete input should still be fully discrete output discrete_result = smoothed_projection( @@ -99,3 +103,18 @@ def _helper_fn(x): val, grad = autograd.value_and_grad(_helper_fn)(arr) assert val == 0.5 assert np.all(~(np.isnan(grad) | np.isinf(grad))) + + +@pytest.mark.parametrize("beta", [0, 5, np.inf]) +@pytest.mark.parametrize("size", [30, 50]) +@pytest.mark.parametrize("radius", [10, 15, 20]) +@pytest.mark.parametrize("smoothing_radius", [3, 5, 7]) +def test_projection_gradient_correctness(beta, size, radius, smoothing_radius): + arr = create_circle(size, size, radius) + filter = ConicFilter(kernel_size=smoothing_radius) + arr = filter(arr) + + def _helper_fn(x): + return smoothed_projection(x, beta=beta, eta=0.5).mean() + + check_grads(_helper_fn, modes=["fwd", "rev"], order=2)(arr) diff --git a/tidy3d/plugins/autograd/invdes/projections.py b/tidy3d/plugins/autograd/invdes/projections.py index 18d51e608b..ee14f38aa2 100644 --- a/tidy3d/plugins/autograd/invdes/projections.py +++ b/tidy3d/plugins/autograd/invdes/projections.py @@ -67,6 +67,8 @@ def tanh_projection( """ if beta == 0: return array + if beta == np.inf: + return np.where(array > eta, 1.0, 0.0) num = np.tanh(beta * eta) + np.tanh(beta * (array - eta)) denom = np.tanh(beta * eta) + np.tanh(beta * (1 - eta)) return num / denom @@ -102,10 +104,6 @@ def smoothed_projection( ```GridSpec.auto``` in the simulation, make sure to place a ``MeshOverrideStructure`` at the position of the optimized geometry. - .. warning:: - When using :math:`\\beta = \\infty` the function will produce NaN values if - the input is exactly equal to ``eta``. - Parameters ---------- array : np.ndarray