Skip to content

Commit

Permalink
unified validation check for missing dependency fields
Browse files Browse the repository at this point in the history
  • Loading branch information
dbochkov-flexcompute committed Jan 4, 2024
1 parent efb87ae commit b10bb94
Show file tree
Hide file tree
Showing 28 changed files with 140 additions and 74 deletions.
6 changes: 3 additions & 3 deletions tests/test_components/test_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,11 +505,11 @@ def test_validate_components_none():
assert SIM._source_homogeneous_isotropic(val=None, values=SIM.dict()) is None


def test_sources_edge_case_validation():
def test_sources_edge_case_validation(log_capture):
values = SIM.dict()
values.pop("sources")
with pytest.raises(ValidationError):
SIM._warn_monitor_simulation_frequency_range(val="test", values=values)
SIM._warn_monitor_simulation_frequency_range(val="test", values=values)
assert_log_level(log_capture, "WARNING")


def test_validate_size_run_time(monkeypatch):
Expand Down
4 changes: 3 additions & 1 deletion tidy3d/components/apodization.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pydantic.v1 as pd
import numpy as np

from .base import Tidy3dBaseModel
from .base import Tidy3dBaseModel, skip_if_fields_missing
from ..constants import SECOND
from ..exceptions import SetupError
from .types import ArrayFloat1D, Ax
Expand Down Expand Up @@ -40,6 +40,7 @@ class ApodizationSpec(Tidy3dBaseModel):
)

@pd.validator("end", always=True, allow_reuse=True)
@skip_if_fields_missing(["start"])
def end_greater_than_start(cls, val, values):
"""Ensure end is greater than or equal to start."""
start = values.get("start")
Expand All @@ -48,6 +49,7 @@ def end_greater_than_start(cls, val, values):
return val

@pd.validator("width", always=True, allow_reuse=True)
@skip_if_fields_missing(["start", "end"])
def width_provided(cls, val, values):
"""Check that width is provided if either start or end apodization is requested."""
start = values.get("start")
Expand Down
22 changes: 22 additions & 0 deletions tidy3d/components/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,28 @@ def _get_valid_extension(fname: str) -> str:
)


def skip_if_fields_missing(fields: List[str]):
"""Decorate ``validator`` to check that other fields have passed validation."""

def actual_decorator(validator):
@wraps(validator)
def _validator(cls, val, values):
"""New validator function."""
for field in fields:
if field not in values:
log.warning(
f"Could not execute validator '{validator.__name__}' because field "
f"'{field}' failed validation."
)
return val

return validator(cls, val, values)

return _validator

return actual_decorator


