Skip to content

Commit

Permalink
Improve type-widening and add tests (#185)
Browse files Browse the repository at this point in the history
* Improve type-widening and add tests

* auto-widen type based on sanitized value; add test that exercises this
  • Loading branch information
rsokl committed Dec 24, 2021
1 parent a57ad67 commit 67d0b86
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 6 deletions.
16 changes: 12 additions & 4 deletions src/hydra_zen/structured_configs/_implementations.py
Expand Up @@ -1565,6 +1565,7 @@ def builds(
)
del field_

# sanitize all types and configured values
sanitized_base_fields: List[Union[Tuple[str, Any], Tuple[str, Any, Field]]] = []

for item in base_fields:
Expand Down Expand Up @@ -1606,22 +1607,29 @@ def builds(
else:
_field = value

# If `.default` is not set, then `value` is a Hydra-supported mutable
# value, and thus it is "sanitized"
sanitized_value = getattr(_field, "default", value)
sanitized_type = (
_utils.sanitized_type(type_, wrap_optional=value is None)
_utils.sanitized_type(type_, wrap_optional=sanitized_value is None)
# OmegaConf's type-checking occurs before instantiation occurs.
# This means that, e.g., passing `Builds[int]` to a field `x: int`
# will fail Hydra's type-checking upon instantiation, even though
# the recursive instantiation will appropriately produce `int` for
# that field. This will not be addressed by hydra/omegaconf:
# https://github.com/facebookresearch/hydra/issues/1759
# Thus we will auto-broaden the annotation when we see that the user
# has specified a `Builds` as a default value.
if not is_builds(value) or hydra_recursive is False
# Thus we will auto-broaden the annotation when we see that a field
# is set with a structured config as a default value - assuming that
# the field isn't annotated with a structured config type.
if hydra_recursive is False
or not is_builds(sanitized_value)
or is_builds(type_)
else Any
)
sanitized_base_fields.append((name, sanitized_type, _field))
del value
del _field
del sanitized_value

out = make_dataclass(
dataclass_name, fields=sanitized_base_fields, bases=builds_bases, frozen=frozen
Expand Down
66 changes: 64 additions & 2 deletions tests/test_signature_parsing.py
Expand Up @@ -6,7 +6,7 @@
from dataclasses import dataclass
from enum import Enum
from inspect import Parameter
from typing import Any, Callable, Dict, Mapping, Optional, Tuple, Type
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Type

import hypothesis.strategies as st
import pytest
Expand Down Expand Up @@ -317,7 +317,7 @@ def expects_int(x: int) -> int:
],
)
@pytest.mark.parametrize("hydra_recursive", [True, None])
def test_setting_default_with_Builds_widens_type(builds_as_default, hydra_recursive):
def test_setting_default_with_builds_widens_type(builds_as_default, hydra_recursive):
# tests that we address https://github.com/facebookresearch/hydra/issues/1759
# via auto type-widening
kwargs = {} if hydra_recursive is None else dict(hydra_recursive=hydra_recursive)
Expand All @@ -329,6 +329,68 @@ def test_setting_default_with_Builds_widens_type(builds_as_default, hydra_recurs
instantiate(builds(expects_int, x=builds_as_default, hydra_recursive=False))


BuildsInt = builds(int)


def f_with_dataclass_annotation(x: BuildsInt = BuildsInt()):
return x


@pytest.mark.parametrize(
"bad_value",
[
1, # not a structured config
builds(str)(), # instance of different structured config
],
)
def test_builds_doesnt_widen_dataclass_type_annotation(bad_value):
with pytest.raises(ValidationError):
instantiate(builds(f_with_dataclass_annotation, x=bad_value))

with pytest.raises(ValidationError):
instantiate(
builds(f_with_dataclass_annotation, populate_full_signature=True),
x=bad_value,
)


def test_dataclass_type_annotation_with_subclass_default():
# ensures that configs that inherite from a base class used
# in an annotation passes Hydra's validation
Child = builds(str, builds_bases=(BuildsInt,))
assert (
instantiate(
builds(f_with_dataclass_annotation, populate_full_signature=True), x=Child()
)
== ""
)
assert instantiate(builds(f_with_dataclass_annotation, x=Child())) == ""


def func_with_list_annotation(x: List[int]):
return x


def test_type_widening_with_internal_conversion_to_Builds():
# This test relies on omegaconf <= 2.1.1 to exercise the desired behavior,
# but should pass regardless.
#
# We contrive a case where an annotation is supported by Hydra, but the supplied
# value requires us to cast internally to a targeted conf; this is due to the
# downstream patch: https://github.com/mit-ll-responsible-ai/hydra-zen/pull/172
#
# Thus in this case we detect that, even though the original configured value is
# valid, we must represent the value with a structured config. Thus we should widen
# the annotation from `List[int]` to `Any`.
#
# Generally, in cases where we need to internally cast to a Builds, the associated
# type annotation is not supported by Hydra to begin with, and thus is broadened
# via type-sanitization.
Base = make_config(x=1)
Conf = builds(func_with_list_annotation, x=[1, 2], builds_bases=(Base,))
instantiate(Conf)


def func_with_various_defaults(x=1, y="a", z=[1, 2]):
pass

Expand Down

0 comments on commit 67d0b86

Please sign in to comment.