Skip to content

Commit

Permalink
Merge pull request #598 from mit-ll-responsible-ai/kwargs-fn
Browse files Browse the repository at this point in the history
Adds `hydra_zen.kwargs_fn`
  • Loading branch information
rsokl committed Nov 24, 2023
2 parents bf4b880 + 26bf8d1 commit be42912
Show file tree
Hide file tree
Showing 10 changed files with 230 additions and 10 deletions.
3 changes: 2 additions & 1 deletion docs/source/api_reference.rst
Expand Up @@ -34,10 +34,11 @@ Creating Configs
.. autosummary::
:toctree: generated/

make_config
builds
just
kwargs_of
hydrated_dataclass
make_config


Storing Configs
Expand Down
4 changes: 3 additions & 1 deletion docs/source/changes.rst
Expand Up @@ -11,7 +11,7 @@ chronological order. All previous releases should still be available on pip.
.. _v0.12.0:

----------------------
0.12.0rc6 - 2023-11-15
0.12.0rc7 - 2023-11-15
----------------------


Expand Down Expand Up @@ -67,6 +67,8 @@ Improvements
- :class:`~hydra_zen.BuildsFn` was introduced to permit customizable auto-config and type-refinement support in config-creation functions. See :pull:`553`.
- :func:`~hydra_zen.builds` and :func:`~hydra_zen.make_custom_builds_fn` now accept a `zen_exclude` field for excluding parameters from auto-population, either by name, position-index, or by pattern. See :pull:`558`.
- :func:`~hydra_zen.builds` and :func:`~hydra_zen.just` can now configure static methods. Previously the incorrect ``_target_`` would be resolved. See :pull:`566`
- :func:`~hydra_zen.kwargs_of` is a new config-creation function that lets you generate a stand-along config from an object's signature. See :pull:`598`.
- Users can now manually specify the ``_target_`` entry of a config produced via :func:`~hydra_zen.builds`. See :pull:`597`.
- :func:`hydra_zen.zen` now has first class support for running code in an isolated :py:class:`contextvars.Context`. This enables users to safely leverage state via :py:class:`contextvars.ContextVar` in their task functions. See :pull:`583`.
- Adds formal support for Python 3.12. See :pull:`555`
- Several new methods were added to :class:`~hydra_zen.ZenStore`, including the abilities to copy, update, and merge stores. As well as remap the groups of a store's entries and delete individual entries. See :pull:`569`
Expand Down
6 changes: 6 additions & 0 deletions docs/source/generated/hydra_zen.kwargs_of.rst
@@ -0,0 +1,6 @@
hydra\_zen.kwargs_of
====================

.. currentmodule:: hydra_zen

.. autofunction:: kwargs_of
8 changes: 7 additions & 1 deletion src/hydra_zen/__init__.py
Expand Up @@ -19,7 +19,12 @@
make_custom_builds_fn,
mutable_value,
)
from .structured_configs._implementations import BuildsFn, DefaultBuilds, get_target
from .structured_configs._implementations import (
BuildsFn,
DefaultBuilds,
get_target,
kwargs_of,
)
from .structured_configs._type_guards import is_partial_builds, uses_zen_processing
from .wrapper import ZenStore, store, zen

