Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -341,9 +341,13 @@ python_files = "*.py"
[tool.mypy]
python_version = "3.10"
files = [
"tidy3d/components/geometry",
"tidy3d/components/autograd",
"tidy3d/components/data",
"tidy3d/components/geometry",
"tidy3d/components/grid",
"tidy3d/components/mode",
"tidy3d/components/tcad",
"tidy3d/components/viz",
"tidy3d/config",
"tidy3d/material_library",
"tidy3d/plugins",
Expand Down
4 changes: 2 additions & 2 deletions tidy3d/components/autograd/boxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@


@classmethod
def from_arraybox(cls, box: ArrayBox) -> TidyArrayBox:
def from_arraybox(cls: Any, box: ArrayBox) -> TidyArrayBox:
"""Construct a TidyArrayBox from an ArrayBox."""
return cls(box._value, box._trace, box._node)

Expand Down Expand Up @@ -142,7 +142,7 @@ def __array_ufunc__(
return NotImplemented


def item(self):
def item(self: Any) -> Any:
if self.size != 1:
raise ValueError("Can only convert an array of size 1 to a scalar")
return anp.ravel(self)[0]
Expand Down
86 changes: 54 additions & 32 deletions tidy3d/components/autograd/derivative_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@

import numpy as np
import xarray as xr
from numpy.typing import NDArray

from tidy3d.compat import Self
from tidy3d.components.data.data_array import FreqDataArray, ScalarFieldDataArray
from tidy3d.components.types import ArrayLike, Bound, Complex
from tidy3d.config import config
Expand All @@ -20,17 +22,19 @@
FieldData = dict[str, ScalarFieldDataArray]
PermittivityData = dict[str, ScalarFieldDataArray]
EpsType = Union[Complex, FreqDataArray]
ArrayFloat = NDArray[np.floating]
ArrayComplex = NDArray[np.complexfloating]


class LazyInterpolator:
"""Lazy wrapper for interpolators that creates them on first access."""

def __init__(self, creator_func: Callable) -> None:
def __init__(self, creator_func: Callable[[], Callable[[ArrayFloat], ArrayComplex]]) -> None:
"""Initialize with a function that creates the interpolator when called."""
self.creator_func = creator_func
self._interpolator = None
self._interpolator: Optional[Callable[[ArrayFloat], ArrayComplex]] = None

def __call__(self, *args: Any, **kwargs: Any):
def __call__(self, *args: Any, **kwargs: Any) -> ArrayComplex:
"""Create interpolator on first call and delegate to it."""
if self._interpolator is None:
self._interpolator = self.creator_func()
Expand Down Expand Up @@ -172,14 +176,16 @@ class DerivativeInfo:
# private cache for interpolators
_interpolators_cache: dict = field(default_factory=dict, init=False, repr=False)

def updated_copy(self, **kwargs: Any):
def updated_copy(self, **kwargs: Any) -> Self:
"""Create a copy with updated fields."""
kwargs.pop("deep", None)
kwargs.pop("validate", None)
return replace(self, **kwargs)

@staticmethod
def _nan_to_num_if_needed(coords: np.ndarray) -> np.ndarray:
def _nan_to_num_if_needed(
coords: Union[ArrayFloat, ArrayComplex],
) -> Union[ArrayFloat, ArrayComplex]:
"""Convert NaN and infinite values to finite numbers, optimized for finite inputs."""
# skip check for small arrays
if coords.size < 1000:
Expand All @@ -191,8 +197,9 @@ def _nan_to_num_if_needed(coords: np.ndarray) -> np.ndarray:

@staticmethod
def _evaluate_with_interpolators(
interpolators: dict, coords: np.ndarray
) -> dict[str, np.ndarray]:
interpolators: dict[str, Callable[[ArrayFloat], ArrayComplex]],
coords: ArrayFloat,
) -> dict[str, ArrayComplex]:
"""Evaluate field components at coordinates using cached interpolators.

Parameters
Expand All @@ -216,7 +223,7 @@ def _evaluate_with_interpolators(
coords = coords.astype(float_dtype, copy=False)
return {name: interp(coords) for name, interp in interpolators.items()}

def create_interpolators(self, dtype: Optional[np.dtype] = None) -> dict:
def create_interpolators(self, dtype: Optional[np.dtype[Any]] = None) -> dict[str, Any]:
"""Create interpolators for field components and permittivity data.

Creates and caches ``RegularGridInterpolator`` objects for all field components
Expand All @@ -226,7 +233,7 @@ def create_interpolators(self, dtype: Optional[np.dtype] = None) -> dict:

Parameters
----------
dtype : np.dtype, optional
dtype : np.dtype[Any], optional = None
Data type for interpolation coordinates and values. Defaults to the
current ``config.adjoint.gradient_dtype_float``.

Expand All @@ -251,8 +258,14 @@ def create_interpolators(self, dtype: Optional[np.dtype] = None) -> dict:
interpolators = {}
coord_cache = {}

def _make_lazy_interpolator_group(field_data_dict, group_key, is_field_group=True) -> None:
def _make_lazy_interpolator_group(
field_data_dict: Optional[FieldData],
group_key: Optional[str],
is_field_group: bool = True,
) -> None:
"""Helper to create a group of lazy interpolators."""
if not field_data_dict:
return
if is_field_group:
interpolators[group_key] = {}

Expand All @@ -264,7 +277,10 @@ def _make_lazy_interpolator_group(field_data_dict, group_key, is_field_group=Tru
coord_cache[arr_id] = points
points = coord_cache[arr_id]

def creator_func(arr=arr, points=points):
def creator_func(
arr: ScalarFieldDataArray = arr,
points: tuple[np.ndarray, ...] = points,
) -> Callable[[ArrayFloat], ArrayComplex]:
data = arr.data.astype(
complex_dtype if np.iscomplexobj(arr.data) else dtype, copy=False
)
Expand Down Expand Up @@ -292,7 +308,7 @@ def creator_func(arr=arr, points=points):
points_with_freq, data, method=method, bounds_error=False, fill_value=None
)

def interpolator(coords):
def interpolator(coords: ArrayFloat) -> ArrayComplex:
# coords: (N, 3) spatial points
n_points = coords.shape[0]
n_freqs = len(freq_coords)
Expand Down Expand Up @@ -399,14 +415,13 @@ def evaluate_gradient_at_points(

def _evaluate_dielectric_gradient_at_points(
self,
spatial_coords: np.ndarray,
normals: np.ndarray,
perps1: np.ndarray,
perps2: np.ndarray,
interpolators: dict,
# todo: type
eps_out,
) -> np.ndarray:
spatial_coords: ArrayFloat,
normals: ArrayFloat,
perps1: ArrayFloat,
perps2: ArrayFloat,
interpolators: dict[str, dict[str, Callable[[ArrayFloat], ArrayComplex]]],
eps_out: ArrayComplex,
) -> ArrayComplex:
# evaluate all field components at surface points
E_fwd_at_coords = {
name: interp(spatial_coords) for name, interp in interpolators["E_fwd"].items()
Expand Down Expand Up @@ -449,15 +464,16 @@ def _evaluate_dielectric_gradient_at_points(

def _evaluate_pec_gradient_at_points(
self,
spatial_coords: np.ndarray,
normals: np.ndarray,
perps1: np.ndarray,
perps2: np.ndarray,
interpolators: dict,
# todo: type
eps_out,
) -> np.ndarray:
def _adjust_spatial_coords_pec(grid_centers: dict[str, np.ndarray]):
spatial_coords: ArrayFloat,
normals: ArrayFloat,
perps1: ArrayFloat,
perps2: ArrayFloat,
interpolators: dict[str, dict[str, Callable[[ArrayFloat], ArrayComplex]]],
eps_out: ArrayComplex,
) -> ArrayComplex:
def _adjust_spatial_coords_pec(
grid_centers: dict[str, ArrayFloat],
) -> tuple[ArrayFloat, ArrayFloat]:
"""Assuming a nearest interpolation, adjust the interpolation points given the grid
defined by `grid_centers` and using `spatial_coords` as a starting point such that we
select a point outside of the PEC boundary.
Expand Down Expand Up @@ -534,7 +550,9 @@ def _adjust_spatial_coords_pec(grid_centers: dict[str, np.ndarray]):

return adjust_spatial_coords, edge_distance

def _snap_coordinate_outside(field_components: FieldData):
def _snap_coordinate_outside(
field_components: FieldData,
) -> dict[str, dict[str, ArrayFloat]]:
"""Helper function to perform coordinate adjustment and compute edge distance for each
component in `field_components`.

Expand Down Expand Up @@ -565,7 +583,9 @@ def _snap_coordinate_outside(field_components: FieldData):

return adjustment

def _interpolate_field_components(interp_coords, field_name):
def _interpolate_field_components(
interp_coords: dict[str, dict[str, ArrayFloat]], field_name: str
) -> dict[str, ArrayComplex]:
return {
name: interp(interp_coords[name]["coords"])
for name, interp in interpolators[field_name].items()
Expand Down Expand Up @@ -596,7 +616,9 @@ def _interpolate_field_components(interp_coords, field_name):
# on of the H field integration components and apply singularity correction
pec_line_integration = is_flat_perp_dim1 or is_flat_perp_dim2

def _compute_singularity_correction(adjustment_: dict[str, dict[str, np.ndarray]]):
def _compute_singularity_correction(
adjustment_: dict[str, dict[str, ArrayFloat]],
) -> ArrayFloat:
"""
Given the `adjustment_` which contains the distance from the PEC edge each field
component is nearest interpolated at, computes the singularity correction when
Expand Down
10 changes: 5 additions & 5 deletions tidy3d/components/autograd/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
from __future__ import annotations

import copy
from typing import Annotated, Literal, Optional, Union, get_origin
from typing import Annotated, Any, Literal, Optional, Union, get_origin

import autograd.numpy as anp
from autograd.builtins import dict as TracedDict
from autograd.extend import Box, defvjp, primitive
from autograd.numpy.numpy_boxes import ArrayBox
from pydantic import BeforeValidator, PlainSerializer, PositiveFloat, TypeAdapter
from pydantic import BeforeValidator, PlainSerializer, PositiveFloat, SerializationInfo, TypeAdapter

from tidy3d.compat import TypeAlias
from tidy3d.components.types import ArrayFloat2D, ArrayLike, Complex, Size1D
Expand All @@ -35,10 +35,10 @@
Box.__repr__ = Box.__str__


def traced_alias(base_alias, *, name: Optional[str] = None) -> TypeAlias:
def traced_alias(base_alias: Any, *, name: Optional[str] = None) -> TypeAlias:
base_adapter = TypeAdapter(base_alias, config={"arbitrary_types_allowed": True})

def _validate_box_or_container(v):
def _validate_box_or_container(v: Any) -> Any:
# case 1: v itself is a tracer
# in this case we just validate but leave the tracer untouched
if isinstance(v, Box):
Expand Down Expand Up @@ -78,7 +78,7 @@ def _validate_box_or_container(v):

return base_adapter.validate_python(v)

def _serialize_traced(a, info):
def _serialize_traced(a: Any, info: SerializationInfo) -> Any:
return _auto_serializer(get_static(a), info)

return Annotated[
Expand Down
8 changes: 5 additions & 3 deletions tidy3d/components/autograd/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
from __future__ import annotations

from collections.abc import Iterable, Mapping, Sequence
from typing import Any
from typing import Any, Union

import autograd.numpy as anp
from autograd.numpy.numpy_boxes import ArrayBox
from autograd.tracer import getval, isbox
from numpy.typing import ArrayLike, NDArray

__all__ = [
"asarray1d",
Expand Down Expand Up @@ -66,12 +68,12 @@ def hasbox(obj: Any) -> bool:
return False


def pack_complex_vec(z):
def pack_complex_vec(z: Union[NDArray, ArrayBox]) -> Union[NDArray, ArrayBox]:
"""Ravel [Re(z); Im(z)] into one real vector (autograd-safe)."""
return anp.concatenate([anp.ravel(anp.real(z)), anp.ravel(anp.imag(z))])


def asarray1d(x):
def asarray1d(x: Union[ArrayLike, ArrayBox]) -> Union[NDArray, ArrayBox]:
"""Autograd-friendly 1D flatten: returns ndarray of shape (-1,)."""
x = anp.array(x)
return x if x.ndim == 1 else anp.ravel(x)
7 changes: 4 additions & 3 deletions tidy3d/components/grid/corner_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any, Literal, Optional

import numpy as np
from numpy.typing import NDArray
from pydantic import Field, PositiveFloat, PositiveInt

from tidy3d.components.base import Tidy3dBaseModel, cached_property
Expand Down Expand Up @@ -68,7 +69,7 @@ class CornerFinderSpec(Tidy3dBaseModel):
)

@cached_property
def _no_min_dl_override(self):
def _no_min_dl_override(self) -> bool:
return all(
(
self.concave_resolution is None,
Expand Down Expand Up @@ -197,7 +198,7 @@ def _corners_and_convexity(
return self._ravel_corners_and_convexity(ravel, corner_list, convexity_list)

def _ravel_corners_and_convexity(
self, ravel: bool, corner_list, convexity_list
self, ravel: bool, corner_list: list[ArrayFloat2D], convexity_list: list[ArrayFloat1D]
) -> tuple[ArrayFloat2D, ArrayFloat1D]:
"""Whether to put the resulting corners in a single list or per polygon."""
if ravel and len(corner_list) > 0:
Expand Down Expand Up @@ -259,7 +260,7 @@ def _filter_collinear_vertices(
Convexity of corners: True for outer corners, False for inner corners.
"""

def normalize(v):
def normalize(v: NDArray) -> NDArray:
return v / np.linalg.norm(v, axis=-1)[:, np.newaxis]

# drop the last vertex, which is identical to the 1st one.
Expand Down
Loading