Skip to content

Commit

Permalink
Enable _partial_ for Hydra > 1.1.1 (#230)
Browse files Browse the repository at this point in the history
* Update tests for new instantiation error-raising

* enable native Hydra support for _partial_ and fix bad annotation

* fix protocol test

* Add github action test against hydra 1.1.2dev

* fix actions workflow

* patch coverage

Co-authored-by: rsokl <ryan.soklaski@ll.mit.edu>
  • Loading branch information
rsokl and rsokl committed Feb 23, 2022
1 parent 7ce21f5 commit d250c59
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 17 deletions.
23 changes: 23 additions & 0 deletions .github/workflows/tox_run.yml
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,29 @@ jobs:
- name: Test with tox
run: tox -e pre-release

test-against-hydra-1_1_2dev:
runs-on: ubuntu-latest

strategy:
max-parallel: 3
matrix:
python-version: [3.8]
fail-fast: false

steps:
- uses: actions/checkout@v1
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install tox tox-gh-actions
- name: Test with tox
run: tox -e hydra-1p1p2-pre-release


run-pyright:
runs-on: ubuntu-latest

Expand Down
7 changes: 7 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,13 @@ deps = {[testenv]deps}
beartype
basepython = python3.8

[testenv:hydra-1p1p2-pre-release] # test against pre-releases of dependencies
pip_pre = true
deps = hydra-core==1.1.2dev
{[testenv]deps}
pydantic
beartype
basepython = python3.8

[testenv:coverage]
setenv = NUMBA_DISABLE_JIT=1
Expand Down
2 changes: 1 addition & 1 deletion src/hydra_zen/_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def _get_version(ver_str: str) -> Version:
#
# Uncomment dynamice setting of `HYDRA_SUPPORTS_PARTIAL` once we can
# begin testing against nightly builds of Hydra
HYDRA_SUPPORTS_PARTIAL: Final = False # Version(1, 1, 1) < HYDRA_VERSION
HYDRA_SUPPORTS_PARTIAL: Final = Version(1, 1, 1) < HYDRA_VERSION

# Indicates primitive types permitted in type-hints of structured configs
HYDRA_SUPPORTED_PRIMITIVE_TYPES: Final = {int, float, bool, str, Enum}
Expand Down
31 changes: 20 additions & 11 deletions src/hydra_zen/structured_configs/_implementations.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,14 +80,19 @@
_CONVERT_FIELD_NAME: Final[str] = "_convert_"
_POS_ARG_FIELD_NAME: Final[str] = "_args_"

_HYDRA_FIELD_NAMES: FrozenSet[str] = frozenset(
(
_TARGET_FIELD_NAME,
_RECURSIVE_FIELD_NAME,
_CONVERT_FIELD_NAME,
_POS_ARG_FIELD_NAME,
)
)
_names = [
_TARGET_FIELD_NAME,
_RECURSIVE_FIELD_NAME,
_CONVERT_FIELD_NAME,
_POS_ARG_FIELD_NAME,
]

if HYDRA_SUPPORTS_PARTIAL: # pragma: no cover
_names.append(_PARTIAL_FIELD_NAME)

_HYDRA_FIELD_NAMES: FrozenSet[str] = frozenset(_names)

del _names

# hydra-zen-specific fields
_ZEN_PROCESSING_LOCATION: Final[str] = _utils.get_obj_path(zen_processing)
Expand Down Expand Up @@ -1210,7 +1215,7 @@ def builds(
),
(
_PARTIAL_FIELD_NAME,
str,
bool,
_utils.field(default=zen_partial, init=False),
),
]
Expand Down Expand Up @@ -1701,8 +1706,12 @@ def builds(
dataclass_name, fields=sanitized_base_fields, bases=builds_bases, frozen=frozen
)

if zen_partial is False and hasattr(out, _ZEN_PARTIAL_TARGET_FIELD_NAME):
# `out._partial_target_` has been inherited; this will lead to an error when
# TODO: revisit this constraint
if zen_partial is False and (
hasattr(out, _ZEN_PARTIAL_TARGET_FIELD_NAME)
or (HYDRA_SUPPORTS_PARTIAL and hasattr(out, _PARTIAL_FIELD_NAME))
):
# `out._partial_` has been inherited; this will lead to an error when
# hydra-instantiation occurs, since it will be passed to target.
# There is not an easy way to delete this, since it comes from a parent class
raise TypeError(
Expand Down
4 changes: 3 additions & 1 deletion tests/test_protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,16 @@
is_partial_builds,
)
from hydra_zen.typing import Builds, Just, Partial, PartialBuilds
from hydra_zen.typing._implementations import HydraPartialBuilds


@pytest.mark.parametrize(
"fn, protocol",
[
(just, Just),
(builds, Builds),
(partial(builds, zen_partial=True), PartialBuilds),
(partial(builds, zen_partial=True), (PartialBuilds, HydraPartialBuilds)),
(partial(builds, zen_partial=True, zen_meta=dict(y=1)), PartialBuilds),
],
)
def test_runtime_checkability_of_protocols(fn, protocol):
Expand Down
9 changes: 5 additions & 4 deletions tests/test_zen_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import hypothesis.strategies as st
import pytest
from hydra.errors import InstantiationException
from hypothesis import given, settings
from omegaconf import OmegaConf
from omegaconf.errors import InterpolationKeyError, InterpolationResolutionError
Expand Down Expand Up @@ -34,7 +35,7 @@ def _coordinate_meta_fields_for_interpolation(wrappers, zen_meta):
# interpolated strings map to the named decorators
if is_interpolated_string(wrappers):
# change level of interpolation
wrappers = wrappers.replace("..", ".") # type: ignore
wrappers = wrappers.replace("..", ".")
dec_name: str = wrappers[3:-1]
item = decorators_by_name[dec_name]
zen_meta[dec_name] = item if item is None else just(item)
Expand Down Expand Up @@ -134,7 +135,7 @@ def target(*args, **kwargs):
as_yaml=st.booleans(),
)
def test_zen_wrappers_expected_behavior(
wrappers: Union[ # type: ignore
wrappers: Union[
Union[TrackedFunc, Just[TrackedFunc], PartialBuilds[TrackedFunc], InterpStr],
List[
Union[TrackedFunc, Just[TrackedFunc], PartialBuilds[TrackedFunc], InterpStr]
Expand Down Expand Up @@ -231,7 +232,7 @@ class NotAWrapper:
],
)
def test_zen_wrappers_validation_during_builds(bad_wrapper):
with pytest.raises(TypeError):
with pytest.raises((TypeError, InstantiationException)):
builds(int, zen_wrappers=bad_wrapper)


Expand All @@ -246,7 +247,7 @@ def test_zen_wrappers_validation_during_builds(bad_wrapper):
)
def test_zen_wrappers_validation_during_instantiation(bad_wrapper):
conf = builds(int, zen_wrappers=bad_wrapper)
with pytest.raises(TypeError):
with pytest.raises((TypeError, InstantiationException)):
instantiate(conf)


Expand Down

0 comments on commit d250c59

Please sign in to comment.