Skip to content

Commit

Permalink
Merge pull request #384 from mit-ll-responsible-ai/improve-zen
Browse files Browse the repository at this point in the history
Improve zen pickle-compat and support for `hydra_main(config_path)`
  • Loading branch information
rsokl committed Jan 13, 2023
2 parents 9e7ad1a + c1a4be4 commit 7c455a6
Show file tree
Hide file tree
Showing 14 changed files with 155 additions and 15 deletions.
15 changes: 15 additions & 0 deletions docs/source/changes.rst
Expand Up @@ -8,6 +8,21 @@ Changelog
This is a record of all past hydra-zen releases and what went into them, in reverse
chronological order. All previous releases should still be available on pip.

.. _v0.9.1:

------------------
0.9.1 - 2023-01-13
------------------


Improvements
------------
- :func:`hydra_zen.zen` now returns pickle-compatible wrapped functions. See :pull:`384`.

Bug Fixes
---------
- :func:`hydra_zen.zen`'s `hydra_main` method now handles string `config_path` entries properly (only for Hydra 1.3.0+). Previously Hydra could not find the path to the wrapped task function. hydra-zen will warn users that a string `config_path` is not supported via :func:`hydra_zen.zen` for Hydra 1.2 and earlier. See :pull:`384`.

.. _v0.9.0:

------------------
Expand Down
61 changes: 48 additions & 13 deletions src/hydra_zen/wrapper/_implementations.py
Expand Up @@ -2,7 +2,9 @@
# SPDX-License-Identifier: MIT
# pyright: strict

import warnings
from collections import defaultdict, deque
from functools import wraps
from inspect import Parameter, signature
from typing import (
Any,
Expand Down Expand Up @@ -42,6 +44,7 @@
)

