Skip to content

Commit

Permalink
proper error when trying to use Union types with with structured conf…
Browse files Browse the repository at this point in the history
…igs (currently unsupported)
  • Loading branch information
omry committed Apr 16, 2020
1 parent 71caceb commit 08bced2
Show file tree
Hide file tree
Showing 8 changed files with 68 additions and 10 deletions.
26 changes: 22 additions & 4 deletions omegaconf/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .errors import (
ConfigIndexError,
ConfigTypeError,
ConfigValueError,
KeyValidationError,
OmegaConfBaseException,
ValidationError,
Expand Down Expand Up @@ -76,9 +77,11 @@ def _get_class(path: str) -> type:
return klass


def _resolve_optional(type_: Any) -> Tuple[bool, Any]:
from typing import Union
def _is_union(type_: Any) -> bool:
return getattr(type_, "__origin__", None) is Union


def _resolve_optional(type_: Any) -> Tuple[bool, Any]:
if getattr(type_, "__origin__", None) is Union:
args = type_.__args__
if len(args) == 2 and args[1] == type(None): # noqa E721
Expand Down Expand Up @@ -107,8 +110,8 @@ def get_attr_data(obj: Any) -> Dict[str, Any]:
obj_type = obj if is_type else type(obj)
for name, attrib in attr.fields_dict(obj_type).items():
is_optional, type_ = _resolve_optional(attrib.type)
type_ = _resolve_forward(type_, obj.__module__)
is_nested = is_attr_class(type_)
type_ = _resolve_forward(type_, obj.__module__)
if not is_type:
value = getattr(obj, name)
else:
Expand All @@ -121,6 +124,12 @@ def get_attr_data(obj: Any) -> Dict[str, Any]:
f"Missing default value for {name}, to indicate "
"default must be populated later use OmegaConf.MISSING"
)
if _is_union(type_):
e = ConfigValueError(
f"Union types are not supported:\n{name}: {type_str(type_)}"
)
format_and_raise(node=None, key=None, value=value, cause=e, msg=str(e))

d[name] = _maybe_wrap(
ref_type=type_, is_optional=is_optional, key=name, value=value, parent=None,
)
Expand Down Expand Up @@ -150,6 +159,11 @@ def get_dataclass_data(obj: Any) -> Dict[str, Any]:
"default must be populated later use OmegaConf.MISSING"
)

if _is_union(type_):
e = ConfigValueError(
f"Union types are not supported:\n{name}: {type_str(type_)}"
)
format_and_raise(node=None, key=None, value=value, cause=e, msg=str(e))
d[name] = _maybe_wrap(
ref_type=type_, is_optional=is_optional, key=name, value=value, parent=None,
)
Expand Down Expand Up @@ -500,7 +514,11 @@ def type_str(t: Any) -> str:
if hasattr(t, "__name__"):
name = str(t.__name__)
else:
name = str(t._name) # pragma: no cover
if t._name is None:
if t.__origin__ is not None:
name = type_str(t.__origin__)
else:
name = str(t._name) # pragma: no cover

args = getattr(t, "__args__", None)
if args is not None:
Expand Down
3 changes: 2 additions & 1 deletion omegaconf/omegaconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
is_primitive_type,
is_structured_config,
isint,
type_str,
valid_value_annotation_type,
)
from .base import Container, Node
Expand Down Expand Up @@ -586,7 +587,7 @@ def _node_wrap(
) -> ValueNode:
if not valid_value_annotation_type(type_):
raise ValidationError(
f"Annotated class '{type_.__name__}' is not a structured config. "
f"Annotated class '{type_str(type_)}' is not a structured config. "
"did you forget to decorate it as a dataclass?"
)
node: ValueNode
Expand Down
7 changes: 6 additions & 1 deletion tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from contextlib import contextmanager
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, Iterator, List, Optional
from typing import Any, Dict, Iterator, List, Optional, Union

from omegaconf import II, MISSING

Expand Down Expand Up @@ -83,3 +83,8 @@ class StructuredWithMissing:
inter_num: int = II("num")
inter_user: User = II("user")
inter_opt_user: Optional[User] = II("opt_user")


@dataclass
class UnionError:
x: Union[int, str] = 10
7 changes: 6 additions & 1 deletion tests/structured_conf/data/attr_classes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, Union

import attr # noqaE402
import pytest
Expand Down Expand Up @@ -387,3 +387,8 @@ class Missing2:
@attr.s(auto_attribs=True)
class NestedWithNone:
plugin: Optional[Plugin] = None


@attr.s(auto_attribs=True)
class UnionError:
x: Union[int, str] = 10
7 changes: 6 additions & 1 deletion tests/structured_conf/data/dataclasses.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass, field # noqaE402
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, Union

import pytest

Expand Down Expand Up @@ -394,3 +394,8 @@ class Missing2:
@dataclass
class NestedWithNone:
plugin: Optional[Plugin] = None


@dataclass
class UnionError:
x: Union[int, str] = 10
5 changes: 5 additions & 0 deletions tests/structured_conf/test_structured_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,11 @@ def test_value_without_a_default(self, class_type: str) -> None:
"no_default": 10
}

def test_union_errors(self, class_type: str) -> None:
module: Any = import_module(class_type)
with pytest.raises(ValueError):
OmegaConf.structured(module.UnionError)

def test_config_with_list(self, class_type: str) -> None:
module: Any = import_module(class_type)

Expand Down
20 changes: 19 additions & 1 deletion tests/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,14 @@
MissingMandatoryValue,
)

from . import Color, ConcretePlugin, IllegalType, Plugin, StructuredWithMissing
from . import (
Color,
ConcretePlugin,
IllegalType,
Plugin,
StructuredWithMissing,
UnionError,
)


@dataclass
Expand Down Expand Up @@ -405,6 +412,17 @@ def finalize(self, cfg: Any) -> None:
),
id="structured:create_from_unsupported_object",
),
pytest.param(
Expected(
create=lambda: None,
op=lambda cfg: OmegaConf.structured(UnionError),
exception_type=ValueError,
msg="Union types are not supported:\nx: Union[int, str]",
object_type_str=None,
ref_type_str=None,
),
id="structured:create_with_union_error",
),
# assign
pytest.param(
Expected(
Expand Down
3 changes: 2 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List
from typing import Any, Dict, List, Union

import attr
import pytest
Expand Down Expand Up @@ -308,6 +308,7 @@ def test_is_primitive_type(type_: Any, is_primitive: bool) -> None:
(List[str], "List[str]"),
(List[Color], "List[Color]"),
(List[Dict[str, Color]], "List[Dict[str, Color]]"),
(Union[str, int, Color], "Union[str, int, Color]"),
],
)
def test_type_str(type_: Any, expected: str) -> None:
Expand Down

0 comments on commit 08bced2

Please sign in to comment.