Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds version_base to launch #273

Merged
merged 5 commits into from May 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/source/changes.rst
Expand Up @@ -60,6 +60,10 @@ will no longer auto-broaden nested container types when ``OmegaConf v2.2.0+`` is
installed. (See :pull:`261`)


Hydra ``v1.2.0`` is introducing a ``version_base`` parameter that can control default behaviors in ``hydra.run`` and ``hydra.initialize``.
Correspondingly, ``version_base`` is now exposed via `~hydra_zen.launch`. See :pull:`273` for more details.


.. _0p7p0-deprecations:

Deprecations
Expand Down
35 changes: 25 additions & 10 deletions src/hydra_zen/_launch.py
Expand Up @@ -17,6 +17,10 @@
from hydra_zen.typing._implementations import DataClass


class _NotSet: # pragma: no cover
pass


def _store_config(
cfg: Union[DataClass, Type[DataClass], DictConfig, ListConfig, Mapping[Any, Any]],
config_name: str = "hydra_launch",
Expand Down Expand Up @@ -54,11 +58,12 @@ def launch(
config: Union[DataClass, Type[DataClass], Mapping[str, Any]],
task_function: Callable[[DictConfig], Any],
overrides: Optional[List[str]] = None,
multirun: bool = False,
version_base: Optional[Union[str, Type[_NotSet]]] = _NotSet,
to_dictconfig: bool = False,
config_name: str = "zen_launch",
job_name: str = "zen_launch",
with_log_configuration: bool = True,
multirun: bool = False,
to_dictconfig: bool = False,
) -> Union[JobReturn, Any]:
r"""Launch a Hydra job using a Python-based interface.

Expand All @@ -84,6 +89,19 @@ def launch(
If provided, sets/overrides values in ``config``. See [1]_ and [2]_
for a detailed discussion of the "grammar" supported by ``overrides``.

multirun : bool (default: False)
Launch a Hydra multi-run ([3]_).

version_base : Optional[str], optional (default=_NotSet)
Available starting with Hydra 1.2.0.
- If the `version_base parameter` is not specified, Hydra 1.x will use defaults compatible with version 1.1. Also in this case, a warning is issued to indicate an explicit version_base is preferred.
- If the `version_base parameter` is `None`, then the defaults are chosen for the current minor Hydra version. For example for Hydra 1.2, then would imply `config_path=None` and `hydra.job.chdir=False`.
- If the `version_base` parameter is an explicit version string like "1.1", then the defaults appropriate to that version are used.

to_dictconfig: bool (default: False)
If ``True``, convert a ``dataclasses.dataclass`` to a ``omegaconf.DictConfig``. Note, this
will remove Hydra's cabability for validation with structured configurations.

config_name : str (default: "zen_launch")
Name of the stored configuration in Hydra's ConfigStore API.

Expand All @@ -92,13 +110,6 @@ def launch(
with_log_configuration : bool (default: True)
If ``True``, enables the configuration of the logging subsystem from the loaded config.

multirun : bool (default: False)
Launch a Hydra multi-run ([3]_).

to_dictconfig: bool (default: False)
If ``True``, convert a ``dataclasses.dataclass`` to a ``omegaconf.DictConfig``. Note, this
will remove Hydra's cabability for validation with structured configurations.

Returns
-------
result : JobReturn | Any
Expand Down Expand Up @@ -221,7 +232,11 @@ def launch(
config_name = _store_config(config, config_name)

# Initializes Hydra and add the config_path to the config search path
with initialize(config_path=None, job_name=job_name):
with initialize(
config_path=None,
job_name=job_name,
**({} if version_base is _NotSet else {"version_base": version_base})
):

# taken from hydra.compose with support for MULTIRUN
gh = GlobalHydra.instance()
Expand Down
16 changes: 14 additions & 2 deletions tests/conftest.py
@@ -1,18 +1,19 @@
# Copyright (c) 2022 Massachusetts Institute of Technology
# SPDX-License-Identifier: MIT


import logging
import os
import sys
import tempfile
from typing import Iterable
from typing import Dict, Iterable, Optional

import hypothesis.strategies as st
import pkg_resources
import pytest
from omegaconf import DictConfig, ListConfig

from hydra_zen._compatibility import HYDRA_VERSION

# Skip collection of tests that don't work on the current version of Python.
collect_ignore_glob = []

Expand Down Expand Up @@ -55,6 +56,17 @@ def cleandir() -> Iterable[str]:
logging.shutdown()


@pytest.fixture()
def version_base() -> Dict[str, Optional[str]]:
"""Return version_base according to local version, or empty dict for versions
preceding version_base"""
return (
{"version_base": ".".join(str(i) for i in HYDRA_VERSION)}
if HYDRA_VERSION >= (1, 2, 0)
else {}
)


pytest_plugins = "pytester"

st.register_type_strategy(ListConfig, st.lists(st.integers()).map(ListConfig))
Expand Down
8 changes: 4 additions & 4 deletions tests/test_defaults_list.py
Expand Up @@ -13,19 +13,19 @@
from hydra_zen.errors import HydraZenValidationError


def test_hydra_defaults_work_builds():
def test_hydra_defaults_work_builds(version_base):
config_store = ConfigStore.instance()
config_store.store(group="x", name="a", node=builds(int, 10))
Conf = builds(dict, x=None, y="hi", hydra_defaults=["_self_", {"x": "a"}])
job = launch(Conf, instantiate)
job = launch(Conf, instantiate, **version_base)
assert job.return_value == {"x": 10, "y": "hi"}


def test_hydra_defaults_work_make_config():
def test_hydra_defaults_work_make_config(version_base):
config_store = ConfigStore.instance()
config_store.store(group="x", name="a", node=builds(int, 10))
Conf = make_config(x=None, y="hi", hydra_defaults=["_self_", {"x": "a"}])
job = launch(Conf, instantiate)
job = launch(Conf, instantiate, **version_base)
assert job.return_value == {"x": 10, "y": "hi"}


Expand Down
3 changes: 2 additions & 1 deletion tests/test_launch/test_callbacks.py
Expand Up @@ -73,7 +73,7 @@ def tracker(x=CustomCallback):

@pytest.mark.usefixtures("cleandir")
@pytest.mark.parametrize("multirun", [False, True])
def test_hydra_run_with_callback(multirun):
def test_hydra_run_with_callback(multirun, version_base: dict):
# Tests that callback methods are called during appropriate
# stages
try:
Expand All @@ -86,6 +86,7 @@ def test_hydra_run_with_callback(multirun):
task_function=instantiate,
overrides=["hydra/callbacks=test_callback"],
multirun=multirun,
**version_base,
)

if multirun:
Expand Down
58 changes: 42 additions & 16 deletions tests/test_launch/test_implementations.py
Expand Up @@ -2,6 +2,7 @@
# SPDX-License-Identifier: MIT

from pathlib import Path
from typing import Optional

import pytest
from hydra.core.config_store import ConfigStore
Expand All @@ -11,6 +12,7 @@
from omegaconf.omegaconf import OmegaConf

from hydra_zen import builds, instantiate, launch, make_config
from hydra_zen._compatibility import HYDRA_VERSION
from hydra_zen._launch import _store_config

try:
Expand Down Expand Up @@ -49,8 +51,8 @@ def test_store_config(cfg):
@pytest.mark.usefixtures("cleandir")
@pytest.mark.parametrize("cfg", CONFIG_TYPE_EXAMPLES)
@pytest.mark.parametrize("multirun", [False, True])
def test_launch_config_type(cfg, multirun):
job = launch(cfg, task_function=instantiate, multirun=multirun)
def test_launch_config_type(cfg, multirun, version_base):
job = launch(cfg, task_function=instantiate, multirun=multirun, **version_base)
if isinstance(job, list):
job = job[0][0]

Expand All @@ -64,13 +66,13 @@ def test_launch_config_type(cfg, multirun):
@pytest.mark.usefixtures("cleandir")
@pytest.mark.parametrize("cfg", DATACLASS_CONFIG_TYPE_EXAMPLES)
@pytest.mark.parametrize("to_dictconfig", [True, False])
def test_launch_to_dictconfig(cfg, to_dictconfig):
def test_launch_to_dictconfig(cfg, to_dictconfig, version_base):
pre_num_fields = len(dataclasses.fields(cfg))

def task_fn(cfg):
_ = cloudpickle.loads(cloudpickle.dumps(cfg))

launch(cfg, task_function=task_fn, to_dictconfig=to_dictconfig)
launch(cfg, task_function=task_fn, to_dictconfig=to_dictconfig, **version_base)

if pre_num_fields > 0:
if not to_dictconfig:
Expand All @@ -79,18 +81,15 @@ def task_fn(cfg):
assert len(dataclasses.fields(cfg)) > 0
else:
# run again with no error
launch(cfg, task_function=task_fn, to_dictconfig=to_dictconfig)
launch(cfg, task_function=task_fn, to_dictconfig=to_dictconfig, **version_base)


@pytest.mark.usefixtures("cleandir")
@pytest.mark.parametrize(
"overrides", [None, [], ["hydra.run.dir=test_hydra_overrided"]]
)
@pytest.mark.parametrize("with_log_configuration", [False, True])
def test_launch_job(
overrides,
with_log_configuration,
):
def test_launch_job(overrides, with_log_configuration, version_base):
cfg = dict(a=1, b=1)
override_exists = overrides and len(overrides) > 1

Expand All @@ -99,6 +98,7 @@ def test_launch_job(
task_function=instantiate,
overrides=overrides,
with_log_configuration=with_log_configuration,
**version_base,
)
assert job.return_value == {"a": 1, "b": 1}

Expand All @@ -113,9 +113,7 @@ def test_launch_job(
@pytest.mark.parametrize("multirun_overrides", [None, ["a=1,2"]])
@pytest.mark.parametrize("with_log_configuration", [False, True])
def test_launch_multirun(
overrides,
multirun_overrides,
with_log_configuration,
overrides, multirun_overrides, with_log_configuration, version_base
):
cfg = dict(a=1, b=1)
override_exists = overrides and len(overrides) > 1
Expand All @@ -134,6 +132,7 @@ def test_launch_multirun(
overrides=_overrides,
with_log_configuration=with_log_configuration,
multirun=True,
**version_base,
)
assert isinstance(job, list) and len(job) == 1
for i, j in enumerate(job[0]):
Expand All @@ -144,11 +143,11 @@ def test_launch_multirun(


@pytest.mark.usefixtures("cleandir")
def test_launch_with_multirun_overrides():
def test_launch_with_multirun_overrides(version_base):
cfg = builds(dict, a=1, b=1)
multirun_overrides = ["hydra/sweeper=basic", "a=1,2"]
with pytest.raises(ConfigCompositionException):
launch(cfg, instantiate, overrides=multirun_overrides)
launch(cfg, instantiate, overrides=multirun_overrides, **version_base)


###############################################
Expand Down Expand Up @@ -191,10 +190,37 @@ def sweep(self, arguments):
"plugin",
[["hydra/sweeper=basic"], ["hydra/sweeper=local_test"]],
)
def test_launch_with_multirun_plugin(plugin):
def test_launch_with_multirun_plugin(plugin, version_base):
cfg = builds(dict, a=1, b=1)
multirun_overrides = plugin + ["a=1,2"]
job = launch(cfg, instantiate, overrides=multirun_overrides, multirun=True)
job = launch(
cfg, instantiate, overrides=multirun_overrides, multirun=True, **version_base
)
assert isinstance(job, list) and len(job) == 1 and len(job[0]) == 2
for i, j in enumerate(job[0]):
assert j.return_value == {"a": i + 1, "b": 1}


@pytest.mark.skipif(HYDRA_VERSION < (1, 2, 0), reason="version_base not supported")
@pytest.mark.parametrize("version_base", ["1.1", "1.2", None])
@pytest.mark.usefixtures("cleandir")
def test_version_base(version_base: Optional[str]):
def task(cfg):
(Path().cwd() / "foo.txt").touch()

expected_dir = Path().cwd() if version_base != "1.1" else (Path().cwd() / "outputs")

glob_pattern = "foo.txt" if version_base != "1.1" else "./**/foo.txt"

assert len(list(expected_dir.glob(glob_pattern))) == 0

launch(make_config(), task, version_base=version_base)
assert len(list(expected_dir.glob(glob_pattern))) == 1, list(expected_dir.glob("*"))

# ensure the file isn't found in the opposite location
not_found = (
Path().cwd().glob("foo.txt")
if version_base == "1.1"
else (Path().cwd() / "outputs").glob("**/foo.txt")
)
assert len(list(not_found)) == 0
31 changes: 25 additions & 6 deletions tests/test_launch/test_logging.py
Expand Up @@ -8,6 +8,7 @@
from hydra.core.utils import JobReturn

from hydra_zen import builds, instantiate, launch
from hydra_zen._compatibility import HYDRA_VERSION

log = logging.getLogger(__name__)

Expand All @@ -18,12 +19,22 @@ def task(cfg):


@pytest.mark.usefixtures("cleandir")
def test_consecutive_logs():
def test_consecutive_logs(version_base):
overrides = ["hydra.run.dir=test"]
if HYDRA_VERSION >= (1, 2, 0):
overrides.append("hydra.job.chdir=True")

job1 = launch(
builds(dict, message="1"), task_function=task, overrides=["hydra.run.dir=test"]
builds(dict, message="1"),
task_function=task,
overrides=overrides,
**version_base,
)
job2 = launch(
builds(dict, message="2"), task_function=task, overrides=["hydra.run.dir=test"]
builds(dict, message="2"),
task_function=task,
overrides=overrides,
**version_base,
)

assert isinstance(job1, JobReturn) and job1.working_dir is not None
Expand All @@ -38,12 +49,20 @@ def test_consecutive_logs():


@pytest.mark.usefixtures("cleandir")
def test_seperate_logs():
def test_seperate_logs(version_base):
extra_overrides = ["hydra.job.chdir=True"] if HYDRA_VERSION >= (1, 2, 0) else []

job1 = launch(
builds(dict, message="1"), task_function=task, overrides=["hydra.run.dir=test1"]
builds(dict, message="1"),
task_function=task,
overrides=["hydra.run.dir=test1"] + extra_overrides,
**version_base,
)
job2 = launch(
builds(dict, message="2"), task_function=task, overrides=["hydra.run.dir=test2"]
builds(dict, message="2"),
task_function=task,
overrides=["hydra.run.dir=test2"] + extra_overrides,
**version_base,
)

assert isinstance(job1, JobReturn) and job1.working_dir is not None
Expand Down