From af00a5b81a080f00f2592c1fcf71d3e160ef0408 Mon Sep 17 00:00:00 2001 From: marcorudolphflex Date: Tue, 25 Nov 2025 09:47:04 +0100 Subject: [PATCH] chore(tidy3d): FXC-4301 Add mypy typedefs in components subpackages autograd, grid, mode and viz --- pyproject.toml | 6 +- tidy3d/components/autograd/boxes.py | 4 +- .../components/autograd/derivative_utils.py | 86 +++++--- tidy3d/components/autograd/types.py | 10 +- tidy3d/components/autograd/utils.py | 8 +- tidy3d/components/grid/corner_finder.py | 7 +- tidy3d/components/grid/grid.py | 27 +-- tidy3d/components/grid/grid_spec.py | 102 ++++++--- tidy3d/components/grid/mesher.py | 4 +- tidy3d/components/mode/derivatives.py | 78 +++++-- tidy3d/components/mode/mode_solver.py | 58 +++--- tidy3d/components/mode/simulation.py | 19 +- tidy3d/components/mode/solver.py | 193 ++++++++++-------- tidy3d/components/mode/transforms.py | 21 +- tidy3d/components/viz/axes_utils.py | 27 ++- tidy3d/components/viz/descartes.py | 15 +- tidy3d/components/viz/plot_params.py | 5 +- tidy3d/components/viz/plot_sim_3d.py | 14 +- tidy3d/components/viz/visualization_spec.py | 4 +- 19 files changed, 450 insertions(+), 238 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e73b2bf00a..b1090a07e0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -341,9 +341,13 @@ python_files = "*.py" [tool.mypy] python_version = "3.10" files = [ - "tidy3d/components/geometry", + "tidy3d/components/autograd", "tidy3d/components/data", + "tidy3d/components/geometry", + "tidy3d/components/grid", + "tidy3d/components/mode", "tidy3d/components/tcad", + "tidy3d/components/viz", "tidy3d/config", "tidy3d/material_library", "tidy3d/plugins", diff --git a/tidy3d/components/autograd/boxes.py b/tidy3d/components/autograd/boxes.py index 481c67ee82..75f5193095 100644 --- a/tidy3d/components/autograd/boxes.py +++ b/tidy3d/components/autograd/boxes.py @@ -26,7 +26,7 @@ @classmethod -def from_arraybox(cls, box: ArrayBox) -> TidyArrayBox: +def from_arraybox(cls: Any, box: ArrayBox) -> TidyArrayBox: """Construct a TidyArrayBox from an ArrayBox.""" return cls(box._value, box._trace, box._node) @@ -142,7 +142,7 @@ def __array_ufunc__( return NotImplemented -def item(self): +def item(self: Any) -> Any: if self.size != 1: raise ValueError("Can only convert an array of size 1 to a scalar") return anp.ravel(self)[0] diff --git a/tidy3d/components/autograd/derivative_utils.py b/tidy3d/components/autograd/derivative_utils.py index ad55d1d337..eed50e60f0 100644 --- a/tidy3d/components/autograd/derivative_utils.py +++ b/tidy3d/components/autograd/derivative_utils.py @@ -7,7 +7,9 @@ import numpy as np import xarray as xr +from numpy.typing import NDArray +from tidy3d.compat import Self from tidy3d.components.data.data_array import FreqDataArray, ScalarFieldDataArray from tidy3d.components.types import ArrayLike, Bound, Complex from tidy3d.config import config @@ -20,17 +22,19 @@ FieldData = dict[str, ScalarFieldDataArray] PermittivityData = dict[str, ScalarFieldDataArray] EpsType = Union[Complex, FreqDataArray] +ArrayFloat = NDArray[np.floating] +ArrayComplex = NDArray[np.complexfloating] class LazyInterpolator: """Lazy wrapper for interpolators that creates them on first access.""" - def __init__(self, creator_func: Callable) -> None: + def __init__(self, creator_func: Callable[[], Callable[[ArrayFloat], ArrayComplex]]) -> None: """Initialize with a function that creates the interpolator when called.""" self.creator_func = creator_func - self._interpolator = None + self._interpolator: 
Optional[Callable[[ArrayFloat], ArrayComplex]] = None - def __call__(self, *args: Any, **kwargs: Any): + def __call__(self, *args: Any, **kwargs: Any) -> ArrayComplex: """Create interpolator on first call and delegate to it.""" if self._interpolator is None: self._interpolator = self.creator_func() @@ -172,14 +176,16 @@ class DerivativeInfo: # private cache for interpolators _interpolators_cache: dict = field(default_factory=dict, init=False, repr=False) - def updated_copy(self, **kwargs: Any): + def updated_copy(self, **kwargs: Any) -> Self: """Create a copy with updated fields.""" kwargs.pop("deep", None) kwargs.pop("validate", None) return replace(self, **kwargs) @staticmethod - def _nan_to_num_if_needed(coords: np.ndarray) -> np.ndarray: + def _nan_to_num_if_needed( + coords: Union[ArrayFloat, ArrayComplex], + ) -> Union[ArrayFloat, ArrayComplex]: """Convert NaN and infinite values to finite numbers, optimized for finite inputs.""" # skip check for small arrays if coords.size < 1000: @@ -191,8 +197,9 @@ def _nan_to_num_if_needed(coords: np.ndarray) -> np.ndarray: @staticmethod def _evaluate_with_interpolators( - interpolators: dict, coords: np.ndarray - ) -> dict[str, np.ndarray]: + interpolators: dict[str, Callable[[ArrayFloat], ArrayComplex]], + coords: ArrayFloat, + ) -> dict[str, ArrayComplex]: """Evaluate field components at coordinates using cached interpolators. Parameters @@ -216,7 +223,7 @@ def _evaluate_with_interpolators( coords = coords.astype(float_dtype, copy=False) return {name: interp(coords) for name, interp in interpolators.items()} - def create_interpolators(self, dtype: Optional[np.dtype] = None) -> dict: + def create_interpolators(self, dtype: Optional[np.dtype[Any]] = None) -> dict[str, Any]: """Create interpolators for field components and permittivity data. Creates and caches ``RegularGridInterpolator`` objects for all field components @@ -226,7 +233,7 @@ def create_interpolators(self, dtype: Optional[np.dtype] = None) -> dict: Parameters ---------- - dtype : np.dtype, optional + dtype : np.dtype[Any], optional = None Data type for interpolation coordinates and values. Defaults to the current ``config.adjoint.gradient_dtype_float``. @@ -251,8 +258,14 @@ def create_interpolators(self, dtype: Optional[np.dtype] = None) -> dict: interpolators = {} coord_cache = {} - def _make_lazy_interpolator_group(field_data_dict, group_key, is_field_group=True) -> None: + def _make_lazy_interpolator_group( + field_data_dict: Optional[FieldData], + group_key: Optional[str], + is_field_group: bool = True, + ) -> None: """Helper to create a group of lazy interpolators.""" + if not field_data_dict: + return if is_field_group: interpolators[group_key] = {} @@ -264,7 +277,10 @@ def _make_lazy_interpolator_group(field_data_dict, group_key, is_field_group=Tru coord_cache[arr_id] = points points = coord_cache[arr_id] - def creator_func(arr=arr, points=points): + def creator_func( + arr: ScalarFieldDataArray = arr, + points: tuple[np.ndarray, ...] 
= points, + ) -> Callable[[ArrayFloat], ArrayComplex]: data = arr.data.astype( complex_dtype if np.iscomplexobj(arr.data) else dtype, copy=False ) @@ -292,7 +308,7 @@ def creator_func(arr=arr, points=points): points_with_freq, data, method=method, bounds_error=False, fill_value=None ) - def interpolator(coords): + def interpolator(coords: ArrayFloat) -> ArrayComplex: # coords: (N, 3) spatial points n_points = coords.shape[0] n_freqs = len(freq_coords) @@ -399,14 +415,13 @@ def evaluate_gradient_at_points( def _evaluate_dielectric_gradient_at_points( self, - spatial_coords: np.ndarray, - normals: np.ndarray, - perps1: np.ndarray, - perps2: np.ndarray, - interpolators: dict, - # todo: type - eps_out, - ) -> np.ndarray: + spatial_coords: ArrayFloat, + normals: ArrayFloat, + perps1: ArrayFloat, + perps2: ArrayFloat, + interpolators: dict[str, dict[str, Callable[[ArrayFloat], ArrayComplex]]], + eps_out: ArrayComplex, + ) -> ArrayComplex: # evaluate all field components at surface points E_fwd_at_coords = { name: interp(spatial_coords) for name, interp in interpolators["E_fwd"].items() @@ -449,15 +464,16 @@ def _evaluate_dielectric_gradient_at_points( def _evaluate_pec_gradient_at_points( self, - spatial_coords: np.ndarray, - normals: np.ndarray, - perps1: np.ndarray, - perps2: np.ndarray, - interpolators: dict, - # todo: type - eps_out, - ) -> np.ndarray: - def _adjust_spatial_coords_pec(grid_centers: dict[str, np.ndarray]): + spatial_coords: ArrayFloat, + normals: ArrayFloat, + perps1: ArrayFloat, + perps2: ArrayFloat, + interpolators: dict[str, dict[str, Callable[[ArrayFloat], ArrayComplex]]], + eps_out: ArrayComplex, + ) -> ArrayComplex: + def _adjust_spatial_coords_pec( + grid_centers: dict[str, ArrayFloat], + ) -> tuple[ArrayFloat, ArrayFloat]: """Assuming a nearest interpolation, adjust the interpolation points given the grid defined by `grid_centers` and using `spatial_coords` as a starting point such that we select a point outside of the PEC boundary. @@ -534,7 +550,9 @@ def _adjust_spatial_coords_pec(grid_centers: dict[str, np.ndarray]): return adjust_spatial_coords, edge_distance - def _snap_coordinate_outside(field_components: FieldData): + def _snap_coordinate_outside( + field_components: FieldData, + ) -> dict[str, dict[str, ArrayFloat]]: """Helper function to perform coordinate adjustment and compute edge distance for each component in `field_components`. 
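# --- Editor's note: illustrative sketch, not part of the patch. ---
# The hunks above annotate the lazy-interpolator pattern in derivative_utils.py: a creator
# callable is stored and only invoked on first use, and both the creator and the resulting
# interpolator are typed as Callable[[NDArray], NDArray]. A minimal stand-alone version of
# that pattern (the class and helper names below are hypothetical) might look like this:
from typing import Callable, Optional

import numpy as np
from numpy.typing import NDArray
from scipy.interpolate import RegularGridInterpolator

ArrayFloat = NDArray[np.floating]
ArrayComplex = NDArray[np.complexfloating]


class LazyInterp:
    """Defer building an expensive interpolator until the first evaluation."""

    def __init__(self, creator: Callable[[], Callable[[ArrayFloat], ArrayComplex]]) -> None:
        self._creator = creator
        self._interp: Optional[Callable[[ArrayFloat], ArrayComplex]] = None

    def __call__(self, pts: ArrayFloat) -> ArrayComplex:
        if self._interp is None:  # build once, reuse afterwards
            self._interp = self._creator()
        return self._interp(pts)


def _make_interp() -> Callable[[ArrayFloat], ArrayComplex]:
    x, y = np.linspace(0, 1, 5), np.linspace(0, 1, 4)
    data = np.random.rand(5, 4) + 1j * np.random.rand(5, 4)
    rgi = RegularGridInterpolator((x, y), data, bounds_error=False, fill_value=None)
    return lambda pts: rgi(pts)


interp = LazyInterp(_make_interp)
print(interp(np.array([[0.2, 0.3], [0.7, 0.1]])).shape)  # (2,)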
@@ -565,7 +583,9 @@ def _snap_coordinate_outside(field_components: FieldData): return adjustment - def _interpolate_field_components(interp_coords, field_name): + def _interpolate_field_components( + interp_coords: dict[str, dict[str, ArrayFloat]], field_name: str + ) -> dict[str, ArrayComplex]: return { name: interp(interp_coords[name]["coords"]) for name, interp in interpolators[field_name].items() @@ -596,7 +616,9 @@ def _interpolate_field_components(interp_coords, field_name): # on of the H field integration components and apply singularity correction pec_line_integration = is_flat_perp_dim1 or is_flat_perp_dim2 - def _compute_singularity_correction(adjustment_: dict[str, dict[str, np.ndarray]]): + def _compute_singularity_correction( + adjustment_: dict[str, dict[str, ArrayFloat]], + ) -> ArrayFloat: """ Given the `adjustment_` which contains the distance from the PEC edge each field component is nearest interpolated at, computes the singularity correction when diff --git a/tidy3d/components/autograd/types.py b/tidy3d/components/autograd/types.py index d853647b71..751c8127cd 100644 --- a/tidy3d/components/autograd/types.py +++ b/tidy3d/components/autograd/types.py @@ -3,13 +3,13 @@ from __future__ import annotations import copy -from typing import Annotated, Literal, Optional, Union, get_origin +from typing import Annotated, Any, Literal, Optional, Union, get_origin import autograd.numpy as anp from autograd.builtins import dict as TracedDict from autograd.extend import Box, defvjp, primitive from autograd.numpy.numpy_boxes import ArrayBox -from pydantic import BeforeValidator, PlainSerializer, PositiveFloat, TypeAdapter +from pydantic import BeforeValidator, PlainSerializer, PositiveFloat, SerializationInfo, TypeAdapter from tidy3d.compat import TypeAlias from tidy3d.components.types import ArrayFloat2D, ArrayLike, Complex, Size1D @@ -35,10 +35,10 @@ Box.__repr__ = Box.__str__ -def traced_alias(base_alias, *, name: Optional[str] = None) -> TypeAlias: +def traced_alias(base_alias: Any, *, name: Optional[str] = None) -> TypeAlias: base_adapter = TypeAdapter(base_alias, config={"arbitrary_types_allowed": True}) - def _validate_box_or_container(v): + def _validate_box_or_container(v: Any) -> Any: # case 1: v itself is a tracer # in this case we just validate but leave the tracer untouched if isinstance(v, Box): @@ -78,7 +78,7 @@ def _validate_box_or_container(v): return base_adapter.validate_python(v) - def _serialize_traced(a, info): + def _serialize_traced(a: Any, info: SerializationInfo) -> Any: return _auto_serializer(get_static(a), info) return Annotated[ diff --git a/tidy3d/components/autograd/utils.py b/tidy3d/components/autograd/utils.py index 5587ba085d..5dbf4a594e 100644 --- a/tidy3d/components/autograd/utils.py +++ b/tidy3d/components/autograd/utils.py @@ -2,10 +2,12 @@ from __future__ import annotations from collections.abc import Iterable, Mapping, Sequence -from typing import Any +from typing import Any, Union import autograd.numpy as anp +from autograd.numpy.numpy_boxes import ArrayBox from autograd.tracer import getval, isbox +from numpy.typing import ArrayLike, NDArray __all__ = [ "asarray1d", @@ -66,12 +68,12 @@ def hasbox(obj: Any) -> bool: return False -def pack_complex_vec(z): +def pack_complex_vec(z: Union[NDArray, ArrayBox]) -> Union[NDArray, ArrayBox]: """Ravel [Re(z); Im(z)] into one real vector (autograd-safe).""" return anp.concatenate([anp.ravel(anp.real(z)), anp.ravel(anp.imag(z))]) -def asarray1d(x): +def asarray1d(x: Union[ArrayLike, ArrayBox]) -> 
Union[NDArray, ArrayBox]: """Autograd-friendly 1D flatten: returns ndarray of shape (-1,).""" x = anp.array(x) return x if x.ndim == 1 else anp.ravel(x) diff --git a/tidy3d/components/grid/corner_finder.py b/tidy3d/components/grid/corner_finder.py index e2a3c9353a..8b4b6c6286 100644 --- a/tidy3d/components/grid/corner_finder.py +++ b/tidy3d/components/grid/corner_finder.py @@ -5,6 +5,7 @@ from typing import Any, Literal, Optional import numpy as np +from numpy.typing import NDArray from pydantic import Field, PositiveFloat, PositiveInt from tidy3d.components.base import Tidy3dBaseModel, cached_property @@ -68,7 +69,7 @@ class CornerFinderSpec(Tidy3dBaseModel): ) @cached_property - def _no_min_dl_override(self): + def _no_min_dl_override(self) -> bool: return all( ( self.concave_resolution is None, @@ -197,7 +198,7 @@ def _corners_and_convexity( return self._ravel_corners_and_convexity(ravel, corner_list, convexity_list) def _ravel_corners_and_convexity( - self, ravel: bool, corner_list, convexity_list + self, ravel: bool, corner_list: list[ArrayFloat2D], convexity_list: list[ArrayFloat1D] ) -> tuple[ArrayFloat2D, ArrayFloat1D]: """Whether to put the resulting corners in a single list or per polygon.""" if ravel and len(corner_list) > 0: @@ -259,7 +260,7 @@ def _filter_collinear_vertices( Convexity of corners: True for outer corners, False for inner corners. """ - def normalize(v): + def normalize(v: NDArray) -> NDArray: return v / np.linalg.norm(v, axis=-1)[:, np.newaxis] # drop the last vertex, which is identical to the 1st one. diff --git a/tidy3d/components/grid/grid.py b/tidy3d/components/grid/grid.py index a38739a2b9..99e3c27bfd 100644 --- a/tidy3d/components/grid/grid.py +++ b/tidy3d/components/grid/grid.py @@ -2,9 +2,10 @@ from __future__ import annotations -from typing import Literal, Union +from typing import Any, Literal, Self, Union import numpy as np +from numpy.typing import NDArray from pydantic import Field from tidy3d.components.base import Tidy3dBaseModel, cached_property @@ -45,12 +46,12 @@ class Coords(Tidy3dBaseModel): ) @property - def to_dict(self): + def to_dict(self) -> dict[str, Any]: """Return a dict of the three Coord1D objects as numpy arrays.""" return {key: self.model_dump()[key] for key in "xyz"} @property - def to_list(self): + def to_list(self) -> list[NDArray]: """Return a list of the three Coord1D objects as numpy arrays.""" return list(self.to_dict.values()) @@ -75,7 +76,7 @@ def cell_sizes(self) -> SpatialDataArray: return cell_sizes @cached_property - def cell_size_meshgrid(self): + def cell_size_meshgrid(self) -> NDArray: """Returns an N-dimensional grid where N is the number of coordinate arrays that have more than one element. 
Each grid element corresponds to the size of the mesh cell in N-dimensions and 1 for N=0.""" coord_dict = self.to_dict @@ -323,7 +324,7 @@ class YeeGrid(Tidy3dBaseModel): ) @property - def grid_dict(self): + def grid_dict(self) -> dict[str, Coords]: """The Yee grid coordinates associated to various field components as a dictionary.""" return { "Ex": self.E.x, @@ -356,12 +357,12 @@ class Grid(Tidy3dBaseModel): ) @staticmethod - def _avg(coords1d: Coords1D): + def _avg(coords1d: Coords1D) -> Coords1D: """Return average positions of an array of 1D coordinates.""" return (coords1d[1:] + coords1d[:-1]) / 2.0 @staticmethod - def _min(coords1d: Coords1D): + def _min(coords1d: Coords1D) -> Coords1D: """Return minus positions of 1D coordinates.""" return coords1d[:-1] @@ -535,7 +536,7 @@ def __getitem__(self, coord_key: str) -> Coords: return coord_dict.get(coord_key) - def _yee_e(self, axis: Axis): + def _yee_e(self, axis: Axis) -> Coords: """E field yee lattice sites for axis.""" boundary_coords = self.boundaries.to_dict @@ -549,7 +550,7 @@ def _yee_e(self, axis: Axis): return Coords(**yee_coords) - def _yee_h(self, axis: Axis): + def _yee_h(self, axis: Axis) -> Coords: """H field yee lattice sites for axis.""" boundary_coords = self.boundaries.to_dict @@ -679,7 +680,7 @@ def extended_subspace( return padded_coords[ind_beg:ind_end] - def snap_to_box_zero_dim(self, box: Box): + def snap_to_box_zero_dim(self, box: Box) -> Self: """Snap a grid to an exact box position for dimensions for which the box is size 0. If the box location is outside of the grid, an error is raised. @@ -714,7 +715,9 @@ def _translated_copy(self, vector: Coordinate) -> Grid: ) return self.updated_copy(boundaries=boundaries) - def _get_geo_inds(self, geo: Geometry, span_inds: ArrayLike = None, expand_inds: int = 2): + def _get_geo_inds( + self, geo: Geometry, span_inds: ArrayLike = None, expand_inds: int = 2 + ) -> NDArray: """ Get ``geo_inds`` based on a geometry's bounding box, enlarged by ``expand_inds``. If ``span_inds`` is supplied, take the intersection of ``span_inds`` and ``geo``'s bounding @@ -731,7 +734,7 @@ def _get_geo_inds(self, geo: Geometry, span_inds: ArrayLike = None, expand_inds: Returns ------- - List[Tuple[int, int]] + np.ndarray The (start, stop) indexes of the cells for interpolation. """ # only interpolate inside the bounding box diff --git a/tidy3d/components/grid/grid_spec.py b/tidy3d/components/grid/grid_spec.py index bfe27c16ad..ca8c6f15d2 100644 --- a/tidy3d/components/grid/grid_spec.py +++ b/tidy3d/components/grid/grid_spec.py @@ -3,7 +3,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any, Literal, Optional, Union +from typing import Any, Literal, Optional, Self, Union import numpy as np from pydantic import ( @@ -30,6 +30,7 @@ Coordinate, CoordinateOptional, PriorityMode, + Shapely, Symmetry, Undefined, ) @@ -52,6 +53,8 @@ # Tolerance for distinguishing pec/grid intersections GAP_MESHING_TOL = 1e-3 +CornersAndConvexity = tuple[list[ArrayFloat2D], list[ArrayFloat1D]] + class GridSpec1d(Tidy3dBaseModel, ABC): """Abstract base class, defines 1D grid generation specifications.""" @@ -295,7 +298,7 @@ class UniformGrid(GridSpec1d): @field_validator("dl") @classmethod - def _validate_dl(cls, val): + def _validate_dl(cls, val: PositiveFloat) -> PositiveFloat: """ Ensure 'dl' is not too small. 
""" @@ -430,7 +433,7 @@ def estimated_min_dl( @field_validator("coords") @classmethod - def _validate_coords(cls, val): + def _validate_coords(cls, val: Coords1D) -> Coords1D: """ Ensure 'coords' is sorted and has at least 2 entries. """ @@ -1093,7 +1096,7 @@ class LayerRefinementSpec(Box): ) @model_validator(mode="after") - def _finite_size_along_axis(self): + def _finite_size_along_axis(self) -> Self: if self.size is None: return self """size must be finite along axis.""" @@ -1116,7 +1119,7 @@ def from_layer_bounds( gap_meshing_iters: NonNegativeInt = 1, dl_min_from_gap_width: bool = True, **kwargs: Any, - ): + ) -> Self: """Constructs a :class:`LayerRefinementSpec` that is unbounded in inplane dimensions from bounds along layer thickness dimension. @@ -1192,7 +1195,7 @@ def from_bounds( gap_meshing_iters: NonNegativeInt = 1, dl_min_from_gap_width: bool = True, **kwargs: Any, - ): + ) -> Self: """Constructs a :class:`LayerRefinementSpec` from minimum and maximum coordinate bounds. Parameters @@ -1269,7 +1272,7 @@ def from_structures( gap_meshing_iters: NonNegativeInt = 1, dl_min_from_gap_width: bool = True, **kwargs: Any, - ): + ) -> Self: """Constructs a :class:`LayerRefinementSpec` from the bounding box of a list of structures. Parameters @@ -1409,7 +1412,9 @@ def suggested_dl_min(self, grid_size_in_vacuum: float, structures: list[Structur return dl_min def generate_snapping_points( - self, structure_list: list[Structure], cached_corners_and_convexity=None + self, + structure_list: list[Structure], + cached_corners_and_convexity: Optional[CornersAndConvexity] = None, ) -> list[CoordinateOptional]: """generate snapping points for mesh refinement.""" snapping_points = self._snapping_points_along_axis @@ -1421,7 +1426,7 @@ def generate_override_structures( self, grid_size_in_vacuum: float, structure_list: list[Structure], - cached_corners_and_convexity=None, + cached_corners_and_convexity: Optional[CornersAndConvexity] = None, ) -> list[MeshOverrideStructure]: """Generate mesh override structures for mesh refinement.""" return self._override_structures_along_axis( @@ -1484,7 +1489,7 @@ def _corners_and_convexity_2d( return inplane_points, convexity - def _dl_min_from_smallest_feature(self, structure_list: list[Structure]): + def _dl_min_from_smallest_feature(self, structure_list: list[Structure]) -> float: """Calculate `dl_min` suggestion based on smallest feature size.""" inplane_points, convexity = self._corners_and_convexity_2d( @@ -1523,7 +1528,9 @@ def _dl_min_from_smallest_feature(self, structure_list: list[Structure]): return dl_min def _corners( - self, structure_list: list[Structure], cached_corners_and_convexity=None + self, + structure_list: list[Structure], + cached_corners_and_convexity: Optional[CornersAndConvexity] = None, ) -> list[CoordinateOptional]: """Inplane corners in 3D coordinate.""" if self.corner_finder is None: @@ -1573,7 +1580,7 @@ def _override_structures_inplane( self, structure_list: list[Structure], grid_size_in_vacuum: float, - cached_corners_and_convexity=None, + cached_corners_and_convexity: Optional[CornersAndConvexity] = None, ) -> list[MeshOverrideStructure]: """Inplane mesh override structures for refining mesh around corners.""" if self.corner_refinement is None: @@ -1643,8 +1650,12 @@ def _override_structures_along_axis( return override_structures def _find_vertical_intersections( - self, grid_x_coords, grid_y_coords, poly_vertices, boundary - ) -> tuple[list[tuple[int, int]], list[float]]: + self, + grid_x_coords: ArrayFloat1D, + grid_y_coords: 
ArrayFloat1D, + poly_vertices: ArrayFloat2D, + boundary: tuple[Optional[str], Optional[str]], + ) -> tuple[np.typing.NDArray[np.int_], np.typing.NDArray[np.float64]]: """Detect intersection points of single polygon and vertical grid lines.""" # indices of cells that contain intersection with grid lines (left edge of a cell) @@ -1809,12 +1820,24 @@ def _find_vertical_intersections( np.zeros(len(cells_ij_one_side)), ] ) + else: + cells_ij = np.empty((0, 2), dtype=int) + cells_dy = np.empty(0, dtype=float) return cells_ij, cells_dy def _process_poly( - self, grid_x_coords, grid_y_coords, poly_vertices, boundaries - ) -> tuple[list[tuple[int, int]], list[float], list[tuple[int, int]], list[float]]: + self, + grid_x_coords: ArrayFloat1D, + grid_y_coords: ArrayFloat1D, + poly_vertices: ArrayFloat2D, + boundaries: tuple[tuple[Optional[str], Optional[str]], tuple[Optional[str], Optional[str]]], + ) -> tuple[ + np.typing.NDArray[np.int_], + np.typing.NDArray[np.float64], + np.typing.NDArray[np.int_], + np.typing.NDArray[np.float64], + ]: """Detect intersection points of single polygon and grid lines.""" # find cells that contain intersections of vertical grid lines @@ -1836,8 +1859,17 @@ def _process_poly( return v_cells_ij, v_cells_dy, h_cells_ij, h_cells_dx def _process_slice( - self, x, y, merged_geos, boundaries - ) -> tuple[list[tuple[int, int]], list[float], list[tuple[int, int]], list[float]]: + self, + x: ArrayFloat1D, + y: ArrayFloat1D, + merged_geos: list[tuple[Any, Shapely]], + boundaries: list[list[Optional[str], Optional[str]], list[Optional[str], Optional[str]]], + ) -> tuple[ + np.typing.NDArray[np.int_], + np.typing.NDArray[np.float64], + np.typing.NDArray[np.int_], + np.typing.NDArray[np.float64], + ]: """Detect intersection points of geometries boundaries and grid lines.""" # cells that contain intersections of vertical grid lines @@ -1911,16 +1943,25 @@ def _process_slice( if len(v_cells_ij) > 0: v_cells_ij = np.concatenate(v_cells_ij) v_cells_dy = np.concatenate(v_cells_dy) + else: + v_cells_ij = np.empty((0, 2), dtype=int) + v_cells_dy = np.empty(0, dtype=float) if len(h_cells_ij) > 0: h_cells_ij = np.concatenate(h_cells_ij) h_cells_dx = np.concatenate(h_cells_dx) + else: + h_cells_ij = np.empty((0, 2), dtype=int) + h_cells_dx = np.empty(0, dtype=float) return v_cells_ij, v_cells_dy, h_cells_ij, h_cells_dx def _generate_horizontal_snapping_lines( - self, grid_y_coords, intersected_cells_ij, relative_vert_disp - ) -> tuple[list[CoordinateOptional], float]: + self, + grid_y_coords: ArrayFloat1D, + intersected_cells_ij: np.typing.NDArray[np.int_], + relative_vert_disp: np.typing.NDArray[np.float64], + ) -> tuple[list[float], float]: """Convert a list of intersections of vertical grid lines, given as coordinates of cells and relative vertical displacement inside each cell, into locations of snapping lines that resolve thin gaps and strips. @@ -1997,8 +2038,15 @@ def _generate_horizontal_snapping_lines( return snapping_lines_y, min_gap_width def _resolve_gaps( - self, structures: list[Structure], grid: Grid, boundary_types: tuple - ) -> tuple[list[CoordinateOptional], float]: + self, + structures: list[Structure], + grid: Grid, + boundary_types: tuple[ + tuple[Optional[str], Optional[str]], + tuple[Optional[str], Optional[str]], + tuple[Optional[str], Optional[str]], + ], + ) -> tuple[tuple[CoordinateOptional], float]: """ Detect underresolved gaps and place snapping lines in them. Also return the detected minimal gap width. 
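# --- Editor's note: illustrative sketch, not part of the patch. ---
# Several hunks above add typed empty-array fallbacks (np.empty((0, 2), dtype=int), etc.)
# so that both branches of _find_vertical_intersections/_process_slice return arrays of the
# same shape and dtype. A reduced stand-alone version of that idea (the function name below
# is hypothetical):
import numpy as np
from numpy.typing import NDArray


def collect_cells(chunks: list[NDArray[np.int_]]) -> NDArray[np.int_]:
    """Concatenate (i, j) index pairs; return a (0, 2) int array when nothing intersected."""
    if chunks:
        return np.concatenate(chunks)
    return np.empty((0, 2), dtype=int)  # same ndim/dtype as the non-empty branch


print(collect_cells([]).shape)                            # (0, 2)
print(collect_cells([np.array([[1, 2], [3, 4]])]).shape)  # (2, 2)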
@@ -2014,7 +2062,7 @@ def _resolve_gaps( Returns ------- - tuple[list[CoordinateOptional], float] + list[list[CoordinateOptional], float] List of snapping lines and the detected minimal gap width. """ @@ -2308,7 +2356,7 @@ def internal_snapping_points( self, structures: list[Structure], lumped_elements: list[LumpedElementType], - cached_corners_and_convexity=None, + cached_corners_and_convexity: Optional[list[CornersAndConvexity]] = None, ) -> list[CoordinateOptional]: """Internal snapping points. So far, internal snapping points are generated by `layer_refinement_specs` and lumped element. @@ -2319,7 +2367,7 @@ def internal_snapping_points( List of physical structures. lumped_elements : list[LumpedElementType] List of lumped elements. - cached_corners_and_convexity : Optional[list[CachedCornersAndConvexity]] + cached_corners_and_convexity : Optional[list[CornersAndConvexity]] Cached corners and convexity data. Returns @@ -2389,7 +2437,7 @@ def internal_override_structures( wavelength: PositiveFloat, sim_size: tuple[float, 3], lumped_elements: list[LumpedElementType], - cached_corners_and_convexity=None, + cached_corners_and_convexity: Optional[list[CornersAndConvexity]] = None, ) -> list[StructureType]: """Internal mesh override structures. So far, internal override structures are generated by `layer_refinement_specs` and lumped element. @@ -2404,7 +2452,7 @@ def internal_override_structures( Simulation domain size. lumped_elements : list[LumpedElementType] List of lumped elements. - cached_corners_and_convexity : Optional[list[CachedCornersAndConvexity]] + cached_corners_and_convexity : Optional[list[CornersAndConvexity]] Cached corners and convexity data. Returns diff --git a/tidy3d/components/grid/mesher.py b/tidy3d/components/grid/mesher.py index 61955438ed..4246f11026 100644 --- a/tidy3d/components/grid/mesher.py +++ b/tidy3d/components/grid/mesher.py @@ -692,7 +692,7 @@ def rotate_structure_bounds(structures: list[StructureType], axis: Axis) -> list return struct_bbox @staticmethod - def bounds_2d_tree(struct_bbox: list[ArrayFloat1D]): + def bounds_2d_tree(struct_bbox: list[ArrayFloat1D]) -> STRtree: """Make a shapely Rtree for the 2D bounding boxes of all structures in the plane perpendicular to the meshing axis.""" @@ -1339,7 +1339,7 @@ def grid_grow_in_interval( if len_mismatch_even > small_dl: - def fun_scale(new_scale): + def fun_scale(new_scale: float) -> float: if isclose(new_scale, 1.0): return len_interval - small_dl * (1 + num_step) return ( diff --git a/tidy3d/components/mode/derivatives.py b/tidy3d/components/mode/derivatives.py index 19f0f4b04f..166e153b32 100644 --- a/tidy3d/components/mode/derivatives.py +++ b/tidy3d/components/mode/derivatives.py @@ -2,12 +2,22 @@ from __future__ import annotations +from collections.abc import Sequence +from typing import TYPE_CHECKING, Literal + import numpy as np +from numpy.typing import NDArray from tidy3d.constants import EPSILON_0, ETA_0 +if TYPE_CHECKING: + from scipy import sparse as sp + +ArrayFloat = NDArray[np.floating] +ArrayComplex = NDArray[np.complexfloating] -def make_dxf(dls, shape, pmc): + +def make_dxf(dls: ArrayFloat, shape: tuple[int, int], pmc: bool) -> sp.csr_matrix: """Forward derivative in x.""" import scipy.sparse as sp @@ -22,7 +32,7 @@ def make_dxf(dls, shape, pmc): return dxf -def make_dxb(dls, shape, pmc): +def make_dxb(dls: ArrayFloat, shape: tuple[int, int], pmc: bool) -> sp.csr_matrix: """Backward derivative in x.""" import scipy.sparse as sp @@ -39,7 +49,7 @@ def make_dxb(dls, shape, pmc): return dxb 
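# --- Editor's note: illustrative sketch, not part of the patch. ---
# The derivative factories above return scipy CSR matrices built from 1-D grid steps. A
# reduced 1-D analogue with the same annotation style (ArrayFloat in, sp.csr_matrix out);
# the function name and the omission of PEC/PMC boundary handling are simplifications:
import numpy as np
import scipy.sparse as sp
from numpy.typing import NDArray


def forward_diff_1d(dls: NDArray[np.floating]) -> sp.csr_matrix:
    """Row i approximates (f[i+1] - f[i]) / dls[i]."""
    n = dls.size
    diff = sp.diags([-1.0, 1.0], [0, 1], shape=(n, n))
    return sp.csr_matrix(sp.diags(1.0 / dls) @ diff)


print(forward_diff_1d(np.full(4, 0.25)).toarray())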
-def make_dyf(dls, shape, pmc): +def make_dyf(dls: ArrayFloat, shape: tuple[int, int], pmc: bool) -> sp.csr_matrix: """Forward derivative in y.""" import scipy.sparse as sp @@ -54,7 +64,7 @@ def make_dyf(dls, shape, pmc): return dyf -def make_dyb(dls, shape, pmc): +def make_dyb(dls: ArrayFloat, shape: tuple[int, int], pmc: bool) -> sp.csr_matrix: """Backward derivative in y.""" import scipy.sparse as sp @@ -71,7 +81,11 @@ def make_dyb(dls, shape, pmc): return dyb -def create_d_matrices(shape, dls, dmin_pmc=(False, False)): +def create_d_matrices( + shape: tuple[int, int], + dls: tuple[Sequence[ArrayFloat], Sequence[ArrayFloat]], + dmin_pmc: tuple[bool, bool] = (False, False), +) -> tuple[sp.csr_matrix, sp.csr_matrix, sp.csr_matrix, sp.csr_matrix]: """Make the derivative matrices without PML. If dmin_pmc is True, the 'backward' derivative in that dimension will be set to implement PMC boundary, otherwise it will be set to PEC.""" @@ -85,7 +99,15 @@ def create_d_matrices(shape, dls, dmin_pmc=(False, False)): return (dxf, dxb, dyf, dyb) -def create_s_matrices(omega, shape, npml, dls, eps_tensor, mu_tensor, dmin_pml=(True, True)): +def create_s_matrices( + omega: float, + shape: tuple[int, int], + npml: tuple[int, int], + dls: tuple[Sequence[ArrayFloat], Sequence[ArrayFloat]], + eps_tensor: ArrayComplex, + mu_tensor: ArrayComplex, + dmin_pml: tuple[bool, bool] = (True, True), +) -> tuple[sp.csr_matrix, sp.csr_matrix, sp.csr_matrix, sp.csr_matrix]: """Makes the 'S-matrices'. When dotted with derivative matrices, they add PML. If dmin_pml is set to False, PML will not be applied on the "bottom" side of the domain.""" @@ -136,17 +158,23 @@ def create_s_matrices(omega, shape, npml, dls, eps_tensor, mu_tensor, dmin_pml=( return sx_f, sx_b, sy_f, sy_b -def average_relative_speed(Nx, Ny, npml, eps_tensor, mu_tensor): +def average_relative_speed( + Nx: int, + Ny: int, + npml: tuple[int, int], + eps_tensor: ArrayComplex, + mu_tensor: ArrayComplex, +) -> ArrayFloat: """Compute the relative speed of light in the four pml regions by averaging the diagonal elements of the relative epsilon and mu within the pml region.""" - def relative_mean(tensor): + def relative_mean(tensor: ArrayComplex) -> float: """Mean for relative parameters. If an empty array just return 1.""" if tensor.size == 0: return 1.0 return np.mean(tensor) - def pml_average_allsides(tensor): + def pml_average_allsides(tensor: ArrayComplex) -> ArrayFloat: """Average ``tensor`` in the PML regions on all four sides. Returns the average values in order (xminus, xplus, yminus, yplus).""" @@ -165,7 +193,15 @@ def pml_average_allsides(tensor): return 1 / np.sqrt(eps_avg * mu_avg) -def create_sfactor(direction, omega, dls, N, n_pml, dmin_pml, avg_speed): +def create_sfactor( + direction: Literal["f", "b"], + omega: float, + dls: ArrayFloat, + N: int, + n_pml: int, + dmin_pml: bool, + avg_speed: Sequence[float], +) -> ArrayComplex: """Creates the S-factor cross section needed in the S-matrices""" # For no PNL, this should just be identity matrix. @@ -181,7 +217,14 @@ def create_sfactor(direction, omega, dls, N, n_pml, dmin_pml, avg_speed): raise ValueError(f"Direction value {direction} not recognized") -def create_sfactor_f(omega, dls, N, n_pml, dmin_pml, avg_speed=(1, 1)): +def create_sfactor_f( + omega: float, + dls: ArrayFloat, + N: int, + n_pml: int, + dmin_pml: bool, + avg_speed: Sequence[float] = (1, 1), +) -> ArrayComplex: """S-factor profile applied after forward derivative matrix, i.e. 
applied to H-field locations.""" sfactor_array = np.ones(N, dtype=np.complex128) @@ -195,7 +238,14 @@ def create_sfactor_f(omega, dls, N, n_pml, dmin_pml, avg_speed=(1, 1)): return sfactor_array -def create_sfactor_b(omega, dls, N, n_pml, dmin_pml, avg_speed=(1, 1)): +def create_sfactor_b( + omega: float, + dls: ArrayFloat, + N: int, + n_pml: int, + dmin_pml: bool, + avg_speed: Sequence[float] = (1, 1), +) -> ArrayComplex: """S-factor profile applied after backward derivative matrix, i.e. applied to E-field locations.""" sfactor_array = np.ones(N, dtype=np.complex128) @@ -209,14 +259,14 @@ def create_sfactor_b(omega, dls, N, n_pml, dmin_pml, avg_speed=(1, 1)): def s_value( dl: float, - step: int, + step: float, omega: float, avg_speed: float, sigma_max: float = 2, kappa_min: float = 1, kappa_max: float = 3, order: int = 3, -): +) -> complex: """S-value to use in the S-matrices. We use coordinate stretching formulation such that s(x) = kappa(x) + 1j * sigma(x) / (omega * EPSILON_0) diff --git a/tidy3d/components/mode/mode_solver.py b/tidy3d/components/mode/mode_solver.py index f1127539f2..082a42dba9 100644 --- a/tidy3d/components/mode/mode_solver.py +++ b/tidy3d/components/mode/mode_solver.py @@ -6,7 +6,7 @@ from functools import wraps from math import isclose -from typing import Any, Literal, Optional, Union, get_args +from typing import Any, Callable, Literal, Optional, ParamSpec, Self, TypeVar, Union, get_args import numpy as np import xarray as xr @@ -121,19 +121,23 @@ # Log a warning when the PML covers more than this portion of the mode plane in any axis WARN_THICK_PML_PERCENT = 50 +P = ParamSpec("P") +R = TypeVar("R") -def require_fdtd_simulation(fn): + +def require_fdtd_simulation(fn: Callable[P, R]) -> Callable[P, R]: """Decorate a function to check that ``simulation`` is an FDTD ``Simulation``.""" @wraps(fn) - def _fn(self, **kwargs: Any): + def _fn(*args: P.args, **kwargs: P.kwargs) -> R: """New decorated function.""" + self = args[0] if not isinstance(self.simulation, Simulation): raise SetupError( f"The function '{fn.__name__}' is only supported " "for 'simulation' of type FDTD 'Simulation'." ) - return fn(self, **kwargs) + return fn(*args, **kwargs) return _fn @@ -209,7 +213,7 @@ class ModeSolver(Tidy3dBaseModel): @field_validator("simulation") @classmethod - def _convert_to_simulation(cls, val): + def _convert_to_simulation(cls, val: MODE_SIMULATION_TYPE) -> MODE_SIMULATION_TYPE: """Convert to regular Simulation if e.g. 
JaxSimulation given.""" if hasattr(val, "to_simulation"): val = val.to_simulation()[0] @@ -221,7 +225,7 @@ def _convert_to_simulation(cls, val): @field_validator("plane") @classmethod - def is_plane(cls, val): + def is_plane(cls, val: MODE_PLANE_TYPE) -> MODE_PLANE_TYPE: """Raise validation error if not planar.""" if val.size.count(0.0) != 1: raise ValidationError(f"ModeSolver plane must be planar, given size={val}") @@ -231,7 +235,7 @@ def is_plane(cls, val): _freqs_lower_bound = validate_freqs_min() @model_validator(mode="after") - def plane_in_sim_bounds(self): + def plane_in_sim_bounds(self) -> Self: """Check that the plane is at least partially inside the simulation bounds.""" sim_box = Box(size=self.simulation.size, center=self.simulation.center) if not sim_box.intersects(self.plane): @@ -239,7 +243,7 @@ def plane_in_sim_bounds(self): return self @model_validator(mode="after") - def _warn_plane_crosses_symmetry(self): + def _warn_plane_crosses_symmetry(self) -> Self: """Warn if the mode plane crosses the symmetry plane of the underlying simulation but the centers do not match.""" for dim in range(3): @@ -259,26 +263,26 @@ def _warn_plane_crosses_symmetry(self): return self @model_validator(mode="after") - def _validate_warn_thick_pml(self): + def _validate_warn_thick_pml(self) -> Self: """Warn if the pml covers a significant portion of the mode plane.""" self._warn_thick_pml(simulation=self.simulation, plane=self.plane, mode_spec=self.mode_spec) self._validate_rotate_structures() return self @model_validator(mode="after") - def _validate_bend_radius(self): + def _validate_bend_radius(self) -> Self: """Validate that the bend radius is not too small.""" sim_box = Box(size=self.simulation.size, center=self.simulation.center) self._validate_mode_plane_radius(self.mode_spec, self.plane, sim_box) return self @model_validator(mode="after") - def _validate_rotate_structures_after(self): + def _validate_rotate_structures_after(self) -> Self: self._validate_rotate_structures() return self @model_validator(mode="after") - def _validate_num_grid_points(self): + def _validate_num_grid_points(self) -> Self: """Upper bound of the product of the number of grid points and the number of modes. The bound is very loose: subspace size times the size of eigenvector can be indexed by a 32bit integer. 
""" @@ -349,7 +353,7 @@ def _validate_rotate_structures(self) -> None: @staticmethod def _make_rotated_structures( structures: list[Structure], translate_kwargs: dict, rotate_kwargs: dict - ): + ) -> list[Structure]: try: rotated_structures = [] for structure in structures: @@ -386,7 +390,7 @@ def normal_axis(self) -> Axis: return self.plane.size.index(0.0) @staticmethod - def plane_center_tangential(plane) -> tuple[float, float]: + def plane_center_tangential(plane: MODE_PLANE_TYPE) -> tuple[float, float]: """Mode lane center in the tangential axes.""" _, plane_center = plane.pop_axis(plane.center, plane.size.index(0.0)) return plane_center @@ -679,7 +683,7 @@ def rotated_mode_solver_data(self) -> ModeSolverData: return rotated_mode_data @cached_property - def rotated_structures_copy(self): + def rotated_structures_copy(self) -> ModeSolver: """Create a copy of the original ModeSolver with rotated structures to the simulation and updates the ModeSpec to disable bend correction and reset angles to normal.""" @@ -1104,7 +1108,7 @@ def theta_reference(self) -> float: return theta_ref @cached_property - def _bend_radius(self): + def _bend_radius(self) -> float: """A bend_radius to use when ``angle_rotation`` is on. When there is no bend defined, we use an effectively very large radius, much larger than the mode plane. This is only used for the rotation of the fields - the reference modes are still computed without any @@ -1116,7 +1120,7 @@ def _bend_radius(self): return EFFECTIVE_RADIUS_FACTOR * largest_dim @cached_property - def bend_center(self) -> list: + def bend_center(self) -> list[float]: """Computes the bend center based on plane center, angle_theta and angle_phi.""" _, id_bend_uv = self.plane.pop_axis((0, 1, 2), axis=self.bend_axis_3d) @@ -1360,7 +1364,7 @@ def _normalize_modes(self, mode_solver_data: ModeSolverData) -> None: """Normalize modes. 
Note: this modifies ``mode_solver_data`` in-place.""" mode_solver_data._normalize_modes() - def _filter_components(self, mode_solver_data: ModeSolverData): + def _filter_components(self, mode_solver_data: ModeSolverData) -> ModeSolverData: skip_components = { comp: None for comp in mode_solver_data.field_components.keys() @@ -1368,7 +1372,7 @@ def _filter_components(self, mode_solver_data: ModeSolverData): } return mode_solver_data.updated_copy(**skip_components, validate=False) - def _filter_polarization(self, mode_solver_data: ModeSolverData): + def _filter_polarization(self, mode_solver_data: ModeSolverData) -> ModeSolverData: """Filter polarization.""" filter_pol = self.mode_spec.filter_pol if filter_pol is None: @@ -1612,7 +1616,13 @@ def _solve_all_freqs_relative( return n_complex, fields, eps_spec @staticmethod - def _postprocess_solver_fields(solver_fields, normal_axis, plane, mode_spec, coords): + def _postprocess_solver_fields( + solver_fields: ArrayComplex4D, + normal_axis: Axis, + plane: MODE_PLANE_TYPE, + mode_spec: ModeSpec, + coords: tuple[ArrayFloat1D, ArrayFloat1D], + ) -> dict[str, ArrayComplex4D]: """Postprocess `solver_fields` from `compute_modes` to proper coordinate""" fields = {key: [] for key in ("Ex", "Ey", "Ez", "Hx", "Hy", "Hz")} diff_coords = (np.diff(coords[0]), np.diff(coords[1])) @@ -1672,7 +1682,9 @@ def _rotate_field_coords_inverse( return np.stack(plane.unpop_axis(f_n, f_ts, axis=2), axis=0) @classmethod - def _postprocess_solver_fields_inverse(cls, fields, normal_axis: Axis, plane: MODE_PLANE_TYPE): + def _postprocess_solver_fields_inverse( + cls, fields: dict[str, ArrayComplex4D], normal_axis: Axis, plane: MODE_PLANE_TYPE + ) -> ArrayComplex4D: """Convert ``fields`` to ``solver_fields``. Doesn't change gauge.""" E = [fields[key] for key in ("Ex", "Ey", "Ez")] H = [fields[key] for key in ("Hx", "Hy", "Hz")] @@ -2721,7 +2733,7 @@ def validate_pre_upload(self) -> None: self._validate_modes_size() @cached_property - def reduced_simulation_copy(self): + def reduced_simulation_copy(self) -> Self: """Strip objects not used by the mode solver from simulation object. This might significantly reduce upload time in the presence of custom mediums. """ @@ -2834,7 +2846,7 @@ def _patch_data(self, data: ModeSolverData) -> None: self._cached_properties.pop("data", None) self._cached_properties.pop("sim_data", None) - def plot_3d(self, width=800, height=800) -> None: + def plot_3d(self, width: int = 800, height: int = 800) -> None: """Render 3D plot of ``ModeSolver`` (in jupyter notebook only). 
Parameters ---------- diff --git a/tidy3d/components/mode/simulation.py b/tidy3d/components/mode/simulation.py index 505ee25e2d..f446be8574 100644 --- a/tidy3d/components/mode/simulation.py +++ b/tidy3d/components/mode/simulation.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Self, Union import numpy as np from pydantic import Field, PositiveFloat, field_validator, model_validator @@ -39,6 +39,9 @@ from .mode_solver import ModeSolver +if TYPE_CHECKING: + from tidy3d.components.mode.data.sim_data import ModeSimulationData + ModeSimulationMonitorType = Union[PermittivityMonitor, MediumMonitor] # dummy run time for conversion to FDTD sim @@ -216,20 +219,20 @@ class ModeSimulation(AbstractYeeGridSimulation): @field_validator("grid_spec") @classmethod - def _validate_auto_grid_wavelength(cls, val): + def _validate_auto_grid_wavelength(cls, val: GridSpec) -> GridSpec: # abstract override, logic is handled in post-init to ensure freqs is defined return val @field_validator("plane") @classmethod - def _validate_planar(cls, val): + def _validate_planar(cls, val: Optional[MODE_PLANE_TYPE]) -> Optional[MODE_PLANE_TYPE]: if val.size.count(0.0) != 1: raise ValidationError(f"'ModeSimulation.plane' must be planar, given 'size={val.size}'") return val @model_validator(mode="before") @classmethod - def is_plane(cls, data): + def is_plane(cls, data: dict[str, Any]) -> dict[str, Any]: """Raise validation error if not planar.""" if hasattr(data, "get") and data.get("plane") is None: val = Box(size=data.get("size"), center=data.get("center")) @@ -242,7 +245,7 @@ def is_plane(cls, data): return data @model_validator(mode="after") - def plane_in_sim_bounds(self): + def plane_in_sim_bounds(self) -> Self: """Check that the plane is at least partially inside the simulation bounds.""" sim_box = Box(size=self.size, center=self.center) if not sim_box.intersects(self.plane): @@ -250,12 +253,12 @@ def plane_in_sim_bounds(self): return self @model_validator(mode="after") - def _validate_mode_solver(self): + def _validate_mode_solver(self) -> Self: _ = self._mode_solver return self @model_validator(mode="after") - def _validate_grid(self): + def _validate_grid(self) -> Self: _ = self.grid return self @@ -266,7 +269,7 @@ def _mode_solver(self) -> ModeSolver: return ModeSolver(simulation=self._as_fdtd_sim, **kwargs) @supports_local_subpixel - def run_local(self): + def run_local(self) -> ModeSimulationData: """Run locally.""" if tidy3d_extras["use_local_subpixel"]: diff --git a/tidy3d/components/mode/solver.py b/tidy3d/components/mode/solver.py index afdfb9a762..be9b04eccf 100644 --- a/tidy3d/components/mode/solver.py +++ b/tidy3d/components/mode/solver.py @@ -2,9 +2,11 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Optional +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any, Literal, Optional, Union import numpy as np +from numpy.typing import NDArray from tidy3d.components.base import Tidy3dBaseModel from tidy3d.components.types import EpsSpecType, ModeSolverType @@ -14,6 +16,9 @@ from .derivatives import create_s_matrices as s_mats from .transforms import angled_transform, radial_transform +if TYPE_CHECKING: + from scipy import sparse as sp + # Consider vec to be complex if norm(vec.imag)/norm(vec) > TOL_COMPLEX TOL_COMPLEX = 1e-10 # Tolerance for eigs @@ -28,12 +33,13 @@ # double precision. This value is very heuristic. 
GOOD_CONDUCTOR_CUT_OFF = 1e70 -if TYPE_CHECKING: - from scipy import sparse as sp # Consider a material to be good conductor if |ep| (or |mu|) > GOOD_CONDUCTOR_THRESHOLD * |pec_val| GOOD_CONDUCTOR_THRESHOLD = 0.9 +ArrayFloat = NDArray[np.floating] +ArrayComplex = NDArray[np.complexfloating] + class EigSolver(Tidy3dBaseModel): """Interface for computing eigenvalues given permittivity and mode spec. @@ -43,18 +49,18 @@ class EigSolver(Tidy3dBaseModel): @classmethod def compute_modes( cls, - eps_cross, - coords, - freq, - mode_spec, - precision, - mu_cross=None, - split_curl_scaling=None, - symmetry=(0, 0), - direction="+", - solver_basis_fields=None, + eps_cross: Union[ArrayComplex, tuple[ArrayComplex, ...]], + coords: Sequence[ArrayFloat], + freq: float, + mode_spec: ModeSolverType, + precision: Literal["single", "double"], + mu_cross: Optional[Union[ArrayComplex, tuple[ArrayComplex, ...]]] = None, + split_curl_scaling: Optional[ArrayFloat] = None, + symmetry: tuple[int, int] = (0, 0), + direction: Literal["+", "-"] = "+", + solver_basis_fields: Optional[ArrayComplex] = None, plane_center: Optional[tuple[float, float]] = None, - ) -> tuple[np.ndarray, np.ndarray, EpsSpecType]: + ) -> tuple[ArrayComplex, ArrayComplex, EpsSpecType]: """ Solve for the modes of a waveguide cross-section. @@ -124,7 +130,7 @@ def compute_modes( if len(coords[0]) != Nx + 1 or len(coords[1]) != Ny + 1: raise ValueError("Mismatch between 'coords' and 'esp_cross' shapes.") - new_coords = [np.copy(c) for c in coords] + new_coords = (np.copy(coords[0]), np.copy(coords[1])) """We work with full tensorial epsilon in mu to handle the most general cases that can be introduced by coordinate transformations. In the solver, we distinguish the case when @@ -299,20 +305,20 @@ def compute_modes( @classmethod def solver_em( cls, - Nx, - Ny, - eps_tensor, - mu_tensor, - der_mats, - num_modes, - neff_guess, - mat_precision, - direction, - enable_incidence_matrices, - basis_E, - dls, - dmin_pmc=None, - ): + Nx: int, + Ny: int, + eps_tensor: ArrayComplex, + mu_tensor: ArrayComplex, + der_mats: Sequence[sp.csr_matrix], + num_modes: int, + neff_guess: float, + mat_precision: Literal["single", "double"], + direction: Literal["+", "-"], + enable_incidence_matrices: bool, + basis_E: Optional[ArrayComplex], + dls: tuple[Sequence[ArrayFloat], Sequence[ArrayFloat]], + dmin_pmc: Optional[Sequence[bool]] = None, + ) -> tuple[ArrayComplex, ArrayComplex, ArrayFloat, ArrayFloat, EpsSpecType]: """Solve for the electromagnetic modes of a system defined by in-plane permittivity and permeability and assuming translational invariance in the normal direction. 
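# --- Editor's note: illustrative sketch, not part of the patch. ---
# solver.py and derivatives.py above import scipy.sparse only under TYPE_CHECKING and rely
# on postponed evaluation of annotations (PEP 563), so annotating with sp.csr_matrix adds
# no runtime import cost. A self-contained version of that pattern (the module-level alias
# and function name below are hypothetical):
from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np
from numpy.typing import NDArray

if TYPE_CHECKING:  # evaluated by mypy only, never at runtime
    from scipy import sparse as sp

ArrayComplex = NDArray[np.complexfloating]


def assemble_diagonal_operator(eps: ArrayComplex) -> sp.csr_matrix:
    """Annotations stay strings; scipy is imported lazily where it is actually used."""
    import scipy.sparse as sparse

    return sparse.diags(eps.ravel()).tocsr()


print(assemble_diagonal_operator(np.array([1 + 0j, 2 + 0j])).shape)  # (2, 2)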
@@ -366,8 +372,8 @@ def solver_em( # use a high-conductivity model for locations associated with a good conductor def conductivity_model_for_good_conductor( - eps, threshold=GOOD_CONDUCTOR_THRESHOLD * pec_val - ): + eps: ArrayComplex, threshold: complex = GOOD_CONDUCTOR_THRESHOLD * pec_val + ) -> ArrayComplex: """Entries associated with 'eps' are converted to a high-conductivity model.""" eps = eps.astype(complex) eps[np.abs(eps) >= abs(threshold)] = 1 + 1j * pec_scaled_val @@ -445,16 +451,16 @@ def conductivity_model_for_good_conductor( @classmethod def solver_diagonal( cls, - eps, - mu, - der_mats, - num_modes, - neff_guess, - vec_init, - mat_precision, - enable_incidence_matrices, - basis_E, - ): + eps: ArrayComplex, + mu: ArrayComplex, + der_mats: Sequence[sp.csr_matrix], + num_modes: int, + neff_guess: float, + vec_init: ArrayComplex, + mat_precision: Literal["single", "double"], + enable_incidence_matrices: bool, + basis_E: Optional[ArrayComplex], + ) -> tuple[ArrayComplex, ArrayComplex, ArrayFloat, ArrayFloat]: """EM eigenmode solver assuming ``eps`` and ``mu`` are diagonal everywhere.""" import scipy.sparse as sp import scipy.sparse.linalg as spl @@ -464,7 +470,9 @@ def solver_diagonal( analyze_conditioning = False _threshold = 0.9 * np.abs(pec_val) - def incidence_matrix_for_pec(eps_vec, threshold=_threshold): + def incidence_matrix_for_pec( + eps_vec: ArrayComplex, threshold: float = _threshold + ) -> sp.csr_matrix: """Incidence matrix indicating non-PEC entries associated with 'eps_vec'.""" nnz = eps_vec[np.abs(eps_vec) < threshold] eps_nz = eps_vec.copy() @@ -553,7 +561,9 @@ def incidence_matrix_for_pec(eps_vec, threshold=_threshold): elif PRECONDITIONER == "Material": - def conditional_inverted_vec(eps_vec, threshold=1): + def conditional_inverted_vec( + eps_vec: ArrayComplex, threshold: float = 1 + ) -> sp.csr_matrix: """Returns a diagonal sparse matrix whose i-th element in the diagonal is |eps_i|^-1 if |eps_i|>threshold, and |eps_i| otherwise. 
""" @@ -671,7 +681,14 @@ def conditional_inverted_vec(eps_vec, threshold=1): return E, H, neff, keff @classmethod - def matrix_data_type(cls, eps, mu, der_mats, mat_precision, is_tensorial): + def matrix_data_type( + cls, + eps: ArrayComplex, + mu: ArrayComplex, + der_mats: Sequence[sp.csr_matrix], + mat_precision: Literal["single", "double"], + is_tensorial: bool, + ) -> np.dtype[Any]: """Determine data type that should be used for the matrix for diagonalization.""" mat_dtype = np.float32 # In tensorial case, even though the matrix can be real, the @@ -708,18 +725,18 @@ def trim_small_values(cls, mat: sp.csr_matrix, tol: float) -> sp.csr_matrix: @classmethod def solver_tensorial( cls, - eps, - mu, - der_mats, - num_modes, - neff_guess, - vec_init, - mat_precision, - direction, - dls, - Nxy=None, - dmin_pmc=None, - ): + eps: ArrayComplex, + mu: ArrayComplex, + der_mats: Sequence[sp.csr_matrix], + num_modes: int, + neff_guess: float, + vec_init: ArrayComplex, + mat_precision: Literal["single", "double"], + direction: Literal["+", "-"], + dls: tuple[Sequence[ArrayFloat], Sequence[ArrayFloat]], + Nxy: Optional[tuple[int, int]] = None, + dmin_pmc: Optional[Sequence[bool]] = None, + ) -> tuple[ArrayComplex, ArrayComplex, ArrayFloat, ArrayFloat]: """EM eigenmode solver assuming ``eps`` or ``mu`` have off-diagonal elements.""" import scipy.sparse as sp @@ -850,13 +867,13 @@ def solver_tensorial( @classmethod def solver_eigs( cls, - mat, - num_modes, - vec_init, - guess_value=1.0, - M=None, + mat: sp.csr_matrix, + num_modes: int, + vec_init: ArrayComplex, + guess_value: float = 1.0, + M: Optional[sp.csr_matrix] = None, **kwargs: Any, - ): + ) -> tuple[ArrayComplex, ArrayComplex]: """Find ``num_modes`` eigenmodes of ``mat`` cloest to ``guess_value``. Parameters @@ -888,14 +905,14 @@ def solver_eigs( @classmethod def solver_eigs_relative( cls, - mat, - num_modes, - vec_init, - guess_value=1.0, - M=None, - basis_vecs=None, + mat: sp.csr_matrix, + num_modes: int, + vec_init: ArrayComplex, + guess_value: float = 1.0, + M: Optional[sp.csr_matrix] = None, + basis_vecs: Optional[ArrayComplex] = None, **kwargs: Any, - ): + ) -> tuple[ArrayComplex, ArrayComplex]: """Find ``num_modes`` eigenmodes of ``mat`` cloest to ``guess_value``. Parameters @@ -916,7 +933,9 @@ def solver_eigs_relative( return values, vectors @classmethod - def isinstance_complex(cls, vec_or_mat, tol=TOL_COMPLEX): + def isinstance_complex( + cls, vec_or_mat: Union[ArrayComplex, sp.csr_matrix], tol: float = TOL_COMPLEX + ) -> bool: """Check if a numpy array or scipy.sparse.csr_matrix has complex component by looking at norm(x.imag)/norm(x)>TOL_COMPLEX @@ -938,7 +957,9 @@ def isinstance_complex(cls, vec_or_mat, tol=TOL_COMPLEX): ) @classmethod - def type_conversion(cls, vec_or_mat, new_dtype): + def type_conversion( + cls, vec_or_mat: Union[ArrayComplex, sp.csr_matrix], new_dtype: np.dtype[Any] + ) -> Union[ArrayComplex, sp.csr_matrix]: """Convert vec_or_mat to new_type. Parameters @@ -962,7 +983,7 @@ def type_conversion(cls, vec_or_mat, new_dtype): raise RuntimeError("Unsupported new_type.") @classmethod - def set_initial_vec(cls, Nx, Ny, is_tensorial=False): + def set_initial_vec(cls, Nx: int, Ny: int, is_tensorial: bool = False) -> ArrayComplex: """Set initial vector for eigs: 1) The field at x=0 and y=0 boundaries are set to 0. 
This should be the case for PEC boundaries, but wouldn't hurt for non-PEC boundary; @@ -1000,7 +1021,9 @@ def set_initial_vec(cls, Nx, Ny, is_tensorial=False): return vec_init.flatten("F") @classmethod - def eigs_to_effective_index(cls, eig_list: np.ndarray, mode_solver_type: ModeSolverType): + def eigs_to_effective_index( + cls, eig_list: ArrayComplex, mode_solver_type: ModeSolverType + ) -> tuple[ArrayFloat, ArrayFloat]: """Convert obtained eigenvalues to n_eff and k_eff. Parameters @@ -1030,7 +1053,9 @@ def eigs_to_effective_index(cls, eig_list: np.ndarray, mode_solver_type: ModeSol raise RuntimeError(f"Unidentified 'mode_solver_type={mode_solver_type}'.") @staticmethod - def format_medium_data(mat_data): + def format_medium_data( + mat_data: Union[ArrayComplex, Sequence[ArrayComplex]], + ) -> tuple[ArrayComplex, ...]: """ mat_data can be either permittivity or permeability. It's either a single 2D array defining the relative property in the cross-section, or nine 2D arrays defining @@ -1038,13 +1063,13 @@ def format_medium_data(mat_data): xx, xy, xz, yx, yy, yz, zx, zy, zz. """ if isinstance(mat_data, np.ndarray): - return (mat_data[i, :, :] for i in range(9)) + return tuple(mat_data[i, :, :] for i in range(9)) if len(mat_data) == 9: - return (np.copy(e) for e in mat_data) + return tuple(np.copy(e) for e in mat_data) raise ValueError("Wrong input to mode solver pemittivity/permeability!") @staticmethod - def split_curl_field_postprocess(split_curl, E): + def split_curl_field_postprocess(split_curl: ArrayFloat, E: ArrayComplex) -> ArrayComplex: """E has the shape (3, N, num_modes)""" _, Nx, Ny = split_curl.shape field_shape = E.shape @@ -1062,7 +1087,9 @@ def split_curl_field_postprocess(split_curl, E): return E @staticmethod - def make_pml_invariant(Nxy, tensor, num_pml): + def make_pml_invariant( + Nxy: tuple[int, int], tensor: ArrayComplex, num_pml: tuple[int, int] + ) -> ArrayComplex: """For a given epsilon or mu tensor of shape ``(3, 3, Nx, Ny)``, and ``num_pml`` pml layers along ``x`` and ``y``, make all the tensor values in the PML equal by replicating the first pixel into the PML.""" @@ -1076,12 +1103,16 @@ def make_pml_invariant(Nxy, tensor, num_pml): return new_ten.reshape((3, 3, -1)) @staticmethod - def split_curl_field_postprocess_inverse(split_curl, E) -> None: + def split_curl_field_postprocess_inverse( + split_curl: ArrayFloat, E: ArrayComplex + ) -> ArrayComplex: """E has the shape (3, N, num_modes)""" raise RuntimeError("Split curl not yet implemented for relative mode solver.") @staticmethod - def mode_plane_contain_good_conductor(material_response) -> bool: + def mode_plane_contain_good_conductor( + material_response: Optional[ArrayComplex], + ) -> bool: """Find out if epsilon on the modal plane contain good conductors whose permittivity or permeability value is very large. 
""" @@ -1090,6 +1121,6 @@ def mode_plane_contain_good_conductor(material_response) -> bool: return np.any(np.abs(material_response) > GOOD_CONDUCTOR_THRESHOLD * np.abs(pec_val)) -def compute_modes(*args: Any, **kwargs: Any) -> tuple[np.ndarray, np.ndarray, str]: +def compute_modes(*args: Any, **kwargs: Any) -> tuple[ArrayComplex, ArrayComplex, EpsSpecType]: """A wrapper around ``EigSolver.compute_modes``, which is used in :class:`.ModeSolver`.""" return EigSolver.compute_modes(*args, **kwargs) diff --git a/tidy3d/components/mode/transforms.py b/tidy3d/components/mode/transforms.py index c498be1af9..d0a433768e 100644 --- a/tidy3d/components/mode/transforms.py +++ b/tidy3d/components/mode/transforms.py @@ -10,10 +10,21 @@ from __future__ import annotations +from collections.abc import Sequence + import numpy as np +from numpy.typing import NDArray + +ArrayFloat = NDArray[np.floating] +CoordsTuple = tuple[ArrayFloat, ArrayFloat] -def radial_transform(coords, radius, bend_axis, plane_center): +def radial_transform( + coords: CoordsTuple, + radius: float, + bend_axis: int, + plane_center: Sequence[float], +) -> tuple[CoordsTuple, ArrayFloat, ArrayFloat]: """Compute the new coordinates and the Jacobian of a polar coordinate transformation. After offsetting the plane such that its center is a distance of ``radius`` away from the center of curvature, we have, e.g. for ``bend_axis=='y'``: @@ -73,7 +84,11 @@ def radial_transform(coords, radius, bend_axis, plane_center): return new_coords, jac_e, jac_h -def angled_transform(coords, angle_theta, angle_phi): +def angled_transform( + coords: CoordsTuple, + angle_theta: float, + angle_phi: float, +) -> tuple[CoordsTuple, ArrayFloat, ArrayFloat]: """Compute the new coordinates and the Jacobian for a transformation that "straightens" an angled waveguide such that it is translationally invariant in w. 
The transformation is u = x - tan(angle) * z @@ -100,7 +115,7 @@ def angled_transform(coords, angle_theta, angle_phi): Nx, Ny = coords[0].size - 1, coords[1].size - 1 # The new coordinates are exactly the same at z = 0 - new_coords = [np.copy(c) for c in coords] + new_coords = tuple(np.copy(c) for c in coords) # The only nontrivial derivatives are dudz, dvdz and they are constant everywhere jac = np.zeros((3, 3, Nx * Ny)) diff --git a/tidy3d/components/viz/axes_utils.py b/tidy3d/components/viz/axes_utils.py index 8fb78f699c..d70d53f9b1 100644 --- a/tidy3d/components/viz/axes_utils.py +++ b/tidy3d/components/viz/axes_utils.py @@ -1,14 +1,23 @@ from __future__ import annotations from functools import wraps -from typing import Any, Optional +from typing import TYPE_CHECKING, Optional from tidy3d.components.types import Ax, Axis, LengthUnit from tidy3d.constants import UnitScaling from tidy3d.exceptions import Tidy3dKeyError +if TYPE_CHECKING: + from typing import Callable, ParamSpec, TypeVar -def _create_unit_aware_locator(): + import matplotlib.ticker as ticker + from matplotlib.axes import Axes + + P = ParamSpec("P") + T = TypeVar("T", bound=Callable[..., Axes]) + + +def _create_unit_aware_locator() -> ticker.Locator: """Create UnitAwareLocator lazily due to matplotlib import restrictions.""" import matplotlib.ticker as ticker @@ -25,15 +34,15 @@ def __init__(self, scale_factor: float) -> None: super().__init__() self.scale_factor = scale_factor - def __call__(self): + def __call__(self) -> list[float]: vmin, vmax = self.axis.get_view_interval() return self.tick_values(vmin, vmax) - def view_limits(self, vmin, vmax): + def view_limits(self, vmin: float, vmax: float) -> tuple[float, float]: """Override to prevent matplotlib from adjusting our limits.""" return vmin, vmax - def tick_values(self, vmin, vmax): + def tick_values(self, vmin: float, vmax: float) -> list[float]: # convert the view range to the target unit vmin_unit = vmin * self.scale_factor vmax_unit = vmax * self.scale_factor @@ -105,13 +114,13 @@ def make_ax() -> Ax: return ax -def add_ax_if_none(plot): +def add_ax_if_none(plot: T) -> T: """Decorates ``plot(*args, **kwargs, ax=None)`` function. if ax=None in the function call, creates an ax and feeds it to rest of function. """ @wraps(plot) - def _plot(*args: Any, **kwargs: Any) -> Ax: + def _plot(*args: P.args, **kwargs: P.kwargs) -> Axes: """New plot function using a generated ax if None.""" if kwargs.get("ax") is None: ax = make_ax() @@ -121,14 +130,14 @@ def _plot(*args: Any, **kwargs: Any) -> Ax: return _plot -def equal_aspect(plot): +def equal_aspect(plot: T) -> T: """Decorates a plotting function returning a matplotlib axes. Ensures the aspect ratio of the returned axes is set to equal. 
Useful for 2D plots, like sim.plot() or sim_data.plot_fields() """ @wraps(plot) - def _plot(*args: Any, **kwargs: Any) -> Ax: + def _plot(*args: P.args, **kwargs: P.kwargs) -> Axes: """New plot function with equal aspect ratio axes returned.""" ax = plot(*args, **kwargs) ax.set_aspect("equal") diff --git a/tidy3d/components/viz/descartes.py b/tidy3d/components/viz/descartes.py index b743839585..67885ea94a 100644 --- a/tidy3d/components/viz/descartes.py +++ b/tidy3d/components/viz/descartes.py @@ -17,25 +17,28 @@ from typing import Any +from shapely.geometry.base import BaseGeometry + try: from matplotlib.patches import PathPatch from matplotlib.path import Path except ImportError: pass from numpy import array, concatenate, ones +from numpy.typing import NDArray class Polygon: """Adapt Shapely polygons to a common interface""" - def __init__(self, context) -> None: + def __init__(self, context: dict[str, Any]) -> None: if isinstance(context, dict): self.context = context["coordinates"] else: self.context = context @property - def exterior(self): + def exterior(self) -> Any: """Get polygon exterior.""" value = getattr(self.context, "exterior", None) if value is None: @@ -43,7 +46,7 @@ def exterior(self): return value @property - def interiors(self): + def interiors(self) -> Any: """Get polygon interiors.""" value = getattr(self.context, "interiors", None) if value is None: @@ -51,11 +54,11 @@ def interiors(self): return value -def polygon_path(polygon): +def polygon_path(polygon: BaseGeometry) -> Path: """Constructs a compound matplotlib path from a Shapely or GeoJSON-like geometric object""" - def coding(obj): + def coding(obj: Any) -> NDArray: # The codes will be all "LINETO" commands, except for "MOVETO"s at the # beginning of each subpath crds = getattr(obj, "coords", None) @@ -88,7 +91,7 @@ def coding(obj): return Path(vertices, codes) -def polygon_patch(polygon, **kwargs: Any): +def polygon_patch(polygon: BaseGeometry, **kwargs: Any) -> PathPatch: """Constructs a matplotlib patch from a geometric object The ``polygon`` may be a Shapely or GeoJSON-like object with or without holes. 
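Note on the typing pattern above: axes_utils.py now guards its matplotlib imports behind TYPE_CHECKING and types the plotting decorators via ParamSpec/TypeVar, which keeps matplotlib off the runtime import path while still giving mypy precise signatures. Below is a minimal, self-contained sketch of the same idea, simplified to a single TypeVar; the helper name _make_ax and the TypeVar F are illustrative and not part of tidy3d:

from __future__ import annotations

from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast

if TYPE_CHECKING:
    # Only the type checker needs matplotlib here; the runtime import is
    # deferred until axes are actually created.
    from matplotlib.axes import Axes

F = TypeVar("F", bound=Callable[..., "Axes"])


def _make_ax() -> Axes:
    """Create a default set of axes (lazy matplotlib import)."""
    import matplotlib.pyplot as plt

    _, ax = plt.subplots(1, 1)
    return ax


def add_ax_if_none(plot: F) -> F:
    """Supply a freshly created ``ax`` when the caller passes ``ax=None``."""

    @wraps(plot)
    def _plot(*args: Any, **kwargs: Any) -> Axes:
        if kwargs.get("ax") is None:
            kwargs["ax"] = _make_ax()
        return plot(*args, **kwargs)

    return cast(F, _plot)

Because the module uses `from __future__ import annotations`, the `Axes` annotations are never evaluated at runtime, so the TYPE_CHECKING-only import is sufficient.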
diff --git a/tidy3d/components/viz/plot_params.py b/tidy3d/components/viz/plot_params.py index a5500f76a9..533f94d7d1 100644 --- a/tidy3d/components/viz/plot_params.py +++ b/tidy3d/components/viz/plot_params.py @@ -6,6 +6,7 @@ from pydantic import Field, NonNegativeFloat from tidy3d.components.base import Tidy3dBaseModel +from tidy3d.components.viz.visualization_spec import VisualizationSpec class AbstractPlotParams(Tidy3dBaseModel): @@ -25,11 +26,11 @@ def include_kwargs(self, **kwargs: Any) -> AbstractPlotParams: } return self.copy(update=update_dict) - def override_with_viz_spec(self, viz_spec) -> AbstractPlotParams: + def override_with_viz_spec(self, viz_spec: VisualizationSpec) -> AbstractPlotParams: """Override plot params with supplied VisualizationSpec.""" return self.include_kwargs(**dict(viz_spec)) - def to_kwargs(self) -> dict: + def to_kwargs(self) -> dict[str, Any]: """Export the plot parameters as kwargs dict that can be supplied to plot function.""" kwarg_dict = self.model_dump() for ignore_key in ("type", "attrs"): diff --git a/tidy3d/components/viz/plot_sim_3d.py b/tidy3d/components/viz/plot_sim_3d.py index 7111309446..c1a1a8ad97 100644 --- a/tidy3d/components/viz/plot_sim_3d.py +++ b/tidy3d/components/viz/plot_sim_3d.py @@ -1,11 +1,17 @@ from __future__ import annotations from html import escape +from typing import TYPE_CHECKING, Union from tidy3d.exceptions import SetupError +if TYPE_CHECKING: + from IPython.core.display_functions import DisplayHandle -def plot_scene_3d(scene, width=800, height=800) -> None: + from tidy3d import Scene, Simulation + + +def plot_scene_3d(scene: Scene, width: int = 800, height: int = 800) -> None: import gzip import json from base64 import b64encode @@ -23,7 +29,7 @@ def plot_scene_3d(scene, width=800, height=800) -> None: buffer2 = BytesIO() with h5py.File(buffer2, "w") as dst: - def copy_item(name, obj) -> None: + def copy_item(name: str, obj: h5py.Group | h5py.Dataset) -> None: if isinstance(obj, h5py.Group): dst.create_group(name) for k, v in obj.attrs.items(): @@ -60,7 +66,9 @@ def copy_item(name, obj) -> None: plot_sim_3d(sim_base64, width=width, height=height, is_gz_base64=True) -def plot_sim_3d(sim, width=800, height=800, is_gz_base64=False) -> None: +def plot_sim_3d( + sim: Union[Simulation, str], width: int = 800, height: int = 800, is_gz_base64: bool = False +) -> DisplayHandle: """Make 3D display of simulation in ipython notebook.""" try: diff --git a/tidy3d/components/viz/visualization_spec.py b/tidy3d/components/viz/visualization_spec.py index 740e7009bc..14dc7d96e1 100644 --- a/tidy3d/components/viz/visualization_spec.py +++ b/tidy3d/components/viz/visualization_spec.py @@ -1,6 +1,6 @@ from __future__ import annotations -from pydantic import Field, field_validator +from pydantic import Field, ValidationInfo, field_validator from tidy3d.components.base import Tidy3dBaseModel from tidy3d.log import log @@ -57,7 +57,7 @@ def _validate_facecolor(cls, value: str) -> str: @field_validator("edgecolor") @classmethod - def _ensure_edgecolor(cls, value, info) -> str: + def _ensure_edgecolor(cls, value: str, info: ValidationInfo) -> str: # if no explicit edgecolor given, fall back to facecolor if (value == "") and "facecolor" in info.data: return is_valid_color(info.data["facecolor"])
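For reference, the ValidationInfo-typed validator added to visualization_spec.py follows the standard pydantic v2 field_validator pattern. A minimal standalone sketch of that pattern is shown here; the EdgeStyle model and its fields are illustrative only, not tidy3d classes:

from __future__ import annotations

from pydantic import BaseModel, ValidationInfo, field_validator


class EdgeStyle(BaseModel):
    """Illustrative model: an empty edgecolor falls back to facecolor."""

    facecolor: str = ""
    edgecolor: str = ""

    @field_validator("edgecolor")
    @classmethod
    def _default_to_facecolor(cls, value: str, info: ValidationInfo) -> str:
        # info.data holds the fields validated so far, in declaration order,
        # so facecolor is already available when edgecolor is checked.
        if value == "" and "facecolor" in info.data:
            return info.data["facecolor"]
        return value


assert EdgeStyle(facecolor="red", edgecolor="").edgecolor == "red"

info.data only contains fields that have already been validated, which is why the validator checks for "facecolor" in info.data before falling back to it.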