Skip to content

Commit

Permalink
polishing unified validation check
Browse files Browse the repository at this point in the history
  • Loading branch information
dbochkov-flexcompute committed Jan 4, 2024
1 parent d0a7616 commit df5a83b
Show file tree
Hide file tree
Showing 22 changed files with 99 additions and 49 deletions.
4 changes: 3 additions & 1 deletion tidy3d/components/apodization.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pydantic.v1 as pd
import numpy as np

from .base import Tidy3dBaseModel
from .base import Tidy3dBaseModel, skip_if_fields_missing
from ..constants import SECOND
from ..exceptions import SetupError
from .types import ArrayFloat1D, Ax
Expand Down Expand Up @@ -40,6 +40,7 @@ class ApodizationSpec(Tidy3dBaseModel):
)

@pd.validator("end", always=True, allow_reuse=True)
@skip_if_fields_missing(["start"])
def end_greater_than_start(cls, val, values):
"""Ensure end is greater than or equal to start."""
start = values.get("start")
Expand All @@ -48,6 +49,7 @@ def end_greater_than_start(cls, val, values):
return val

@pd.validator("width", always=True, allow_reuse=True)
@skip_if_fields_missing(["start", "end"])
def width_provided(cls, val, values):
"""Check that width is provided if either start or end apodization is requested."""
start = values.get("start")
Expand Down
4 changes: 2 additions & 2 deletions tidy3d/components/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,14 @@ def _get_valid_extension(fname: str) -> str:
)


def check_previous_fields_validation(required_fields):
def skip_if_fields_missing(fields: List[str]):
"""Decorate ``validator`` to check that other fields have passed validation."""

def actual_decorator(validator):
@wraps(validator)
def _validator(cls, val, values):
"""New validator function."""
for field in required_fields:
for field in fields:
if field not in values:
log.warning(
f"Could not execute validator '{validator.__name__}' because field "
Expand Down
6 changes: 3 additions & 3 deletions tidy3d/components/base_sim/data/sim_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from ..simulation import AbstractSimulation
from ...data.dataset import UnstructuredGridDatasetType
from ...base import Tidy3dBaseModel
from ...base import check_previous_fields_validation
from ...base import skip_if_fields_missing
from ...types import FieldVal
from ....exceptions import DataError, Tidy3dKeyError, ValidationError

Expand Down Expand Up @@ -52,7 +52,7 @@ def monitor_data(self) -> Dict[str, AbstractMonitorData]:
return {monitor_data.monitor.name: monitor_data for monitor_data in self.data}

@pd.validator("data", always=True)
@check_previous_fields_validation(["simulation"])
@skip_if_fields_missing(["simulation"])
def data_monitors_match_sim(cls, val, values):
"""Ensure each :class:`AbstractMonitorData` in ``.data`` corresponds to a monitor in
``.simulation``.
Expand All @@ -71,7 +71,7 @@ def data_monitors_match_sim(cls, val, values):
return val

@pd.validator("data", always=True)
@check_previous_fields_validation(["simulation"])
@skip_if_fields_missing(["simulation"])
def validate_no_ambiguity(cls, val, values):
"""Ensure all :class:`AbstractMonitorData` entries in ``.data`` correspond to different
monitors in ``.simulation``.
Expand Down
3 changes: 2 additions & 1 deletion tidy3d/components/base_sim/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from .monitor import AbstractMonitor

from ..base import cached_property
from ..base import cached_property, skip_if_fields_missing
from ..validators import assert_unique_names, assert_objects_in_sim_bounds
from ..geometry.base import Box
from ..types import Ax, Bound, Axis, Symmetry, TYPE_TAG_STR
Expand Down Expand Up @@ -97,6 +97,7 @@ class AbstractSimulation(Box, ABC):
_structures_in_bounds = assert_objects_in_sim_bounds("structures", error=False)

@pd.validator("structures", always=True)
@skip_if_fields_missing(["size", "center"])
def _structures_not_at_edges(cls, val, values):
"""Warn if any structures lie at the simulation boundaries."""

Expand Down
6 changes: 3 additions & 3 deletions tidy3d/components/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from ..viz import equal_aspect, add_ax_if_none, plot_params_grid
from ..base import Tidy3dBaseModel, cached_property
from ..base import check_previous_fields_validation
from ..base import skip_if_fields_missing
from ..types import Axis, Bound, ArrayLike, Ax, Coordinate, Literal
from ..types import vtk, requires_vtk
from ...exceptions import DataError, ValidationError, Tidy3dNotImplementedError
Expand Down Expand Up @@ -525,7 +525,7 @@ def match_cells_to_vtk_type(cls, val):
return CellDataArray(val.data.astype(vtk["id_type"], copy=False), coords=val.coords)

@pd.validator("values", always=True)
@check_previous_fields_validation(["points"])
@skip_if_fields_missing(["points"])
def number_of_values_matches_points(cls, val, values):
"""Check that the number of data values matches the number of grid points."""
num_values = len(val)
Expand Down Expand Up @@ -565,7 +565,7 @@ def cells_right_type(cls, val):
return val

@pd.validator("cells", always=True)
@check_previous_fields_validation(["points"])
@skip_if_fields_missing(["points"])
def check_cell_vertex_range(cls, val, values):
"""Check that cell connections use only defined points."""
all_point_indices_used = val.data.ravel()
Expand Down
3 changes: 2 additions & 1 deletion tidy3d/components/field_projection.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from .monitor import FieldProjectionCartesianMonitor, FieldProjectionKSpaceMonitor
from .types import Direction, Coordinate, ArrayComplex4D
from .medium import MediumType
from .base import Tidy3dBaseModel, cached_property
from .base import Tidy3dBaseModel, cached_property, skip_if_fields_missing
from ..exceptions import SetupError
from ..constants import C_0, MICROMETER, ETA_0, EPSILON_0, MU_0
from ..log import get_logging_console
Expand Down Expand Up @@ -72,6 +72,7 @@ class FieldProjector(Tidy3dBaseModel):
)

