Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 25 additions & 7 deletions tidy3d/components/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ class Config:
copy_on_model_validation = "none"

_cached_properties = pydantic.PrivateAttr({})
_has_tracers: Optional[bool] = pydantic.PrivateAttr(default=None)

@pydantic.root_validator(skip_on_failure=True)
def _special_characters_not_in_name(cls, values):
Expand Down Expand Up @@ -283,6 +284,7 @@ def copy(self, deep: bool = True, validate: bool = True, **kwargs: Any) -> Self:
# cached property is cleared automatically when validation is on, but it
# needs to be manually cleared when validation is off
new_copy._cached_properties = {}
new_copy._has_tracers = None
return new_copy

def updated_copy(
Expand Down Expand Up @@ -1054,7 +1056,7 @@ def _json(self, indent=INDENT, exclude_unset=False, **kwargs: Any) -> str:
return json_string

def _strip_traced_fields(
self, starting_path: tuple[str] = (), include_untraced_data_arrays: bool = False
self, starting_path: tuple[str, ...] = (), include_untraced_data_arrays: bool = False
) -> AutogradFieldMap:
"""Extract a dictionary mapping paths in the model to the data traced by ``autograd``.

Expand All @@ -1073,6 +1075,10 @@ def _strip_traced_fields(

"""

path = tuple(starting_path)
if self._has_tracers is False and not include_untraced_data_arrays:
return dict_ag()

field_mapping = {}

def handle_value(x: Any, path: tuple[str, ...]) -> None:
Expand Down Expand Up @@ -1100,14 +1106,20 @@ def handle_value(x: Any, path: tuple[str, ...]) -> None:
self_dict = self.dict()

# if an include_only string was provided, only look at that subset of the dict
if starting_path:
for key in starting_path:
if path:
for key in path:
self_dict = self_dict[key]

handle_value(self_dict, path=starting_path)
handle_value(self_dict, path=path)

if field_mapping:
if not include_untraced_data_arrays:
self._has_tracers = True
return dict_ag(field_mapping)

# convert the resulting field_mapping to an autograd-traced dictionary
return dict_ag(field_mapping)
if not include_untraced_data_arrays and not path:
self._has_tracers = False
return dict_ag()

def _insert_traced_fields(self, field_mapping: AutogradFieldMap) -> Self:
"""Recursively insert a map of paths to autograd-traced fields into a copy of this obj."""
Expand Down Expand Up @@ -1157,18 +1169,24 @@ def _serialized_traced_field_keys(
def to_static(self) -> Self:
"""Version of object with all autograd-traced fields removed."""

if self._has_tracers is False:
return self

# get dictionary of all traced fields
field_mapping = self._strip_traced_fields()

# shortcut to just return self if no tracers found, for performance
if not field_mapping:
self._has_tracers = False
return self

# convert all fields to static values
field_mapping_static = {key: get_static(val) for key, val in field_mapping.items()}

# insert the static values into a copy of self
return self._insert_traced_fields(field_mapping_static)
static_self = self._insert_traced_fields(field_mapping_static)
static_self._has_tracers = False
return static_self

@classmethod
def add_type_field(cls) -> None:
Expand Down
Loading