diff --git a/src/hydra_zen/structured_configs/_implementations.py b/src/hydra_zen/structured_configs/_implementations.py index ba756c98..7073dd9a 100644 --- a/src/hydra_zen/structured_configs/_implementations.py +++ b/src/hydra_zen/structured_configs/_implementations.py @@ -2186,11 +2186,12 @@ def builds(self,target, populate_full_signature=False, **kw): if ( zen_convert_settings["flat_target"] and isinstance(target, type) + and is_builds(target) and is_dataclass(target) - and hasattr(target, TARGET_FIELD_NAME) ): # pass through _target_ field - target_path = safe_getattr(target, TARGET_FIELD_NAME) + target_path = get_target_path(target) + assert isinstance(target_path, str) else: target_path = cls._get_obj_path(target) else: @@ -2491,7 +2492,7 @@ def builds(self,target, populate_full_signature=False, **kw): # We want to rely on `inspect.signature` logic for raising # against an uninspectable sig, before we start inspecting # class-specific attributes below. - signature_params = dict(inspect.signature(target).parameters) + signature_params = dict(inspect.signature(target).parameters) # type: ignore except ValueError: if populate_full_signature: raise ValueError( @@ -3466,6 +3467,43 @@ def __post_init__(self, CBuildsFn: Type[BuildsFn[Any]]) -> None: # pragma: no c del CBuildsFn +def get_target_path(obj: Union[HasTarget, HasTargetInst]) -> Any: + """ + Returns the import-path from a targeted config. + + Parameters + ---------- + obj : HasTarget + An object with a ``_target_`` attribute. + + Returns + ------- + target_str : str + The import path stored on the config object. + + Raises + ------ + TypeError: ``obj`` does not have a ``_target_`` attribute. + """ + if is_old_partial_builds(obj): + # obj._partial_target_ is `Just[obj]` + return get_target(getattr(obj, "_partial_target_")) + elif uses_zen_processing(obj): + field_name = ZEN_TARGET_FIELD_NAME + elif is_just(obj): + field_name = JUST_FIELD_NAME + elif is_builds(obj): + field_name = TARGET_FIELD_NAME + else: + raise TypeError( + f"`obj` must specify a target; i.e. it must have an attribute named" + f" {TARGET_FIELD_NAME} or named {ZEN_PARTIAL_FIELD_NAME} that" + f" points to a target-object or target-string" + ) + target = safe_getattr(obj, field_name) + return target + + @overload def get_target(obj: InstOrType[Builds[_T]]) -> _T: ... @@ -3490,6 +3528,10 @@ def get_target(obj: Union[HasTarget, HasTargetInst]) -> Any: Returns ------- target : Any + The target object of the config. + + Note that this will import the object using the import + path specified by the config. Raises ------ @@ -3536,22 +3578,7 @@ def get_target(obj: Union[HasTarget, HasTargetInst]) -> Any: >>> get_target(loaded_conf) # type: ignore __main__.B """ - if is_old_partial_builds(obj): - # obj._partial_target_ is `Just[obj]` - return get_target(getattr(obj, "_partial_target_")) - elif uses_zen_processing(obj): - field_name = ZEN_TARGET_FIELD_NAME - elif is_just(obj): - field_name = JUST_FIELD_NAME - elif is_builds(obj): - field_name = TARGET_FIELD_NAME - else: - raise TypeError( - f"`obj` must specify a target; i.e. it must have an attribute named" - f" {TARGET_FIELD_NAME} or named {ZEN_PARTIAL_FIELD_NAME} that" - f" points to a target-object or target-string" - ) - target = safe_getattr(obj, field_name) + target = get_target_path(obj=obj) if isinstance(target, str): target = get_obj(path=target) diff --git a/tests/test_flat_target.py b/tests/test_flat_target.py index 5a858069..63c8070a 100644 --- a/tests/test_flat_target.py +++ b/tests/test_flat_target.py @@ -3,6 +3,7 @@ from dataclasses import dataclass +from functools import partial import pytest @@ -64,3 +65,17 @@ class A: assert instantiate(store[None, "a"])(x=10, y=22) == (10, 22) assert instantiate(store[None, "b"])(x=2) == (2, -1) assert instantiate(store[None, "c"])() == (-1, 2) + + +def test_supports_meta_fields_via_inheritance(): + A = builds(dict, x=1, zen_meta={"META": 2}) + B = builds(A, x="${META}", builds_bases=(A,)) + assert instantiate(B) == {"x": 2} + + +def test_supports_meta_fields_with_partial_via_inheritance(): + A = builds(dict, x=1, zen_partial=True, zen_meta={"META": 2}) + B = builds(A, x="${META}", builds_bases=(A,)) + obj = instantiate(B) + assert isinstance(obj, partial) + assert obj() == {"x": 2}