From 433fe16a8ccb7e9c6a3585e9d6442b1e1540ef54 Mon Sep 17 00:00:00 2001 From: Omry Yadan Date: Mon, 21 Sep 2020 14:03:21 -0700 Subject: [PATCH] Internal API to allow non primitive values --- omegaconf/base.py | 13 +++++++++- omegaconf/basecontainer.py | 8 ++++-- omegaconf/dictconfig.py | 12 ++++++--- omegaconf/listconfig.py | 2 ++ omegaconf/nodes.py | 5 +++- omegaconf/omegaconf.py | 52 ++++++++++++++++++++++++++++++-------- 6 files changed, 73 insertions(+), 19 deletions(-) diff --git a/omegaconf/base.py b/omegaconf/base.py index 1539277f9..e12c44713 100644 --- a/omegaconf/base.py +++ b/omegaconf/base.py @@ -23,9 +23,13 @@ class Metadata: # unset : inherit from parent (None if no parent specifies) # set to true: flag is true # set to false: flag is false - flags: Dict[str, bool] = field(default_factory=dict) + flags: Optional[Dict[str, bool]] = None resolver_cache: Dict[str, Any] = field(default_factory=lambda: defaultdict(dict)) + def __post_init__(self) -> None: + if self.flags is None: + self.flags = {} + @dataclass class ContainerMetadata(Metadata): @@ -37,6 +41,9 @@ def __post_init__(self) -> None: if self.element_type is not None: assert self.element_type is Any or isinstance(self.element_type, type) + if self.flags is None: + self.flags = {} + class Node(ABC): _metadata: Metadata @@ -59,9 +66,11 @@ def _get_parent(self) -> Optional["Container"]: def _set_flag(self, flag: str, value: Optional[bool]) -> "Node": assert value is None or isinstance(value, bool) if value is None: + assert self._metadata.flags is not None if flag in self._metadata.flags: del self._metadata.flags[flag] else: + assert self._metadata.flags is not None self._metadata.flags[flag] = value return self @@ -70,6 +79,7 @@ def _get_node_flag(self, flag: str) -> Optional[bool]: :param flag: flag to inspect :return: the state of the flag on this node. """ + assert self._metadata.flags is not None return self._metadata.flags[flag] if flag in self._metadata.flags else None def _get_flag(self, flag: str) -> Optional[bool]: @@ -80,6 +90,7 @@ def _get_flag(self, flag: str) -> Optional[bool]: :return: """ flags = self._metadata.flags + assert flags is not None if flag in flags and flags[flag] is not None: return flags[flag] diff --git a/omegaconf/basecontainer.py b/omegaconf/basecontainer.py index 41e56540a..83f23a798 100644 --- a/omegaconf/basecontainer.py +++ b/omegaconf/basecontainer.py @@ -313,7 +313,9 @@ def expand(node: Container) -> None: dest._metadata.object_type = src_type # explicit flags on the source config are replacing the flag values in the destination - for flag, value in src._metadata.flags.items(): + flags = src._metadata.flags + assert flags is not None + for flag, value in flags.items(): if value is not None: dest._set_flag(flag, value) @@ -363,7 +365,9 @@ def _merge_with( self.append(item) # explicit flags on the source config are replacing the flag values in the destination - for flag, value in other._metadata.flags.items(): + flags = other._metadata.flags + assert flags is not None + for flag, value in flags.items(): if value is not None: self._set_flag(flag, value) else: diff --git a/omegaconf/dictconfig.py b/omegaconf/dictconfig.py index f0ab24793..828757c33 100644 --- a/omegaconf/dictconfig.py +++ b/omegaconf/dictconfig.py @@ -60,6 +60,7 @@ def __init__( key_type: Union[Any, Type[Any]] = Any, element_type: Union[Any, Type[Any]] = Any, is_optional: bool = True, + flags: Optional[Dict[str, bool]] = None, ) -> None: try: super().__init__( @@ -71,6 +72,7 @@ def __init__( object_type=None, key_type=key_type, element_type=element_type, + flags=flags, ), ) if not valid_value_annotation_type( @@ -470,8 +472,7 @@ def setdefault(self, key: Union[str, Enum], default: Any = None) -> Any: def items_ex( self, resolve: bool = True, keys: Optional[List[str]] = None ) -> AbstractSet[Tuple[str, Any]]: - # Using a dictionary because the keys are ordered - items: Dict[Tuple[str, Any], None] = {} + items: List[Tuple[str, Any]] = [] for key in self.keys(): if resolve: value = self.get(key) @@ -480,9 +481,12 @@ def items_ex( if isinstance(value, ValueNode): value = value._value() if keys is None or key in keys: - items[(key, value)] = None + items.append((key, value)) - return items.keys() + # For some reason items wants to return a Set, but if the values are not + # hashable this is a problem. We use a list instead. most use cases should just + # be iterating on pairs anyway. + return items # type: ignore def __eq__(self, other: Any) -> bool: if other is None: diff --git a/omegaconf/listconfig.py b/omegaconf/listconfig.py index 3fe829949..5c6ea753e 100644 --- a/omegaconf/listconfig.py +++ b/omegaconf/listconfig.py @@ -50,6 +50,7 @@ def __init__( element_type: Optional[Type[Any]] = None, is_optional: bool = True, ref_type: Union[Type[Any], Any] = Any, + flags: Optional[Dict[str, bool]] = None, ) -> None: try: super().__init__( @@ -61,6 +62,7 @@ def __init__( optional=is_optional, element_type=element_type, key_type=int, + flags=flags, ), ) if not (valid_value_annotation_type(self._metadata.element_type)): diff --git a/omegaconf/nodes.py b/omegaconf/nodes.py index 315cc603e..16ecaa8c7 100644 --- a/omegaconf/nodes.py +++ b/omegaconf/nodes.py @@ -147,7 +147,10 @@ def __init__( def validate_and_convert(self, value: Any) -> Any: from ._utils import is_primitive_type - if not is_primitive_type(value): + # _allow_non_primitive_ is internal and not an official API. use at your own risk. + if self._get_flag( + "_allow_non_primitive_" + ) is not True and not is_primitive_type(value): t = get_type_of(value) raise UnsupportedValueType( f"Value '{t.__name__}' is not a supported primitive type" diff --git a/omegaconf/omegaconf.py b/omegaconf/omegaconf.py index ea163d7fd..329e6550a 100644 --- a/omegaconf/omegaconf.py +++ b/omegaconf/omegaconf.py @@ -137,49 +137,75 @@ def __init__(self) -> None: raise NotImplementedError("Use one of the static construction functions") @staticmethod - def structured(obj: Any, parent: Optional[BaseContainer] = None) -> Any: - return OmegaConf.create(obj, parent) + def structured( + obj: Any, + parent: Optional[BaseContainer] = None, + flags: Optional[Dict[str, bool]] = None, + ) -> Any: + return OmegaConf.create(obj, parent, flags) @staticmethod @overload def create( - obj: str, parent: Optional[BaseContainer] = None + obj: str, + parent: Optional[BaseContainer] = None, + flags: Optional[Dict[str, bool]] = None, ) -> Union[DictConfig, ListConfig]: ... @staticmethod @overload def create( - obj: Union[List[Any], Tuple[Any, ...]], parent: Optional[BaseContainer] = None + obj: Union[List[Any], Tuple[Any, ...]], + parent: Optional[BaseContainer] = None, + flags: Optional[Dict[str, bool]] = None, ) -> ListConfig: ... @staticmethod @overload - def create(obj: DictConfig, parent: Optional[BaseContainer] = None) -> DictConfig: + def create( + obj: DictConfig, + parent: Optional[BaseContainer] = None, + flags: Optional[Dict[str, bool]] = None, + ) -> DictConfig: ... @staticmethod @overload - def create(obj: ListConfig, parent: Optional[BaseContainer] = None) -> ListConfig: + def create( + obj: ListConfig, + parent: Optional[BaseContainer] = None, + flags: Optional[Dict[str, bool]] = None, + ) -> ListConfig: ... @staticmethod @overload def create( - obj: Union[Dict[str, Any], None] = None, parent: Optional[BaseContainer] = None + obj: Union[Dict[str, Any], None] = None, + parent: Optional[BaseContainer] = None, + flags: Optional[Dict[str, bool]] = None, ) -> DictConfig: ... @staticmethod def create( # noqa F811 - obj: Any = _EMPTY_MARKER_, parent: Optional[BaseContainer] = None + obj: Any = _EMPTY_MARKER_, + parent: Optional[BaseContainer] = None, + flags: Optional[Dict[str, bool]] = None, ) -> Union[DictConfig, ListConfig]: - return OmegaConf._create_impl(obj=obj, parent=parent) + return OmegaConf._create_impl( + obj=obj, + parent=parent, + flags=flags, + ) @staticmethod def _create_impl( # noqa F811 - obj: Any = _EMPTY_MARKER_, parent: Optional[BaseContainer] = None + obj: Any = _EMPTY_MARKER_, + parent: Optional[BaseContainer] = None, + flags: Optional[Dict[str, bool]] = None, ) -> Union[DictConfig, ListConfig]: try: from ._utils import get_yaml_loader @@ -225,12 +251,16 @@ def _create_impl( # noqa F811 ref_type=ref_type, key_type=key_type, element_type=element_type, + flags=flags, ) elif is_primitive_list(obj) or OmegaConf.is_list(obj): ref_type = OmegaConf.get_type(obj) element_type = get_list_element_type(ref_type) return ListConfig( - element_type=element_type, content=obj, parent=parent + element_type=element_type, + content=obj, + parent=parent, + flags=flags, ) else: if isinstance(obj, type):