Skip to content

Commit

Permalink
small fixes to config
Browse files Browse the repository at this point in the history
Summary:
- indicate location of OmegaConf.structured failures
- split the data gathering from enable_get_default_args to ease experimenting with it.
- comment fixes.
- nicer error when a_class_type has weird type.

Reviewed By: kjchalup

Differential Revision: D39434447

fbshipit-source-id: b80c7941547ca450e848038ef5be95b7ebbe8f3e
  • Loading branch information
bottler authored and facebook-github-bot committed Sep 15, 2022
1 parent cb7bd33 commit da7fe28
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 17 deletions.
67 changes: 52 additions & 15 deletions pytorch3d/implicitron/tools/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,10 @@ class type
raise ValueError(
f"Cannot look up {base_class_wanted}. Cannot tell what it is."
)
if not isinstance(name, str):
raise ValueError(
f"Cannot look up a {type(name)} in the registry. Got {name}."
)
result = self._mapping[base_class].get(name)
if result is None:
raise ValueError(f"{name} has not been registered.")
Expand Down Expand Up @@ -446,6 +450,11 @@ def create_pluggable(self, type_name, args):
setattr(self, name, None)
return

if not isinstance(type_name, str):
raise ValueError(
f"A {type(type_name)} was received as the type of {name}."
+ f" Perhaps this is from {name}{TYPE_SUFFIX}?"
)
chosen_class = registry.get(type_, type_name)
if self._known_implementations.get(type_name, chosen_class) is not chosen_class:
# If this warning is raised, it means that a new definition of
Expand Down Expand Up @@ -514,7 +523,10 @@ def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig
# because in practice get_default_args_field is used for
# separate types than the outer type.

out: DictConfig = OmegaConf.structured(C)
try:
out: DictConfig = OmegaConf.structured(C)
except Exception as e:
raise ValueError(f"OmegaConf.structured({C}) failed") from e
exclude = getattr(C, "_processed_members", ())
with open_dict(out):
for field in exclude:
Expand All @@ -534,7 +546,11 @@ def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig
f"Cannot get args for {C}. Was enable_get_default_args forgotten?"
)

return OmegaConf.structured(dataclass)
try:
out: DictConfig = OmegaConf.structured(dataclass)
except Exception as e:
raise ValueError(f"OmegaConf.structured failed for {dataclass_name}") from e
return out


def _dataclass_name_for_function(C: Any) -> str:
Expand All @@ -546,22 +562,21 @@ def _dataclass_name_for_function(C: Any) -> str:
return name


def enable_get_default_args(C: Any, *, overwrite: bool = True) -> None:
def _field_annotations_for_default_args(
C: Any,
) -> List[Tuple[str, Any, dataclasses.Field]]:
"""
If C is a function or a plain class with an __init__ function,
and you want get_default_args(C) to work, then add
`enable_get_default_args(C)` straight after the definition of C.
This makes a dataclass corresponding to the default arguments of C
and stores it in the same module as C.
return the fields which `enable_get_default_args(C)` will need
to make a dataclass with.
Args:
C: a function, or a class with an __init__ function. Must
have types for all its defaulted args.
overwrite: whether to allow calling this a second time on
the same function.
Returns:
a list of fields for a dataclass.
"""
if not inspect.isfunction(C) and not inspect.isclass(C):
raise ValueError(f"Unexpected {C}")

field_annotations = []
for pname, defval in _params_iter(C):
Expand All @@ -572,8 +587,8 @@ def enable_get_default_args(C: Any, *, overwrite: bool = True) -> None:

if defval.annotation == inspect._empty:
raise ValueError(
"All arguments of the input callable have to be typed."
+ f" Argument '{pname}' does not have a type annotation."
"All arguments of the input to enable_get_default_args have to"
f" be typed. Argument '{pname}' does not have a type annotation."
)

_, annotation = _resolve_optional(defval.annotation)
Expand All @@ -591,6 +606,28 @@ def enable_get_default_args(C: Any, *, overwrite: bool = True) -> None:
field_ = dataclasses.field(default=default)
field_annotations.append((pname, defval.annotation, field_))

return field_annotations


def enable_get_default_args(C: Any, *, overwrite: bool = True) -> None:
"""
If C is a function or a plain class with an __init__ function,
and you want get_default_args(C) to work, then add
`enable_get_default_args(C)` straight after the definition of C.
This makes a dataclass corresponding to the default arguments of C
and stores it in the same module as C.
Args:
C: a function, or a class with an __init__ function. Must
have types for all its defaulted args.
overwrite: whether to allow calling this a second time on
the same function.
"""
if not inspect.isfunction(C) and not inspect.isclass(C):
raise ValueError(f"Unexpected {C}")

field_annotations = _field_annotations_for_default_args(C)

name = _dataclass_name_for_function(C)
module = sys.modules[C.__module__]
if hasattr(module, name):
Expand Down Expand Up @@ -767,7 +804,7 @@ def create_x_impl(self, enabled, args):
Also adds the following class members, unannotated so that dataclass
ignores them.
- _creation_functions: Tuple[str] of all the create_ functions,
- _creation_functions: Tuple[str, ...] of all the create_ functions,
including those from base classes (not the create_x_impl ones).
- _known_implementations: Dict[str, Type] containing the classes which
have been found from the registry.
Expand Down Expand Up @@ -945,7 +982,7 @@ def _get_type_to_process(type_) -> Optional[Tuple[Type, _ProcessType]]:
return underlying, _ProcessType.OPTIONAL_CONFIGURABLE

if not isinstance(type_, type):
# e.g. any other Union or Tuple
# e.g. any other Union or Tuple. Or ClassVar.
return

if issubclass(type_, ReplaceableBase) and ReplaceableBase in type_.__bases__:
Expand Down
5 changes: 3 additions & 2 deletions tests/implicitron/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def test_registry_entries(self):
self.assertIn(Banana, all_fruit)
self.assertIn(Pear, all_fruit)
self.assertIn(LargePear, all_fruit)
self.assertEqual(set(registry.get_all(Pear)), {LargePear})
self.assertEqual(registry.get_all(Pear), [LargePear])

@registry.register
class Apple(Fruit):
Expand All @@ -178,7 +178,7 @@ class Apple(Fruit):
class CrabApple(Apple):
pass

self.assertEqual(set(registry.get_all(Apple)), {CrabApple})
self.assertEqual(registry.get_all(Apple), [CrabApple])

self.assertIs(registry.get(Fruit, "CrabApple"), CrabApple)

Expand Down Expand Up @@ -601,6 +601,7 @@ def __init__(self, a: A = A.B1) -> None:

for C_ in [C, C_fn, C_cl]:
base = get_default_args(C_)
self.assertEqual(OmegaConf.to_yaml(base), "a: B1\n")
self.assertEqual(base.a, A.B1)
replaced = OmegaConf.merge(base, {"a": "B2"})
self.assertEqual(replaced.a, A.B2)
Expand Down

0 comments on commit da7fe28

Please sign in to comment.