Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unified validation check for dependency fields #1348

Merged
merged 1 commit into from
Jan 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions tests/test_components/test_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,11 +505,11 @@ def test_validate_components_none():
assert SIM._source_homogeneous_isotropic(val=None, values=SIM.dict()) is None


def test_sources_edge_case_validation():
def test_sources_edge_case_validation(log_capture):
values = SIM.dict()
values.pop("sources")
with pytest.raises(ValidationError):
SIM._warn_monitor_simulation_frequency_range(val="test", values=values)
SIM._warn_monitor_simulation_frequency_range(val="test", values=values)
assert_log_level(log_capture, "WARNING")


def test_validate_size_run_time(monkeypatch):
Expand Down
4 changes: 3 additions & 1 deletion tidy3d/components/apodization.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pydantic.v1 as pd
import numpy as np

from .base import Tidy3dBaseModel
from .base import Tidy3dBaseModel, skip_if_fields_missing
from ..constants import SECOND
from ..exceptions import SetupError
from .types import ArrayFloat1D, Ax
Expand Down Expand Up @@ -40,6 +40,7 @@ class ApodizationSpec(Tidy3dBaseModel):
)

@pd.validator("end", always=True, allow_reuse=True)
@skip_if_fields_missing(["start"])
def end_greater_than_start(cls, val, values):
"""Ensure end is greater than or equal to start."""
start = values.get("start")
Expand All @@ -48,6 +49,7 @@ def end_greater_than_start(cls, val, values):
return val

@pd.validator("width", always=True, allow_reuse=True)
@skip_if_fields_missing(["start", "end"])
def width_provided(cls, val, values):
"""Check that width is provided if either start or end apodization is requested."""
start = values.get("start")
Expand Down
22 changes: 22 additions & 0 deletions tidy3d/components/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,28 @@ def _get_valid_extension(fname: str) -> str:
)


def skip_if_fields_missing(fields: List[str]):
"""Decorate ``validator`` to check that other fields have passed validation."""

def actual_decorator(validator):
@wraps(validator)
def _validator(cls, val, values):
"""New validator function."""
for field in fields:
if field not in values:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just to confirm, if field fails validation, it's just not put in values? (as opposed to put into values with None?)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, just double checked with pydantic documentation https://docs.pydantic.dev/1.10/usage/validators/:

If validation fails on another field (or that field is missing) it will not be included in values, hence if 'password1' in values and ... in this example.

log.warning(
f"Could not execute validator '{validator.__name__}' because field "
f"'{field}' failed validation."
)
return val

return validator(cls, val, values)

return _validator

return actual_decorator