Expand All @@ -29,6 +34,7 @@
"DefaultBuilds",
"hydrated_dataclass",
"just",
"kwargs_of",
"mutable_value",
"get_target",
"MISSING",
Expand Down
139 changes: 136 additions & 3 deletions src/hydra_zen/structured_configs/_implementations.py
Expand Up @@ -645,7 +645,9 @@ def __post_init__(


class BuildsFn(Generic[T]):
"""A class that can be modified to customize the behavior of `builds`, `just`, and `make_config`. These functions are exposed as class methods of `BuildsFn`.
"""A class that can be modified to customize the behavior of `builds`, `just`, `kwargs_of, and `make_config`.
These functions are exposed as class methods of `BuildsFn`.
To customize type-refinement support, override `_sanitized_type`.
To customize auto-config support, override `_make_hydra_compatible`.
Expand All @@ -654,6 +656,9 @@ class BuildsFn(Generic[T]):

__slots__ = ()

_default_dataclass_options_for_kwargs_of: Optional[DataclassOptions] = None
"""Specifies the default options for `cls.kwargs_of(..., zen_dataclass)"""

@classmethod
def _sanitized_type(
cls,
Expand Down Expand Up @@ -1985,6 +1990,7 @@ def builds(self,target, populate_full_signature=False, **kw):
_utils.parse_dataclass_options(zen_dataclass)

manual_target_path = zen_dataclass.pop("target", None)
target_repr = zen_dataclass.pop("target_repr", True)

if "frozen" in kwargs_for_target:
warnings.warn(
Expand Down Expand Up @@ -2252,7 +2258,7 @@ def builds(self,target, populate_full_signature=False, **kw):
(
TARGET_FIELD_NAME,
str,
_utils.field(default=target_path, init=False),
_utils.field(default=target_path, init=False, repr=target_repr),
),
(
PARTIAL_FIELD_NAME,
Expand Down Expand Up @@ -2338,7 +2344,7 @@ def builds(self,target, populate_full_signature=False, **kw):
(
TARGET_FIELD_NAME,
str,
_utils.field(default=target_path, init=False),
_utils.field(default=target_path, init=False, repr=target_repr),
)
]

Expand Down Expand Up @@ -3198,12 +3204,139 @@ def make_config(

return cast(Type[DataClass], out)

@overload
@classmethod
def kwargs_of(
cls,
__hydra_target: Callable[P, Any],
*,
zen_dataclass: Optional[DataclassOptions] = ...,
zen_exclude: Literal[None] = ...,
) -> Type[BuildsWithSig[Type[Dict[str, Any]], P]]:
...

@overload
@classmethod
def kwargs_of(
cls,
__hydra_target: Callable[P, Any],
*,
zen_dataclass: Optional[DataclassOptions] = ...,
zen_exclude: Union["Collection[Union[str, int]]", Callable[[str], bool]],
) -> Type[Builds[Type[Dict[str, Any]]]]:
...

@overload
@classmethod
def kwargs_of(
cls,
__hydra_target: Callable[P, Any],
*,
zen_dataclass: Optional[DataclassOptions] = ...,
zen_exclude: Union[
None, "Collection[Union[str, int]]", Callable[[str], bool]
] = ...,
**kwargs_for_target: T,
) -> Type[Builds[Type[Dict[str, Any]]]]:
...

@classmethod
def kwargs_of(
cls,
__hydra_target: Callable[P, Any],
*,
zen_dataclass: Optional[DataclassOptions] = None,
zen_exclude: Union[
None, "Collection[Union[str, int]]", Callable[[str], bool]
] = None,
**kwargs_for_target: T,
) -> Union[
Type[BuildsWithSig[Type[Dict[str, Any]], P]], Type[Builds[Type[Dict[str, Any]]]]
]:
"""Returns a config whose signature matches that of the provided target.
Instantiating the config returns a dictionary.
.. note::
``kwargs_of`` is a new feature as of hydra-zen v0.12.0rc7.
You can try out this pre-release feature using `pip install --pre hydra-zen`
Parameters
----------
__hydra_target : Callable[P, Any]
An object with an inspectable signature.
zen_exclude : Collection[str | int] | Callable[[str], bool], optional (default=[])
Specifies parameter names and/or indices, or a function for checking names,
to exclude those parameters from the config-creation process.
Returns
-------
type[Builds[type[dict[str, Any]]]]
Examples
--------
>>> from inspect import signature
>>> from hydra_zen import kwargs_of, instantiate
>>> Config = kwargs_of(lambda x, y: None)
>>> signature(Config)
<Signature (x:Any, y: Any) -> None>
>>> config = Config(x=1, y=2)
>>> config
kwargs_of_lambda(x=1, y=2)
>>> instantiate(config)
{'x': 1, 'y': 2}
Excluding the first parameter from the target's signature:
>>> Config = kwargs_of(lambda *, x, y: None, zen_exclude=[0])
>>> signature(Config)
<Signature (y: Any) -> None>
>>> instantiate(Config(y=88))
{'y': 88}
Overwriting a default
>>> Config = kwargs_of(lambda *, x, y: None, y=22)
>>> signature(Config)
<Signature (x: Any, y: Any = 22) -> None>
"""
base_zen_detaclass: DataclassOptions = (
cls._default_dataclass_options_for_kwargs_of.copy()
if cls._default_dataclass_options_for_kwargs_of
else {}
)
if zen_dataclass is None:
zen_dataclass = {}

zen_dataclass = {**base_zen_detaclass, **zen_dataclass}
zen_dataclass["target"] = "builtins.dict"
zen_dataclass.setdefault(
"cls_name", f"kwargs_of_{_utils.safe_name(__hydra_target)}"
)
zen_dataclass.setdefault("target_repr", False)

if zen_exclude is None:
zen_exclude = ()
return cls.builds( # type: ignore
__hydra_target,
populate_full_signature=True,
zen_exclude=zen_exclude,
zen_dataclass=zen_dataclass,
**kwargs_for_target, # type: ignore
)


class DefaultBuilds(BuildsFn[SupportedPrimitive]):
_default_dataclass_options_for_kwargs_of = {}
pass


builds: Final = DefaultBuilds.builds
kwargs_of: Final = DefaultBuilds.kwargs_of


@dataclass(unsafe_hash=True)
Expand Down
4 changes: 2 additions & 2 deletions src/hydra_zen/structured_configs/_utils.py
Expand Up @@ -129,10 +129,10 @@ def safe_name(obj: Any, repr_allowed: bool = True) -> str:
instead of raising - useful for writing descriptive/dafe error messages."""

if hasattr(obj, "__name__"):
return obj.__name__
return obj.__name__.replace("<lambda>", "lambda")

if repr_allowed and hasattr(obj, "__repr__"):
return repr(obj)
return repr(obj).replace("<lambda>", "lambda")

return UNKNOWN_NAME

Expand Down
1 change: 1 addition & 0 deletions src/hydra_zen/typing/_implementations.py
Expand Up @@ -575,6 +575,7 @@ class DataclassOptions(_Py312Dataclass, total=False):

module: Optional[str]
target: str
target_repr: bool


def _permitted_keys(typed_dict: Any) -> FrozenSet[str]:
Expand Down
23 changes: 23 additions & 0 deletions tests/annotations/declarations.py
Expand Up @@ -41,6 +41,7 @@
get_target,
instantiate,
just,
kwargs_of,
make_config,
make_custom_builds_fn,
mutable_value,
Expand Down Expand Up @@ -1483,3 +1484,25 @@ def check_target_override():
builds(int, zen_dataclass={"target": 1}) # type: ignore
builds(int, zen_dataclass={"target": ["a"]}) # type: ignore
builds(int, zen_dataclass={"target": "foo.bar"})


def check_kwargs_of():
def foo(x: int, y: str):
...

Conf = kwargs_of(foo)
reveal_type(
Conf,
expected_text="type[BuildsWithSig[type[Dict[str, Any]], (x: int, y: str)]]",
)

Conf2 = kwargs_of(foo, zen_exclude=[0])
reveal_type(Conf2, expected_text="type[Builds[type[Dict[str, Any]]]]")

Conf3 = kwargs_of(foo, x=1)
reveal_type(Conf3, expected_text="type[Builds[type[Dict[str, Any]]]]")

class NotSupported:
...

Conf3 = kwargs_of(foo, x=NotSupported()) # type: ignore
48 changes: 48 additions & 0 deletions tests/test_kwargs_of.py
@@ -0,0 +1,48 @@
# Copyright (c) 2023 Massachusetts Institute of Technology
# SPDX-License-Identifier: MIT
from inspect import signature

import pytest

from hydra_zen import BuildsFn, instantiate, kwargs_of


def test_basic():
Conf = kwargs_of(lambda x, y: None)
assert set(signature(Conf).parameters) == {"x", "y"}
out = instantiate(Conf(x=-9, y=10))
assert isinstance(out, dict)
assert out == dict(x=-9, y=10)


@pytest.mark.parametrize(
"exclude,params",
[
([0], {"y"}),
(["x", -1], set()),
],
)
def test_exclude(exclude, params):
Conf = kwargs_of((lambda x, y: None), zen_exclude=exclude)
assert set(signature(Conf).parameters) == params


def test_dataclass_options():
Conf = kwargs_of((lambda x, y: None), zen_dataclass={"cls_name": "foo"})
assert Conf.__name__ == "foo"


def test_dataclass_options_via_cls_defaults():
class Moo(BuildsFn):
_default_dataclass_options_for_kwargs_of = {"cls_name": "bar"}

Conf1 = kwargs_of((lambda: None), zen_dataclass={"cls_name": "foo"})
assert Conf1.__name__ == "foo"

Conf2 = Moo.kwargs_of((lambda: None))
assert Conf2.__name__ == "bar"


def test_kwarg_override():
Config = kwargs_of(lambda *, x, y: None, y=22)
assert instantiate(Config(x=1)) == {"x": 1, "y": 22}
4 changes: 2 additions & 2 deletions tests/test_zen.py
Expand Up @@ -50,10 +50,10 @@ def f(x: int, y: str):


def test_zen_repr():
assert repr(zen(lambda x, y: None)) == "zen[<lambda>(x, y)](cfg, /)"
assert repr(zen(lambda x, y: None)) == "zen[lambda(x, y)](cfg, /)"
assert (
repr(zen(pre_call=lambda x: x)(lambda x, y: None))
== "zen[<lambda>(x, y)](cfg, /)"
== "zen[lambda(x, y)](cfg, /)"
)
assert repr(zen(make_config("x", "y"))) == "zen[Config(x, y)](cfg, /)"

Expand Down

0 comments on commit be42912

Please sign in to comment.