class Tidy3dBaseModel(pydantic.BaseModel):
"""Base pydantic model that all Tidy3d components inherit from.
Defines configuration for handling data structures
Expand Down
10 changes: 4 additions & 6 deletions tidy3d/components/base_sim/data/sim_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from ..simulation import AbstractSimulation
from ...data.dataset import UnstructuredGridDatasetType
from ...base import Tidy3dBaseModel
from ...base import skip_if_fields_missing
from ...types import FieldVal
from ....exceptions import DataError, Tidy3dKeyError, ValidationError

Expand Down Expand Up @@ -51,13 +52,13 @@ def monitor_data(self) -> Dict[str, AbstractMonitorData]:
return {monitor_data.monitor.name: monitor_data for monitor_data in self.data}

@pd.validator("data", always=True)
@skip_if_fields_missing(["simulation"])
def data_monitors_match_sim(cls, val, values):
"""Ensure each :class:`AbstractMonitorData` in ``.data`` corresponds to a monitor in
``.simulation``.
"""
sim = values.get("simulation")
if sim is None:
raise ValidationError("'.simulation' failed validation, can't validate data.")

for mnt_data in val:
try:
monitor_name = mnt_data.monitor.name
Expand All @@ -70,14 +71,11 @@ def data_monitors_match_sim(cls, val, values):
return val

@pd.validator("data", always=True)
@skip_if_fields_missing(["simulation"])
def validate_no_ambiguity(cls, val, values):
"""Ensure all :class:`AbstractMonitorData` entries in ``.data`` correspond to different
monitors in ``.simulation``.
"""
sim = values.get("simulation")
if sim is None:
raise ValidationError("'.simulation' failed validation, can't validate data.")

names = [mnt_data.monitor.name for mnt_data in val]

if len(set(names)) != len(names):
Expand Down
3 changes: 2 additions & 1 deletion tidy3d/components/base_sim/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from .monitor import AbstractMonitor

from ..base import cached_property
from ..base import cached_property, skip_if_fields_missing
from ..validators import assert_unique_names, assert_objects_in_sim_bounds
from ..geometry.base import Box
from ..types import Ax, Bound, Axis, Symmetry, TYPE_TAG_STR
Expand Down Expand Up @@ -97,6 +97,7 @@ class AbstractSimulation(Box, ABC):
_structures_in_bounds = assert_objects_in_sim_bounds("structures", error=False)

@pd.validator("structures", always=True)
@skip_if_fields_missing(["size", "center"])
def _structures_not_at_edges(cls, val, values):
"""Warn if any structures lie at the simulation boundaries."""

Expand Down
7 changes: 3 additions & 4 deletions tidy3d/components/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from ..viz import equal_aspect, add_ax_if_none, plot_params_grid
from ..base import Tidy3dBaseModel, cached_property
from ..base import skip_if_fields_missing
from ..types import Axis, Bound, ArrayLike, Ax, Coordinate, Literal
from ..types import vtk, requires_vtk
from ...exceptions import DataError, ValidationError, Tidy3dNotImplementedError
Expand Down Expand Up @@ -524,13 +525,12 @@ def match_cells_to_vtk_type(cls, val):
return CellDataArray(val.data.astype(vtk["id_type"], copy=False), coords=val.coords)

@pd.validator("values", always=True)
@skip_if_fields_missing(["points"])
def number_of_values_matches_points(cls, val, values):
"""Check that the number of data values matches the number of grid points."""
num_values = len(val)

points = values.get("points")
if points is None:
raise ValidationError("Cannot validate '.values' because '.points' failed validation.")
num_points = len(points)

if num_points != num_values:
Expand Down Expand Up @@ -565,15 +565,14 @@ def cells_right_type(cls, val):
return val

@pd.validator("cells", always=True)
@skip_if_fields_missing(["points"])
def check_cell_vertex_range(cls, val, values):
"""Check that cell connections use only defined points."""
all_point_indices_used = val.data.ravel()
min_index_used = np.min(all_point_indices_used)
max_index_used = np.max(all_point_indices_used)

points = values.get("points")
if points is None:
raise ValidationError("Cannot validate '.values' because '.points' failed validation.")
num_points = len(points)

if max_index_used != num_points - 1 or min_index_used != 0:
Expand Down
3 changes: 2 additions & 1 deletion tidy3d/components/data/monitor_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from .data_array import FreqDataArray, TimeDataArray, FreqModeDataArray
from .dataset import Dataset, AbstractFieldDataset, ElectromagneticFieldDataset
from .dataset import FieldDataset, FieldTimeDataset, ModeSolverDataset, PermittivityDataset
from ..base import TYPE_TAG_STR, cached_property
from ..base import TYPE_TAG_STR, cached_property, skip_if_fields_missing
from ..types import Coordinate, Symmetry, ArrayFloat1D, ArrayFloat2D, Size, Numpy, TrackFreq
from ..types import EpsSpecType, Literal
from ..grid.grid import Grid, Coords
Expand Down Expand Up @@ -926,6 +926,7 @@ class ModeSolverData(ModeSolverDataset, ElectromagneticFieldData):
)

@pd.validator("eps_spec", always=True)
@skip_if_fields_missing(["monitor"])
def eps_spec_match_mode_spec(cls, val, values):
"""Raise validation error if frequencies in eps_spec does not match frequency list"""
if val:
Expand Down
3 changes: 2 additions & 1 deletion tidy3d/components/field_projection.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from .monitor import FieldProjectionCartesianMonitor, FieldProjectionKSpaceMonitor
from .types import Direction, Coordinate, ArrayComplex4D
from .medium import MediumType
from .base import Tidy3dBaseModel, cached_property
from .base import Tidy3dBaseModel, cached_property, skip_if_fields_missing
from ..exceptions import SetupError
from ..constants import C_0, MICROMETER, ETA_0, EPSILON_0, MU_0
from ..log import get_logging_console
Expand Down Expand Up @@ -72,6 +72,7 @@ class FieldProjector(Tidy3dBaseModel):
)

@pydantic.validator("origin", always=True)
@skip_if_fields_missing(["surfaces"])
def set_origin(cls, val, values):
"""Sets .origin as the average of centers of all surface monitors if not provided."""
if val is None:
Expand Down
5 changes: 3 additions & 2 deletions tidy3d/components/geometry/polyslab.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from matplotlib import path

from ..base import cached_property
from ..base import skip_if_fields_missing
from ..types import Axis, Bound, PlanePosition, ArrayFloat2D, Coordinate
from ..types import MatrixReal4x4, Shapely, trimesh
from ...log import log
Expand Down Expand Up @@ -105,6 +106,7 @@ def correct_shape(cls, val):
return val

@pydantic.validator("vertices", always=True)
@skip_if_fields_missing(["dilation"])
def no_complex_self_intersecting_polygon_at_reference_plane(cls, val, values):
"""At the reference plane, check if the polygon is self-intersecting.
Expand Down Expand Up @@ -154,6 +156,7 @@ def no_complex_self_intersecting_polygon_at_reference_plane(cls, val, values):
return val

