diff --git a/CHANGELOG.md b/CHANGELOG.md index 370f69647a..16fafa42d1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Correct sign in objective function history depending on `Optimizer.maximize`. - Fix to batch mode solver run that could create multiple copies of the same folder. - Fixed ``ModeSolver.plot`` method when the simulation is not at the origin. +- Gradient calculation is orders of magnitude faster for large datasets and many structures by applying more efficient handling of field interpolation and passing to structures. ## [2.7.4] - 2024-09-25 diff --git a/tidy3d/components/autograd/derivative_utils.py b/tidy3d/components/autograd/derivative_utils.py index 612fc018f6..781148f8d9 100644 --- a/tidy3d/components/autograd/derivative_utils.py +++ b/tidy3d/components/autograd/derivative_utils.py @@ -167,7 +167,7 @@ def evaluate_flds_at( components = {} for fld_name, arr in fld_dataset.items(): - components[fld_name] = arr.interp(**interp_kwargs).sum("f") + components[fld_name] = arr.interp(**interp_kwargs, assume_sorted=True).sum("f") return components diff --git a/tidy3d/components/geometry/base.py b/tidy3d/components/geometry/base.py index 26d4c7c6eb..cd2fff7b3f 100644 --- a/tidy3d/components/geometry/base.py +++ b/tidy3d/components/geometry/base.py @@ -3214,7 +3214,7 @@ def compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldM _, index, *geo_path = field_path geo = self.geometries[index] geo_info = derivative_info.updated_copy( - paths=[geo_path], bounds=geo.bounds, eps_approx=True + paths=[geo_path], bounds=geo.bounds, eps_approx=True, deep=False ) vjp_dict_geo = geo.compute_derivatives(geo_info) grad_vjp_values = list(vjp_dict_geo.values()) diff --git a/tidy3d/components/geometry/primitives.py b/tidy3d/components/geometry/primitives.py index e769c4a3e4..c8316659e1 100644 --- a/tidy3d/components/geometry/primitives.py +++ b/tidy3d/components/geometry/primitives.py @@ -295,7 +295,7 @@ def compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldM # construct equivalent polyslab and compute the derivatives polyslab = self.to_polyslab(num_pts_circumference=num_pts_circumference) - derivative_info_polyslab = derivative_info.updated_copy(paths=[("vertices",)]) + derivative_info_polyslab = derivative_info.updated_copy(paths=[("vertices",)], deep=False) vjps_polyslab = polyslab.compute_derivatives(derivative_info_polyslab) vjps_vertices_xs, vjps_vertices_ys = vjps_polyslab[("vertices",)].T diff --git a/tidy3d/components/medium.py b/tidy3d/components/medium.py index ac29f31d3c..5e60901fc0 100644 --- a/tidy3d/components/medium.py +++ b/tidy3d/components/medium.py @@ -1426,7 +1426,9 @@ def _derivative_field_cmp( # TODO: probably this could be more robust. eg if the DataArray has weird edge cases E_der_dim = E_der_map[f"E{dim}"] - E_der_dim_interp = E_der_dim.interp(**coords_interp).fillna(0.0).sum(dims_sum).sum("f") + E_der_dim_interp = ( + E_der_dim.interp(**coords_interp, assume_sorted=True).fillna(0.0).sum(dims_sum).sum("f") + ) vjp_array = np.array(E_der_dim_interp.values).astype(complex) vjp_array = vjp_array.reshape(eps_data.shape) @@ -2618,7 +2620,9 @@ def _derivative_field_cmp( # TODO: probably this could be more robust. eg if the DataArray has weird edge cases E_der_dim = E_der_map[f"E{dim}"] - E_der_dim_interp = E_der_dim.interp(**coords_interp).fillna(0.0).sum(dims_sum).real + E_der_dim_interp = ( + E_der_dim.interp(**coords_interp, assume_sorted=True).fillna(0.0).sum(dims_sum).real + ) E_der_dim_interp = E_der_dim_interp.sum("f") vjp_array = np.array(E_der_dim_interp.values, dtype=float) diff --git a/tidy3d/components/structure.py b/tidy3d/components/structure.py index fdfee65fe3..518dc3e3ba 100644 --- a/tidy3d/components/structure.py +++ b/tidy3d/components/structure.py @@ -256,7 +256,7 @@ def compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldM for med_or_geo, field_paths in structure_fields_map.items(): # grab derivative values {field_name -> vjp_value} med_or_geo_field = self.medium if med_or_geo == "medium" else self.geometry - info = derivative_info.updated_copy(paths=field_paths) + info = derivative_info.updated_copy(paths=field_paths, deep=False) derivative_values_map = med_or_geo_field.compute_derivatives(derivative_info=info) # construct map of {field path -> derivative value}