Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve type-widening and add tests #185

Merged
merged 2 commits into from Dec 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
16 changes: 12 additions & 4 deletions src/hydra_zen/structured_configs/_implementations.py
Expand Up @@ -1565,6 +1565,7 @@ def builds(
)
del field_

# sanitize all types and configured values
sanitized_base_fields: List[Union[Tuple[str, Any], Tuple[str, Any, Field]]] = []

for item in base_fields:
Expand Down Expand Up @@ -1606,22 +1607,29 @@ def builds(
else:
_field = value

# If `.default` is not set, then `value` is a Hydra-supported mutable
# value, and thus it is "sanitized"
sanitized_value = getattr(_field, "default", value)
sanitized_type = (
_utils.sanitized_type(type_, wrap_optional=value is None)
_utils.sanitized_type(type_, wrap_optional=sanitized_value is None)
# OmegaConf's type-checking occurs before instantiation occurs.
# This means that, e.g., passing `Builds[int]` to a field `x: int`
# will fail Hydra's type-checking upon instantiation, even though
# the recursive instantiation will appropriately produce `int` for
# that field. This will not be addressed by hydra/omegaconf:
# https://github.com/facebookresearch/hydra/issues/1759
# Thus we will auto-broaden the annotation when we see that the user
# has specified a `Builds` as a default value.
if not is_builds(value) or hydra_recursive is False
# Thus we will auto-broaden the annotation when we see that a field
# is set with a structured config as a default value - assuming that
# the field isn't annotated with a structured config type.
if hydra_recursive is False
or not is_builds(sanitized_value)
or is_builds(type_)
else Any
)
sanitized_base_fields.append((name, sanitized_type, _field))
del value
del _field
del sanitized_value

out = make_dataclass(
dataclass_name, fields=sanitized_base_fields, bases=builds_bases, frozen=frozen
Expand Down
66 changes: 64 additions & 2 deletions tests/test_signature_parsing.py
Expand Up @@ -6,7 +6,7 @@
from dataclasses import dataclass
from enum import Enum
from inspect import Parameter
from typing import Any, Callable, Dict, Mapping, Optional, Tuple, Type
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Type

import hypothesis.strategies as st
import pytest
Expand Down Expand Up @@ -317,7 +317,7 @@ def expects_int(x: int) -> int:
],
)
@pytest.mark.parametrize("hydra_recursive", [True, None])
def test_setting_default_with_Builds_widens_type(builds_as_default, hydra_recursive):
def test_setting_default_with_builds_widens_type(builds_as_default, hydra_recursive):
# tests that we address https://github.com/facebookresearch/hydra/issues/1759
# via auto type-widening
kwargs = {} if hydra_recursive is None else dict(hydra_recursive=hydra_recursive)
Expand All @@ -329,6 +329,68 @@ def test_setting_default_with_Builds_widens_type(builds_as_default, hydra_recurs
instantiate(builds(expects_int, x=builds_as_default, hydra_recursive=False))


BuildsInt = builds(int)


def f_with_dataclass_annotation(x: BuildsInt = BuildsInt()):
return x


@pytest.mark.parametrize(
"bad_value",
[
1, # not a structured config
builds(str)(), # instance of different structured config
],
)
def test_builds_doesnt_widen_dataclass_type_annotation(bad_value):
with pytest.raises(ValidationError):
instantiate(builds(f_with_dataclass_annotation, x=bad_value))

with pytest.raises(ValidationError):
instantiate(
builds(f_with_dataclass_annotation, populate_full_signature=True),
x=bad_value,
)


def test_dataclass_type_annotation_with_subclass_default():
# ensures that configs that inherite from a base class used
# in an annotation passes Hydra's validation
Child = builds(str, builds_bases=(BuildsInt,))
assert (
instantiate(
builds(f_with_dataclass_annotation, populate_full_signature=True), x=Child()
)
== ""
)
assert instantiate(builds(f_with_dataclass_annotation, x=Child())) == ""


def func_with_list_annotation(x: List[int]):
return x


def test_type_widening_with_internal_conversion_to_Builds():
# This test relies on omegaconf <= 2.1.1 to exercise the desired behavior,
# but should pass regardless.
#
# We contrive a case where an annotation is supported by Hydra, but the supplied
# value requires us to cast internally to a targeted conf; this is due to the
# downstream patch: https://github.com/mit-ll-responsible-ai/hydra-zen/pull/172
#
# Thus in this case we detect that, even though the original configured value is
# valid, we must represent the value with a structured config. Thus we should widen
# the annotation from `List[int]` to `Any`.
#
# Generally, in cases where we need to internally cast to a Builds, the associated
# type annotation is not supported by Hydra to begin with, and thus is broadened
# via type-sanitization.
Base = make_config(x=1)
Conf = builds(func_with_list_annotation, x=[1, 2], builds_bases=(Base,))
instantiate(Conf)


def func_with_various_defaults(x=1, y="a", z=[1, 2]):
pass

Expand Down