Skip to content

Commit

Permalink
recursively update nested compoennts with updated_copy path kwarg
Browse files Browse the repository at this point in the history
  • Loading branch information
tylerflex committed Apr 10, 2024
1 parent 44d8239 commit 52995d2
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Uniaxial medium Lithium Niobate to material library.
- Added support for conformal mesh methods near PEC structures that can be specified through the field `pec_conformal_mesh_spec` in the `Simulation` class.
- EME solver through `EMESimulation` class.
- Ability to add `path` to `updated_copy()` method to recursively update sub-components of a tidy3d model. For example `sim2 = sim.updated_copy(size=new_size, path="structures/0/geometry")` creates a recursively updated copy of `sim` where `sim.structures[0].geometry` is updated with `size=new_size`.

### Changed
- `run_time` of the adjoint simulation is set more robustly based on the adjoint sources and the forward simulation `run_time` as `sim_fwd.run_time + c / fwdith_adj` where `c=10`.
Expand Down
51 changes: 51 additions & 0 deletions tests/test_components/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,57 @@ def test_updated_copy():
assert s3 == s2


def test_updated_copy_path():
"""Make sure updated copying shortcut works as expected with defaults."""
b = td.Box(size=(1, 1, 1))
m = td.Medium(permittivity=1)

s = td.Structure(
geometry=b,
medium=m,
)

index = 5
structures = (index + 1) * [s]

sim = td.Simulation(
size=(4, 4, 4),
run_time=1e-12,
grid_spec=td.GridSpec.auto(wavelength=1.0),
structures=structures,
)

# works as expected
new_size = (2, 2, 2)
sim2 = sim.updated_copy(size=new_size, path=f"structures/{index}/geometry")
assert sim2.structures[index].geometry.size != sim.structures[index].geometry.size
assert sim2.structures[index].geometry.size == new_size

# wrong integer index
with pytest.raises(ValueError):
sim2 = sim.updated_copy(size=new_size, path="structures/blah/geometry")

# sim2 = sim.updated_copy(size=new_size, path="structures/blah/geometry")

# try with medium for good measure
new_permittivity = 2.0
sim3 = sim.updated_copy(permittivity=new_permittivity, path=f"structures/{index}/medium")
assert sim3.structures[index].medium.permittivity == new_permittivity
assert sim3.structures[index].medium.permittivity != sim.structures[index].medium.permittivity

# wrong field name
with pytest.raises(AttributeError):
sim3 = sim.updated_copy(
permittivity=new_permittivity, path=f"structures/{index}/not_a_field"
)

# forgot path
with pytest.raises(ValueError):
assert sim == sim.updated_copy(permittivity=2.0)

assert sim.updated_copy(size=(6, 6, 6)) == sim.updated_copy(size=(6, 6, 6), path=None)


def test_equality():
# test freqs / arraylike
mnt1 = td.FluxMonitor(size=(1, 1, 0), freqs=np.array([1, 2, 3]) * 1e12, name="1")
Expand Down
57 changes: 56 additions & 1 deletion tidy3d/components/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,62 @@ def copy(self, **kwargs) -> Tidy3dBaseModel:
new_copy = pydantic.BaseModel.copy(self, **kwargs)
return self.validate(new_copy.dict())

def updated_copy(self, **kwargs) -> Tidy3dBaseModel:
def updated_copy(self, path: str = None, **kwargs) -> Tidy3dBaseModel:
"""Make copy of a component instance with ``**kwargs`` indicating updated field values.
Note
----
If ``path`` supplied, applies the updated copy with the update performed on the sub-
component corresponding to the path. For indexing into a tuple or list, use the integer
value.
Example
-------
>>> sim = simulation.updated_copy(size=new_size, path=f"structures/{i}/geometry") # doctest: +SKIP
"""

if not path:
return self._updated_copy(**kwargs)

path_components = path.split("/")

field_name = path_components[0]
sub_path = "/".join(path_components[1:])

try:
sub_component = getattr(self, field_name)
except AttributeError as e:
raise AttributeError(
f"Could not field field '{field_name}' in the sub-component `path`. "
f"Found fields of '{tuple(self.__fields__.keys())}'. "
"Please double check the `path` passed to `.updated_copy()`."
) from e

if isinstance(sub_component, (list, tuple)):
integer_index_path = sub_path[0]

try:
index = int(integer_index_path)
except ValueError:
raise ValueError(
f"Could not grab integer index from path '{path}'. "
f"Please correct the sub path containing '{integer_index_path}' to be an "
f"integer index into '{field_name}', containing {len(sub_component)} elements."
)

sub_component_list = list(sub_component)
sub_component = sub_component_list[index]
sub_path = "/".join(path_components[2:])

sub_component_list[index] = sub_component.updated_copy(path=sub_path, **kwargs)
new_component = tuple(sub_component_list)
else:
sub_path = "/".join(path_components[1:])
new_component = sub_component.updated_copy(path=sub_path, **kwargs)

return self._updated_copy(**{field_name: new_component})

def _updated_copy(self, **kwargs) -> Tidy3dBaseModel:
"""Make copy of a component instance with ``**kwargs`` indicating updated field values."""
return self.copy(update=kwargs)

Expand Down

0 comments on commit 52995d2

Please sign in to comment.