@pydantic.validator("vertices", always=True)
@skip_if_fields_missing(["sidewall_angle", "dilation", "slab_bounds", "reference_plane"])
def no_self_intersecting_polygon_during_extrusion(cls, val, values):
"""In this simple polyslab, we don't support self-intersecting polygons yet, meaning that
any normal cross section of the PolySlab cannot be self-intersecting. This part checks
Expand All @@ -168,8 +171,6 @@ def no_self_intersecting_polygon_during_extrusion(cls, val, values):
To detect this, we sample _N_SAMPLE_POLYGON_INTERSECT cross sections to see if any creation
of polygons/holes, and changes in vertices number.
"""
if "sidewall_angle" not in values:
raise ValidationError("'sidewall_angle' failed validation.")

# no need to valiate anything here
if isclose(values["sidewall_angle"], 0):
Expand Down
3 changes: 2 additions & 1 deletion tidy3d/components/geometry/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import numpy as np
import shapely

from ..base import cached_property
from ..base import cached_property, skip_if_fields_missing
from ..types import Axis, Bound, Coordinate, MatrixReal4x4, Shapely, trimesh
from ...exceptions import SetupError, ValidationError
from ...constants import MICROMETER, LARGE_NUMBER
Expand Down Expand Up @@ -191,6 +191,7 @@ class Cylinder(base.Centered, base.Circular, base.Planar):
)

@pydantic.validator("length", always=True)
@skip_if_fields_missing(["sidewall_angle", "reference_plane"])
def _only_middle_for_infinite_length_slanted_cylinder(cls, val, values):
"""For a slanted cylinder of infinite length, ``reference_plane`` can only
be ``middle``; otherwise, the radius at ``center`` is either td.inf or 0.
Expand Down
2 changes: 2 additions & 0 deletions tidy3d/components/heat/data/monitor_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pydantic.v1 as pd

from ..monitor import TemperatureMonitor, HeatMonitorType
from ...base import skip_if_fields_missing
from ...base_sim.data.monitor_data import AbstractMonitorData
from ...data.data_array import SpatialDataArray
from ...data.dataset import TriangularGridDataset, TetrahedralGridDataset
Expand Down Expand Up @@ -74,6 +75,7 @@ class TemperatureData(HeatMonitorData):
)

@pd.validator("temperature", always=True)
@skip_if_fields_missing(["monitor"])
def warn_no_data(cls, val, values):
"""Warn if no data provided."""

Expand Down
3 changes: 2 additions & 1 deletion tidy3d/components/heat/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Union, Tuple
import pydantic.v1 as pd

from ..base import Tidy3dBaseModel
from ..base import Tidy3dBaseModel, skip_if_fields_missing
from ...constants import MICROMETER
from ...exceptions import ValidationError

Expand Down Expand Up @@ -107,6 +107,7 @@ class DistanceUnstructuredGrid(Tidy3dBaseModel):
)

@pd.validator("distance_bulk", always=True)
@skip_if_fields_missing(["distance_interface"])
def names_exist_bcs(cls, val, values):
"""Error if distance_bulk is less than distance_interface"""
distance_interface = values.get("distance_interface")
Expand Down
5 changes: 4 additions & 1 deletion tidy3d/components/heat/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from .viz import plot_params_heat_bc, plot_params_heat_source, HEAT_SOURCE_CMAP

from ..base_sim.simulation import AbstractSimulation
from ..base import cached_property
from ..base import cached_property, skip_if_fields_missing
from ..types import Ax, Shapely, TYPE_TAG_STR, ScalarSymmetry, Bound
from ..viz import add_ax_if_none, equal_aspect, PlotParams
from ..structure import Structure
Expand Down Expand Up @@ -139,6 +139,7 @@ def check_zero_dim_domain(cls, val, values):
return val

@pd.validator("boundary_spec", always=True)
@skip_if_fields_missing(["structures", "medium"])
def names_exist_bcs(cls, val, values):
"""Error if boundary conditions point to non-existing structures/media."""

Expand Down Expand Up @@ -175,6 +176,7 @@ def names_exist_bcs(cls, val, values):
return val

@pd.validator("grid_spec", always=True)
@skip_if_fields_missing(["structures"])
def names_exist_grid_spec(cls, val, values):
"""Warn if UniformUnstructuredGrid points at a non-existing structure."""

Expand All @@ -191,6 +193,7 @@ def names_exist_grid_spec(cls, val, values):
return val

@pd.validator("sources", always=True)
@skip_if_fields_missing(["structures"])
def names_exist_sources(cls, val, values):
"""Error if a heat source point to non-existing structures."""
structures = values.get("structures")
Expand Down
Loading

0 comments on commit b10bb94

Please sign in to comment.