from hydra_zen import instantiate, just, make_custom_builds_fn
from hydra_zen._compatibility import HYDRA_VERSION, Version
from hydra_zen.errors import HydraZenValidationError
from hydra_zen.structured_configs._type_guards import safe_getattr
from hydra_zen.structured_configs._utils import get_obj_path
Expand Down Expand Up @@ -161,7 +164,10 @@ def __init__(
self.func: Callable[P, R] = __func

try:
self.parameters: Mapping[str, Parameter] = signature(self.func).parameters
# Must cast to dict so that `self` is pickle-compatible.
self.parameters: Mapping[str, Parameter] = dict(
signature(self.func).parameters
)
except (ValueError, TypeError):
raise HydraZenValidationError(
"hydra_zen.zen can only wrap callables that possess inspectable signatures."
Expand Down Expand Up @@ -372,15 +378,20 @@ def hydra_main(
Parameters
----------
config_path : Optional[str]
The config path, a directory relative to the declaring python file.
The config path, an absolute path to a directory or a directory relative to
the declaring python file. If `config_path` is not specified no directory is
added to the config search path.
If config_path is not specified no directory is added to the Config search path.
Specifying `config_path` via `Zen.hydra_main` is only supported for
Hydra 1.3.0+.
config_name : Optional[str]
The name of the config (usually the file name without the .yaml extension)
version_base : Optional[str]
There are three classes of values that the version_base parameter supports, given new and existing users greater control of the default behaviors to use.
There are three classes of values that the version_base parameter supports,
given new and existing users greater control of the default behaviors to
use.
- If the version_base parameter is not specified, Hydra 1.x will use defaults compatible with version 1.1. Also in this case, a warning is issued to indicate an explicit version_base is preferred.
- If the version_base parameter is None, then the defaults are chosen for the current minor Hydra version. For example for Hydra 1.2, then would imply config_path=None and hydra.job.chdir=False.
Expand All @@ -394,6 +405,36 @@ def hydra_main(

kw = dict(config_name=config_name)

# For relative config paths, Hydra looks in the directory relative to the file
# in which the task function is defined. Unfortunately, it is only able to
# follow wrappers starting in Hydra 1.3.0. Thus `Zen.hydra_main` cannot
# handle string config_path entries until Hydra 1.3.0
if (config_path is _UNSPECIFIED_ and HYDRA_VERSION < Version(1, 2, 0)) or (
(
isinstance(config_path, str)
or (config_path is _UNSPECIFIED_ and version_base == "1.1")
)
and HYDRA_VERSION < Version(1, 3, 0)
): # pragma: no cover
warnings.warn(
"Specifying config_path via hydra_zen.zen(...).hydra_main "
"is only supported for Hydra 1.3.0+"
)
if Version(1, 3, 0) <= HYDRA_VERSION and isinstance(config_path, str):
# Here we create an on-the-fly wrapper so that Hydra can trace
# back through the wrapper to the original task function
# We could give `Zen` as `__wrapped__` attr, but this messes with
# things like `inspect.signature`.
#
# A downside of this is that `wrapper` is not pickle-able.
@wraps(self.func)
def wrapper(cfg: Any):
return self(cfg)

target = wrapper
else:
target = self

if config_path is not _UNSPECIFIED_:
kw["config_path"] = config_path

Expand All @@ -402,7 +443,7 @@ def hydra_main(
): # pragma: no cover
kw["version_base"] = version_base

return hydra.main(**kw)(self)()
return hydra.main(**kw)(target)()


@overload
Expand Down Expand Up @@ -517,6 +558,8 @@ def wrapped(cfg):
will cause `zen` to pass the full, resolved config to that field. This specific
parameter name can be overridden via `Zen.CFG_NAME`.
Specifying `config_path` via `Zen.hydra_main` is only supported for Hydra 1.3.0+.
Examples
--------
**Basic Usage**
Expand Down Expand Up @@ -549,14 +592,6 @@ def wrapped(cfg):
>>> zen_f.func(-1, 1)
0
`zen` can be used as a decorator
>>> @zen
... def zen_g(x, y):
... return x + y
>>> zen_g({'x': 1, 'y': 2})
3
`zen` is compatible with partial'd functions.
>>> from functools import partial
Expand Down
2 changes: 2 additions & 0 deletions tests/example_app/__init__.py
@@ -0,0 +1,2 @@
# Copyright (c) 2023 Massachusetts Institute of Technology
# SPDX-License-Identifier: MIT
1 change: 1 addition & 0 deletions tests/example_app/config.yaml
@@ -0,0 +1 @@
default_default: true
2 changes: 2 additions & 0 deletions tests/example_app/dir1/__init__.py
@@ -0,0 +1,2 @@
# Copyright (c) 2023 Massachusetts Institute of Technology
# SPDX-License-Identifier: MIT
1 change: 1 addition & 0 deletions tests/example_app/dir1/cfg1.yaml
@@ -0,0 +1 @@
dir1_cfg1: true
1 change: 1 addition & 0 deletions tests/example_app/dir1/cfg2.yaml
@@ -0,0 +1 @@
dir1_cfg2: true
2 changes: 2 additions & 0 deletions tests/example_app/dir2/__init__.py
@@ -0,0 +1,2 @@
# Copyright (c) 2023 Massachusetts Institute of Technology
# SPDX-License-Identifier: MIT
1 change: 1 addition & 0 deletions tests/example_app/dir2/cfg1.yaml
@@ -0,0 +1 @@
dir2_cfg1: true
1 change: 1 addition & 0 deletions tests/example_app/dir2/cfg2.yaml
@@ -0,0 +1 @@
dir2_cfg2: true
File renamed without changes.
16 changes: 16 additions & 0 deletions tests/example_app/zen_main_w_config_path.py
@@ -0,0 +1,16 @@
# Copyright (c) 2023 Massachusetts Institute of Technology
# SPDX-License-Identifier: MIT

from pathlib import Path

from hydra_zen import zen

cwd = Path.cwd()


def main(zen_cfg):
print(zen_cfg)


if __name__ == "__main__":
zen(main).hydra_main(config_name="config", config_path=".")
2 changes: 1 addition & 1 deletion tests/test_with_hydra_submitit.py
Expand Up @@ -18,7 +18,7 @@ def test_pickling_with_hydra_main():
import subprocess
from pathlib import Path

path = (Path(__file__).parent / "dummy_zen_main.py").absolute()
path = (Path(__file__).parent / "example_app" / "dummy_zen_main.py").absolute()
assert not (Path.cwd() / "multirun").is_dir()
subprocess.run(
["python", path, "x=1", "y=2", "hydra/launcher=submitit_local", "--multirun"]
Expand Down
65 changes: 64 additions & 1 deletion tests/test_zen_decorator.py → tests/test_zen.py
Expand Up @@ -2,6 +2,7 @@
# SPDX-License-Identifier: MIT

import os
import pickle
import random
import sys
from dataclasses import dataclass
Expand All @@ -12,6 +13,7 @@
from omegaconf import DictConfig

from hydra_zen import builds, make_config, to_yaml, zen
from hydra_zen._compatibility import HYDRA_VERSION
from hydra_zen.errors import HydraZenValidationError
from hydra_zen.wrapper import Zen
from tests.custom_strategies import everything_except
Expand Down Expand Up @@ -438,7 +440,7 @@ def test_hydra_main():

from hydra_zen import load_from_yaml

path = (Path(__file__).parent / "dummy_zen_main.py").absolute()
path = (Path(__file__).parent / "example_app" / "dummy_zen_main.py").absolute()
assert not (Path.cwd() / "outputs").is_dir()
subprocess.run(["python", path, "x=1", "y=2"]).check_returncode()
assert (Path.cwd() / "outputs").is_dir()
Expand All @@ -453,6 +455,55 @@ def test_hydra_main():
}


@pytest.mark.xfail(
HYDRA_VERSION < (1, 3, 0),
reason="hydra_main(config_path=...) only supports wrapped task functions starting "
"in Hydra 1.3.0",
)
@pytest.mark.skipif(
sys.platform.startswith("win") and bool(os.environ.get("CI")),
reason="Things are weird on GitHub Actions and Windows",
)
@pytest.mark.parametrize(
"dir_, name",
[
("dir1", "cfg1"),
("dir1", "cfg2"),
("dir2", "cfg1"),
("dir2", "cfg2"),
(None, None),
],
)
@pytest.mark.usefixtures("cleandir")
def test_hydra_main_config_path(dir_, name):
# regression test for https://github.com/mit-ll-responsible-ai/hydra-zen/issues/381
import subprocess
from pathlib import Path

from hydra_zen import load_from_yaml

path = (
Path(__file__).parent / "example_app" / "zen_main_w_config_path.py"
).absolute()
assert not (Path.cwd() / "outputs").is_dir()

run_in = ["python", path]

if dir_ is not None:
run_in.extend([f"--config-name={name}", f"--config-path={dir_}"])
else:
dir_, name = "default", "default"
subprocess.run(run_in).check_returncode()

assert (Path.cwd() / "outputs").is_dir()

*_, latest_job = sorted((Path.cwd() / "outputs").glob("*/*"))

assert load_from_yaml(latest_job / ".hydra" / "config.yaml") == {
f"{dir_}_{name}": 1
}


@pytest.mark.parametrize(
"zen_func",
[
Expand Down Expand Up @@ -496,3 +547,15 @@ def f(y):
return y.x

assert zen(f)(Conf) == 1


def pikl(x):
return x * 2


zpikl = zen(pikl)


def test_pickle_compatible():
loaded = pickle.loads(pickle.dumps(zpikl))
assert loaded({"x": 3}) == pikl(3)

0 comments on commit 7c455a6

Please sign in to comment.