Skip to content

Commit

Permalink
Internal API to allow non primitive values
Browse files Browse the repository at this point in the history
  • Loading branch information
omry committed Sep 21, 2020
1 parent 4ef3bab commit 433fe16
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 19 deletions.
13 changes: 12 additions & 1 deletion omegaconf/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,13 @@ class Metadata:
# unset : inherit from parent (None if no parent specifies)
# set to true: flag is true
# set to false: flag is false
flags: Dict[str, bool] = field(default_factory=dict)
flags: Optional[Dict[str, bool]] = None
resolver_cache: Dict[str, Any] = field(default_factory=lambda: defaultdict(dict))

def __post_init__(self) -> None:
if self.flags is None:
self.flags = {}


@dataclass
class ContainerMetadata(Metadata):
Expand All @@ -37,6 +41,9 @@ def __post_init__(self) -> None:
if self.element_type is not None:
assert self.element_type is Any or isinstance(self.element_type, type)

if self.flags is None:
self.flags = {}


class Node(ABC):
_metadata: Metadata
Expand All @@ -59,9 +66,11 @@ def _get_parent(self) -> Optional["Container"]:
def _set_flag(self, flag: str, value: Optional[bool]) -> "Node":
assert value is None or isinstance(value, bool)
if value is None:
assert self._metadata.flags is not None
if flag in self._metadata.flags:
del self._metadata.flags[flag]
else:
assert self._metadata.flags is not None
self._metadata.flags[flag] = value
return self

Expand All @@ -70,6 +79,7 @@ def _get_node_flag(self, flag: str) -> Optional[bool]:
:param flag: flag to inspect
:return: the state of the flag on this node.
"""
assert self._metadata.flags is not None
return self._metadata.flags[flag] if flag in self._metadata.flags else None

def _get_flag(self, flag: str) -> Optional[bool]:
Expand All @@ -80,6 +90,7 @@ def _get_flag(self, flag: str) -> Optional[bool]:
:return:
"""
flags = self._metadata.flags
assert flags is not None
if flag in flags and flags[flag] is not None:
return flags[flag]

Expand Down
8 changes: 6 additions & 2 deletions omegaconf/basecontainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,9 @@ def expand(node: Container) -> None:
dest._metadata.object_type = src_type

# explicit flags on the source config are replacing the flag values in the destination
for flag, value in src._metadata.flags.items():
flags = src._metadata.flags
assert flags is not None
for flag, value in flags.items():
if value is not None:
dest._set_flag(flag, value)