class Tidy3dBaseModel(pydantic.BaseModel):
"""Base pydantic model that all Tidy3d components inherit from.
Defines configuration for handling data structures
Expand Down
10 changes: 4 additions & 6 deletions tidy3d/components/base_sim/data/sim_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from ..simulation import AbstractSimulation
from ...data.dataset import UnstructuredGridDatasetType
from ...base import Tidy3dBaseModel
from ...base import skip_if_fields_missing
from ...types import FieldVal
from ....exceptions import DataError, Tidy3dKeyError, ValidationError

Expand Down Expand Up @@ -51,13 +52,13 @@ def monitor_data(self) -> Dict[str, AbstractMonitorData]:
return {monitor_data.monitor.name: monitor_data for monitor_data in self.data}

@pd.validator("data", always=True)
@skip_if_fields_missing(["simulation"])
def data_monitors_match_sim(cls, val, values):
"""Ensure each :class:`AbstractMonitorData` in ``.data`` corresponds to a monitor in
``.simulation``.
"""
sim = values.get("simulation")
if sim is None:
raise ValidationError("'.simulation' failed validation, can't validate data.")

for mnt_data in val:
try:
monitor_name = mnt_data.monitor.name
Expand All @@ -70,14 +71,11 @@ def data_monitors_match_sim(cls, val, values):
return val

@pd.validator("data", always=True)
@skip_if_fields_missing(["simulation"])
def validate_no_ambiguity(cls, val, values):
"""Ensure all :class:`AbstractMonitorData` entries in ``.data`` correspond to different
monitors in ``.simulation``.
"""
sim = values.get("simulation")
if sim is None:
raise ValidationError("'.simulation' failed validation, can't validate data.")

names = [mnt_data.monitor.name for mnt_data in val]

if len(set(names)) != len(names):
Expand Down
3 changes: 2 additions & 1 deletion tidy3d/components/base_sim/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from .monitor import AbstractMonitor

from ..base import cached_property
from ..base import cached_property, skip_if_fields_missing
from ..validators import assert_unique_names, assert_objects_in_sim_bounds
from ..geometry.base import Box
from ..types import Ax, Bound, Axis, Symmetry, TYPE_TAG_STR
Expand Down Expand Up @@ -97,6 +97,7 @@ class AbstractSimulation(Box, ABC):
_structures_in_bounds = assert_objects_in_sim_bounds("structures", error=False)

@pd.validator("structures", always=True)
@skip_if_fields_missing(["size", "center"])
def _structures_not_at_edges(cls, val, values):
"""Warn if any structures lie at the simulation boundaries."""

Expand Down
7 changes: 3 additions & 4 deletions tidy3d/components/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from ..viz import equal_aspect, add_ax_if_none, plot_params_grid
from ..base import Tidy3dBaseModel, cached_property
from ..base import skip_if_fields_missing
from ..types import Axis, Bound, ArrayLike, Ax, Coordinate, Literal
from ..types import vtk, requires_vtk
from ...exceptions import DataError, ValidationError, Tidy3dNotImplementedError
Expand Down Expand Up @@ -524,13 +525,12 @@ def match_cells_to_vtk_type(cls, val):
return CellDataArray(val.data.astype(vtk["id_type"], copy=False), coords=val.coords)

@pd.validator("values", always=True)
@skip_if_fields_missing(["points"])
def number_of_values_matches_points(cls, val, values):
"""Check that the number of data values matches the number of grid points."""
num_values = len(val)

points = values.get("points")
if points is None:
raise ValidationError("Cannot validate '.values' because '.points' failed validation.")
num_points = len(points)

if num_points != num_values:
Expand Down Expand Up @@ -565,15 +565,14 @@ def cells_right_type(cls, val):
return val

@pd.validator("cells", always=True)
@skip_if_fields_missing(["points"])
def check_cell_vertex_range(cls, val, values):
"""Check that cell connections use only defined points."""
all_point_indices_used = val.data.ravel()
min_index_used = np.min(all_point_indices_used)
max_index_used = np.max(all_point_indices_used)

points = values.get("points")
if points is None:
raise ValidationError("Cannot validate '.values' because '.points' failed validation.")
num_points = len(points)

if max_index_used != num_points - 1 or min_index_used != 0:
Expand Down
3 changes: 2 additions & 1 deletion tidy3d/components/data/monitor_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from .data_array import FreqDataArray, TimeDataArray, FreqModeDataArray
from .dataset import Dataset, AbstractFieldDataset, ElectromagneticFieldDataset
from .dataset import FieldDataset, FieldTimeDataset, ModeSolverDataset, PermittivityDataset
from ..base import TYPE_TAG_STR, cached_property
from ..base import TYPE_TAG_STR, cached_property, skip_if_fields_missing
from ..types import Coordinate, Symmetry, ArrayFloat1D, ArrayFloat2D, Size, Numpy, TrackFreq
from ..types import EpsSpecType, Literal
from ..grid.grid import Grid, Coords
Expand Down Expand Up @@ -926,6 +926,7 @@ class ModeSolverData(ModeSolverDataset, ElectromagneticFieldData):
)

@pd.validator("eps_spec", always=True)
@skip_if_fields_missing(["monitor"])
def eps_spec_match_mode_spec(cls, val, values):
"""Raise validation error if frequencies in eps_spec does not match frequency list"""
if val:
Expand Down
3 changes: 2 additions & 1 deletion tidy3d/components/field_projection.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from .monitor import FieldProjectionCartesianMonitor, FieldProjectionKSpaceMonitor
from .types import Direction, Coordinate, ArrayComplex4D
from .medium import MediumType
from .base import Tidy3dBaseModel, cached_property
from .base import Tidy3dBaseModel, cached_property, skip_if_fields_missing
from ..exceptions import SetupError
from ..constants import C_0, MICROMETER, ETA_0, EPSILON_0, MU_0
from ..log import get_logging_console
Expand Down Expand Up @@ -72,6 +72,7 @@ class FieldProjector(Tidy3dBaseModel):
)

@pydantic.validator("origin", always=True)
@skip_if_fields_missing(["surfaces"])
def set_origin(cls, val, values):
"""Sets .origin as the average of centers of all surface monitors if not provided."""
if val is None:
Expand Down
5 changes: 3 additions & 2 deletions tidy3d/components/geometry/polyslab.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from matplotlib import path

from ..base import cached_property
from ..base import skip_if_fields_missing
from ..types import Axis, Bound, PlanePosition, ArrayFloat2D, Coordinate
from ..types import MatrixReal4x4, Shapely, trimesh
from ...log import log
Expand Down Expand Up @@ -105,6 +106,7 @@ def correct_shape(cls, val):
return val

@pydantic.validator("vertices", always=True)
@skip_if_fields_missing(["dilation"])
def no_complex_self_intersecting_polygon_at_reference_plane(cls, val, values):
"""At the reference plane, check if the polygon is self-intersecting.

