From d4fca02a040f3acf79d829caaca4421ce6fa90fc Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Wed, 6 Oct 2021 16:57:54 -0500 Subject: [PATCH] refactor DictConfig._validate_non_optional, standardize err messages --- omegaconf/dictconfig.py | 34 ++++++++----------- .../structured_conf/test_structured_basic.py | 2 +- tests/test_errors.py | 4 +-- tests/test_matrix.py | 2 +- tests/test_merge.py | 2 +- 5 files changed, 19 insertions(+), 25 deletions(-) diff --git a/omegaconf/dictconfig.py b/omegaconf/dictconfig.py index 9d7272050..c591b6bfe 100644 --- a/omegaconf/dictconfig.py +++ b/omegaconf/dictconfig.py @@ -236,33 +236,27 @@ def _validate_merge(self, value: Any) -> None: ) raise ValidationError(msg) - def _validate_non_optional(self, key: Any, value: Any) -> None: + def _validate_non_optional(self, key: Optional[DictKeyType], value: Any) -> None: if _is_none(value, resolve=True, throw_on_resolution_failure=False): + if key is not None: child = self._get_node(key) if child is not None: assert isinstance(child, Node) - if not child._is_optional(): - self._format_and_raise( - key=key, - value=value, - cause=ValidationError("child '$FULL_KEY' is not Optional"), - ) + field_is_optional = child._is_optional() else: - is_optional, _ = _resolve_optional(self._metadata.element_type) - if not is_optional: - self._format_and_raise( - key=key, - value=value, - cause=ValidationError("field '$FULL_KEY' is not Optional"), - ) - else: - if not self._is_optional(): - self._format_and_raise( - key=None, - value=value, - cause=ValidationError("field '$FULL_KEY' is not Optional"), + field_is_optional, _ = _resolve_optional( + self._metadata.element_type ) + else: + field_is_optional = self._is_optional() + + if not field_is_optional: + self._format_and_raise( + key=key, + value=value, + cause=ValidationError("field '$FULL_KEY' is not Optional"), + ) def _raise_invalid_value( self, value: Any, value_type: Any, target_type: Any diff --git a/tests/structured_conf/test_structured_basic.py b/tests/structured_conf/test_structured_basic.py index 7fe3b3887..5d0d27af2 100644 --- a/tests/structured_conf/test_structured_basic.py +++ b/tests/structured_conf/test_structured_basic.py @@ -106,7 +106,7 @@ def test_merge_error_override_bad_type(self, module: Any) -> None: def test_error_message(self, module: Any) -> None: cfg = OmegaConf.structured(module.StructuredOptional) - msg = re.escape("child 'not_optional' is not Optional") + msg = re.escape("field 'not_optional' is not Optional") with raises(ValidationError, match=msg): cfg.not_optional = None diff --git a/tests/test_errors.py b/tests/test_errors.py index cc50f3f75..e2ac89a4e 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -131,7 +131,7 @@ def finalize(self, cfg: Any) -> None: create=lambda: OmegaConf.structured(StructuredWithMissing), op=lambda cfg: OmegaConf.update(cfg, "num", None), exception_type=ValidationError, - msg="child 'num' is not Optional", + msg="field 'num' is not Optional", parent_node=lambda cfg: cfg, child_node=lambda cfg: cfg._get_node("num"), object_type=StructuredWithMissing, @@ -370,7 +370,7 @@ def finalize(self, cfg: Any) -> None: ), op=lambda cfg: setattr(cfg, "foo", None), exception_type=ValidationError, - msg="child 'foo' is not Optional", + msg="field 'foo' is not Optional", key="foo", full_key="foo", child_node=lambda cfg: cfg.foo, diff --git a/tests/test_matrix.py b/tests/test_matrix.py index 5cce208a1..d26ae740e 100644 --- a/tests/test_matrix.py +++ b/tests/test_matrix.py @@ -111,7 +111,7 @@ def test_none_assignment_and_merging_in_dict( data = {"node": node} cfg = OmegaConf.create(obj=data) verify(cfg, "node", none=False, opt=False, missing=False, inter=False) - msg = "child 'node' is not Optional" + msg = "field 'node' is not Optional" with raises(ValidationError, match=re.escape(msg)): cfg.node = None diff --git a/tests/test_merge.py b/tests/test_merge.py index eae936934..e28b83cff 100644 --- a/tests/test_merge.py +++ b/tests/test_merge.py @@ -630,7 +630,7 @@ def test_merge( match=re.escape( dedent( """\ - child 'foo' is not Optional + field 'foo' is not Optional full_key: foo object_type=dict""" )