Skip to content

Commit

Permalink
Fix missing autograd box schema and move schema creation code to comp…
Browse files Browse the repository at this point in the history
…onents/type_util.py
  • Loading branch information
yaugenst-flex authored and momchil-flex committed Jun 12, 2024
1 parent ba11a51 commit 474417d
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 11 deletions.
4 changes: 4 additions & 0 deletions tidy3d/components/autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,12 @@
from autograd.extend import Box, defvjp, primitive
from autograd.tracer import getval

from tidy3d.components.type_util import _add_schema

from .types import ArrayFloat2D, ArrayLike, Bound, Size1D

_add_schema(Box, title="AutogradBox", field_type_str="autograd.tracer.Box")

# TODO: should we use ArrayBox? Box is more general

# Types for floats, or collections of floats that can also be autograd tracers
Expand Down
12 changes: 12 additions & 0 deletions tidy3d/components/type_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""Utilities for type & schema creation."""


def _add_schema(arbitrary_type: type, title: str, field_type_str: str) -> None:
"""Adds a schema to the ``arbitrary_type`` class without subclassing."""

@classmethod
def mod_schema_fn(cls, field_schema: dict) -> None:
"""Function that gets set to ``arbitrary_type.__modify_schema__``."""
field_schema.update(dict(title=title, type=field_type_str))

arbitrary_type.__modify_schema__ = mod_schema_fn
13 changes: 2 additions & 11 deletions tidy3d/plugins/adjoint/components/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

import numpy as np

from tidy3d.components.type_util import _add_schema

# special handling if we cant import the JVPTracer in the future (so it doesn't break tidy3d).
try:
from jax.interpreters.ad import JVPTracer
Expand Down Expand Up @@ -37,17 +39,6 @@ def __modify_schema__(cls, field_schema):
field_schema.update(schema)


def _add_schema(arbitrary_type: type, title: str, field_type_str: str) -> None:
"""Adds a schema to the ``arbitrary_type`` class without subclassing."""

@classmethod
def mod_schema_fn(cls, field_schema: dict) -> None:
"""Function that gets set to ``arbitrary_type.__modify_schema__``."""
field_schema.update(dict(title=title, type=field_type_str))

arbitrary_type.__modify_schema__ = mod_schema_fn


_add_schema(JaxArrayType, title="JaxArray", field_type_str="jax.numpy.ndarray")

# if the ImportError didnt occur, add the schema
Expand Down

0 comments on commit 474417d

Please sign in to comment.