Expand Down Expand Up @@ -154,6 +156,7 @@ def no_complex_self_intersecting_polygon_at_reference_plane(cls, val, values):
return val

@pydantic.validator("vertices", always=True)
@skip_if_fields_missing(["sidewall_angle", "dilation", "slab_bounds", "reference_plane"])
def no_self_intersecting_polygon_during_extrusion(cls, val, values):
"""In this simple polyslab, we don't support self-intersecting polygons yet, meaning that
any normal cross section of the PolySlab cannot be self-intersecting. This part checks
Expand All @@ -168,8 +171,6 @@ def no_self_intersecting_polygon_during_extrusion(cls, val, values):
To detect this, we sample _N_SAMPLE_POLYGON_INTERSECT cross sections to see if any creation
of polygons/holes, and changes in vertices number.
"""
if "sidewall_angle" not in values:
raise ValidationError("'sidewall_angle' failed validation.")

# no need to valiate anything here
if isclose(values["sidewall_angle"], 0):
Expand Down
3 changes: 2 additions & 1 deletion tidy3d/components/geometry/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import numpy as np
import shapely

from ..base import cached_property
from ..base import cached_property, skip_if_fields_missing
from ..types import Axis, Bound, Coordinate, MatrixReal4x4, Shapely, trimesh
from ...exceptions import SetupError, ValidationError
from ...constants import MICROMETER, LARGE_NUMBER
Expand Down Expand Up @@ -191,6 +191,7 @@ class Cylinder(base.Centered, base.Circular, base.Planar):
)

@pydantic.validator("length", always=True)
@skip_if_fields_missing(["sidewall_angle", "reference_plane"])
def _only_middle_for_infinite_length_slanted_cylinder(cls, val, values):
"""For a slanted cylinder of infinite length, ``reference_plane`` can only
be ``middle``; otherwise, the radius at ``center`` is either td.inf or 0.
Expand Down
2 changes: 2 additions & 0 deletions tidy3d/components/heat/data/monitor_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pydantic.v1 as pd

from ..monitor import TemperatureMonitor, HeatMonitorType
from ...base import skip_if_fields_missing
from ...base_sim.data.monitor_data import AbstractMonitorData
from ...data.data_array import SpatialDataArray
from ...data.dataset import TriangularGridDataset, TetrahedralGridDataset
Expand Down Expand Up @@ -74,6 +75,7 @@ class TemperatureData(HeatMonitorData):
)

@pd.validator("temperature", always=True)
@skip_if_fields_missing(["monitor"])
def warn_no_data(cls, val, values):
"""Warn if no data provided."""

Expand Down
3 changes: 2 additions & 1 deletion tidy3d/components/heat/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Union, Tuple
import pydantic.v1 as pd

from ..base import Tidy3dBaseModel
from ..base import Tidy3dBaseModel, skip_if_fields_missing
from ...constants import MICROMETER
from ...exceptions import ValidationError

Expand Down Expand Up @@ -107,6 +107,7 @@ class DistanceUnstructuredGrid(Tidy3dBaseModel):
)

@pd.validator("distance_bulk", always=True)
@skip_if_fields_missing(["distance_interface"])
def names_exist_bcs(cls, val, values):
"""Error if distance_bulk is less than distance_interface"""
distance_interface = values.get("distance_interface")
Expand Down
5 changes: 4 additions & 1 deletion tidy3d/components/heat/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from .viz import plot_params_heat_bc, plot_params_heat_source, HEAT_SOURCE_CMAP

from ..base_sim.simulation import AbstractSimulation
from ..base import cached_property
from ..base import cached_property, skip_if_fields_missing
from ..types import Ax, Shapely, TYPE_TAG_STR, ScalarSymmetry, Bound
from ..viz import add_ax_if_none, equal_aspect, PlotParams
from ..structure import Structure
Expand Down Expand Up @@ -139,6 +139,7 @@ def check_zero_dim_domain(cls, val, values):
return val

@pd.validator("boundary_spec", always=True)
@skip_if_fields_missing(["structures", "medium"])
def names_exist_bcs(cls, val, values):
"""Error if boundary conditions point to non-existing structures/media."""

Expand Down Expand Up @@ -175,6 +176,7 @@ def names_exist_bcs(cls, val, values):
return val

@pd.validator("grid_spec", always=True)
@skip_if_fields_missing(["structures"])
def names_exist_grid_spec(cls, val, values):
"""Warn if UniformUnstructuredGrid points at a non-existing structure."""

Expand All @@ -191,6 +193,7 @@ def names_exist_grid_spec(cls, val, values):
return val

@pd.validator("sources", always=True)
@skip_if_fields_missing(["structures"])
def names_exist_sources(cls, val, values):
"""Error if a heat source point to non-existing structures."""
structures = values.get("structures")
Expand Down
Loading
Loading