@pydantic.validator("origin", always=True)
@skip_if_fields_missing(["surfaces"])
def set_origin(cls, val, values):
"""Sets .origin as the average of centers of all surface monitors if not provided."""
if val is None:
Expand Down
4 changes: 2 additions & 2 deletions tidy3d/components/geometry/polyslab.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from matplotlib import path

from ..base import cached_property
from ..base import check_previous_fields_validation
from ..base import skip_if_fields_missing
from ..types import Axis, Bound, PlanePosition, ArrayFloat2D, Coordinate
from ..types import MatrixReal4x4, Shapely, trimesh
from ...log import log
Expand Down Expand Up @@ -155,7 +155,7 @@ def no_complex_self_intersecting_polygon_at_reference_plane(cls, val, values):
return val

@pydantic.validator("vertices", always=True)
@check_previous_fields_validation(["sidewall_angle"])
@skip_if_fields_missing(["sidewall_angle"])
def no_self_intersecting_polygon_during_extrusion(cls, val, values):
"""In this simple polyslab, we don't support self-intersecting polygons yet, meaning that
any normal cross section of the PolySlab cannot be self-intersecting. This part checks
Expand Down
2 changes: 2 additions & 0 deletions tidy3d/components/heat/data/monitor_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pydantic.v1 as pd

from ..monitor import TemperatureMonitor, HeatMonitorType
from ...base import skip_if_fields_missing
from ...base_sim.data.monitor_data import AbstractMonitorData
from ...data.data_array import SpatialDataArray
from ...data.dataset import TriangularGridDataset, TetrahedralGridDataset
Expand Down Expand Up @@ -74,6 +75,7 @@ class TemperatureData(HeatMonitorData):
)

@pd.validator("temperature", always=True)
@skip_if_fields_missing(["monitor"])
def warn_no_data(cls, val, values):
"""Warn if no data provided."""

Expand Down
3 changes: 2 additions & 1 deletion tidy3d/components/heat/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Union, Tuple
import pydantic.v1 as pd

from ..base import Tidy3dBaseModel
from ..base import Tidy3dBaseModel, skip_if_fields_missing
from ...constants import MICROMETER
from ...exceptions import ValidationError

Expand Down Expand Up @@ -107,6 +107,7 @@ class DistanceUnstructuredGrid(Tidy3dBaseModel):
)

@pd.validator("distance_bulk", always=True)
@skip_if_fields_missing(["distance_interface"])
def names_exist_bcs(cls, val, values):
"""Error if distance_bulk is less than distance_interface"""
distance_interface = values.get("distance_interface")
Expand Down
5 changes: 4 additions & 1 deletion tidy3d/components/heat/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from .viz import plot_params_heat_bc, plot_params_heat_source, HEAT_SOURCE_CMAP

from ..base_sim.simulation import AbstractSimulation
from ..base import cached_property
from ..base import cached_property, skip_if_fields_missing
from ..types import Ax, Shapely, TYPE_TAG_STR, ScalarSymmetry, Bound
from ..viz import add_ax_if_none, equal_aspect, PlotParams
from ..structure import Structure
Expand Down Expand Up @@ -139,6 +139,7 @@ def check_zero_dim_domain(cls, val, values):
return val

@pd.validator("boundary_spec", always=True)
@skip_if_fields_missing(["structures", "medium"])
def names_exist_bcs(cls, val, values):
"""Error if boundary conditions point to non-existing structures/media."""

Expand Down Expand Up @@ -175,6 +176,7 @@ def names_exist_bcs(cls, val, values):
return val

@pd.validator("grid_spec", always=True)
@skip_if_fields_missing(["structures"])
def names_exist_grid_spec(cls, val, values):
"""Warn if UniformUnstructuredGrid points at a non-existing structure."""

Expand All @@ -191,6 +193,7 @@ def names_exist_grid_spec(cls, val, values):
return val

@pd.validator("sources", always=True)
@skip_if_fields_missing(["structures"])
def names_exist_sources(cls, val, values):
"""Error if a heat source point to non-existing structures."""
structures = values.get("structures")
Expand Down
Loading

0 comments on commit df5a83b

Please sign in to comment.