Expand Down Expand Up @@ -363,7 +365,9 @@ def _merge_with(
self.append(item)

# explicit flags on the source config are replacing the flag values in the destination
for flag, value in other._metadata.flags.items():
flags = other._metadata.flags
assert flags is not None
for flag, value in flags.items():
if value is not None:
self._set_flag(flag, value)
else:
Expand Down
12 changes: 8 additions & 4 deletions omegaconf/dictconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def __init__(
key_type: Union[Any, Type[Any]] = Any,
element_type: Union[Any, Type[Any]] = Any,
is_optional: bool = True,
flags: Optional[Dict[str, bool]] = None,
) -> None:
try:
super().__init__(
Expand All @@ -71,6 +72,7 @@ def __init__(
object_type=None,
key_type=key_type,
element_type=element_type,
flags=flags,
),
)
if not valid_value_annotation_type(
Expand Down Expand Up @@ -470,8 +472,7 @@ def setdefault(self, key: Union[str, Enum], default: Any = None) -> Any:
def items_ex(
self, resolve: bool = True, keys: Optional[List[str]] = None
) -> AbstractSet[Tuple[str, Any]]:
# Using a dictionary because the keys are ordered
items: Dict[Tuple[str, Any], None] = {}
items: List[Tuple[str, Any]] = []
for key in self.keys():
if resolve:
value = self.get(key)
Expand All @@ -480,9 +481,12 @@ def items_ex(
if isinstance(value, ValueNode):
value = value._value()
if keys is None or key in keys:
items[(key, value)] = None
items.append((key, value))

return items.keys()
# For some reason items wants to return a Set, but if the values are not
# hashable this is a problem. We use a list instead. most use cases should just
# be iterating on pairs anyway.
return items # type: ignore

def __eq__(self, other: Any) -> bool:
if other is None:
Expand Down
2 changes: 2 additions & 0 deletions omegaconf/listconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def __init__(
element_type: Optional[Type[Any]] = None,
is_optional: bool = True,
ref_type: Union[Type[Any], Any] = Any,
flags: Optional[Dict[str, bool]] = None,
) -> None:
try:
super().__init__(
Expand All @@ -61,6 +62,7 @@ def __init__(
optional=is_optional,
element_type=element_type,
key_type=int,
flags=flags,
),
)
if not (valid_value_annotation_type(self._metadata.element_type)):
Expand Down
5 changes: 4 additions & 1 deletion omegaconf/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,10 @@ def __init__(
def validate_and_convert(self, value: Any) -> Any:
from ._utils import is_primitive_type

if not is_primitive_type(value):
# _allow_non_primitive_ is internal and not an official API. use at your own risk.
if self._get_flag(
"_allow_non_primitive_"
) is not True and not is_primitive_type(value):
t = get_type_of(value)
raise UnsupportedValueType(
f"Value '{t.__name__}' is not a supported primitive type"
Expand Down
52 changes: 41 additions & 11 deletions omegaconf/omegaconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,49 +137,75 @@ def __init__(self) -> None:
raise NotImplementedError("Use one of the static construction functions")

@staticmethod
def structured(obj: Any, parent: Optional[BaseContainer] = None) -> Any:
return OmegaConf.create(obj, parent)
def structured(
obj: Any,
parent: Optional[BaseContainer] = None,
flags: Optional[Dict[str, bool]] = None,
) -> Any:
return OmegaConf.create(obj, parent, flags)

@staticmethod
@overload
def create(
obj: str, parent: Optional[BaseContainer] = None
obj: str,
parent: Optional[BaseContainer] = None,
flags: Optional[Dict[str, bool]] = None,
) -> Union[DictConfig, ListConfig]:
...

@staticmethod
@overload
def create(
obj: Union[List[Any], Tuple[Any, ...]], parent: Optional[BaseContainer] = None
obj: Union[List[Any], Tuple[Any, ...]],
parent: Optional[BaseContainer] = None,
flags: Optional[Dict[str, bool]] = None,
) -> ListConfig:
...

@staticmethod
@overload
def create(obj: DictConfig, parent: Optional[BaseContainer] = None) -> DictConfig:
def create(
obj: DictConfig,
parent: Optional[BaseContainer] = None,
flags: Optional[Dict[str, bool]] = None,
) -> DictConfig:
...

@staticmethod
@overload
def create(obj: ListConfig, parent: Optional[BaseContainer] = None) -> ListConfig:
def create(
obj: ListConfig,
parent: Optional[BaseContainer] = None,
flags: Optional[Dict[str, bool]] = None,
) -> ListConfig:
...

@staticmethod
@overload
def create(
obj: Union[Dict[str, Any], None] = None, parent: Optional[BaseContainer] = None
obj: Union[Dict[str, Any], None] = None,
parent: Optional[BaseContainer] = None,
flags: Optional[Dict[str, bool]] = None,
) -> DictConfig:
...

@staticmethod
def create( # noqa F811
obj: Any = _EMPTY_MARKER_, parent: Optional[BaseContainer] = None
obj: Any = _EMPTY_MARKER_,
parent: Optional[BaseContainer] = None,
flags: Optional[Dict[str, bool]] = None,
) -> Union[DictConfig, ListConfig]:
return OmegaConf._create_impl(obj=obj, parent=parent)
return OmegaConf._create_impl(
obj=obj,
parent=parent,
flags=flags,
)

@staticmethod
def _create_impl( # noqa F811
obj: Any = _EMPTY_MARKER_, parent: Optional[BaseContainer] = None
obj: Any = _EMPTY_MARKER_,
parent: Optional[BaseContainer] = None,
flags: Optional[Dict[str, bool]] = None,
) -> Union[DictConfig, ListConfig]:
try:
from ._utils import get_yaml_loader
Expand Down Expand Up @@ -225,12 +251,16 @@ def _create_impl( # noqa F811
ref_type=ref_type,
key_type=key_type,
element_type=element_type,
flags=flags,
)
elif is_primitive_list(obj) or OmegaConf.is_list(obj):
ref_type = OmegaConf.get_type(obj)
element_type = get_list_element_type(ref_type)
return ListConfig(
element_type=element_type, content=obj, parent=parent
element_type=element_type,
content=obj,
parent=parent,
flags=flags,
)
else:
if isinstance(obj, type):
Expand Down

0 comments on commit 433fe16

Please sign in to comment.