Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds the ability to manually override _target_ in builds #597

Merged
merged 4 commits into from Nov 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
31 changes: 20 additions & 11 deletions src/hydra_zen/structured_configs/_implementations.py
Expand Up @@ -1489,7 +1489,8 @@ def builds(
Type[PartialBuilds[Importable]],
Type[BuildsWithSig[Type[R], P]],
]:
"""builds(hydra_target, /, *pos_args, zen_partial=None, zen_wrappers=(), zen_meta=None, populate_full_signature=False, zen_exclude=(), hydra_recursive=None, hydra_convert=None, hydra_defaults=None, frozen=False, dataclass_name=None, builds_bases=(), **kwargs_for_target)
"""builds(hydra_target, /, *pos_args, zen_partial=None, zen_wrappers=(), zen_meta=None, populate_full_signature=False, zen_exclude=(), hydra_recursive=None, hydra_convert=None, hydra_defaults=None, builds_bases=(),
zen_dataclass=None, **kwargs_for_target)

`builds(target, *args, **kw)` returns a Hydra-compatible config that, when
instantiated, returns `target(*args, **kw)`.
Expand Down Expand Up @@ -1606,7 +1607,10 @@ def builds(
:py:func:`dataclasses.make_dataclass` other than `fields`.
The default value for `unsafe_hash` is `True`.

Additionally, the `module` field can be specified to enable pickle
`target` can be specified as a string to override the `_target_` field
set on the dataclass type returned by `builds`.

The `module` field can be specified to enable pickle
compatibility. See `hydra_zen.typing.DataclassOptions` for details.

frozen : bool, optional (default=False)
Expand Down Expand Up @@ -1980,6 +1984,8 @@ def builds(self,target, populate_full_signature=False, **kw):
# initial validation
_utils.parse_dataclass_options(zen_dataclass)

manual_target_path = zen_dataclass.pop("target", None)

if "frozen" in kwargs_for_target:
warnings.warn(
HydraZenDeprecationWarning(
Expand Down Expand Up @@ -2098,16 +2104,19 @@ def builds(self,target, populate_full_signature=False, **kw):
)

target_path: str
if (
zen_convert_settings["flat_target"]
and isinstance(target, type)
and is_dataclass(target)
and hasattr(target, TARGET_FIELD_NAME)
):
# pass through _target_ field
target_path = safe_getattr(target, TARGET_FIELD_NAME)
if manual_target_path is None:
if (
zen_convert_settings["flat_target"]
and isinstance(target, type)
and is_dataclass(target)
and hasattr(target, TARGET_FIELD_NAME)
):
# pass through _target_ field
target_path = safe_getattr(target, TARGET_FIELD_NAME)
else:
target_path = cls._get_obj_path(target)
else:
target_path = cls._get_obj_path(target)
target_path = manual_target_path

if zen_wrappers is not None:
if not isinstance(zen_wrappers, Sequence) or isinstance(zen_wrappers, str):
Expand Down
8 changes: 8 additions & 0 deletions src/hydra_zen/structured_configs/_utils.py
Expand Up @@ -367,6 +367,14 @@ def parse_dataclass_options(
f"dataclass option `{name}` must be a mapping with string-valued keys "
f"that are valid identifiers. Got {val}."
)
elif name == "target":
if not isinstance(val, str) or not all(
x.isidentifier() for x in val.split(".")
):
raise TypeError(
f"dataclass option `target` must be a string and an import path, "
f"got {val!r}"
)
elif not isinstance(val, bool):
raise TypeError(
f"dataclass option `{name}` must be of type `bool`. Got {val} "
Expand Down
4 changes: 3 additions & 1 deletion src/hydra_zen/typing/_implementations.py
Expand Up @@ -435,7 +435,8 @@ class DataclassOptions(_Py312Dataclass, total=False):
pickle-compatibility for that dataclass. See the Examples section for
clarification.

This is a hydra-zen exclusive feature.
target : str, optional (unspecified by default)
If specified, overrides the `_target_` field set on the resulting dataclass.

init : bool, optional (default=True)
If true (the default), a __init__() method will be generated. If the class
Expand Down Expand Up @@ -573,6 +574,7 @@ class DataclassOptions(_Py312Dataclass, total=False):
"""

module: Optional[str]
target: str


def _permitted_keys(typed_dict: Any) -> FrozenSet[str]:
Expand Down
6 changes: 6 additions & 0 deletions tests/annotations/declarations.py
Expand Up @@ -1477,3 +1477,9 @@ def foo(x: A):
assert_type(bg(A, B()), Type[Builds[Type[A]]])
assert_type(bg(A, B(), zen_partial=True), Type[PartialBuilds[Type[A]]])
bg(A, C()) # type: ignore


def check_target_override():
builds(int, zen_dataclass={"target": 1}) # type: ignore
builds(int, zen_dataclass={"target": ["a"]}) # type: ignore
builds(int, zen_dataclass={"target": "foo.bar"})
61 changes: 61 additions & 0 deletions tests/test_manual_target.py
@@ -0,0 +1,61 @@
# Copyright (c) 2023 Massachusetts Institute of Technology
# SPDX-License-Identifier: MIT
from functools import partial

import pytest

from hydra_zen import builds, instantiate


@pytest.mark.parametrize(
"target_path",
[
int,
1,
["a"],
"not a path",
],
)
def test_validation(target_path):
with pytest.raises(TypeError, match="dataclass option `target`"):
builds(int, zen_dataclass={"target": target_path})


def foo(x=1, y=2):
raise AssertionError("I should not get called")


def passthrough(x):
return x


@pytest.mark.parametrize(
"target",
[
foo,
partial(foo),
builds(foo, populate_full_signature=True),
],
)
@pytest.mark.parametrize(
"kwargs",
[
{},
{"x": 1},
# overrides default with different value
pytest.param({"x": 2}, marks=pytest.mark.xfail),
# exercise zen_processing branch
{"zen_wrappers": passthrough},
],
)
def test_manual(target, kwargs):
out = instantiate(
builds(
target,
populate_full_signature=True,
zen_dataclass={"target": "builtins.dict"},
**kwargs,
)
)
assert isinstance(out, dict)
assert out == dict(x=1, y=2)