Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make flat-target path point to correct target with zen-processing #638

Merged
merged 5 commits into from Mar 17, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
65 changes: 46 additions & 19 deletions src/hydra_zen/structured_configs/_implementations.py
Expand Up @@ -2185,11 +2185,12 @@ def builds(self,target, populate_full_signature=False, **kw):
if (
zen_convert_settings["flat_target"]
and isinstance(target, type)
and is_builds(target)
and is_dataclass(target)
and hasattr(target, TARGET_FIELD_NAME)
):
# pass through _target_ field
target_path = safe_getattr(target, TARGET_FIELD_NAME)
target_path = get_target_path(target)
assert isinstance(target_path, str)
else:
target_path = cls._get_obj_path(target)
else:
Expand Down Expand Up @@ -2490,7 +2491,7 @@ def builds(self,target, populate_full_signature=False, **kw):
# We want to rely on `inspect.signature` logic for raising
# against an uninspectable sig, before we start inspecting
# class-specific attributes below.
signature_params = dict(inspect.signature(target).parameters)
signature_params = dict(inspect.signature(target).parameters) # type: ignore
except ValueError:
if populate_full_signature:
raise ValueError(
Expand Down Expand Up @@ -3465,6 +3466,43 @@ def __post_init__(self, CBuildsFn: Type[BuildsFn[Any]]) -> None: # pragma: no c
del CBuildsFn


def get_target_path(obj: Union[HasTarget, HasTargetInst]) -> Any:
"""
Returns the target-object from a targeted config.
rsokl marked this conversation as resolved.
Show resolved Hide resolved

Parameters
----------
obj : HasTarget
An object with a ``_target_`` attribute.

Returns
-------
target_str : str
The import path stored on the config object.

Raises
------
TypeError: ``obj`` does not have a ``_target_`` attribute.
"""
if is_old_partial_builds(obj):
# obj._partial_target_ is `Just[obj]`
return get_target(getattr(obj, "_partial_target_"))
elif uses_zen_processing(obj):
field_name = ZEN_TARGET_FIELD_NAME
elif is_just(obj):
field_name = JUST_FIELD_NAME
elif is_builds(obj):
field_name = TARGET_FIELD_NAME
else:
raise TypeError(
f"`obj` must specify a target; i.e. it must have an attribute named"
f" {TARGET_FIELD_NAME} or named {ZEN_PARTIAL_FIELD_NAME} that"
f" points to a target-object or target-string"
)
target = safe_getattr(obj, field_name)
return target


@overload
def get_target(obj: InstOrType[Builds[_T]]) -> _T: ...

Expand All @@ -3489,6 +3527,10 @@ def get_target(obj: Union[HasTarget, HasTargetInst]) -> Any:
Returns
-------
target : Any
The target object of the config.

Note that this will import the object using the import
path specified by the config.

Raises
------
Expand Down Expand Up @@ -3535,22 +3577,7 @@ def get_target(obj: Union[HasTarget, HasTargetInst]) -> Any:
>>> get_target(loaded_conf) # type: ignore
__main__.B
"""
if is_old_partial_builds(obj):
# obj._partial_target_ is `Just[obj]`
return get_target(getattr(obj, "_partial_target_"))
elif uses_zen_processing(obj):
field_name = ZEN_TARGET_FIELD_NAME
elif is_just(obj):
field_name = JUST_FIELD_NAME
elif is_builds(obj):
field_name = TARGET_FIELD_NAME
else:
raise TypeError(
f"`obj` must specify a target; i.e. it must have an attribute named"
f" {TARGET_FIELD_NAME} or named {ZEN_PARTIAL_FIELD_NAME} that"
f" points to a target-object or target-string"
)
target = safe_getattr(obj, field_name)
target = get_target_path(obj=obj)

if isinstance(target, str):
target = get_obj(path=target)
Expand Down
15 changes: 15 additions & 0 deletions tests/test_flat_target.py
Expand Up @@ -3,6 +3,7 @@


from dataclasses import dataclass
from functools import partial

import pytest

Expand Down Expand Up @@ -64,3 +65,17 @@ class A:
assert instantiate(store[None, "a"])(x=10, y=22) == (10, 22)
assert instantiate(store[None, "b"])(x=2) == (2, -1)
assert instantiate(store[None, "c"])() == (-1, 2)


def test_supports_meta_fields_via_inheritance():
A = builds(dict, x=1, zen_meta={"META": 2})
B = builds(A, x="${META}", builds_bases=(A,))
assert instantiate(B) == {"x": 2}


def test_supports_meta_fields_with_partial_via_inheritance():
A = builds(dict, x=1, zen_partial=True, zen_meta={"META": 2})
B = builds(A, x="${META}", builds_bases=(A,))
obj = instantiate(B)
assert isinstance(obj, partial)
assert obj() == {"x": 2}