diff --git a/.isort.cfg b/.isort.cfg index ba2778dc8..3b3abdd03 100644 --- a/.isort.cfg +++ b/.isort.cfg @@ -4,3 +4,6 @@ include_trailing_comma=True force_grid_wrap=0 use_parentheses=True line_length=88 +ensure_newline_before_comments=True +known_third_party=omegaconf,ray,pytest +known_first_party=hydra,hydra_plugins diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0b44a8eee..02bcb354f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,7 +11,8 @@ repos: language: system entry: flake8 types: [python] - - repo: https://github.com/pre-commit/mirrors-isort - rev: '' # Use the revision sha / tag you want to point at + # isort to ensure imports remains sorted + - repo: https://github.com/timothycrosley/isort + rev: 'c54b3dd' hooks: - - id: isort \ No newline at end of file + - id: isort \ No newline at end of file diff --git a/docs/source/structured_config.rst b/docs/source/structured_config.rst index ee28abf50..13555d80b 100644 --- a/docs/source/structured_config.rst +++ b/docs/source/structured_config.rst @@ -20,7 +20,7 @@ Two types of structures classes that are supported: dataclasses and attr classes This documentation will use dataclasses, but you can use the annotation `@attr.s(auto_attribs=True)` from attrs instead of `@dataclass`. -Basic usage involves passing in a structured config class or instance to OmegaConf.create(), which will return an OmegaConf config that matches +Basic usage involves passing in a structured config class or instance to OmegaConf.structured(), which will return an OmegaConf config that matches the values and types specified in the input. OmegaConf will validate modifications to the created config object at runtime against the schema specified in the input class. @@ -58,8 +58,8 @@ fields during construction. .. doctest:: - >>> conf1 = OmegaConf.create(SimpleTypes) - >>> conf2 = OmegaConf.create(SimpleTypes()) + >>> conf1 = OmegaConf.structured(SimpleTypes) + >>> conf2 = OmegaConf.structured(SimpleTypes()) >>> # The two configs are identical in this case >>> assert conf1 == conf2 >>> # But the second form allow for easy customization of the values: @@ -80,7 +80,7 @@ Configs in struct mode rejects attempts to access or set fields that are not alr .. doctest:: - >>> conf = OmegaConf.create(SimpleTypes) + >>> conf = OmegaConf.structured(SimpleTypes) >>> with raises(KeyError): ... conf.does_not_exist @@ -91,7 +91,7 @@ Python type annotation can be used by static type checkers like Mypy/Pyre or by .. doctest:: - >>> conf: SimpleTypes = OmegaConf.create(SimpleTypes) + >>> conf: SimpleTypes = OmegaConf.structured(SimpleTypes) >>> # passes static type checking >>> assert conf.description == "text" >>> with raises(ValidationError): @@ -155,7 +155,7 @@ Structured configs can be nested. ... # You can also specify different defaults for nested classes ... manager: User = User(name="manager", height=Height.TALL) - >>> conf : Group = OmegaConf.create(Group) + >>> conf : Group = OmegaConf.structured(Group) >>> print(conf.pretty()) admin: height: ??? @@ -200,7 +200,7 @@ OmegaConf verifies at runtime that your Lists contains only values of the correc .. doctest:: - >>> conf : Lists = OmegaConf.create(Lists) + >>> conf : Lists = OmegaConf.structured(Lists) >>> # Okay, 10 is an int >>> conf.ints.append(10) @@ -232,7 +232,7 @@ OmegaConf supports field modifiers such as MISSING and Optional. ... optional_num: Optional[int] = 10 ... another_num: int = MISSING - >>> conf : Modifiers = OmegaConf.create(Modifiers) + >>> conf : Modifiers = OmegaConf.structured(Modifiers) Mandatory missing values ^^^^^^^^^^^^^^^^^^^^^^^^ @@ -283,7 +283,7 @@ To work around it, use SI and II described below. ... # wrapped with ${} automatically. ... c: int = II("val") - >>> conf : Interpolation = OmegaConf.create(Interpolation) + >>> conf : Interpolation = OmegaConf.structured(Interpolation) >>> assert conf.a == 100 >>> assert conf.b == 100 >>> assert conf.c == 100 @@ -301,7 +301,7 @@ Frozen dataclasses and attr classes are supported via OmegaConf :ref:`read-only- ... x: int = 10 ... list: List = field(default_factory=lambda: [1, 2, 3]) - >>> conf = OmegaConf.create(FrozenClass) + >>> conf = OmegaConf.structured(FrozenClass) >>> with raises(ReadonlyConfigError): ... conf.x = 20 @@ -349,7 +349,7 @@ This will cause a validation error when merging the config from the file with th .. doctest:: - >>> schema = OmegaConf.create(MyConfig) + >>> schema = OmegaConf.structured(MyConfig) >>> conf = OmegaConf.load("source/example.yaml") >>> with raises(ValidationError): ... OmegaConf.merge(schema, conf) diff --git a/docs/source/usage.rst b/docs/source/usage.rst index c6105eb13..42675b2d0 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -151,7 +151,8 @@ See :doc:`structured_config` for more details, or keep reading for a minimal exa ... class MyConfig: ... port: int = 80 ... host: str = "localhost" - >>> conf = OmegaConf.create(MyConfig) + >>> # For strict typing purposes, prefer OmegaConf.structured() when creating structured configs + >>> conf = OmegaConf.structured(MyConfig) >>> print(conf.pretty()) host: localhost port: 80 @@ -161,7 +162,7 @@ You can use an object to initialize the config as well: .. doctest:: - >>> conf = OmegaConf.create(MyConfig(port=443)) + >>> conf = OmegaConf.structured(MyConfig(port=443)) >>> print(conf.pretty()) host: localhost port: 443 diff --git a/news/114.feature b/news/114.feature new file mode 100644 index 000000000..19e93f75a --- /dev/null +++ b/news/114.feature @@ -0,0 +1 @@ +DictConfig and ListConfig now implements typing.MutableMapping and typing.MutableSequence. \ No newline at end of file diff --git a/noxfile.py b/noxfile.py index 447222375..eb95423bf 100644 --- a/noxfile.py +++ b/noxfile.py @@ -35,8 +35,7 @@ def coverage(session): session.run("pip", "install", ".[coverage]", silent=True) session.run("coverage", "erase") session.run("coverage", "run", "--append", "-m", "pytest", silent=True) - # Increase the fail_under as coverage improves - session.run("coverage", "report", "--fail-under=95") + session.run("coverage", "report", "--fail-under=100") # report to coveralls session.run("coveralls", success_codes=[0, 1]) @@ -53,10 +52,9 @@ def lint(session): # if this fails you need to format your code with black session.run("black", "--check", ".") - session.run("mypy", "tests") - session.run("mypy", "omegaconf", "--strict") + session.run("mypy", ".", "--strict") - session.run("isort", "--check") + session.run("isort", ".", "--check") @nox.session(python=PYTHON_VERSIONS) diff --git a/omegaconf/__init__.py b/omegaconf/__init__.py index 785fba1e2..7f13a5ed6 100644 --- a/omegaconf/__init__.py +++ b/omegaconf/__init__.py @@ -1,5 +1,4 @@ -from .base import Node -from .basecontainer import BaseContainer +from .base import Container, Node from .dictconfig import DictConfig from .errors import ( MissingMandatoryValue, @@ -18,7 +17,16 @@ StringNode, ValueNode, ) -from .omegaconf import II, MISSING, SI, OmegaConf, flag_override, open_dict, read_write +from .omegaconf import ( + II, + MISSING, + SI, + OmegaConf, + Resolver, + flag_override, + open_dict, + read_write, +) from .version import __version__ __all__ = [ @@ -28,10 +36,11 @@ "ReadonlyConfigError", "UnsupportedValueType", "UnsupportedKeyType", - "BaseContainer", + "Container", "ListConfig", "DictConfig", "OmegaConf", + "Resolver", "flag_override", "read_write", "open_dict", diff --git a/omegaconf/_utils.py b/omegaconf/_utils.py index 3f0b71d6f..f8848a973 100644 --- a/omegaconf/_utils.py +++ b/omegaconf/_utils.py @@ -8,11 +8,13 @@ try: import dataclasses + except ImportError: # pragma: no cover dataclasses = None # type: ignore # pragma: no cover try: import attr + except ImportError: # pragma: no cover attr = None # type: ignore # pragma: no cover @@ -261,10 +263,11 @@ def is_int(st: str) -> bool: # noinspection PyProtectedMember def _re_parent(node: Node) -> None: - from .listconfig import ListConfig from .dictconfig import DictConfig + from .listconfig import ListConfig # update parents of first level Config nodes to self + assert isinstance(node, Node) if isinstance(node, DictConfig): for _key, value in node.__dict__["content"].items(): @@ -274,3 +277,19 @@ def _re_parent(node: Node) -> None: for item in node.__dict__["content"]: item._set_parent(node) _re_parent(item) + + +def is_primitive_list(obj: Any) -> bool: + from .base import Container + + return not isinstance(obj, Container) and isinstance(obj, (list, tuple)) + + +def is_primitive_dict(obj: Any) -> bool: + from .base import Container + + return not isinstance(obj, Container) and isinstance(obj, (dict)) + + +def is_primitive_container(obj: Any) -> bool: + return is_primitive_list(obj) or is_primitive_dict(obj) diff --git a/omegaconf/base.py b/omegaconf/base.py index b90e98ec0..e7966b2bf 100644 --- a/omegaconf/base.py +++ b/omegaconf/base.py @@ -1,5 +1,5 @@ -from abc import ABC -from typing import Optional +from abc import ABC, abstractmethod +from typing import Any, Iterator, Optional, Union class Node(ABC): @@ -12,7 +12,7 @@ def __init__(self, parent: Optional["Node"]): # set to false: flag is false self.__dict__["flags"] = {} - def _set_parent(self, parent: "Container") -> None: + def _set_parent(self, parent: Optional["Container"]) -> None: assert parent is None or isinstance(parent, Container) self.__dict__["parent"] = parent @@ -60,4 +60,46 @@ class Container(Node): Container tagging interface """ - pass + @abstractmethod + def pretty(self, resolve: bool = False) -> str: + ... # pragma: no cover + + @abstractmethod + def update_node(self, key: str, value: Any = None) -> None: + ... # pragma: no cover + + @abstractmethod + def select(self, key: str) -> Any: + ... # pragma: no cover + + @abstractmethod + def __delitem__(self, key: Union[str, int, slice]) -> None: + ... # pragma: no cover + + @abstractmethod + def __setitem__(self, key: Any, value: Any) -> None: + ... # pragma: no cover + + @abstractmethod + def get_node(self, key: Any) -> Node: + ... # pragma: no cover + + @abstractmethod + def __eq__(self, other: Any) -> bool: + ... # pragma: no cover + + @abstractmethod + def __ne__(self, other: Any) -> bool: + ... # pragma: no cover + + @abstractmethod + def __hash__(self) -> int: + ... # pragma: no cover + + @abstractmethod + def __iter__(self) -> Iterator[str]: + ... # pragma: no cover + + @abstractmethod + def __getitem__(self, key_or_index: Any) -> Any: + ... # pragma: no cover diff --git a/omegaconf/basecontainer.py b/omegaconf/basecontainer.py index 34b0cedc4..ecd828d67 100644 --- a/omegaconf/basecontainer.py +++ b/omegaconf/basecontainer.py @@ -1,10 +1,9 @@ import copy import sys import warnings -from abc import abstractmethod from collections import defaultdict from enum import Enum -from typing import Any, Dict, Iterator, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import yaml @@ -13,6 +12,7 @@ _re_parent, get_value_kind, get_yaml_loader, + is_primitive_container, is_structured_config, ) from .base import Container, Node @@ -33,34 +33,6 @@ def __init__(self, element_type: type, parent: Optional["Container"]): self.__dict__["_resolver_cache"] = defaultdict(dict) self.__dict__["_element_type"] = element_type - @abstractmethod - def __setitem__(self, key: Any, value: Any) -> None: - ... # pragma: no cover - - @abstractmethod - def get_node(self, key: Any) -> Node: - ... # pragma: no cover - - @abstractmethod - def __eq__(self, other: Any) -> bool: - ... # pragma: no cover - - @abstractmethod - def __ne__(self, other: Any) -> bool: - ... # pragma: no cover - - @abstractmethod - def __hash__(self) -> int: - ... # pragma: no cover - - @abstractmethod - def __iter__(self) -> Iterator[str]: - ... # pragma: no cover - - @abstractmethod - def __getitem__(self, key_or_index: Any) -> Any: - ... # pragma: no cover - def save(self, f: str) -> None: warnings.warn( "Use OmegaConf.save(config, filename) (since 1.4.0)", @@ -94,8 +66,8 @@ def is_mandatory_missing(val: Any) -> bool: return value def get_full_key(self, key: str) -> str: - from .listconfig import ListConfig from .dictconfig import DictConfig + from .listconfig import ListConfig full_key: Union[str, int] = "" child = None @@ -179,11 +151,11 @@ def fail() -> None: value = arg[idx + 1 :] value = yaml.load(value, Loader=get_yaml_loader()) - self.update(key, value) + self.update_node(key, value) - def update(self, key: str, value: Any = None) -> None: - from .listconfig import ListConfig + def update_node(self, key: str, value: Any = None) -> None: from .dictconfig import DictConfig + from .listconfig import ListConfig from .omegaconf import _select_one """Updates a dot separated key sequence to a value""" @@ -244,8 +216,8 @@ def is_empty(self) -> bool: def _to_content( conf: Container, resolve: bool, enum_to_str: bool = False ) -> Union[Dict[str, Any], List[Any]]: - from .listconfig import ListConfig from .dictconfig import DictConfig + from .listconfig import ListConfig def convert(val: Any) -> Any: if enum_to_str: @@ -257,7 +229,7 @@ def convert(val: Any) -> Any: assert isinstance(conf, Container) if isinstance(conf, DictConfig): retdict: Dict[str, Any] = {} - for key, value in conf.items(resolve=resolve): + for key, value in conf.items_ex(resolve=resolve): if isinstance(value, Container): retdict[key] = BaseContainer._to_content( value, resolve=resolve, enum_to_str=enum_to_str @@ -299,7 +271,9 @@ def pretty(self, resolve: bool = False) -> str: :return: A string containing the yaml representation. """ container = OmegaConf.to_container(self, resolve=resolve, enum_to_str=True) - return yaml.dump(container, default_flow_style=False, allow_unicode=True) # type: ignore + return yaml.dump( # type: ignore + container, default_flow_style=False, allow_unicode=True + ) @staticmethod def _map_merge(dest: "BaseContainer", src: "BaseContainer") -> None: @@ -311,7 +285,7 @@ def _map_merge(dest: "BaseContainer", src: "BaseContainer") -> None: assert isinstance(src, DictConfig) src = copy.deepcopy(src) - for key, value in src.items(resolve=False): + for key, value in src.items_ex(resolve=False): dest_type = dest.__dict__["_element_type"] typed = dest_type not in (None, Any) if (dest.get_node(key) is not None) or typed: @@ -338,13 +312,13 @@ def merge_with( self, *others: Union["BaseContainer", Dict[str, Any], List[Any], Tuple[Any], Any], ) -> None: - from .omegaconf import OmegaConf - from .listconfig import ListConfig from .dictconfig import DictConfig + from .listconfig import ListConfig + from .omegaconf import OmegaConf """merge a list of other Config objects into this one, overriding as needed""" for other in others: - if isinstance(other, (dict, list, tuple)) or is_structured_config(other): + if is_primitive_container(other) or is_structured_config(other): other = OmegaConf.create(other) if other is None: @@ -414,9 +388,10 @@ def _resolve_single(self, value: Any) -> Any: # noinspection PyProtectedMember def _set_item_impl(self, key: Union[str, int], value: Any) -> None: from omegaconf.omegaconf import _maybe_wrap + from .nodes import ValueNode - must_wrap = isinstance(value, (dict, list)) + must_wrap = is_primitive_container(value) input_config = isinstance(value, Container) input_node = isinstance(value, ValueNode) if isinstance(self.__dict__["content"], dict): @@ -520,8 +495,8 @@ def _dict_conf_eq(d1: "BaseContainer", d2: "BaseContainer") -> bool: @staticmethod def _config_eq(c1: "BaseContainer", c2: "BaseContainer") -> bool: - from .listconfig import ListConfig from .dictconfig import DictConfig + from .listconfig import ListConfig assert isinstance(c1, Container) assert isinstance(c2, Container) diff --git a/omegaconf/dictconfig.py b/omegaconf/dictconfig.py index a26a476bc..157d6af5d 100644 --- a/omegaconf/dictconfig.py +++ b/omegaconf/dictconfig.py @@ -1,10 +1,21 @@ import copy from enum import Enum -from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union +from typing import ( + Any, + Dict, + Iterable, + Iterator, + List, + MutableMapping, + Optional, + Tuple, + Union, +) from ._utils import ( _re_parent, get_structured_config_data, + is_primitive_dict, is_structured_config, is_structured_config_frozen, ) @@ -21,7 +32,7 @@ from .nodes import ValueNode -class DictConfig(BaseContainer): +class DictConfig(BaseContainer, MutableMapping[str, Any]): def __init__( self, content: Union[Dict[str, Any], Any], @@ -49,6 +60,10 @@ def __init__( for k, v in content.items(): self.__setitem__(k, v) + if isinstance(content, BaseContainer): + for field in ["flags", "_element_type", "_resolver_cache"]: + self.__dict__[field] = copy.deepcopy(content.__dict__[field]) + def __deepcopy__(self, memo: Dict[int, Any] = {}) -> "DictConfig": res = DictConfig({}) res.__dict__["content"] = copy.deepcopy(self.__dict__["content"], memo=memo) @@ -185,7 +200,7 @@ def pop(self, key: Union[str, Enum], default: Any = __marker) -> Any: def keys(self) -> Any: return self.content.keys() - def __contains__(self, key: Union[str, Enum]) -> bool: + def __contains__(self, key: object) -> bool: """ A key is contained in a DictConfig if there is an associated value and it is not a mandatory missing value ('???'). @@ -197,6 +212,7 @@ def __contains__(self, key: Union[str, Enum]) -> bool: if isinstance(key, Enum): str_key = key.name else: + assert isinstance(key, str) str_key = key try: @@ -219,9 +235,13 @@ def __contains__(self, key: Union[str, Enum]) -> bool: def __iter__(self) -> Iterator[str]: return iter(self.keys()) - def items( + # TODO: figure out why this is incompatible with Mapping + def items(self) -> Iterator[Tuple[str, Any]]: # type: ignore + return self.items_ex(resolve=True, keys=None) + + def items_ex( self, resolve: bool = True, keys: Optional[List[str]] = None - ) -> Iterator[Any]: + ) -> Iterator[Tuple[str, Any]]: class MyItems(Iterator[Any]): def __init__(self, m: DictConfig) -> None: self.map = m @@ -251,7 +271,7 @@ def _next_pair(self) -> Tuple[str, Any]: return MyItems(self) def __eq__(self, other: Any) -> bool: - if isinstance(other, dict): + if is_primitive_dict(other): return BaseContainer._dict_conf_eq(self, DictConfig(other)) if isinstance(other, DictConfig): return BaseContainer._dict_conf_eq(self, other) @@ -274,9 +294,12 @@ def _validate_access(self, key: str) -> None: if is_typed and node_open: return if is_typed or is_closed: - msg = "Accessing unknown key in a struct : {}".format( - self.get_full_key(key) - ) + if is_typed: + msg = f"Accessing unknown key in {self.__dict__['_type'].__name__} : {self.get_full_key(key)}" + else: + msg = "Accessing unknown key in a struct : {}".format( + self.get_full_key(key) + ) if is_closed: raise AttributeError(msg) else: diff --git a/omegaconf/listconfig.py b/omegaconf/listconfig.py index b1b6b06a8..0669cb8c5 100644 --- a/omegaconf/listconfig.py +++ b/omegaconf/listconfig.py @@ -1,15 +1,26 @@ import copy import itertools -from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Union - -from ._utils import _re_parent, isint +from typing import ( + Any, + Callable, + Dict, + Iterable, + Iterator, + List, + MutableSequence, + Optional, + Tuple, + Union, +) + +from ._utils import _re_parent, is_primitive_list, isint from .base import Container, Node from .basecontainer import BaseContainer from .errors import ReadonlyConfigError, UnsupportedKeyType, UnsupportedValueType from .nodes import AnyNode, ValueNode -class ListConfig(BaseContainer): +class ListConfig(BaseContainer, MutableSequence[Any]): def __init__( self, content: Union[List[Any], Tuple[Any, ...]], @@ -18,7 +29,7 @@ def __init__( ) -> None: super().__init__(parent=parent, element_type=element_type) self.__dict__["content"] = [] - assert isinstance(content, (list, tuple)) + assert is_primitive_list(content) or isinstance(content, ListConfig) for item in content: self.append(item) @@ -45,7 +56,7 @@ def __dir__(self) -> Iterable[str]: def __len__(self) -> int: return len(self.content) - def __getitem__(self, index: int) -> Any: + def __getitem__(self, index: Union[int, slice]) -> Any: assert isinstance(index, (int, slice)) if isinstance(index, slice): result = [] @@ -62,7 +73,7 @@ def __getitem__(self, index: int) -> Any: key=index, value=self.content[index], default_value=None ) - def _set_at_index(self, index: int, value: Any) -> None: + def _set_at_index(self, index: Union[int, slice], value: Any) -> None: if not isinstance(index, int): raise UnsupportedKeyType(f"Key type {type(index).__name__} is not an int") @@ -84,7 +95,7 @@ def _set_at_index(self, index: int, value: Any) -> None: ) ) - def __setitem__(self, index: int, value: Any) -> None: + def __setitem__(self, index: Union[int, slice], value: Any) -> None: self._set_at_index(index, value) def append(self, item: Any) -> None: @@ -118,7 +129,7 @@ def insert(self, index: int, item: Any) -> None: del self.__dict__["content"][index] raise - def extend(self, lst: Union[List[Any], Tuple[Any, ...], "ListConfig"]) -> None: + def extend(self, lst: Iterable[Any]) -> None: assert isinstance(lst, (tuple, list, ListConfig)) for x in lst: self.append(x) @@ -129,9 +140,18 @@ def remove(self, x: Any) -> None: def clear(self) -> None: del self[:] - def index(self, x: Any) -> int: + def index( + self, x: Any, start: Optional[int] = None, end: Optional[int] = None, + ) -> int: + if start is None: + start = 0 + if end is None: + end = len(self) + assert start >= 0 + assert end <= len(self) found_idx = -1 - for idx, item in enumerate(self): + for idx in range(start, end): + item = self[idx] if x == item: found_idx = idx break @@ -220,9 +240,19 @@ def next(self) -> Any: return MyItems(self.content) - def __add__(self, o: List[Any]) -> "ListConfig": + def __add__(self, other: Union[List[Any], "ListConfig"]) -> "ListConfig": # res is sharing this list's parent to allow interpolation to work as expected res = ListConfig(parent=self._get_parent(), content=[]) res.extend(self) - res.extend(o) + res.extend(other) return res + + def __iadd__(self, other: Iterable[Any]) -> "ListConfig": + self.extend(other) + return self + + def __contains__(self, item: Any) -> bool: + for x in iter(self): + if x == item: + return True + return False diff --git a/omegaconf/nodes.py b/omegaconf/nodes.py index f95238789..7628fe91b 100644 --- a/omegaconf/nodes.py +++ b/omegaconf/nodes.py @@ -49,7 +49,6 @@ def __repr__(self) -> str: return repr(self.val) if hasattr(self, "val") else "__INVALID__" def __eq__(self, other: Any) -> bool: - # TODO: this type ignore makes no sense. return self.val == other # type: ignore def __ne__(self, other: Any) -> bool: @@ -135,9 +134,7 @@ def validate_and_convert(self, value: Any) -> Optional[int]: else: raise ValueError() except ValueError: - raise ValidationError( - "Value '{}' could not be converted to Integer".format(value) - ) + raise ValidationError(f"Value '{value}' could not be converted to Integer") return val def __deepcopy__(self, memo: Dict[int, Any] = {}) -> "IntegerNode": @@ -166,9 +163,7 @@ def validate_and_convert(self, value: Any) -> Optional[float]: else: raise ValueError() except ValueError: - raise ValidationError( - "Value '{}' could not be converted to float".format(value) - ) + raise ValidationError(f"Value '{value}' could not be converted to float") def __eq__(self, other: Any) -> bool: if isinstance(other, ValueNode): @@ -223,9 +218,7 @@ def validate_and_convert(self, value: Any) -> Optional[bool]: ) else: raise ValidationError( - "Value '{}' is not a valid bool (type {})".format( - value, type(value).__name__ - ) + f"Value '{value}' is not a valid bool (type {type(value).__name__})" ) def __deepcopy__(self, memo: Dict[int, Any] = {}) -> "BooleanNode": @@ -251,7 +244,7 @@ def __init__( super().__init__(parent=parent, is_optional=is_optional) if not isinstance(enum_type, type) or not issubclass(enum_type, Enum): raise ValidationError( - "EnumNode can only operate on Enum subclasses ({})".format(enum_type) + f"EnumNode can only operate on Enum subclasses ({enum_type})" ) self.fields: Dict[str, str] = {} self.val = None @@ -266,9 +259,7 @@ def validate_and_convert(self, value: Any) -> Optional[Enum]: type_ = type(value) if not issubclass(type_, self.enum_type) and type_ not in (str, int): raise ValidationError( - "Value {} ({}) is not a valid input for {}".format( - value, type_, self.enum_type - ) + f"Value {value} ({type_}) is not a valid input for {self.enum_type}" ) if isinstance(value, self.enum_type): diff --git a/omegaconf/omegaconf.py b/omegaconf/omegaconf.py index f9591f753..b64f42c2c 100644 --- a/omegaconf/omegaconf.py +++ b/omegaconf/omegaconf.py @@ -6,12 +6,32 @@ import sys from contextlib import contextmanager from enum import Enum -from typing import Any, Callable, Dict, Generator, List, Match, Optional, Tuple, Union +from typing import ( + IO, + Any, + Callable, + Dict, + Generator, + List, + Match, + Optional, + Tuple, + Union, + overload, +) import yaml +from typing_extensions import Protocol from . import DictConfig, ListConfig -from ._utils import decode_primitive, is_structured_config, isint +from ._utils import ( + decode_primitive, + is_primitive_container, + is_primitive_dict, + is_primitive_list, + is_structured_config, + isint, +) from .base import Container, Node from .basecontainer import BaseContainer from .errors import MissingMandatoryValue, ValidationError @@ -43,10 +63,32 @@ def SI(interpolation: str) -> Any: :param interpolation: interpolation string :return: input interpolation with type Any """ - assert interpolation.find("${") != -1 return interpolation +class Resolver0(Protocol): + def __call__(self) -> Any: + ... # pragma: no cover + + +class Resolver1(Protocol): + def __call__(self, __x1: str) -> Any: + ... # pragma: no cover + + +class Resolver2(Protocol): + def __call__(self, __x1: str, __x2: str) -> Any: + ... # pragma: no cover + + +class Resolver3(Protocol): + def __call__(self, __x1: str, __x2: str, __x3: str) -> Any: + ... # pragma: no cover + + +Resolver = Union[Resolver0, Resolver1, Resolver2, Resolver3] + + def register_default_resolvers() -> None: def env(key: str) -> Any: try: @@ -64,47 +106,85 @@ def __init__(self) -> None: raise NotImplementedError("Use one of the static construction functions") @staticmethod - def create( + def structured(obj: Any, parent: Optional[BaseContainer] = None) -> DictConfig: + assert is_structured_config(obj) + ret = OmegaConf.create(obj, parent) + assert isinstance(ret, DictConfig) + return ret + + @staticmethod + @overload + def create( # noqa F811 + obj: Union[List[Any], Tuple[Any, ...]], parent: Optional[BaseContainer] = None + ) -> ListConfig: + ... # pragma: no cover + + @staticmethod + @overload + def create( # noqa F811 + obj: Union[BaseContainer, str], parent: Optional[BaseContainer] = None, + ) -> Union[DictConfig, ListConfig]: + ... # pragma: no cover + + @staticmethod + @overload + def create( # noqa F811 + obj: Union[Dict[str, Any], None] = None, parent: Optional[BaseContainer] = None, + ) -> DictConfig: + ... # pragma: no cover + + @staticmethod + def create( # noqa F811 obj: Any = None, parent: Optional[BaseContainer] = None ) -> Union[DictConfig, ListConfig]: + from ._utils import get_yaml_loader from .dictconfig import DictConfig from .listconfig import ListConfig - from ._utils import get_yaml_loader if isinstance(obj, str): - new_obj = yaml.load(obj, Loader=get_yaml_loader()) - if new_obj is None: - new_obj = {} - elif isinstance(new_obj, str): - new_obj = {obj: None} - return OmegaConf.create(new_obj) + obj = yaml.load(obj, Loader=get_yaml_loader()) + new_obj: Dict[str, Any] + if obj is None: + return OmegaConf.create({}) + elif isinstance(obj, str): + return OmegaConf.create({obj: None}) + else: + assert isinstance(obj, (list, dict)) + return OmegaConf.create(obj) + else: if obj is None: obj = {} - if isinstance(obj, BaseContainer): - obj = OmegaConf.to_container(obj) - if isinstance(obj, dict) or is_structured_config(obj): + if ( + is_primitive_dict(obj) + or OmegaConf.is_dict(obj) + or is_structured_config(obj) + ): return DictConfig(obj, parent) - elif isinstance(obj, list) or isinstance(obj, tuple): + elif is_primitive_list(obj) or OmegaConf.is_list(obj): return ListConfig(obj, parent) else: raise ValidationError("Unsupported type {}".format(type(obj).__name__)) @staticmethod - def load(file_: str) -> Union[DictConfig, ListConfig]: + def load(file_: Union[str, IO[bytes]]) -> Union[DictConfig, ListConfig]: from ._utils import get_yaml_loader if isinstance(file_, str): with io.open(os.path.abspath(file_), "r", encoding="utf-8") as f: - return OmegaConf.create(yaml.load(f, Loader=get_yaml_loader())) + obj = yaml.load(f, Loader=get_yaml_loader()) + assert isinstance(obj, (list, dict)) + return OmegaConf.create(obj) elif getattr(file_, "read", None): - return OmegaConf.create(yaml.load(file_, Loader=get_yaml_loader())) + obj = yaml.load(file_, Loader=get_yaml_loader()) + assert isinstance(obj, (list, dict)) + return OmegaConf.create(obj) else: raise TypeError("Unexpected file type") @staticmethod - def save(config: BaseContainer, f: str, resolve: bool = False) -> None: + def save(config: Container, f: Union[str, IO[str]], resolve: bool = False) -> None: """ Save as configuration object to a file :param config: omegaconf.Config object (DictConfig or ListConfig). @@ -137,17 +217,18 @@ def from_dotlist(dotlist: List[str]) -> DictConfig: """ conf = OmegaConf.create() conf.merge_with_dotlist(dotlist) - return conf # type: ignore + return conf @staticmethod def merge( - *others: Union[BaseContainer, Dict[str, Any], List[Any], Tuple[Any], Any] - ) -> BaseContainer: + *others: Union[BaseContainer, Dict[str, Any], List[Any], Tuple[Any, ...], Any] + ) -> Union[ListConfig, DictConfig]: """Merge a list of previously created configs into a single one""" assert len(others) > 0 target = copy.deepcopy(others[0]) - if isinstance(target, (dict, list, tuple)) or is_structured_config(target): + if is_primitive_container(target) or is_structured_config(target): target = OmegaConf.create(target) + assert isinstance(target, (DictConfig, ListConfig)) target.merge_with(*others[1:]) return target @@ -166,7 +247,7 @@ def _unescape_word_boundary(match: Match[str]) -> str: return [re.sub(r"(\\([ ,]))", lambda x: x.group(2), x) for x in escaped] @staticmethod - def register_resolver(name: str, resolver: Callable[[Any], Any]) -> None: + def register_resolver(name: str, resolver: Resolver) -> None: assert callable(resolver), "resolver must be callable" # noinspection PyProtectedMember assert ( @@ -184,6 +265,7 @@ def caching(config: BaseContainer, key: str) -> Any: # noinspection PyProtectedMember BaseContainer._resolvers[name] = caching + # TODO : improve this API (return type seems wrong) # noinspection PyProtectedMember @staticmethod def get_resolver(name: str) -> Optional[Callable[[Container, Any], Any]]: @@ -244,12 +326,12 @@ def masked_copy(conf: DictConfig, keys: Union[str, List[str]]) -> DictConfig: if isinstance(keys, str): keys = [keys] - content = {key: value for key, value in conf.items(resolve=False, keys=keys)} + content = {key: value for key, value in conf.items_ex(resolve=False, keys=keys)} return DictConfig(content=content) @staticmethod def to_container( - cfg: BaseContainer, resolve: bool = False, enum_to_str: bool = False + cfg: Container, resolve: bool = False, enum_to_str: bool = False ) -> Union[Dict[str, Any], List[Any]]: """ Resursively converts an OmegaConf config to a primitive container (dict or list). @@ -258,7 +340,7 @@ def to_container( :param enum_to_str: True to convert Enum values to strings :return: A dict or a list representing this config as a primitive container. """ - assert isinstance(cfg, BaseContainer) + assert isinstance(cfg, Container) # noinspection PyProtectedMember return BaseContainer._to_content(cfg, resolve=resolve, enum_to_str=enum_to_str) @@ -284,9 +366,9 @@ def is_dict(obj: Any) -> bool: @staticmethod def is_config(obj: Any) -> bool: - from . import BaseContainer + from . import Container - return isinstance(obj, BaseContainer) + return isinstance(obj, Container) # register all default resolvers @@ -432,8 +514,8 @@ def _maybe_wrap( def _select_one(c: BaseContainer, key: str) -> Tuple[Any, Union[str, int]]: - from .listconfig import ListConfig from .dictconfig import DictConfig + from .listconfig import ListConfig ret_key: Union[str, int] = key assert isinstance(c, (DictConfig, ListConfig)), f"Unexpected type : {c}" @@ -452,5 +534,7 @@ def _select_one(c: BaseContainer, key: str) -> Tuple[Any, Union[str, int]]: val = None else: val = c[ret_key] + else: + assert False # pragma: no cover return val, ret_key diff --git a/omegaconf/version.py b/omegaconf/version.py index 040726379..ae5fe967d 100644 --- a/omegaconf/version.py +++ b/omegaconf/version.py @@ -1,6 +1,6 @@ import sys # pragma: no cover -__version__ = "2.0.0rc2" +__version__ = "2.0.0rc3" msg = """OmegaConf 2.0 and above is compatible with Python 3.6 and newer. You have the following options: diff --git a/setup.cfg b/setup.cfg index 1a63147e4..152465c9d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -4,4 +4,8 @@ test=pytest [mypy] python_version = 3.6 mypy_path=.stubs -strict_equality=True \ No newline at end of file +strict_equality=True +warn_unused_configs = True + +;[mypy-tests.*] +;disallow_untyped_decorators: True \ No newline at end of file diff --git a/setup.py b/setup.py index e1e99669a..f14321eea 100644 --- a/setup.py +++ b/setup.py @@ -1,3 +1,4 @@ +# type: ignore """ OmegaConf setup Instructions: @@ -7,7 +8,6 @@ # Upload: twine upload dist/* """ - import codecs import os import re @@ -56,6 +56,7 @@ def read(*parts): "PyYAML", # Use dataclasses backport for Python 3.6. "dataclasses;python_version=='3.6'", + "typing-extensions", ], # Install development dependencies with # pip install -e ".[dev]" @@ -64,6 +65,7 @@ def read(*parts): "black", "coveralls", "flake8", + "pyflakes@git+git://github.com/pycqa/pyflakes.git@1911c20#egg=pyflakes", "pre-commit", "pytest", "pytest-mock", @@ -72,10 +74,17 @@ def read(*parts): "twine", "sphinx", "mypy", - "isort", + "isort@git+git://github.com/timothycrosley/isort.git@c54b3dd#egg=isort", ], "coverage": ["coveralls"], - "lint": ["pytest", "black", "flake8", "mypy", "isort"], + "lint": [ + "pytest", + "black", + "flake8", + "pyflakes@git+git://github.com/pycqa/pyflakes.git@1911c20#egg=pyflakes", + "mypy", + "isort@git+git://github.com/timothycrosley/isort.git@c54b3dd#egg=isort", + ], }, package_data={"omegaconf": ["py.typed"]}, ) diff --git a/tests/__init__.py b/tests/__init__.py index 923b7f019..ba95b7ee8 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,11 +1,12 @@ from contextlib import contextmanager +from typing import Any, Iterator class IllegalType: - def __init__(self): + def __init__(self) -> None: pass @contextmanager -def does_not_raise(enter_result=None): +def does_not_raise(enter_result: Any = None) -> Iterator[Any]: yield enter_result diff --git a/tests/examples/test_dataclass_example.py b/tests/examples/test_dataclass_example.py index 8e2ac13cd..957817787 100644 --- a/tests/examples/test_dataclass_example.py +++ b/tests/examples/test_dataclass_example.py @@ -3,7 +3,6 @@ from typing import Dict, List, Optional import pytest - from omegaconf import ( MISSING, MissingMandatoryValue, diff --git a/tests/structured_conf/attr_test_data.py b/tests/structured_conf/attr_test_data.py index 355fdfcd7..d065367ce 100644 --- a/tests/structured_conf/attr_test_data.py +++ b/tests/structured_conf/attr_test_data.py @@ -1,8 +1,7 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple import attr # noqaE402 -import pytest # type: ignore - +import pytest from omegaconf import II, MISSING, SI from .common import Color @@ -122,17 +121,13 @@ class EnumConfig: @attr.s(auto_attribs=True) class ConfigWithList: - list1: List = [1, 2, 3] - list2: list = [1, 2, 3] - # tuples are converted to ListConfig - list3: tuple = (1, 2, 3) - list4: List[int] = [1, 2, 3] + list1: List[int] = [1, 2, 3] + list2: Tuple[int, int, int] = (1, 2, 3) @attr.s(auto_attribs=True) class ConfigWithDict: - dict1: Dict = {"foo": "bar"} - dict2: dict = {"foo": "bar"} + dict1: Dict[str, Any] = {"foo": "bar"} @attr.s(auto_attribs=True) @@ -230,7 +225,7 @@ class EnumOptional: class FrozenClass: user: User = User(name="Bart", age=10) x: int = 10 - list: List = [1, 2, 3] + list: List[int] = [1, 2, 3] @attr.s(auto_attribs=True) @@ -267,8 +262,7 @@ class ErrorListUnsupportedValue: @attr.s(auto_attribs=True) class ListExamples: - any1: List = [1, "foo"] - any2: List[Any] = [1, "foo"] + any: List[Any] = [1, "foo"] ints: List[int] = [1, 2] strings: List[str] = ["foo", "bar"] booleans: List[bool] = [True, False] @@ -277,8 +271,7 @@ class ListExamples: @attr.s(auto_attribs=True) class DictExamples: - any1: Dict = {"a": 1, "b": "foo"} - any2: Dict[str, Any] = {"a": 1, "b": "foo"} + any: Dict[str, Any] = {"a": 1, "b": "foo"} ints: Dict[str, int] = {"a": 10, "b": 20} strings: Dict[str, str] = {"a": "foo", "b": "bar"} booleans: Dict[str, bool] = {"a": True, "b": False} diff --git a/tests/structured_conf/dataclass_test_data.py b/tests/structured_conf/dataclass_test_data.py index 631bffb32..47ed177a9 100644 --- a/tests/structured_conf/dataclass_test_data.py +++ b/tests/structured_conf/dataclass_test_data.py @@ -1,8 +1,7 @@ from dataclasses import dataclass, field # noqaE402 -from typing import Any, Dict, List, Optional - -import pytest # type: ignore +from typing import Any, Dict, List, Optional, Tuple +import pytest from omegaconf import II, MISSING, SI from .common import Color @@ -117,17 +116,13 @@ class EnumConfig: @dataclass class ConfigWithList: - list1: List = field(default_factory=lambda: [1, 2, 3]) - list2: list = field(default_factory=lambda: [1, 2, 3]) - # tuples are converted to ListConfig - list3: tuple = field(default_factory=lambda: (1, 2, 3)) - list4: List[int] = field(default_factory=lambda: [1, 2, 3]) + list1: List[int] = field(default_factory=lambda: [1, 2, 3]) + list2: Tuple[int, int, int] = field(default_factory=lambda: (1, 2, 3)) @dataclass class ConfigWithDict: - dict1: Dict = field(default_factory=lambda: {"foo": "bar"}) - dict2: dict = field(default_factory=lambda: {"foo": "bar"}) + dict1: Dict[str, Any] = field(default_factory=lambda: {"foo": "bar"}) @dataclass @@ -225,7 +220,7 @@ class EnumOptional: class FrozenClass: user: User = User(name="Bart", age=10) x: int = 10 - list: List = field(default_factory=lambda: [1, 2, 3]) + list: List[int] = field(default_factory=lambda: [1, 2, 3]) @dataclass @@ -268,8 +263,7 @@ class ErrorListUnsupportedStructuredConfig: @dataclass class ListExamples: - any1: List = field(default_factory=lambda: [1, "foo"]) - any2: List[Any] = field(default_factory=lambda: [1, "foo"]) + any: List[Any] = field(default_factory=lambda: [1, "foo"]) ints: List[int] = field(default_factory=lambda: [1, 2]) strings: List[str] = field(default_factory=lambda: ["foo", "bar"]) booleans: List[bool] = field(default_factory=lambda: [True, False]) @@ -278,8 +272,7 @@ class ListExamples: @dataclass class DictExamples: - any1: Dict = field(default_factory=lambda: {"a": 1, "b": "foo"}) - any2: Dict[str, Any] = field(default_factory=lambda: {"a": 1, "b": "foo"}) + any: Dict[str, Any] = field(default_factory=lambda: {"a": 1, "b": "foo"}) ints: Dict[str, int] = field(default_factory=lambda: {"a": 10, "b": 20}) strings: Dict[str, str] = field(default_factory=lambda: {"a": "foo", "b": "bar"}) booleans: Dict[str, bool] = field(default_factory=lambda: {"a": True, "b": False}) diff --git a/tests/structured_conf/test_config_eq.py b/tests/structured_conf/test_config_eq.py new file mode 100644 index 000000000..5ce8ad990 --- /dev/null +++ b/tests/structured_conf/test_config_eq.py @@ -0,0 +1,109 @@ +from typing import Any, List + +import pytest +from omegaconf import AnyNode, OmegaConf +from omegaconf.basecontainer import BaseContainer + + +@pytest.mark.parametrize( # type: ignore + "l1,l2", + [ + # === LISTS === + # empty list + ([], []), + # simple list + (["a", 12, "15"], ["a", 12, "15"]), + # raw vs any + ([1, 2, 12], [1, 2, AnyNode(12)]), + # nested empty dict + ([12, dict()], [12, dict()]), + # nested dict + ([12, dict(c=10)], [12, dict(c=10)]), + # nested list + ([1, 2, 3, [10, 20, 30]], [1, 2, 3, [10, 20, 30]]), + # nested list with any + ([1, 2, 3, [1, 2, AnyNode(3)]], [1, 2, 3, [1, 2, AnyNode(3)]]), + # === DICTS == + # empty + (dict(), dict()), + # simple + (dict(a=12), dict(a=12)), + # any vs raw + (dict(a=12), dict(a=AnyNode(12))), + # nested dict empty + (dict(a=12, b=dict()), dict(a=12, b=dict())), + # nested dict + (dict(a=12, b=dict(c=10)), dict(a=12, b=dict(c=10))), + # nested list + (dict(a=12, b=[1, 2, 3]), dict(a=12, b=[1, 2, 3])), + # nested list with any + (dict(a=12, b=[1, 2, AnyNode(3)]), dict(a=12, b=[1, 2, AnyNode(3)])), + # In python 3.6 insert order changes iteration order. this ensures that equality is preserved. + (dict(a=1, b=2, c=3, d=4, e=5), dict(e=5, b=2, c=3, d=4, a=1)), + # With interpolations + ([10, "${0}"], [10, 10]), + (dict(a=12, b="${a}"), dict(a=12, b=12)), + ], +) +def test_list_eq(l1: List[Any], l2: List[Any]) -> None: + c1 = OmegaConf.create(l1) + c2 = OmegaConf.create(l2) + + def eq(a: Any, b: Any) -> None: + assert a == b + assert b == a + assert not a != b + assert not b != a + + eq(c1, c2) + eq(c1, l1) + eq(c2, l2) + + +@pytest.mark.parametrize( # type: ignore + "input1, input2", + [ + # Dicts + (dict(), dict(a=10)), + ({}, []), + (dict(a=12), dict(a=13)), + (dict(a=0), dict(b=0)), + (dict(a=12), dict(a=AnyNode(13))), + (dict(a=12, b=dict()), dict(a=13, b=dict())), + (dict(a=12, b=dict(c=10)), dict(a=13, b=dict(c=10))), + (dict(a=12, b=[1, 2, 3]), dict(a=12, b=[10, 2, 3])), + (dict(a=12, b=[1, 2, AnyNode(3)]), dict(a=12, b=[1, 2, AnyNode(30)])), + # Lists + ([], [10]), + ([10], [11]), + ([12], [AnyNode(13)]), + ([12, dict()], [13, dict()]), + ([12, dict(c=10)], [13, dict(c=10)]), + ([12, [1, 2, 3]], [12, [10, 2, 3]]), + ([12, [1, 2, AnyNode(3)]], [12, [1, 2, AnyNode(30)]]), + ], +) +def test_not_eq(input1: Any, input2: Any) -> None: + c1 = OmegaConf.create(input1) + c2 = OmegaConf.create(input2) + + def neq(a: Any, b: Any) -> None: + assert a != b + assert b != a + assert not a == b + assert not b == a + + neq(c1, c2) + + +# --- +def test_config_eq_mismatch_types() -> None: + c1 = OmegaConf.create({}) + c2 = OmegaConf.create([]) + assert not BaseContainer._config_eq(c1, c2) + assert not BaseContainer._config_eq(c2, c1) + + +def test_dict_not_eq_with_another_class() -> None: + assert OmegaConf.create({}) != "string" + assert OmegaConf.create([]) != "string" diff --git a/tests/test_base_config.py b/tests/test_base_config.py index a409029a8..37c5678a2 100644 --- a/tests/test_base_config.py +++ b/tests/test_base_config.py @@ -1,9 +1,10 @@ import copy +from typing import Any, Dict, List, Union import pytest - from omegaconf import ( MISSING, + Container, DictConfig, IntegerNode, ListConfig, @@ -19,7 +20,7 @@ from . import does_not_raise -@pytest.mark.parametrize( +@pytest.mark.parametrize( # type: ignore "input_, key, value, expected", [ # dict @@ -37,13 +38,15 @@ ([1, StringNode("str")], 1, IntegerNode(10), [1, 10]), ], ) -def test_set_value(input_, key, value, expected): +def test_set_value( + input_: Any, key: Union[str, int], value: Any, expected: Any +) -> None: c = OmegaConf.create(input_) c[key] = value assert c == expected -@pytest.mark.parametrize( +@pytest.mark.parametrize( # type: ignore "input_, key, value", [ # dict @@ -52,13 +55,13 @@ def test_set_value(input_, key, value, expected): ([1, IntegerNode(10)], 1, "str"), ], ) -def test_set_value_validation_fail(input_, key, value): +def test_set_value_validation_fail(input_: Any, key: Any, value: Any) -> None: c = OmegaConf.create(input_) with pytest.raises(ValidationError): c[key] = value -@pytest.mark.parametrize( +@pytest.mark.parametrize( # type: ignore "input_, key, value", [ # dict @@ -67,14 +70,16 @@ def test_set_value_validation_fail(input_, key, value): ([1, IntegerNode(10)], 1, StringNode("str")), ], ) -def test_replace_value_node_type_with_another(input_, key, value): +def test_replace_value_node_type_with_another( + input_: Any, key: Any, value: Any +) -> None: c = OmegaConf.create(input_) c[key] = value assert c[key] == value assert c[key] == value.value() -@pytest.mark.parametrize( +@pytest.mark.parametrize( # type: ignore "input_", [ [1, 2, 3], @@ -84,8 +89,8 @@ def test_replace_value_node_type_with_another(input_, key, value): dict(b=[1, 2, 3]), ], ) -def test_to_container_returns_primitives(input_): - def assert_container_with_primitives(container): +def test_to_container_returns_primitives(input_: Any) -> None: + def assert_container_with_primitives(container: Any) -> None: if isinstance(container, list): for v in container: assert_container_with_primitives(v) @@ -104,15 +109,15 @@ def assert_container_with_primitives(container): assert_container_with_primitives(res) -@pytest.mark.parametrize( +@pytest.mark.parametrize( # type: ignore "input_, is_empty", [([], True), ({}, True), ([1, 2], False), (dict(a=10), False)] ) -def test_empty(input_, is_empty): +def test_empty(input_: Any, is_empty: bool) -> None: c = OmegaConf.create(input_) assert c.is_empty() == is_empty -@pytest.mark.parametrize( +@pytest.mark.parametrize( # type: ignore "input_", [ [], @@ -124,12 +129,12 @@ def test_empty(input_, is_empty): dict(b=[1, 2, 3]), ], ) -def test_repr(input_): +def test_repr(input_: Any) -> None: c = OmegaConf.create(input_) assert repr(input_) == repr(c) -@pytest.mark.parametrize( +@pytest.mark.parametrize( # type: ignore "input_", [ [], @@ -141,13 +146,13 @@ def test_repr(input_): dict(b=[1, 2, 3]), ], ) -def test_str(input_): +def test_str(input_: Any) -> None: c = OmegaConf.create(input_) assert str(input_) == str(c) -@pytest.mark.parametrize("flag", ["readonly", "struct"]) -def test_flag_dict(flag): +@pytest.mark.parametrize("flag", ["readonly", "struct"]) # type: ignore +def test_flag_dict(flag: str) -> None: c = OmegaConf.create() assert c._get_flag(flag) is None c._set_flag(flag, True) @@ -158,8 +163,8 @@ def test_flag_dict(flag): assert c._get_flag(flag) is None -@pytest.mark.parametrize("flag", ["readonly", "struct"]) -def test_freeze_nested_dict(flag): +@pytest.mark.parametrize("flag", ["readonly", "struct"]) # type: ignore +def test_freeze_nested_dict(flag: str) -> None: c = OmegaConf.create(dict(a=dict(b=2))) assert not c._get_flag(flag) assert not c.a._get_flag(flag) @@ -179,7 +184,7 @@ def test_freeze_nested_dict(flag): @pytest.mark.parametrize("src", [[], [1, 2, 3], dict(), dict(a=10)]) class TestDeepCopy: - def test_deepcopy(self, src): + def test_deepcopy(self, src: Any) -> None: c1 = OmegaConf.create(src) c2 = copy.deepcopy(c1) assert c1 == c2 @@ -189,7 +194,7 @@ def test_deepcopy(self, src): c2.foo = "bar" assert c1 != c2 - def test_deepcopy_readonly(self, src): + def test_deepcopy_readonly(self, src: Any) -> None: c1 = OmegaConf.create(src) OmegaConf.set_readonly(c1, True) c2 = copy.deepcopy(c1) @@ -202,7 +207,7 @@ def test_deepcopy_readonly(self, src): c2.foo = "bar" assert c1 == c2 - def test_deepcopy_struct(self, src): + def test_deepcopy_struct(self, src: Any) -> None: c1 = OmegaConf.create(src) OmegaConf.set_struct(c1, True) c2 = copy.deepcopy(c1) @@ -214,7 +219,7 @@ def test_deepcopy_struct(self, src): c2.foo = "bar" -def test_deepcopy_after_del(): +def test_deepcopy_after_del() -> None: # make sure that deepcopy does not resurrect deleted fields (as it once did, believe it or not). c1 = OmegaConf.create(dict(foo=[1, 2, 3], bar=10)) c2 = copy.deepcopy(c1) @@ -224,7 +229,7 @@ def test_deepcopy_after_del(): assert c1 == c3 -def test_deepcopy_with_interpolation(): +def test_deepcopy_with_interpolation() -> None: c1 = OmegaConf.create(dict(a=dict(b="${c}"), c=10)) assert c1.a.b == 10 c2 = copy.deepcopy(c1) @@ -232,7 +237,7 @@ def test_deepcopy_with_interpolation(): # Yes, there was a bug that was a combination of an interaction between the three -def test_deepcopy_and_merge_and_flags(): +def test_deepcopy_and_merge_and_flags() -> None: c1 = OmegaConf.create( {"dataset": {"name": "imagenet", "path": "/datasets/imagenet"}, "defaults": []} ) @@ -242,19 +247,19 @@ def test_deepcopy_and_merge_and_flags(): OmegaConf.merge(c2, OmegaConf.from_dotlist(["dataset.bad_key=yes"])) -@pytest.mark.parametrize( +@pytest.mark.parametrize( # type: ignore "cfg", [ ListConfig(content=[], element_type=int), DictConfig(content={}, element_type=int), ], ) -def test_deepcopy_preserves_container_type(cfg): +def test_deepcopy_preserves_container_type(cfg: Container) -> None: cp = copy.deepcopy(cfg) assert cp.__dict__["_element_type"] == cfg.__dict__["_element_type"] -@pytest.mark.parametrize( +@pytest.mark.parametrize( # type: ignore "src, flag_name, flag_value, func, expectation", [ ( @@ -273,7 +278,9 @@ def test_deepcopy_preserves_container_type(cfg): ), ], ) -def test_flag_override(src, flag_name, flag_value, func, expectation): +def test_flag_override( + src: Dict[str, Any], flag_name: str, flag_value: bool, func: Any, expectation: Any +) -> None: c = OmegaConf.create(src) c._set_flag(flag_name, True) with expectation: @@ -284,14 +291,14 @@ def test_flag_override(src, flag_name, flag_value, func, expectation): func(c) -@pytest.mark.parametrize( +@pytest.mark.parametrize( # type: ignore "src, func, expectation", [ ({}, lambda c: c.__setitem__("foo", 1), pytest.raises(ReadonlyConfigError)), ([], lambda c: c.append(1), pytest.raises(ReadonlyConfigError)), ], ) -def test_read_write_override(src, func, expectation): +def test_read_write_override(src: Any, func: Any, expectation: Any) -> None: c = OmegaConf.create(src) OmegaConf.set_readonly(c, True) @@ -303,7 +310,7 @@ def test_read_write_override(src, func, expectation): func(c) -@pytest.mark.parametrize( +@pytest.mark.parametrize( # type: ignore "string, tokenized", [ ("dog,cat", ["dog", "cat"]), @@ -319,15 +326,15 @@ def test_read_write_override(src, func, expectation): ("no , escape", ["no", "escape"]), ], ) -def test_tokenize_with_escapes(string, tokenized): +def test_tokenize_with_escapes(string: str, tokenized: List[str]) -> None: assert OmegaConf._tokenize_args(string) == tokenized -@pytest.mark.parametrize( +@pytest.mark.parametrize( # type: ignore "src, func, expectation", [({}, lambda c: c.__setattr__("foo", 1), pytest.raises(AttributeError))], ) -def test_struct_override(src, func, expectation): +def test_struct_override(src: Any, func: Any, expectation: Any) -> None: c = OmegaConf.create(src) OmegaConf.set_struct(c, True) @@ -339,10 +346,10 @@ def test_struct_override(src, func, expectation): func(c) -@pytest.mark.parametrize( +@pytest.mark.parametrize( # type: ignore "flag_name,ctx", [("struct", open_dict), ("readonly", read_write)] ) -def test_open_dict_restore(flag_name, ctx): +def test_open_dict_restore(flag_name: str, ctx: Any) -> None: """ Tests that internal flags are restored properly when applying context on a child node """ @@ -358,22 +365,22 @@ def test_open_dict_restore(flag_name, ctx): @pytest.mark.parametrize("copy_method", [lambda x: copy.copy(x), lambda x: x.copy()]) class TestCopy: - @pytest.mark.parametrize( + @pytest.mark.parametrize( # type: ignore "src", [[], [1, 2], ["a", "b", "c"], {}, {"a": "b"}, {"a": {"b": []}}] ) - def test_copy(self, copy_method, src): + def test_copy(self, copy_method: Any, src: Any) -> None: src = OmegaConf.create(src) cp = copy_method(src) assert id(src) != id(cp) assert src == cp - @pytest.mark.parametrize( + @pytest.mark.parametrize( # type: ignore "src,interpolating_key,interpolated_key", [([1, 2, "${0}"], 2, 0), ({"a": 10, "b": "${a}"}, "b", "a")], ) def test_copy_with_interpolation( - self, copy_method, src, interpolating_key, interpolated_key - ): + self, copy_method: Any, src: Any, interpolating_key: str, interpolated_key: str + ) -> None: cfg = OmegaConf.create(src) assert cfg[interpolated_key] == cfg[interpolating_key] cp = copy_method(cfg) @@ -389,32 +396,32 @@ def test_copy_with_interpolation( cp[interpolated_key] = "XXX" assert cp[interpolated_key] == cp[interpolating_key] - def test_list_copy_is_shallow(self, copy_method): + def test_list_copy_is_shallow(self, copy_method: Any) -> None: cfg = OmegaConf.create([[10, 20]]) cp = copy_method(cfg) assert id(cfg) != id(cp) assert id(cfg[0]) == id(cp[0]) -def test_not_implemented(): +def test_not_implemented() -> None: with pytest.raises(NotImplementedError): OmegaConf() -@pytest.mark.parametrize("query, result", [("a", "a"), ("${foo}", 10), ("${bar}", 10)]) -def test_resolve_single(query, result): +@pytest.mark.parametrize("query, result", [("a", "a"), ("${foo}", 10), ("${bar}", 10)]) # type: ignore +def test_resolve_single(query: str, result: Any) -> None: cfg = OmegaConf.create({"foo": 10, "bar": "${foo}"}) assert cfg._resolve_single(value=query) == result -def test_omegaconf_create(): +def test_omegaconf_create() -> None: assert OmegaConf.create([]) == [] assert OmegaConf.create({}) == {} with pytest.raises(ValidationError): - assert OmegaConf.create(10) + assert OmegaConf.create(10) # type: ignore -@pytest.mark.parametrize( +@pytest.mark.parametrize( # type: ignore "cfg, key, expected", [ ({}, "foo", False), @@ -423,12 +430,12 @@ def test_omegaconf_create(): ({"foo": "${bar}", "bar": MISSING}, "foo", True), ], ) -def test_is_missing(cfg, key, expected): +def test_is_missing(cfg: Any, key: str, expected: Any) -> None: cfg = OmegaConf.create(cfg) assert OmegaConf.is_missing(cfg, key) == expected -@pytest.mark.parametrize( +@pytest.mark.parametrize( # type: ignore "cfg, is_conf, is_list, is_dict", [ (None, False, False, False), @@ -442,13 +449,13 @@ def test_is_missing(cfg, key, expected): (OmegaConf.create([]), True, True, False), ], ) -def test_is_config(cfg, is_conf, is_list, is_dict): +def test_is_config(cfg: Any, is_conf: bool, is_list: bool, is_dict: bool) -> None: assert OmegaConf.is_config(cfg) == is_conf assert OmegaConf.is_list(cfg) == is_list assert OmegaConf.is_dict(cfg) == is_dict -@pytest.mark.parametrize( +@pytest.mark.parametrize( # type: ignore "parent, index, value, expected", [ ([10, 11], 0, ["a", "b"], [["a", "b"], 11]), @@ -459,7 +466,7 @@ def test_is_config(cfg, is_conf, is_list, is_dict): ({}, "foo", OmegaConf.create({"foo": "bar"}), {"foo": {"foo": "bar"}}), ], ) -def test_assign(parent, index, value, expected): +def test_assign(parent: Any, index: Union[str, int], value: Any, expected: Any) -> None: c = OmegaConf.create(parent) c[index] = value assert c == expected diff --git a/tests/test_basic_ops_dict.py b/tests/test_basic_ops_dict.py index ba8c900b5..67ef44af1 100644 --- a/tests/test_basic_ops_dict.py +++ b/tests/test_basic_ops_dict.py @@ -2,19 +2,17 @@ import re import tempfile from enum import Enum -from typing import Any - -import pytest # type: ignore +from typing import Any, Dict, List, Union +import pytest from omegaconf import ( - AnyNode, - BaseContainer, DictConfig, MissingMandatoryValue, OmegaConf, UnsupportedKeyType, UnsupportedValueType, ) +from omegaconf.basecontainer import BaseContainer from . import IllegalType @@ -24,13 +22,13 @@ class Enum1(Enum): BAR = 2 -def test_setattr_deep_value(): +def test_setattr_deep_value() -> None: c = OmegaConf.create(dict(a=dict(b=dict(c=1)))) c.a.b = 9 assert {"a": {"b": 9}} == c -def test_setattr_deep_from_empty(): +def test_setattr_deep_from_empty() -> None: c = OmegaConf.create() # Unfortunately we can't just do c.a.b = 9 here. # The reason is that if c.a is being resolved first and it does not exist, so there @@ -38,55 +36,56 @@ def test_setattr_deep_from_empty(): # The alternative is to auto-create fields as they are being accessed, but this is opening # a whole new can of worms, and is also breaking map semantics. c.a = {} - c.a.b = 9 + c.a.b = 9 # type: ignore assert {"a": {"b": 9}} == c -def test_setattr_deep_map(): +def test_setattr_deep_map() -> None: c = OmegaConf.create(dict(a=dict(b=dict(c=1)))) c.a.b = {"z": 10} assert {"a": {"b": {"z": 10}}} == c -def test_getattr(): +def test_getattr() -> None: c = OmegaConf.create("a: b") assert "b" == c.a -def test_getattr_dict(): +def test_getattr_dict() -> None: c = OmegaConf.create("a: {b: 1}") assert {"b": 1} == c.a -def test_mandatory_value(): +def test_mandatory_value() -> None: c = OmegaConf.create(dict(a="???")) with pytest.raises(MissingMandatoryValue, match="a"): c.a -def test_nested_dict_mandatory_value(): +def test_nested_dict_mandatory_value() -> None: c = OmegaConf.create(dict(a=dict(b="???"))) with pytest.raises(MissingMandatoryValue): c.a.b -def test_mandatory_with_default(): +def test_mandatory_with_default() -> None: c = OmegaConf.create(dict(name="???")) assert c.get("name", "default value") == "default value" -def test_subscript_get(): +def test_subscript_get() -> None: c = OmegaConf.create("a: b") + assert isinstance(c, DictConfig) assert "b" == c["a"] -def test_subscript_set(): +def test_subscript_set() -> None: c = OmegaConf.create() c["a"] = "b" assert {"a": "b"} == c -def test_pretty_dict(): +def test_pretty_dict() -> None: c = OmegaConf.create(dict(hello="world", list=[1, 2])) expected = """hello: world list: @@ -97,7 +96,7 @@ def test_pretty_dict(): assert OmegaConf.create(c.pretty()) == c -def test_pretty_dict_unicode(): +def test_pretty_dict_unicode() -> None: c = OmegaConf.create(dict(你好="世界", list=[1, 2])) expected = """list: - 1 @@ -108,49 +107,51 @@ def test_pretty_dict_unicode(): assert OmegaConf.create(c.pretty()) == c -def test_default_value(): +def test_default_value() -> None: c = OmegaConf.create() assert c.missing_key or "a default value" == "a default value" -def test_get_default_value(): +def test_get_default_value() -> None: c = OmegaConf.create() assert c.get("missing_key", "a default value") == "a default value" -def test_scientific_notation_float(): +def test_scientific_notation_float() -> None: c = OmegaConf.create("a: 10e-3") assert 10e-3 == c.a -def test_dict_get_with_default(): +def test_dict_get_with_default() -> None: s = "{hello: {a : 2}}" c = OmegaConf.create(s) + assert isinstance(c, DictConfig) assert c.get("missing", 4) == 4 assert c.hello.get("missing", 5) == 5 -def test_map_expansion(): +def test_map_expansion() -> None: c = OmegaConf.create("{a: 2, b: 10}") + assert isinstance(c, DictConfig) - def foo(a, b): + def foo(a: int, b: int) -> int: return a + b assert 12 == foo(**c) -def test_items(): +def test_items() -> None: c = OmegaConf.create(dict(a=2, b=10)) assert sorted([("a", 2), ("b", 10)]) == sorted(list(c.items())) ii = c.items() - next(ii) == ("a", 2) - next(ii) == ("b", 10) + assert next(ii) == ("a", 2) + assert next(ii) == ("b", 10) with pytest.raises(StopIteration): next(ii) -def test_items2(): +def test_items2() -> None: c = OmegaConf.create(dict(a=dict(v=1), b=dict(v=1))) for k, v in c.items(): v.v = 2 @@ -159,7 +160,7 @@ def test_items2(): assert c.b.v == 2 -def test_items_with_interpolation(): +def test_items_with_interpolation() -> None: c = OmegaConf.create(dict(a=2, b="${a}")) r = {} for k, v in c.items(): @@ -168,18 +169,19 @@ def test_items_with_interpolation(): assert r["b"] == 2 -def test_dict_keys(): +def test_dict_keys() -> None: c = OmegaConf.create("{a: 2, b: 10}") assert {"a": 2, "b": 10}.keys() == c.keys() -def test_pickle_get_root(): +def test_pickle_get_root() -> None: # Test that get_root() is reconstructed correctly for pickle loaded files. with tempfile.TemporaryFile() as fp: c1 = OmegaConf.create(dict(a=dict(a1=1, a2=2))) - c2 = OmegaConf.create(dict(b=dict(b1="???", b2=4, bb=dict(bb1=3, bb2=4)))) + c2 = OmegaConf.create(dict(b=dict(b1="${a.a1}", b2=4, bb=dict(bb1=3, bb2=4)))) c3 = OmegaConf.merge(c1, c2) + assert isinstance(c3, DictConfig) import pickle @@ -188,7 +190,7 @@ def test_pickle_get_root(): fp.seek(0) loaded_c3 = pickle.load(fp) - def test(conf): + def test(conf: DictConfig) -> None: assert conf._get_root() == conf assert conf.a._get_root() == conf assert conf.b._get_root() == conf @@ -199,7 +201,7 @@ def test(conf): test(loaded_c3) -def test_iterate_dictionary(): +def test_iterate_dictionary() -> None: c = OmegaConf.create(dict(a=1, b=2)) m2 = {} for key in c: @@ -207,7 +209,7 @@ def test_iterate_dictionary(): assert m2 == c -def test_dict_pop(): +def test_dict_pop() -> None: c = OmegaConf.create(dict(a=1, b=2)) assert c.pop("a") == 1 assert c.pop("not_found", "default") == "default" @@ -216,9 +218,10 @@ def test_dict_pop(): c.pop("not_found") -def test_dict_enum_pop(): +def test_dict_enum_pop() -> None: - c = OmegaConf.create({Enum1.FOO: "bar"}) + # TODO: fix this type: ignore + c = OmegaConf.create({Enum1.FOO: "bar"}) # type: ignore with pytest.raises(KeyError): c.pop(Enum1.BAR) @@ -227,7 +230,7 @@ def test_dict_enum_pop(): c.pop(Enum1.FOO) -@pytest.mark.parametrize( +@pytest.mark.parametrize( # type: ignore "conf,key,expected", [ ({"a": 1, "b": {}}, "a", True), @@ -243,19 +246,19 @@ def test_dict_enum_pop(): ({Enum1.FOO: 1, "b": {}}, Enum1.FOO, True), ], ) -def test_in_dict(conf, key, expected): +def test_in_dict(conf: Any, key: str, expected: Any) -> None: conf = OmegaConf.create(conf) ret = key in conf assert ret == expected -def test_get_root(): +def test_get_root() -> None: c = OmegaConf.create(dict(a=123, b=dict(bb=456, cc=7))) assert c._get_root() == c assert c.b._get_root() == c -def test_get_root_of_merged(): +def test_get_root_of_merged() -> None: c1 = OmegaConf.create(dict(a=dict(a1=1, a2=2))) c2 = OmegaConf.create(dict(b=dict(b1="???", b2=4, bb=dict(bb1=3, bb2=4)))) @@ -267,12 +270,12 @@ def test_get_root_of_merged(): assert c3.b.bb._get_root() == c3 -def test_dict_config(): +def test_dict_config() -> None: c = OmegaConf.create(dict()) assert isinstance(c, DictConfig) -def test_dict_delitem(): +def test_dict_delitem() -> None: c = OmegaConf.create(dict(a=10, b=11)) assert c == dict(a=10, b=11) del c["a"] @@ -281,31 +284,31 @@ def test_dict_delitem(): del c["not_found"] -def test_dict_len(): +def test_dict_len() -> None: c = OmegaConf.create(dict(a=10, b=11)) assert len(c) == 2 -def test_dict_assign_illegal_value(): +def test_dict_assign_illegal_value() -> None: c = OmegaConf.create(dict()) with pytest.raises(UnsupportedValueType, match=re.escape("key a")): c.a = IllegalType() -def test_dict_assign_illegal_value_nested(): +def test_dict_assign_illegal_value_nested() -> None: c = OmegaConf.create(dict(a=dict())) with pytest.raises(UnsupportedValueType, match=re.escape("key a.b")): c.a.b = IllegalType() -def test_assign_dict_in_dict(): +def test_assign_dict_in_dict() -> None: c = OmegaConf.create(dict()) c.foo = dict(foo="bar") assert c.foo == dict(foo="bar") assert isinstance(c.foo, DictConfig) -@pytest.mark.parametrize( +@pytest.mark.parametrize( # type: ignore "src", [ dict(a=1, b=2, c=dict(aa=10)), @@ -313,128 +316,43 @@ def test_assign_dict_in_dict(): {"a": 1, "b": 2, "c": {"aa": 10, "lst": [1, 2, 3]}}, ], ) -def test_to_container(src): +def test_to_container(src: Any) -> None: c = OmegaConf.create(src) result = OmegaConf.to_container(c) assert type(result) == type(src) assert result == src -def test_pretty_without_resolve(): +def test_pretty_without_resolve() -> None: c = OmegaConf.create(dict(a1="${ref}", ref="bar")) # without resolve, references are preserved c2 = OmegaConf.create(c.pretty(resolve=False)) + assert isinstance(c2, DictConfig) assert c2.a1 == "bar" c2.ref = "changed" assert c2.a1 == "changed" -def test_pretty_with_resolve(): +def test_pretty_with_resolve() -> None: c = OmegaConf.create(dict(a1="${ref}", ref="bar")) c2 = OmegaConf.create(c.pretty(resolve=True)) + assert isinstance(c2, DictConfig) assert c2.a1 == "bar" c2.ref = "changed" assert c2.a1 == "bar" -def test_instantiate_config_fails(): +def test_instantiate_config_fails() -> None: with pytest.raises(TypeError): - BaseContainer(element_type=Any, parent=None) + BaseContainer(element_type=Any, parent=None) # type: ignore -def test_dir(): +def test_dir() -> None: c = OmegaConf.create(dict(a=1, b=2, c=3)) assert ["a", "b", "c"] == dir(c) -@pytest.mark.parametrize( - "input1, input2", - [ - # empty - (dict(), dict()), - # simple - (dict(a=12), dict(a=12)), - # any vs raw - (dict(a=12), dict(a=AnyNode(12))), - # nested dict empty - (dict(a=12, b=dict()), dict(a=12, b=dict())), - # nested dict - (dict(a=12, b=dict(c=10)), dict(a=12, b=dict(c=10))), - # nested list - (dict(a=12, b=[1, 2, 3]), dict(a=12, b=[1, 2, 3])), - # nested list with any - (dict(a=12, b=[1, 2, AnyNode(3)]), dict(a=12, b=[1, 2, AnyNode(3)])), - # In python 3.6 insert order changes iteration order. this ensures that equality is preserved. - (dict(a=1, b=2, c=3, d=4, e=5), dict(e=5, b=2, c=3, d=4, a=1)), - ], -) -def test_dict_eq(input1, input2): - c1 = OmegaConf.create(input1) - c2 = OmegaConf.create(input2) - - def eq(a, b): - assert a == b - assert b == a - assert not a != b - assert not b != a - - eq(c1, c2) - eq(c1, input1) - eq(c2, input2) - - -@pytest.mark.parametrize("input1, input2", [(dict(a=12, b="${a}"), dict(a=12, b=12))]) -def test_dict_eq_with_interpolation(input1, input2): - c1 = OmegaConf.create(input1) - c2 = OmegaConf.create(input2) - - def eq(a, b): - assert a == b - assert b == a - assert not a != b - assert not b != a - - eq(c1, c2) - - -@pytest.mark.parametrize( - "input1, input2", - [ - (dict(), dict(a=10)), - ({}, []), - (dict(a=12), dict(a=13)), - (dict(a=0), dict(b=0)), - (dict(a=12), dict(a=AnyNode(13))), - (dict(a=12, b=dict()), dict(a=13, b=dict())), - (dict(a=12, b=dict(c=10)), dict(a=13, b=dict(c=10))), - (dict(a=12, b=[1, 2, 3]), dict(a=12, b=[10, 2, 3])), - (dict(a=12, b=[1, 2, AnyNode(3)]), dict(a=12, b=[1, 2, AnyNode(30)])), - ], -) -def test_dict_not_eq(input1, input2): - c1 = OmegaConf.create(input1) - c2 = OmegaConf.create(input2) - - def neq(a, b): - assert a != b - assert b != a - assert not a == b - assert not b == a - - neq(c1, c2) - - -def test_config_eq_mismatch_types(): - c1 = OmegaConf.create({}) - c2 = OmegaConf.create([]) - assert not BaseContainer._config_eq(c1, c2) - - -def test_dict_not_eq_with_another_class(): - assert OmegaConf.create() != "string" - - -def test_hash(): +def test_hash() -> None: c1 = OmegaConf.create(dict(a=10)) c2 = OmegaConf.create(dict(a=10)) assert hash(c1) == hash(c2) @@ -442,20 +360,20 @@ def test_hash(): assert hash(c1) != hash(c2) -def test_get_with_default_from_struct_not_throwing(): +def test_get_with_default_from_struct_not_throwing() -> None: c = OmegaConf.create(dict(a=10, b=20)) OmegaConf.set_struct(c, True) assert c.get("z", "default") == "default" -def test_members(): +def test_members() -> None: # Make sure accessing __members__ does not return None or throw. c = OmegaConf.create({"foo": {}}) assert c.__members__ == {} -@pytest.mark.parametrize( - "cfg, mask_keys, expected", +@pytest.mark.parametrize( # type: ignore + "in_cfg, mask_keys, expected", [ ({}, [], {}), ({"a": 1}, "a", {"a": 1}), @@ -464,13 +382,15 @@ def test_members(): ({"a": 1, "b": 2}, ["a", "b"], {"a": 1, "b": 2}), ], ) -def test_masked_copy(cfg, mask_keys, expected): - cfg = OmegaConf.create(cfg) +def test_masked_copy( + in_cfg: Dict[str, Any], mask_keys: Union[str, List[str]], expected: Any +) -> None: + cfg = OmegaConf.create(in_cfg) masked = OmegaConf.masked_copy(cfg, keys=mask_keys) assert masked == expected -def test_masked_copy_is_deep(): +def test_masked_copy_is_deep() -> None: cfg = OmegaConf.create({"a": {"b": 1, "c": 2}}) expected = {"a": {"b": 1, "c": 2}} masked = OmegaConf.masked_copy(cfg, keys=["a"]) @@ -479,27 +399,27 @@ def test_masked_copy_is_deep(): assert cfg != expected with pytest.raises(ValueError): - OmegaConf.masked_copy("fail", []) + OmegaConf.masked_copy("fail", []) # type: ignore -def test_creation_with_invalid_key(): +def test_creation_with_invalid_key() -> None: with pytest.raises(UnsupportedKeyType): - OmegaConf.create({1: "a"}) + OmegaConf.create({1: "a"}) # type: ignore -def test_set_with_invalid_key(): +def test_set_with_invalid_key() -> None: cfg = OmegaConf.create() with pytest.raises(UnsupportedKeyType): - cfg[1] = "a" + cfg[1] = "a" # type: ignore -def test_get_with_invalid_key(): +def test_get_with_invalid_key() -> None: cfg = OmegaConf.create() with pytest.raises(UnsupportedKeyType): - cfg[1] + cfg[1] # type: ignore -def test_hasattr(): +def test_hasattr() -> None: cfg = OmegaConf.create({"foo": "bar"}) OmegaConf.set_struct(cfg, True) assert hasattr(cfg, "foo") diff --git a/tests/test_basic_ops_list.py b/tests/test_basic_ops_list.py index 830188cd8..48b3a925e 100644 --- a/tests/test_basic_ops_list.py +++ b/tests/test_basic_ops_list.py @@ -1,28 +1,28 @@ # -*- coding: utf-8 -*- import re +from typing import Any, List, Optional import pytest - -from omegaconf import AnyNode, DictConfig, ListConfig, OmegaConf +from omegaconf import AnyNode, ListConfig, OmegaConf from omegaconf.errors import UnsupportedKeyType, UnsupportedValueType from omegaconf.nodes import IntegerNode, StringNode from . import IllegalType, does_not_raise -def test_list_value(): +def test_list_value() -> None: c = OmegaConf.create("a: [1,2]") assert {"a": [1, 2]} == c -def test_list_of_dicts(): +def test_list_of_dicts() -> None: v = [dict(key1="value1"), dict(key2="value2")] c = OmegaConf.create(v) assert c[0].key1 == "value1" assert c[1].key2 == "value2" -def test_pretty_list(): +def test_pretty_list() -> None: c = OmegaConf.create(["item1", "item2", dict(key3="value3")]) expected = """- item1 - item2 @@ -32,7 +32,7 @@ def test_pretty_list(): assert OmegaConf.create(c.pretty()) == c -def test_pretty_list_unicode(): +def test_pretty_list_unicode() -> None: c = OmegaConf.create(["item一", "item二", dict(key三="value三")]) expected = """- item一 - item二 @@ -42,27 +42,26 @@ def test_pretty_list_unicode(): assert OmegaConf.create(c.pretty()) == c -def test_list_get_with_default(): +def test_list_get_with_default() -> None: c = OmegaConf.create([None, "???", "found"]) assert c.get(0, "default_value") == "default_value" assert c.get(1, "default_value") == "default_value" assert c.get(2, "default_value") == "found" -def test_iterate_list(): +def test_iterate_list() -> None: c = OmegaConf.create([1, 2]) items = [x for x in c] assert items[0] == 1 assert items[1] == 2 -def test_items_with_interpolation(): +def test_items_with_interpolation() -> None: c = OmegaConf.create(["foo", "${0}"]) - assert c == ["foo", "foo"] -def test_list_pop(): +def test_list_pop() -> None: c = OmegaConf.create([1, 2, 3, 4]) assert c.pop(0) == 1 assert c.pop() == 4 @@ -71,7 +70,7 @@ def test_list_pop(): c.pop(100) -def test_in_list(): +def test_in_list() -> None: c = OmegaConf.create([10, 11, dict(a=12)]) assert 10 in c assert 11 in c @@ -79,24 +78,24 @@ def test_in_list(): assert "blah" not in c -def test_list_config_with_list(): +def test_list_config_with_list() -> None: c = OmegaConf.create([]) assert isinstance(c, ListConfig) -def test_list_config_with_tuple(): +def test_list_config_with_tuple() -> None: c = OmegaConf.create(()) assert isinstance(c, ListConfig) -def test_items_on_list(): +def test_items_on_list() -> None: c = OmegaConf.create([1, 2]) with pytest.raises(AttributeError): c.items() -def test_list_enumerate(): - src = ["a", "b", "c", "d"] +def test_list_enumerate() -> None: + src: List[Optional[str]] = ["a", "b", "c", "d"] c = OmegaConf.create(src) for i, v in enumerate(c): assert src[i] == v @@ -107,7 +106,7 @@ def test_list_enumerate(): assert v is None -def test_list_delitem(): +def test_list_delitem() -> None: c = OmegaConf.create([1, 2, 3]) assert c == [1, 2, 3] del c[0] @@ -116,65 +115,58 @@ def test_list_delitem(): del c[100] -def test_list_len(): +def test_list_len() -> None: c = OmegaConf.create([1, 2]) assert len(c) == 2 -def test_nested_list_assign_illegal_value(): +def test_nested_list_assign_illegal_value() -> None: c = OmegaConf.create(dict(a=[None])) with pytest.raises(UnsupportedValueType, match=re.escape("key a[0]")): c.a[0] = IllegalType() -def test_list_append(): +def test_list_append() -> None: c = OmegaConf.create([]) c.append(1) c.append(2) c.append({}) c.append([]) - assert isinstance(c[2], DictConfig) - assert isinstance(c[3], ListConfig) assert c == [1, 2, {}, []] -def test_pretty_without_resolve(): +def test_pretty_without_resolve() -> None: c = OmegaConf.create([100, "${0}"]) # without resolve, references are preserved c2 = OmegaConf.create(c.pretty(resolve=False)) + assert isinstance(c2, ListConfig) c2[0] = 1000 assert c2[1] == 1000 -def test_pretty_with_resolve(): +def test_pretty_with_resolve() -> None: c = OmegaConf.create([100, "${0}"]) # with resolve, references are not preserved. c2 = OmegaConf.create(c.pretty(resolve=True)) + assert isinstance(c2, ListConfig) c2[0] = 1000 assert c[1] == 100 -def test_index_slice(): - c = OmegaConf.create([10, 11, 12, 13]) - assert c[1:3] == [11, 12] - - -def test_index_slice2(): - c = OmegaConf.create([10, 11, 12, 13]) - assert c[0:3:2] == [10, 12] - - -def test_negative_index(): +@pytest.mark.parametrize( # type: ignore + "index, expected", [(slice(1, 3), [11, 12]), (slice(0, 3, 2), [10, 12]), (-1, 13)] +) +def test_list_index(index: Any, expected: Any) -> None: c = OmegaConf.create([10, 11, 12, 13]) - assert c[-1] == 13 + assert c[index] == expected -def test_list_dir(): +def test_list_dir() -> None: c = OmegaConf.create([1, 2, 3]) assert ["0", "1", "2"] == dir(c) -def test_getattr(): +def test_getattr() -> None: c = OmegaConf.create(["a", "b", "c"]) assert getattr(c, "0") == "a" assert getattr(c, "1") == "b" @@ -183,7 +175,7 @@ def test_getattr(): getattr(c, "anything") -@pytest.mark.parametrize( +@pytest.mark.parametrize( # type: ignore "input_, index, value, expected, expected_node_type", [ (["a", "b", "c"], 1, 100, ["a", 100, "b", "c"], AnyNode), @@ -192,14 +184,16 @@ def test_getattr(): (["a", "b", "c"], 1, StringNode("foo"), ["a", "foo", "b", "c"], StringNode), ], ) -def test_insert(input_, index, value, expected, expected_node_type): +def test_insert( + input_: List[str], index: int, value: Any, expected: Any, expected_node_type: type +) -> None: c = OmegaConf.create(input_) c.insert(index, value) assert c == expected assert type(c.get_node(index)) == expected_node_type -@pytest.mark.parametrize( +@pytest.mark.parametrize( # type: ignore "src, append, result", [ ([], [], []), @@ -207,13 +201,13 @@ def test_insert(input_, index, value, expected, expected_node_type): ([1, 2], ("a", "b", "c"), [1, 2, "a", "b", "c"]), ], ) -def test_extend(src, append, result): - src = OmegaConf.create(src) - src.extend(append) - assert src == result +def test_extend(src: List[Any], append: List[Any], result: List[Any]) -> None: + lst = OmegaConf.create(src) + lst.extend(append) + assert lst == result -@pytest.mark.parametrize( +@pytest.mark.parametrize( # type: ignore "src, remove, result, expectation", [ ([10], 10, [], does_not_raise()), @@ -222,23 +216,24 @@ def test_extend(src, append, result): ([1, 2, 1, 2], 2, [1, 1, 2], does_not_raise()), ], ) -def test_remove(src, remove, result, expectation): +def test_remove(src: List[Any], remove: Any, result: Any, expectation: Any) -> None: with expectation: - src = OmegaConf.create(src) - src.remove(remove) - assert src == result + lst = OmegaConf.create(src) + assert isinstance(lst, ListConfig) + lst.remove(remove) + assert lst == result -@pytest.mark.parametrize("src", [[], [1, 2, 3], [None, dict(foo="bar")]]) -@pytest.mark.parametrize("num_clears", [1, 2]) -def test_clear(src, num_clears): - src = OmegaConf.create(src) +@pytest.mark.parametrize("src", [[], [1, 2, 3], [None, dict(foo="bar")]]) # type: ignore +@pytest.mark.parametrize("num_clears", [1, 2]) # type: ignore +def test_clear(src: List[Any], num_clears: int) -> None: + lst = OmegaConf.create(src) for i in range(num_clears): - src.clear() - assert src == [] + lst.clear() + assert lst == [] -@pytest.mark.parametrize( +@pytest.mark.parametrize( # type: ignore "src, item, expected_index, expectation", [ ([], 20, -1, pytest.raises(ValueError)), @@ -246,22 +241,36 @@ def test_clear(src, num_clears): ([10, 20], 20, 1, does_not_raise()), ], ) -def test_index(src, item, expected_index, expectation): +def test_index( + src: List[Any], item: Any, expected_index: int, expectation: Any +) -> None: with expectation: - src = OmegaConf.create(src) - assert src.index(item) == expected_index + lst = OmegaConf.create(src) + assert lst.index(item) == expected_index -@pytest.mark.parametrize( +def test_index_with_range() -> None: + lst = OmegaConf.create([10, 20, 30, 40, 50]) + assert lst.index(x=30) == 2 + assert lst.index(x=30, start=1) == 2 + assert lst.index(x=30, start=1, end=3) == 2 + with pytest.raises(ValueError): + lst.index(x=30, start=3) + + with pytest.raises(ValueError): + lst.index(x=30, end=2) + + +@pytest.mark.parametrize( # type: ignore "src, item, count", [([], 10, 0), ([10], 10, 1), ([10, 2, 10], 10, 2), ([10, 2, 10], None, 0)], ) -def test_count(src, item, count): - src = OmegaConf.create(src) - assert src.count(item) == count +def test_count(src: List[Any], item: Any, count: int) -> None: + lst = OmegaConf.create(src) + assert lst.count(item) == count -def test_sort(): +def test_sort() -> None: c = OmegaConf.create(["bbb", "aa", "c"]) c.sort() assert ["aa", "bbb", "c"] == c @@ -273,80 +282,7 @@ def test_sort(): assert ["bbb", "aa", "c"] == c -@pytest.mark.parametrize( - "l1,l2", - [ - # empty list - ([], []), - # simple list - (["a", 12, "15"], ["a", 12, "15"]), - # raw vs any - ([1, 2, 12], [1, 2, AnyNode(12)]), - # nested empty dict - ([12, dict()], [12, dict()]), - # nested dict - ([12, dict(c=10)], [12, dict(c=10)]), - # nested list - ([1, 2, 3, [10, 20, 30]], [1, 2, 3, [10, 20, 30]]), - # nested list with any - ([1, 2, 3, [1, 2, AnyNode(3)]], [1, 2, 3, [1, 2, AnyNode(3)]]), - ], -) -def test_list_eq(l1, l2): - c1 = OmegaConf.create(l1) - c2 = OmegaConf.create(l2) - - def eq(a, b): - assert a == b - assert b == a - assert not a != b - assert not b != a - - eq(c1, c2) - eq(c1, l1) - eq(c2, l2) - - -@pytest.mark.parametrize("l1,l2", [([10, "${0}"], [10, 10])]) -def test_list_eq_with_interpolation(l1, l2): - c1 = OmegaConf.create(l1) - c2 = OmegaConf.create(l2) - - def eq(a, b): - assert a == b - assert b == a - assert not a != b - assert not b != a - - eq(c1, c2) - - -@pytest.mark.parametrize( - "input1, input2", - [ - ([], [10]), - ([10], [11]), - ([12], [AnyNode(13)]), - ([12, dict()], [13, dict()]), - ([12, dict(c=10)], [13, dict(c=10)]), - ([12, [1, 2, 3]], [12, [10, 2, 3]]), - ([12, [1, 2, AnyNode(3)]], [12, [1, 2, AnyNode(30)]]), - ], -) -def test_list_not_eq(input1, input2): - c1 = OmegaConf.create(input1) - c2 = OmegaConf.create(input2) - - def neq(a, b): - assert a != b - assert b != a - assert not a == b - assert not b == a - - neq(c1, c2) - - -def test_insert_throws_not_changing_list(): +def test_insert_throws_not_changing_list() -> None: c = OmegaConf.create([]) with pytest.raises(ValueError): c.insert(0, IllegalType()) @@ -354,7 +290,7 @@ def test_insert_throws_not_changing_list(): assert c == [] -def test_append_throws_not_changing_list(): +def test_append_throws_not_changing_list() -> None: c = OmegaConf.create([]) with pytest.raises(ValueError): c.append(IllegalType()) @@ -362,7 +298,7 @@ def test_append_throws_not_changing_list(): assert c == [] -def test_hash(): +def test_hash() -> None: c1 = OmegaConf.create([10]) c2 = OmegaConf.create([10]) assert hash(c1) == hash(c2) @@ -371,7 +307,7 @@ def test_hash(): @pytest.mark.parametrize( - "list1, list2, expected", + "in_list1, in_list2,in_expected", [ ([], [], []), ([1, 2], [3, 4], [1, 2, 3, 4]), @@ -379,28 +315,32 @@ def test_hash(): ], ) class TestListAdd: - def test_list_plus(self, list1, list2, expected): - list1 = OmegaConf.create(list1) - list2 = OmegaConf.create(list2) - expected = OmegaConf.create(expected) + def test_list_plus( + self, in_list1: List[Any], in_list2: List[Any], in_expected: List[Any] + ) -> None: + list1 = OmegaConf.create(in_list1) + list2 = OmegaConf.create(in_list2) + expected = OmegaConf.create(in_expected) ret = list1 + list2 assert ret == expected - def test_list_plus_eq(self, list1, list2, expected): - list1 = OmegaConf.create(list1) - list2 = OmegaConf.create(list2) - expected = OmegaConf.create(expected) + def test_list_plus_eq( + self, in_list1: List[Any], in_list2: List[Any], in_expected: List[Any] + ) -> None: + list1 = OmegaConf.create(in_list1) + list2 = OmegaConf.create(in_list2) + expected = OmegaConf.create(in_expected) list1 += list2 assert list1 == expected -def test_deep_add(): +def test_deep_add() -> None: cfg = OmegaConf.create({"foo": [1, 2, "${bar}"], "bar": "xx"}) lst = cfg.foo + [10, 20] assert lst == [1, 2, "xx", 10, 20] -def test_set_with_invalid_key(): +def test_set_with_invalid_key() -> None: cfg = OmegaConf.create([1, 2, 3]) with pytest.raises(UnsupportedKeyType): - cfg["foo"] = 4 + cfg["foo"] = 4 # type: ignore diff --git a/tests/test_create.py b/tests/test_create.py index b9b74a9f7..b92a0d2cc 100644 --- a/tests/test_create.py +++ b/tests/test_create.py @@ -1,16 +1,16 @@ """Testing for OmegaConf""" import re import sys +from typing import Any, Dict, List import pytest - from omegaconf import OmegaConf from omegaconf.errors import UnsupportedValueType from . import IllegalType -@pytest.mark.parametrize( +@pytest.mark.parametrize( # type: ignore "input_,expected", [ # empty @@ -42,24 +42,24 @@ (OmegaConf.create([OmegaConf.create({})]), [{}]), ], ) -def test_create_value(input_, expected): +def test_create_value(input_: Any, expected: Any) -> None: assert expected == OmegaConf.create(input_) -def test_create_from_cli(): +def test_create_from_cli() -> None: sys.argv = ["program.py", "a=1", "b.c=2"] c = OmegaConf.from_cli() assert {"a": 1, "b": {"c": 2}} == c -def test_cli_passing(): +def test_cli_passing() -> None: args_list = ["a=1", "b.c=2"] c = OmegaConf.from_cli(args_list) assert {"a": 1, "b": {"c": 2}} == c -@pytest.mark.parametrize( - "input_, expected", +@pytest.mark.parametrize( # type: ignore + "input_,expected", [ # simple (["a=1", "b.c=2"], dict(a=1, b=dict(c=2))), @@ -69,35 +69,43 @@ def test_cli_passing(): (["my_date=2019-12-11"], dict(my_date="2019-12-11")), ], ) -def test_dotlist(input_, expected): +def test_dotlist(input_: List[str], expected: Dict[str, Any]) -> None: c = OmegaConf.from_dotlist(input_) assert c == expected -def test_create_list_with_illegal_value_idx0(): +def test_create_list_with_illegal_value_idx0() -> None: with pytest.raises(UnsupportedValueType, match=re.escape("key [0]")): OmegaConf.create([IllegalType()]) -def test_create_list_with_illegal_value_idx1(): +def test_create_list_with_illegal_value_idx1() -> None: with pytest.raises(UnsupportedValueType, match=re.escape("key [1]")): OmegaConf.create([1, IllegalType(), 3]) -def test_create_dict_with_illegal_value(): +def test_create_dict_with_illegal_value() -> None: with pytest.raises(UnsupportedValueType, match=re.escape("key a")): OmegaConf.create(dict(a=IllegalType())) # TODO: improve exception message to contain full key a.b # https://github.com/omry/omegaconf/issues/14 -def test_create_nested_dict_with_illegal_value(): +def test_create_nested_dict_with_illegal_value() -> None: with pytest.raises(ValueError): OmegaConf.create(dict(a=dict(b=IllegalType()))) -def test_create_from_oc(): +def test_create_from_oc() -> None: c = OmegaConf.create( {"a": OmegaConf.create([1, 2, 3]), "b": OmegaConf.create({"c": 10})} ) assert c == {"a": [1, 2, 3], "b": {"c": 10}} + + +def test_create_from_oc_with_flags() -> None: + c1 = OmegaConf.create({"foo": "bar"}) + OmegaConf.set_struct(c1, True) + c2 = OmegaConf.create(c1) + assert c1 == c2 + assert c1.flags == c2.flags diff --git a/tests/test_get_full_key.py b/tests/test_get_full_key.py index 00ee7b71d..52b69149c 100644 --- a/tests/test_get_full_key.py +++ b/tests/test_get_full_key.py @@ -1,71 +1,84 @@ -from omegaconf import OmegaConf +from omegaconf import DictConfig, ListConfig, OmegaConf class TestGetFullKey: # 1 - def test_dict(self): + def test_dict(self) -> None: c = OmegaConf.create(dict(a=1)) + assert isinstance(c, DictConfig) assert c.get_full_key("a") == "a" - def test_list(self): + def test_list(self) -> None: c = OmegaConf.create([1, 2, 3]) + assert isinstance(c, ListConfig) assert c.get_full_key("2") == "[2]" # 2 - def test_dd(self): + def test_dd(self) -> None: c = OmegaConf.create(dict(a=1, b=dict(c=1))) + assert isinstance(c, DictConfig) assert c.b.get_full_key("c") == "b.c" - def test_dl(self): + def test_dl(self) -> None: c = OmegaConf.create(dict(a=[1, 2, 3])) + assert isinstance(c, DictConfig) assert c.a.get_full_key(1) == "a[1]" - def test_ll(self): + def test_ll(self) -> None: c = OmegaConf.create([[1, 2, 3]]) + assert isinstance(c, ListConfig) assert c[0].get_full_key("2") == "[0][2]" - def test_ld(self): + def test_ld(self) -> None: c = OmegaConf.create([1, 2, dict(a=1)]) + assert isinstance(c, ListConfig) assert c[2].get_full_key("a") == "[2].a" # 3 - def test_ddd(self): + def test_ddd(self) -> None: c = OmegaConf.create(dict(a=dict(b=dict(c=1)))) + assert isinstance(c, DictConfig) assert c.a.b.get_full_key("c") == "a.b.c" - def test_ddl(self): + def test_ddl(self) -> None: c = OmegaConf.create(dict(a=dict(b=[0, 1]))) assert c.a.b.get_full_key(0) == "a.b[0]" - def test_dll(self): + def test_dll(self) -> None: c = OmegaConf.create(dict(a=[1, [2]])) assert c.a[1].get_full_key(0) == "a[1][0]" - def test_dld(self): + def test_dld(self) -> None: c = OmegaConf.create(dict(a=[dict(b=2)])) assert c.a[0].get_full_key("b") == "a[0].b" - def test_ldd(self): + def test_ldd(self) -> None: c = OmegaConf.create([dict(a=dict(b=1))]) + assert isinstance(c, ListConfig) assert c[0].a.get_full_key("b") == "[0].a.b" - def test_ldl(self): + def test_ldl(self) -> None: c = OmegaConf.create([dict(a=[0])]) + assert isinstance(c, ListConfig) assert c[0].a.get_full_key(0) == "[0].a[0]" - def test_lll(self): + def test_lll(self) -> None: c = OmegaConf.create([[[0]]]) + assert isinstance(c, ListConfig) assert c[0][0].get_full_key(0) == "[0][0][0]" - def test_lld(self): + def test_lld(self) -> None: c = OmegaConf.create([[dict(a=1)]]) + assert isinstance(c, ListConfig) assert c[0][0].get_full_key("a") == "[0][0].a" - def test_lldddl(self): + def test_lldddl(self) -> None: c = OmegaConf.create([[dict(a=dict(a=[0]))]]) + assert isinstance(c, ListConfig) assert c[0][0].a.a.get_full_key(0) == "[0][0].a.a[0]" # special cases - def test_parent_with_missing_item(self): + def test_parent_with_missing_item(self) -> None: c = OmegaConf.create(dict(x="???", a=1, b=dict(c=1))) + assert isinstance(c, DictConfig) assert c.b.get_full_key("c") == "b.c" diff --git a/tests/test_interpolation.py b/tests/test_interpolation.py index b9ad97fc1..8e277c0c2 100644 --- a/tests/test_interpolation.py +++ b/tests/test_interpolation.py @@ -1,19 +1,19 @@ import os import random +from typing import Any, Dict import pytest +from omegaconf import DictConfig, ListConfig, OmegaConf, Resolver -from omegaconf import OmegaConf - -def test_str_interpolation_dict_1(): +def test_str_interpolation_dict_1() -> None: # Simplest str_interpolation c = OmegaConf.create(dict(a="${referenced}", referenced="bar")) assert c.referenced == "bar" assert c.a == "bar" -def test_str_interpolation_key_error_1(): +def test_str_interpolation_key_error_1() -> None: # Test that a KeyError is thrown if an str_interpolation key is not available c = OmegaConf.create(dict(a="${not_found}")) @@ -21,7 +21,7 @@ def test_str_interpolation_key_error_1(): _ = c.a -def test_str_interpolation_key_error_2(): +def test_str_interpolation_key_error_2() -> None: # Test that a KeyError is thrown if an str_interpolation key is not available c = OmegaConf.create(dict(a="${not.found}")) @@ -29,14 +29,14 @@ def test_str_interpolation_key_error_2(): c.a -def test_str_interpolation_3(): +def test_str_interpolation_3() -> None: # Test that str_interpolation works with complex strings c = OmegaConf.create(dict(a="the year ${year}", year="of the cat")) assert c.a == "the year of the cat" -def test_str_interpolation_4(): +def test_str_interpolation_4() -> None: # Test that a string with multiple str_interpolations works c = OmegaConf.create( dict(a="${ha} ${ha} ${ha}, said Pennywise, ${ha} ${ha}... ${ha}!", ha="HA") @@ -45,7 +45,7 @@ def test_str_interpolation_4(): assert c.a == "HA HA HA, said Pennywise, HA HA... HA!" -def test_deep_str_interpolation_1(): +def test_deep_str_interpolation_1() -> None: # Test deep str_interpolation works c = OmegaConf.create( dict( @@ -57,7 +57,7 @@ def test_deep_str_interpolation_1(): assert c.a == "the answer to the universe and everything is 42" -def test_deep_str_interpolation_2(): +def test_deep_str_interpolation_2() -> None: # Test that str_interpolation of a key that is nested works c = OmegaConf.create( dict( @@ -69,7 +69,7 @@ def test_deep_str_interpolation_2(): assert c.deep.inside == "the answer to the universe and everything is 42" -def test_simple_str_interpolation_inherit_type(): +def test_simple_str_interpolation_inherit_type() -> None: # Test that str_interpolation of a key that is nested works c = OmegaConf.create( dict( @@ -90,7 +90,7 @@ def test_simple_str_interpolation_inherit_type(): assert type(c.inter4) == str -def test_complex_str_interpolation_is_always_str_1(): +def test_complex_str_interpolation_is_always_str_1() -> None: c = OmegaConf.create(dict(two=2, four=4, inter1="${four}${two}", inter2="4${two}")) assert type(c.inter1) == str @@ -99,7 +99,7 @@ def test_complex_str_interpolation_is_always_str_1(): assert c.inter2 == "42" -@pytest.mark.parametrize( +@pytest.mark.parametrize( # type: ignore "input_,key,expected", [ (dict(a=10, b="${a}"), "b", 10), @@ -110,18 +110,18 @@ def test_complex_str_interpolation_is_always_str_1(): (dict(a="foo-${b}", b=dict(c=10)), "a", "foo-{'c': 10}"), ], ) -def test_interpolation(input_, key, expected): +def test_interpolation(input_: Dict[str, Any], key: str, expected: str) -> None: c = OmegaConf.create(input_) assert c.select(key) == expected -def test_2_step_interpolation(): +def test_2_step_interpolation() -> None: c = OmegaConf.create(dict(src="bar", copy_src="${src}", copy_copy="${copy_src}")) assert c.copy_src == "bar" assert c.copy_copy == "bar" -def test_env_interpolation1(): +def test_env_interpolation1() -> None: try: os.environ["foobar"] = "1234" c = OmegaConf.create(dict(path="/test/${env:foobar}")) @@ -131,13 +131,13 @@ def test_env_interpolation1(): OmegaConf.clear_resolvers() -def test_env_interpolation_not_found(): +def test_env_interpolation_not_found() -> None: c = OmegaConf.create(dict(path="/test/${env:foobar}")) with pytest.raises(KeyError): c.path -@pytest.mark.parametrize( +@pytest.mark.parametrize( # type: ignore "value,expected", [ # bool @@ -162,7 +162,7 @@ def test_env_interpolation_not_found(): ("foo: \n - bar\n - baz", "foo: \n - bar\n - baz"), ], ) -def test_env_values_are_typed(value, expected): +def test_env_values_are_typed(value: Any, expected: Any) -> None: try: os.environ["my_key"] = value c = OmegaConf.create(dict(my_key="${env:my_key}")) @@ -172,16 +172,20 @@ def test_env_values_are_typed(value, expected): OmegaConf.clear_resolvers() -def test_register_resolver_twice_error(): +def test_register_resolver_twice_error() -> None: try: - OmegaConf.register_resolver("foo", lambda: 10) + + def foo() -> int: + return 10 + + OmegaConf.register_resolver("foo", foo) with pytest.raises(AssertionError): OmegaConf.register_resolver("foo", lambda: 10) finally: OmegaConf.clear_resolvers() -def test_clear_resolvers(): +def test_clear_resolvers() -> None: assert OmegaConf.get_resolver("foo") is None try: OmegaConf.register_resolver("foo", lambda x: int(x) + 10) @@ -191,7 +195,7 @@ def test_clear_resolvers(): assert OmegaConf.get_resolver("foo") is None -def test_register_resolver_1(): +def test_register_resolver_1() -> None: try: OmegaConf.register_resolver("plus_10", lambda x: int(x) + 10) c = OmegaConf.create(dict(k="${plus_10:990}")) @@ -202,7 +206,7 @@ def test_register_resolver_1(): OmegaConf.clear_resolvers() -def test_resolver_cache_1(): +def test_resolver_cache_1() -> None: # resolvers are always converted to stateless idempotent functions # subsequent calls to the same function with the same argument will always return the same value. # this is important to allow embedding of functions like time() without having the value change during @@ -215,7 +219,7 @@ def test_resolver_cache_1(): OmegaConf.clear_resolvers() -def test_resolver_cache_2(): +def test_resolver_cache_2() -> None: """ Tests that resolver cache is not shared between different OmegaConf objects """ @@ -230,7 +234,7 @@ def test_resolver_cache_2(): OmegaConf.clear_resolvers() -@pytest.mark.parametrize( +@pytest.mark.parametrize( # type: ignore "resolver,name,key,result", [ (lambda *args: args, "arg_list", "${my_resolver:cat, dog}", ("cat", "dog")), @@ -249,16 +253,19 @@ def test_resolver_cache_2(): (lambda: "zero", "zero_arg", "${my_resolver:}", "zero"), ], ) -def test_resolver_that_allows_a_list_of_arguments(resolver, name, key, result): +def test_resolver_that_allows_a_list_of_arguments( + resolver: Resolver, name: str, key: str, result: Any +) -> None: try: OmegaConf.register_resolver("my_resolver", resolver) c = OmegaConf.create({name: key}) + assert isinstance(c, DictConfig) assert c[name] == result finally: OmegaConf.clear_resolvers() -def test_copy_cache(): +def test_copy_cache() -> None: OmegaConf.register_resolver("random", lambda _: random.randint(0, 10000000)) c1 = OmegaConf.create(dict(k="${random:_}")) assert c1.k == c1.k @@ -275,7 +282,7 @@ def test_copy_cache(): assert c3.k == c1.k -def test_supported_chars(): +def test_supported_chars() -> None: supported_chars = "%_-abc123." c = OmegaConf.create(dict(dir1="${copy:" + supported_chars + "}")) @@ -283,30 +290,33 @@ def test_supported_chars(): assert c.dir1 == supported_chars -def test_interpolation_in_list_key_error(): +def test_interpolation_in_list_key_error() -> None: # Test that a KeyError is thrown if an str_interpolation key is not available c = OmegaConf.create(["${10}"]) + assert isinstance(c, ListConfig) with pytest.raises(KeyError): c[0] -def test_unsupported_interpolation_type(): +def test_unsupported_interpolation_type() -> None: c = OmegaConf.create(dict(foo="${wrong_type:ref}")) with pytest.raises(ValueError): c.foo -def test_incremental_dict_with_interpolation(): +def test_incremental_dict_with_interpolation() -> None: conf = OmegaConf.create() + assert isinstance(conf, DictConfig) conf.a = 1 conf.b = OmegaConf.create() + assert isinstance(conf.b, DictConfig) conf.b.c = "${a}" - assert conf.b.c == conf.a + assert conf.b.c == conf.a # type: ignore -@pytest.mark.parametrize( +@pytest.mark.parametrize( # type: ignore "cfg,key,expected", [ ({"a": 10, "b": "${a}"}, "b", 10), @@ -317,6 +327,6 @@ def test_incremental_dict_with_interpolation(): ({"list": ["${ref}"], "ref": "bar"}, "list.0", "bar"), ], ) -def test_interpolations(cfg, key, expected): - cfg = OmegaConf.create(cfg) - assert cfg.select(key) == expected +def test_interpolations(cfg: DictConfig, key: str, expected: Any) -> None: + c = OmegaConf.create(cfg) + assert c.select(key) == expected diff --git a/tests/test_merge.py b/tests/test_merge.py index 5e726e02b..f8741e188 100644 --- a/tests/test_merge.py +++ b/tests/test_merge.py @@ -1,9 +1,8 @@ from dataclasses import dataclass, field -from typing import Dict +from typing import Any, Dict, Tuple import pytest - -from omegaconf import MISSING, OmegaConf, nodes +from omegaconf import MISSING, DictConfig, OmegaConf, nodes @dataclass @@ -17,7 +16,7 @@ class Users: name2user: Dict[str, User] = field(default_factory=lambda: {}) -@pytest.mark.parametrize( +@pytest.mark.parametrize( # type: ignore "inputs, expected", [ # dictionaries @@ -75,7 +74,7 @@ class Users: ), ], ) -def test_merge(inputs, expected): +def test_merge(inputs: Any, expected: Any) -> None: configs = [OmegaConf.create(c) for c in inputs] merged = OmegaConf.merge(*configs) assert merged == expected @@ -84,21 +83,21 @@ def test_merge(inputs, expected): for i in range(len(inputs)): input_i = OmegaConf.create(inputs[i]) orig = OmegaConf.to_container(input_i, resolve=False) - merged = OmegaConf.to_container(configs[i], resolve=False) - assert orig == merged + merged2 = OmegaConf.to_container(configs[i], resolve=False) + assert orig == merged2 -def test_primitive_dicts(): +def test_primitive_dicts() -> None: c1 = {"a": 10} c2 = {"b": 20} merged = OmegaConf.merge(c1, c2) assert merged == {"a": 10, "b": 20} -# like above but don't verify merge does not change because even eq does not work no tuples because we convert -# them to a list -@pytest.mark.parametrize("a_, b_, expected", [((1, 2, 3), (4, 5, 6), [4, 5, 6])]) -def test_merge_no_eq_verify(a_, b_, expected): +@pytest.mark.parametrize("a_, b_, expected", [((1, 2, 3), (4, 5, 6), [4, 5, 6])]) # type: ignore +def test_merge_no_eq_verify( + a_: Tuple[int], b_: Tuple[int], expected: Tuple[int] +) -> None: a = OmegaConf.create(a_) b = OmegaConf.create(b_) c = OmegaConf.merge(a, b) @@ -106,15 +105,16 @@ def test_merge_no_eq_verify(a_, b_, expected): assert expected == c -def test_merge_with_1(): +def test_merge_with_1() -> None: a = OmegaConf.create() b = OmegaConf.create(dict(a=1, b=2)) a.merge_with(b) assert a == b -def test_merge_with_2(): +def test_merge_with_2() -> None: a = OmegaConf.create() + assert isinstance(a, DictConfig) a.inner = {} b = OmegaConf.create( """ @@ -122,11 +122,11 @@ def test_merge_with_2(): b : 2 """ ) - a.inner.merge_with(b) + a.inner.merge_with(b) # type: ignore assert a.inner == b -def test_3way_dict_merge(): +def test_3way_dict_merge() -> None: c1 = OmegaConf.create("{a: 1, b: 2}") c2 = OmegaConf.create("{b: 3}") c3 = OmegaConf.create("{a: 2, c: 3}") @@ -134,14 +134,14 @@ def test_3way_dict_merge(): assert {"a": 2, "b": 3, "c": 3} == c4 -def test_merge_list_list(): +def test_merge_list_list() -> None: a = OmegaConf.create([1, 2, 3]) b = OmegaConf.create([4, 5, 6]) a.merge_with(b) assert a == b -@pytest.mark.parametrize( +@pytest.mark.parametrize( # type: ignore "base, merge, exception", [ ({}, [], TypeError), @@ -150,17 +150,20 @@ def test_merge_list_list(): (dict(a=10), None, ValueError), ], ) -def test_merge_error(base, merge, exception): +def test_merge_error(base: Any, merge: Any, exception: Any) -> None: base = OmegaConf.create(base) merge = None if merge is None else OmegaConf.create(merge) with pytest.raises(exception): OmegaConf.merge(base, merge) -def test_parent_maintained(): +def test_parent_maintained() -> None: c1 = OmegaConf.create(dict(a=dict(b=10))) c2 = OmegaConf.create(dict(aa=dict(bb=100))) c3 = OmegaConf.merge(c1, c2) + assert isinstance(c1, DictConfig) + assert isinstance(c2, DictConfig) + assert isinstance(c3, DictConfig) assert id(c1.a.parent) == id(c1) assert id(c2.aa.parent) == id(c2) assert id(c3.a.parent) == id(c3) diff --git a/tests/test_nodes.py b/tests/test_nodes.py index 2d0b84ee3..5f6b02a2a 100644 --- a/tests/test_nodes.py +++ b/tests/test_nodes.py @@ -1,9 +1,8 @@ import copy from enum import Enum -from typing import Any +from typing import Any, Dict, Tuple, Type import pytest - from omegaconf import ( AnyNode, BooleanNode, @@ -14,12 +13,13 @@ ListConfig, OmegaConf, StringNode, + ValueNode, ) from omegaconf.errors import ValidationError # testing valid conversions -@pytest.mark.parametrize( +@pytest.mark.parametrize( # type: ignore "type_,input_,output_", [ # string @@ -57,7 +57,7 @@ (BooleanNode, 0, False), ], ) -def test_valid_inputs(type_, input_, output_): +def test_valid_inputs(type_: type, input_: Any, output_: Any) -> None: node = type_(input_) assert node == output_ assert node == node @@ -67,7 +67,7 @@ def test_valid_inputs(type_, input_, output_): # testing invalid conversions -@pytest.mark.parametrize( +@pytest.mark.parametrize( # type: ignore "type_,input_", [ (IntegerNode, "abc"), @@ -79,7 +79,7 @@ def test_valid_inputs(type_, input_, output_): (BooleanNode, "Yup"), ], ) -def test_invalid_inputs(type_, input_): +def test_invalid_inputs(type_: type, input_: Any) -> None: empty_node = type_() with pytest.raises(ValidationError): empty_node.set_value(input_) @@ -88,7 +88,7 @@ def test_invalid_inputs(type_, input_): type_(input_) -@pytest.mark.parametrize( +@pytest.mark.parametrize( # type: ignore "input_, expected_type", [ ({}, DictConfig), @@ -100,14 +100,16 @@ def test_invalid_inputs(type_, input_): ("str", AnyNode), ], ) -def test_assigned_value_node_type(input_, expected_type): +def test_assigned_value_node_type(input_: type, expected_type: Any) -> None: c = OmegaConf.create() + assert isinstance(c, DictConfig) c.foo = input_ assert type(c.get_node("foo")) == expected_type -def test_get_node_no_validate_access(): +def test_get_node_no_validate_access() -> None: c = OmegaConf.create({"foo": "bar"}) + assert isinstance(c, DictConfig) OmegaConf.set_struct(c, True) with pytest.raises(AttributeError): c.get_node("zoo", validate_access=True) @@ -115,13 +117,14 @@ def test_get_node_no_validate_access(): assert c.get_node("zoo", validate_access=False) is None assert ( - c.get_node("zoo", validate_access=False, default_value="default") == "default" + c.get_node("zoo", validate_access=False, default_value="default") == "default" # type: ignore ) # dict -def test_dict_any(): +def test_dict_any() -> None: c = OmegaConf.create() + assert isinstance(c, DictConfig) # default type is Any c.foo = 10 c[Enum1.FOO] = "bar" @@ -134,16 +137,18 @@ def test_dict_any(): assert type(c.get_node(Enum1.FOO)) == AnyNode -def test_dict_integer_1(): +def test_dict_integer_1() -> None: c = OmegaConf.create() + assert isinstance(c, DictConfig) c.foo = IntegerNode(10) assert type(c.get_node("foo")) == IntegerNode assert c.foo == 10 # list -def test_list_any(): +def test_list_any() -> None: c = OmegaConf.create([]) + assert isinstance(c, ListConfig) # default type is Any c.append(10) assert c[0] == 10 @@ -152,16 +157,18 @@ def test_list_any(): assert c[0] == "string" -def test_list_integer(): +def test_list_integer() -> None: val = 10 c = OmegaConf.create([]) + assert isinstance(c, ListConfig) c.append(IntegerNode(val)) assert type(c.get_node(0)) == IntegerNode assert c.get(0) == val -def test_list_integer_rejects_string(): +def test_list_integer_rejects_string() -> None: c = OmegaConf.create([]) + assert isinstance(c, ListConfig) c.append(IntegerNode(10)) assert c.get(0) == 10 with pytest.raises(ValidationError): @@ -171,7 +178,7 @@ def test_list_integer_rejects_string(): # Test merge raises validation error -@pytest.mark.parametrize( +@pytest.mark.parametrize( # type: ignore "c1, c2", [ (dict(a=IntegerNode(10)), dict(a="str")), @@ -180,7 +187,7 @@ def test_list_integer_rejects_string(): (dict(foo=dict(bar=IntegerNode(10))), dict(foo=dict(bar="str"))), ], ) -def test_merge_validation_error(c1, c2): +def test_merge_validation_error(c1: Dict[str, Any], c2: Dict[str, Any]) -> None: conf1 = OmegaConf.create(c1) conf2 = OmegaConf.create(c2) with pytest.raises(ValidationError): @@ -190,7 +197,7 @@ def test_merge_validation_error(c1, c2): assert conf2 == OmegaConf.create(c2) -@pytest.mark.parametrize( +@pytest.mark.parametrize( # type: ignore "type_,valid_value, invalid_value", [ (IntegerNode, 1, "invalid"), @@ -200,12 +207,15 @@ def test_merge_validation_error(c1, c2): (StringNode, "blah", None), ], ) -def test_accepts_mandatory_missing(type_, valid_value, invalid_value): +def test_accepts_mandatory_missing( + type_: type, valid_value: Any, invalid_value: Any +) -> None: node = type_() node.set_value("???") assert node.value() == "???" conf = OmegaConf.create({"foo": node}) + assert isinstance(conf, DictConfig) assert "foo" not in conf assert type(conf.get_node("foo")) == type_ @@ -230,10 +240,10 @@ class Enum2(Enum): NOT_BAR = 2 -@pytest.mark.parametrize( +@pytest.mark.parametrize( # type: ignore "type_", [BooleanNode, EnumNode, FloatNode, IntegerNode, StringNode, AnyNode] ) -@pytest.mark.parametrize( +@pytest.mark.parametrize( # type: ignore "values, success_map", [ ( @@ -282,7 +292,9 @@ class Enum2(Enum): ), ], ) -def test_legal_assignment(type_, values, success_map): +def test_legal_assignment( + type_: type, values: Any, success_map: Dict[Any, Dict[str, Any]] +) -> None: if not isinstance(values, (list, tuple)): values = [values] @@ -299,7 +311,7 @@ def test_legal_assignment(type_, values, success_map): type_(value) -@pytest.mark.parametrize( +@pytest.mark.parametrize( # type: ignore "node,value", [ (IntegerNode(), "foo"), @@ -308,15 +320,15 @@ def test_legal_assignment(type_, values, success_map): (EnumNode(enum_type=Enum1), "foo"), ], ) -def test_illegal_assignment(node, value): +def test_illegal_assignment(node: ValueNode, value: Any) -> None: with pytest.raises(ValidationError): node.set_value(value) -@pytest.mark.parametrize( +@pytest.mark.parametrize( # type: ignore "node_type", [BooleanNode, EnumNode, FloatNode, IntegerNode, StringNode, AnyNode] ) -@pytest.mark.parametrize( +@pytest.mark.parametrize( # type: ignore "enum_type, values, success_map", [ ( @@ -326,7 +338,12 @@ def test_illegal_assignment(node, value): ) ], ) -def test_legal_assignment_enum(node_type, enum_type, values, success_map): +def test_legal_assignment_enum( + node_type: Type[EnumNode], + enum_type: Type[Enum], + values: Tuple[Any], + success_map: Dict[Any, Any], +) -> None: assert isinstance(values, (list, tuple)) for value in values: @@ -342,8 +359,9 @@ def test_legal_assignment_enum(node_type, enum_type, values, success_map): node_type(enum_type) -def test_pretty_with_enum(): +def test_pretty_with_enum() -> None: cfg = OmegaConf.create() + assert isinstance(cfg, DictConfig) cfg.foo = EnumNode(Enum1) cfg.foo = Enum1.FOO @@ -356,8 +374,8 @@ class DummyEnum(Enum): FOO = 1 -@pytest.mark.parametrize("is_optional", [True, False]) -@pytest.mark.parametrize( +@pytest.mark.parametrize("is_optional", [True, False]) # type: ignore +@pytest.mark.parametrize( # type: ignore "type_,value, expected_type", [ (Any, 10, AnyNode), @@ -368,7 +386,9 @@ class DummyEnum(Enum): (str, "foo", StringNode), ], ) -def test_node_wrap(type_, is_optional, value, expected_type): +def test_node_wrap( + type_: type, is_optional: bool, value: Any, expected_type: Any +) -> None: from omegaconf.omegaconf import _node_wrap ret = _node_wrap(type_=type_, value=value, is_optional=is_optional, parent=None) @@ -382,7 +402,7 @@ def test_node_wrap(type_, is_optional, value, expected_type): assert ret == None # noqa E711 -def test_node_wrap_illegal_type(): +def test_node_wrap_illegal_type() -> None: class UserClass: pass @@ -392,7 +412,7 @@ class UserClass: _node_wrap(type_=UserClass, value=UserClass(), is_optional=False, parent=None) -@pytest.mark.parametrize( +@pytest.mark.parametrize( # type: ignore "obj", [ StringNode(), @@ -406,13 +426,13 @@ class UserClass: OmegaConf.create({"foo": "foo"}), ], ) -def test_deepcopy(obj): +def test_deepcopy(obj: Any) -> None: cp = copy.deepcopy(obj) assert cp == obj assert id(cp) != id(obj) -@pytest.mark.parametrize( +@pytest.mark.parametrize( # type: ignore "node, value, expected", [ (StringNode(), None, True), @@ -435,7 +455,7 @@ def test_deepcopy(obj): (BooleanNode(False), False, True), ], ) -def test_eq(node, value, expected): +def test_eq(node: ValueNode, value: Any, expected: Any) -> None: assert (node == value) == expected assert (node != value) != expected assert (value == node) == expected diff --git a/tests/test_readonly.py b/tests/test_readonly.py index ffd01c62f..6abac1931 100644 --- a/tests/test_readonly.py +++ b/tests/test_readonly.py @@ -1,12 +1,12 @@ import re +from typing import Any, Callable, Dict, List, Union import pytest +from omegaconf import DictConfig, ListConfig, OmegaConf, ReadonlyConfigError from pytest import raises -from omegaconf import OmegaConf, ReadonlyConfigError - -@pytest.mark.parametrize( +@pytest.mark.parametrize( # type: ignore "src, func, expectation", [ ({}, lambda c: c.__setitem__("a", 1), raises(ReadonlyConfigError, match="a")), @@ -15,7 +15,11 @@ lambda c: c.__getattr__("a").__getattr__("b").__setitem__("c", 1), raises(ReadonlyConfigError, match="a.b.c"), ), - ({}, lambda c: c.update("a.b", 10), raises(ReadonlyConfigError, match="a")), + ( + {}, + lambda c: c.update_node("a.b", 10), + raises(ReadonlyConfigError, match="a"), + ), ( dict(a=10), lambda c: c.__setattr__("a", 1), @@ -29,12 +33,18 @@ ), # list ([], lambda c: c.__setitem__(0, 1), raises(ReadonlyConfigError, match="0")), - ([], lambda c: c.update("0.b", 10), raises(ReadonlyConfigError, match="[0]")), + ( + [], + lambda c: c.update_node("0.b", 10), + raises(ReadonlyConfigError, match="[0]"), + ), ([10], lambda c: c.pop(), raises(ReadonlyConfigError)), ([0], lambda c: c.__delitem__(0), raises(ReadonlyConfigError, match="[0]")), ], ) -def test_readonly(src, func, expectation): +def test_readonly( + src: Union[Dict[str, Any], List[Any]], func: Callable[[Any], Any], expectation: Any +) -> None: c = OmegaConf.create(src) OmegaConf.set_readonly(c, True) with expectation: @@ -42,8 +52,8 @@ def test_readonly(src, func, expectation): assert c == src -@pytest.mark.parametrize("src", [{}, []]) -def test_readonly_flag(src): +@pytest.mark.parametrize("src", [{}, []]) # type: ignore +def test_readonly_flag(src: Union[Dict[str, Any], List[Any]]) -> None: c = OmegaConf.create(src) assert not OmegaConf.is_readonly(c) OmegaConf.set_readonly(c, True) @@ -54,8 +64,9 @@ def test_readonly_flag(src): assert not OmegaConf.is_readonly(c) -def test_readonly_nested_list(): +def test_readonly_nested_list() -> None: c = OmegaConf.create([[1]]) + assert isinstance(c, ListConfig) assert not OmegaConf.is_readonly(c) assert not OmegaConf.is_readonly(c[0]) OmegaConf.set_readonly(c, True) @@ -72,7 +83,7 @@ def test_readonly_nested_list(): assert OmegaConf.is_readonly(c[0]) -def test_readonly_list_insert(): +def test_readonly_list_insert() -> None: c = OmegaConf.create([]) OmegaConf.set_readonly(c, True) with raises(ReadonlyConfigError, match="[0]"): @@ -80,16 +91,17 @@ def test_readonly_list_insert(): assert c == [] -def test_readonly_list_insert_deep(): - src = [dict(a=[dict(b=[])])] +def test_readonly_list_insert_deep() -> None: + src: List[Dict[str, Any]] = [dict(a=[dict(b=[])])] c = OmegaConf.create(src) + assert isinstance(c, ListConfig) OmegaConf.set_readonly(c, True) with raises(ReadonlyConfigError, match=re.escape("[0].a[0].b[0]")): c[0].a[0].b.insert(0, 10) assert c == src -def test_readonly_list_append(): +def test_readonly_list_append() -> None: c = OmegaConf.create([]) OmegaConf.set_readonly(c, True) with raises(ReadonlyConfigError, match="[0]"): @@ -97,40 +109,45 @@ def test_readonly_list_append(): assert c == [] -def test_readonly_list_change_item(): +def test_readonly_list_change_item() -> None: c = OmegaConf.create([1, 2, 3]) + assert isinstance(c, ListConfig) OmegaConf.set_readonly(c, True) with raises(ReadonlyConfigError, match="[1]"): c[1] = 10 assert c == [1, 2, 3] -def test_readonly_list_pop(): +def test_readonly_list_pop() -> None: c = OmegaConf.create([1, 2, 3]) + assert isinstance(c, ListConfig) OmegaConf.set_readonly(c, True) with raises(ReadonlyConfigError, match="[1]"): c.pop(1) assert c == [1, 2, 3] -def test_readonly_list_del(): +def test_readonly_list_del() -> None: c = OmegaConf.create([1, 2, 3]) + assert isinstance(c, ListConfig) OmegaConf.set_readonly(c, True) with raises(ReadonlyConfigError, match="[1]"): del c[1] assert c == [1, 2, 3] -def test_readonly_list_sort(): +def test_readonly_list_sort() -> None: c = OmegaConf.create([3, 1, 2]) + assert isinstance(c, ListConfig) OmegaConf.set_readonly(c, True) with raises(ReadonlyConfigError): c.sort() assert c == [3, 1, 2] -def test_readonly_from_cli(): +def test_readonly_from_cli() -> None: c = OmegaConf.create({"foo": {"bar": [1]}}) + assert isinstance(c, DictConfig) OmegaConf.set_readonly(c, True) cli = OmegaConf.from_dotlist(["foo.bar=[2]"]) with raises(ReadonlyConfigError, match="foo.bar"): diff --git a/tests/test_select.py b/tests/test_select.py index 94cf40ee9..dcbe65062 100644 --- a/tests/test_select.py +++ b/tests/test_select.py @@ -1,23 +1,24 @@ -import pytest -from pytest import raises +from typing import Optional +import pytest from omegaconf import OmegaConf +from pytest import raises -@pytest.mark.parametrize("struct", [True, False, None]) -def test_select_key_from_empty(struct): +@pytest.mark.parametrize("struct", [True, False, None]) # type: ignore +def test_select_key_from_empty(struct: Optional[bool]) -> None: c = OmegaConf.create() OmegaConf.set_struct(c, struct) assert c.select("not_there") is None -def test_select_dotkey_from_empty(): +def test_select_dotkey_from_empty() -> None: c = OmegaConf.create() assert c.select("not.there") is None assert c.select("still.not.there") is None -def test_select_from_dict(): +def test_select_from_dict() -> None: c = OmegaConf.create(dict(a=dict(v=1), b=dict(v=1))) assert c.select("a") == {"v": 1} @@ -26,12 +27,12 @@ def test_select_from_dict(): assert c.select("nope") is None -def test_select_from_empty_list(): +def test_select_from_empty_list() -> None: c = OmegaConf.create([]) assert c.select("0") is None -def test_select_from_primitive_list(): +def test_select_from_primitive_list() -> None: c = OmegaConf.create([1, 2, 3, "4"]) assert c.select("0") == 1 assert c.select("1") == 2 @@ -39,7 +40,7 @@ def test_select_from_primitive_list(): assert c.select("3") == "4" -def test_select_from_dict_in_list(): +def test_select_from_dict_in_list() -> None: c = OmegaConf.create([1, dict(a=10, c=["foo", "bar"])]) assert c.select("0") == 1 assert c.select("1.a") == 10 @@ -48,7 +49,7 @@ def test_select_from_dict_in_list(): assert c.select("1.c.1") == "bar" -def test_list_select_non_int_key(): +def test_list_select_non_int_key() -> None: c = OmegaConf.create([1, 2, 3]) with raises(TypeError): c.select("a") diff --git a/tests/test_serialization.py b/tests/test_serialization.py index 27a2124bf..1e17fe002 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -2,20 +2,20 @@ import io import os import tempfile +from typing import Any, Dict import pytest +from omegaconf import Container, OmegaConf -from omegaconf import OmegaConf - -def save_load_from_file(conf, resolve, expected): +def save_load_from_file(conf: Container, resolve: bool, expected: Any) -> None: if expected is None: expected = conf try: with tempfile.NamedTemporaryFile( mode="wt", delete=False, encoding="utf-8" ) as fp: - OmegaConf.save(conf, fp.file, resolve=resolve) + OmegaConf.save(conf, fp.file, resolve=resolve) # type: ignore with io.open(os.path.abspath(fp.name), "rt", encoding="utf-8") as handle: c2 = OmegaConf.load(handle) assert c2 == expected @@ -23,7 +23,7 @@ def save_load_from_file(conf, resolve, expected): os.unlink(fp.name) -def save_load_from_filename(conf, resolve, expected): +def save_load_from_filename(conf: Container, resolve: bool, expected: Any) -> None: if expected is None: expected = conf # note that delete=False here is a work around windows incompetence. @@ -36,13 +36,13 @@ def save_load_from_filename(conf, resolve, expected): os.unlink(fp.name) -def test_load_from_invalid(): +def test_load_from_invalid() -> None: with pytest.raises(TypeError): - OmegaConf.load(3.1415) + OmegaConf.load(3.1415) # type: ignore @pytest.mark.parametrize( - "cfg,resolve,expected", + "input_,resolve,expected", [ (dict(a=10), False, None), ({"foo": 10, "bar": "${foo}"}, False, None), @@ -51,21 +51,25 @@ def test_load_from_invalid(): ], ) class TestSaveLoad: - def test_save_load__from_file(self, cfg, resolve, expected): - cfg = OmegaConf.create(cfg) + def test_save_load__from_file( + self, input_: Dict[str, Any], resolve: bool, expected: Any + ) -> None: + cfg = OmegaConf.create(input_) save_load_from_file(cfg, resolve, expected) - def test_save_load__from_filename(self, cfg, resolve, expected): - cfg = OmegaConf.create(cfg) + def test_save_load__from_filename( + self, input_: Dict[str, Any], resolve: bool, expected: Any + ) -> None: + cfg = OmegaConf.create(input_) save_load_from_filename(cfg, resolve, expected) -def test_save_illegal_type(): +def test_save_illegal_type() -> None: with pytest.raises(TypeError): - OmegaConf.save(OmegaConf.create(), 1000) + OmegaConf.save(OmegaConf.create(), 1000) # type: ignore -def test_pickle_dict(): +def test_pickle_dict() -> None: with tempfile.TemporaryFile() as fp: import pickle @@ -77,7 +81,7 @@ def test_pickle_dict(): assert c == c1 -def test_pickle_list(): +def test_pickle_list() -> None: with tempfile.TemporaryFile() as fp: import pickle diff --git a/tests/test_serialization_deprecated_save.py b/tests/test_serialization_deprecated_save.py index 85bbed69c..f9531e611 100644 --- a/tests/test_serialization_deprecated_save.py +++ b/tests/test_serialization_deprecated_save.py @@ -3,15 +3,14 @@ import os import tempfile +from omegaconf import DictConfig, OmegaConf from pytest import raises -from omegaconf import OmegaConf - -def save_load_file_deprecated(conf): +def save_load_file_deprecated(conf: DictConfig) -> None: try: with tempfile.NamedTemporaryFile(mode="wt", delete=False) as fp: - conf.save(fp.file) + conf.save(fp.file) # type: ignore with io.open(os.path.abspath(fp.name), "rt") as handle: c2 = OmegaConf.load(handle) assert conf == c2 @@ -19,7 +18,7 @@ def save_load_file_deprecated(conf): os.unlink(fp.name) -def save_load_filename(conf): +def save_load_filename(conf: DictConfig) -> None: # note that delete=False here is a work around windows incompetence. try: with tempfile.NamedTemporaryFile(delete=False) as fp: @@ -31,12 +30,10 @@ def save_load_filename(conf): # Test deprecated config.save() - - -def save_load__from_file_deprecated(conf): +def save_load__from_file_deprecated(conf: DictConfig) -> None: try: with tempfile.NamedTemporaryFile(mode="wt", delete=False) as fp: - conf.save(fp.file) + conf.save(fp.file) # type: ignore with io.open(os.path.abspath(fp.name), "rt") as handle: c2 = OmegaConf.load(handle) assert conf == c2 @@ -44,7 +41,7 @@ def save_load__from_file_deprecated(conf): os.unlink(fp.name) -def save_load__from_filename_deprecated(conf): +def save_load__from_filename_deprecated(conf: DictConfig) -> None: # note that delete=False here is a work around windows incompetence. try: with tempfile.NamedTemporaryFile(delete=False) as fp: @@ -55,30 +52,42 @@ def save_load__from_filename_deprecated(conf): os.unlink(fp.name) -def test_save_load_file_deprecated(): - save_load_file_deprecated(OmegaConf.create(dict(a=10))) +def test_save_load_file_deprecated() -> None: + cfg = OmegaConf.create(dict(a=10)) + assert isinstance(cfg, DictConfig) + save_load_file_deprecated(cfg) -def test_save_load_filename_deprecated(): - save_load_filename(OmegaConf.create(dict(a=10))) +def test_save_load_filename_deprecated() -> None: + cfg = OmegaConf.create(dict(a=10)) + assert isinstance(cfg, DictConfig) + save_load_filename(cfg) -def test_save_load__from_file_deprecated(): - save_load__from_file_deprecated(OmegaConf.create(dict(a=10))) +def test_save_load__from_file_deprecated() -> None: + cfg = OmegaConf.create(dict(a=10)) + assert isinstance(cfg, DictConfig) + save_load__from_file_deprecated(cfg) -def test_save_load__from_filename_deprecated(): - save_load__from_filename_deprecated(OmegaConf.create(dict(a=10))) +def test_save_load__from_filename_deprecated() -> None: + cfg = OmegaConf.create(dict(a=10)) + assert isinstance(cfg, DictConfig) + save_load__from_filename_deprecated(cfg) -def test_save_illegal_type_deprecated(): +def test_save_illegal_type_deprecated() -> None: with raises(TypeError): - OmegaConf.create().save(1000) + OmegaConf.create().save(1000) # type: ignore -def test_save_load_file(): - save_load_file_deprecated(OmegaConf.create(dict(a=10))) +def test_save_load_file() -> None: + cfg = OmegaConf.create(dict(a=10)) + assert isinstance(cfg, DictConfig) + save_load_file_deprecated(cfg) -def test_save_load_filename(): - save_load_filename(OmegaConf.create(dict(a=10))) +def test_save_load_filename() -> None: + cfg = OmegaConf.create(dict(a=10)) + assert isinstance(cfg, DictConfig) + save_load_filename(cfg) diff --git a/tests/test_struct.py b/tests/test_struct.py index 29e25a4f9..23d8f4f8b 100644 --- a/tests/test_struct.py +++ b/tests/test_struct.py @@ -1,17 +1,17 @@ import re +from typing import Any, Dict import pytest - from omegaconf import OmegaConf -def test_struct_default(): +def test_struct_default() -> None: c = OmegaConf.create() assert c.not_found is None assert OmegaConf.is_struct(c) is None -def test_struct_set_on_dict(): +def test_struct_set_on_dict() -> None: c = OmegaConf.create(dict(a=dict())) OmegaConf.set_struct(c, True) # Throwing when it hits foo, so exception key is a.foo and not a.foo.bar @@ -20,7 +20,7 @@ def test_struct_set_on_dict(): c.a.foo.bar -def test_struct_set_on_nested_dict(): +def test_struct_set_on_nested_dict() -> None: c = OmegaConf.create(dict(a=dict(b=10))) OmegaConf.set_struct(c, True) with pytest.raises(AttributeError): @@ -34,23 +34,25 @@ def test_struct_set_on_nested_dict(): c.a.foo -def test_merge_dotlist_into_struct(): +def test_merge_dotlist_into_struct() -> None: c = OmegaConf.create(dict(a=dict(b=10))) OmegaConf.set_struct(c, True) with pytest.raises(AttributeError, match=re.escape("foo")): c.merge_with_dotlist(["foo=1"]) -@pytest.mark.parametrize("base, merged", [(dict(), dict(a=10))]) -def test_merge_config_with_struct(base, merged): - base = OmegaConf.create(base) - merged = OmegaConf.create(merged) +@pytest.mark.parametrize("in_base, in_merged", [(dict(), dict(a=10))]) # type: ignore +def test_merge_config_with_struct( + in_base: Dict[str, Any], in_merged: Dict[str, Any] +) -> None: + base = OmegaConf.create(in_base) + merged = OmegaConf.create(in_merged) OmegaConf.set_struct(base, True) with pytest.raises(AttributeError): OmegaConf.merge(base, merged) -def test_struct_contain_missing(): +def test_struct_contain_missing() -> None: c = OmegaConf.create(dict()) OmegaConf.set_struct(c, True) assert "foo" not in c diff --git a/tests/test_structured_config.py b/tests/test_structured_config.py index e0755d024..0e3da7dc9 100644 --- a/tests/test_structured_config.py +++ b/tests/test_structured_config.py @@ -1,11 +1,11 @@ import sys from importlib import import_module -from typing import Any +from typing import Any, Dict import pytest - from omegaconf import ( AnyNode, + DictConfig, MissingMandatoryValue, OmegaConf, ReadonlyConfigError, @@ -85,18 +85,18 @@ class AnyTypeConfigAssignments: @pytest.mark.skipif(sys.version_info < (3, 6), reason="requires python3.6 or higher") @pytest.mark.parametrize("class_type", ["dataclass_test_data", "attr_test_data"]) class TestConfigs: - def test_nested_config_errors_on_missing(self, class_type): - module = import_module("tests.structured_conf." + class_type) + def test_nested_config_errors_on_missing(self, class_type: str) -> None: + module: Any = import_module("tests.structured_conf." + class_type) with pytest.raises(ValueError): - OmegaConf.create(module.ErrorOnMissingNestedConfig) + OmegaConf.structured(module.ErrorOnMissingNestedConfig) - def test_nested_config_errors_on_none(self, class_type): - module = import_module("tests.structured_conf." + class_type) + def test_nested_config_errors_on_none(self, class_type: str) -> None: + module: Any = import_module("tests.structured_conf." + class_type) with pytest.raises(ValueError): - OmegaConf.create(module.ErrorOnNoneNestedConfig) + OmegaConf.structured(module.ErrorOnNoneNestedConfig) - def test_nested_config(self, class_type): - def validate(cfg): + def test_nested_config(self, class_type: str) -> None: + def validate(cfg: DictConfig) -> None: assert cfg == { "default_value": { "with_default": 10, @@ -113,63 +113,61 @@ def validate(cfg): "value_at_root": 1000, } - module = import_module("tests.structured_conf." + class_type) + module: Any = import_module("tests.structured_conf." + class_type) - conf1 = OmegaConf.create(module.NestedConfig) + conf1 = OmegaConf.structured(module.NestedConfig) validate(conf1) - conf1 = OmegaConf.create(module.NestedConfig(default_value=module.Nested())) + conf1 = OmegaConf.structured(module.NestedConfig(default_value=module.Nested())) validate(conf1) - def test_no_default_errors(self, class_type): - module = import_module("tests.structured_conf." + class_type) + def test_no_default_errors(self, class_type: str) -> None: + module: Any = import_module("tests.structured_conf." + class_type) with pytest.raises(ValueError): - OmegaConf.create(module.NoDefaultErrors) + OmegaConf.structured(module.NoDefaultErrors) - def test_config_with_list(self, class_type): - module = import_module("tests.structured_conf." + class_type) + def test_config_with_list(self, class_type: str) -> None: + module: Any = import_module("tests.structured_conf." + class_type) - def validate(cfg): + def validate(cfg: DictConfig) -> None: assert cfg == { "list1": [1, 2, 3], "list2": [1, 2, 3], - "list3": [1, 2, 3], - "list4": [1, 2, 3], } with pytest.raises(ValidationError): - cfg.list4[1] = "foo" + cfg.list1[1] = "foo" - conf1 = OmegaConf.create(module.ConfigWithList) + conf1 = OmegaConf.structured(module.ConfigWithList) validate(conf1) - conf1 = OmegaConf.create(module.ConfigWithList()) + conf1 = OmegaConf.structured(module.ConfigWithList()) validate(conf1) - def test_assignment_to_nested_structured_config(self, class_type): - module = import_module("tests.structured_conf." + class_type) - conf = OmegaConf.create(module.NestedConfig) + def test_assignment_to_nested_structured_config(self, class_type: str) -> None: + module: Any = import_module("tests.structured_conf." + class_type) + conf = OmegaConf.structured(module.NestedConfig) with pytest.raises(ValidationError): conf.default_value = 10 conf.default_value = module.Nested() - def test_config_with_dict(self, class_type): - module = import_module("tests.structured_conf." + class_type) + def test_config_with_dict(self, class_type: str) -> None: + module: Any = import_module("tests.structured_conf." + class_type) - def validate(cfg): - assert cfg == {"dict1": {"foo": "bar"}, "dict2": {"foo": "bar"}} + def validate(cfg: DictConfig) -> None: + assert cfg == {"dict1": {"foo": "bar"}} - conf1 = OmegaConf.create(module.ConfigWithDict) + conf1 = OmegaConf.structured(module.ConfigWithDict) validate(conf1) - conf1 = OmegaConf.create(module.ConfigWithDict()) + conf1 = OmegaConf.structured(module.ConfigWithDict()) validate(conf1) - conf1 = OmegaConf.create(module.ConfigWithDict()) + conf1 = OmegaConf.structured(module.ConfigWithDict()) validate(conf1) - def test_structured_config_struct_behavior(self, class_type): - module = import_module("tests.structured_conf." + class_type) + def test_structured_config_struct_behavior(self, class_type: str) -> None: + module: Any = import_module("tests.structured_conf." + class_type) - def validate(cfg): + def validate(cfg: DictConfig) -> None: assert not OmegaConf.is_struct(cfg) with pytest.raises(KeyError): # noinspection PyStatementEffect @@ -184,12 +182,12 @@ def validate(cfg): cfg.foo = 20 assert cfg.foo == 20 - conf = OmegaConf.create(module.ConfigWithDict) + conf = OmegaConf.structured(module.ConfigWithDict) validate(conf) - conf = OmegaConf.create(module.ConfigWithDict()) + conf = OmegaConf.structured(module.ConfigWithDict()) validate(conf) - @pytest.mark.parametrize( + @pytest.mark.parametrize( # type:ignore "tested_type,assignment_data, init_dict", [ # Use class to build config @@ -208,13 +206,17 @@ def validate(cfg): ], ) def test_field_with_default_value( - self, class_type, tested_type, init_dict, assignment_data - ): - module = import_module("tests.structured_conf." + class_type) + self, + class_type: str, + tested_type: str, + init_dict: Dict[str, Any], + assignment_data: Any, + ) -> None: + module: Any = import_module("tests.structured_conf." + class_type) input_class = getattr(module, tested_type) - def validate(input_, expected): - conf = OmegaConf.create(input_) + def validate(input_: Any, expected: Any) -> None: + conf = OmegaConf.structured(input_) # Test access assert conf.with_default == expected.with_default assert conf.null_default is None @@ -257,7 +259,7 @@ def validate(input_, expected): validate(input_class, input_class()) validate(input_class(**init_dict), input_class(**init_dict)) - @pytest.mark.parametrize( + @pytest.mark.parametrize( # type: ignore "input_init, expected_init", [ # attr class as class @@ -270,13 +272,16 @@ def validate(input_, expected): ({"int_default": 30}, {"int_default": 30}), ], ) - def test_untyped(self, class_type, input_init, expected_init): - input_ = import_module("tests.structured_conf." + class_type).AnyTypeConfig + def test_untyped( + self, class_type: str, input_init: Any, expected_init: Any + ) -> None: + module: Any = import_module("tests.structured_conf." + class_type) + input_ = module.AnyTypeConfig expected = input_(**expected_init) if input_init is not None: input_ = input_(**input_init) - conf = OmegaConf.create(input_) + conf = OmegaConf.structured(input_) assert conf.null_default == expected.null_default assert conf.int_default == expected.int_default assert conf.float_default == expected.float_default @@ -314,9 +319,10 @@ def test_untyped(self, class_type, input_init, expected_init): assert conf.str_default == val assert conf.bool_default == val - def test_interpolation(self, class_type): - input_ = import_module("tests.structured_conf." + class_type).Interpolation() - conf = OmegaConf.create(input_) + def test_interpolation(self, class_type: str) -> Any: + module: Any = import_module("tests.structured_conf." + class_type) + input_ = module.Interpolation() + conf = OmegaConf.structured(input_) assert conf.x == input_.x assert conf.z1 == conf.x assert conf.z2 == f"{conf.x}_{conf.y}" @@ -325,7 +331,7 @@ def test_interpolation(self, class_type): assert type(conf.z1) == int assert type(conf.z2) == str - @pytest.mark.parametrize( + @pytest.mark.parametrize( # type: ignore "tested_type", [ "BoolOptional", @@ -335,11 +341,11 @@ def test_interpolation(self, class_type): "EnumOptional", ], ) - def test_optional(self, class_type, tested_type): - module = import_module("tests.structured_conf." + class_type) + def test_optional(self, class_type: str, tested_type: str) -> None: + module: Any = import_module("tests.structured_conf." + class_type) input_ = getattr(module, tested_type) obj = input_(no_default=None) - conf = OmegaConf.create(input_) + conf = OmegaConf.structured(input_) assert conf.no_default is None assert conf.as_none is None assert conf.with_default == obj.with_default @@ -351,9 +357,10 @@ def test_optional(self, class_type, tested_type): with pytest.raises(ValidationError): conf.not_optional = None - def test_typed_list(self, class_type): - input_ = import_module("tests.structured_conf." + class_type).WithTypedList - conf = OmegaConf.create(input_) + def test_typed_list(self, class_type: str) -> None: + module: Any = import_module("tests.structured_conf." + class_type) + input_ = module.WithTypedList + conf = OmegaConf.structured(input_) with pytest.raises(ValidationError): conf.list[0] = "fail" @@ -364,47 +371,46 @@ def test_typed_list(self, class_type): cfg2 = OmegaConf.create({"list": ["fail"]}) OmegaConf.merge(conf, cfg2) - def test_typed_dict(self, class_type): - input_ = import_module("tests.structured_conf." + class_type).WithTypedDict - conf = OmegaConf.create(input_) + def test_typed_dict(self, class_type: str) -> None: + module: Any = import_module("tests.structured_conf." + class_type) + input_ = module.WithTypedDict + conf = OmegaConf.structured(input_) with pytest.raises(ValidationError): conf.dict["foo"] = "fail" with pytest.raises(ValidationError): OmegaConf.merge(conf, OmegaConf.create({"dict": {"foo": "fail"}})) - def test_typed_dict_key_error(self, class_type): - input_ = import_module("tests.structured_conf." + class_type).ErrorDictIntKey + def test_typed_dict_key_error(self, class_type: str) -> None: + module: Any = import_module("tests.structured_conf." + class_type) + input_ = module.ErrorDictIntKey with pytest.raises(UnsupportedKeyType): - OmegaConf.create(input_) + OmegaConf.structured(input_) - def test_typed_dict_value_error(self, class_type): - input_ = import_module( - "tests.structured_conf." + class_type - ).ErrorDictUnsupportedValue + def test_typed_dict_value_error(self, class_type: str) -> None: + module: Any = import_module("tests.structured_conf." + class_type) + input_ = module.ErrorDictUnsupportedValue with pytest.raises(ValidationError): - OmegaConf.create(input_) + OmegaConf.structured(input_) - def test_typed_list_value_error(self, class_type): - input_ = import_module( - "tests.structured_conf." + class_type - ).ErrorListUnsupportedValue + def test_typed_list_value_error(self, class_type: str) -> None: + module: Any = import_module("tests.structured_conf." + class_type) + input_ = module.ErrorListUnsupportedValue with pytest.raises(ValidationError): - OmegaConf.create(input_) + OmegaConf.structured(input_) - def test_list_examples(self, class_type): - module = import_module("tests.structured_conf." + class_type) - conf = OmegaConf.create(module.ListExamples) + def test_list_examples(self, class_type: str) -> None: + module: Any = import_module("tests.structured_conf." + class_type) + conf = OmegaConf.structured(module.ListExamples) - def test_any(name): + def test_any(name: str) -> None: conf[name].append(True) conf[name].extend([Color.RED, 3.1415]) conf[name][2] = False assert conf[name] == [1, "foo", False, Color.RED, 3.1415] # any and untyped - test_any("any1") - test_any("any2") + test_any("any") # test ints with pytest.raises(ValidationError): @@ -439,17 +445,16 @@ def test_any(name): Color.BLUE, ] - def test_dict_examples(self, class_type): - module = import_module("tests.structured_conf." + class_type) - conf = OmegaConf.create(module.DictExamples) - # any1: Dict = {"a": 1, "b": "foo"} - # any2: Dict[str, Any] = {"a": 1, "b": "foo"} + def test_dict_examples(self, class_type: str) -> None: + module: Any = import_module("tests.structured_conf." + class_type) + conf = OmegaConf.structured(module.DictExamples) + # any: Dict = {"a": 1, "b": "foo"} # ints: Dict[str, int] = {"a": 10, "b": 20} # strings: Dict[str, str] = {"a": "foo", "b": "bar"} # booleans: Dict[str, bool] = {"a": True, "b": False} # colors: Dict[str, Color] = {"red": Color.RED, "green": "GREEN", "blue": 3} - def test_any(name): + def test_any(name: str) -> None: conf[name].c = True conf[name].d = Color.RED conf[name].e = 3.1415 @@ -462,8 +467,7 @@ def test_any(name): } # any and untyped - test_any("any1") - test_any("any2") + test_any("any") # test ints with pytest.raises(ValidationError): @@ -506,9 +510,9 @@ def test_any(name): "f": Color.BLUE, } - def test_enum_key(self, class_type): - module = import_module("tests.structured_conf." + class_type) - conf = OmegaConf.create(module.DictWithEnumKeys) + def test_enum_key(self, class_type: str) -> None: + module: Any = import_module("tests.structured_conf." + class_type) + conf = OmegaConf.structured(module.DictWithEnumKeys) # When an Enum is a dictionary key the name of the Enum is actually used # as the key @@ -516,9 +520,9 @@ def test_enum_key(self, class_type): assert conf.enum_key["GREEN"] == "green" assert conf.enum_key[Color.GREEN] == "green" - def test_dict_of_objects(self, class_type): - module = import_module("tests.structured_conf." + class_type) - conf = OmegaConf.create(module.DictOfObjects) + def test_dict_of_objects(self, class_type: str) -> None: + module: Any = import_module("tests.structured_conf." + class_type) + conf = OmegaConf.structured(module.DictOfObjects) assert conf.users.joe.age == 18 assert conf.users.joe.name == "Joe" @@ -529,9 +533,9 @@ def test_dict_of_objects(self, class_type): with pytest.raises(ValidationError): conf.users.fail = "fail" - def test_list_of_objects(self, class_type): - module = import_module("tests.structured_conf." + class_type) - conf = OmegaConf.create(module.ListOfObjects) + def test_list_of_objects(self, class_type: str) -> None: + module: Any = import_module("tests.structured_conf." + class_type) + conf = OmegaConf.structured(module.ListOfObjects) assert conf.users[0].age == 18 assert conf.users[0].name == "Joe" @@ -543,7 +547,7 @@ def test_list_of_objects(self, class_type): conf.users.append("fail") -def validate_frozen_impl(conf): +def validate_frozen_impl(conf: DictConfig) -> None: with pytest.raises(ReadonlyConfigError): conf.x = 20 @@ -554,17 +558,17 @@ def validate_frozen_impl(conf): conf.user.age = 20 -@pytest.mark.skipif(sys.version_info < (3, 6), reason="requires python3.6 or higher") +@pytest.mark.skipif(sys.version_info < (3, 6), reason="requires python3.6 or higher") # type: ignore def test_attr_frozen() -> None: from tests.structured_conf.attr_test_data import FrozenClass - validate_frozen_impl(OmegaConf.create(FrozenClass)) - validate_frozen_impl(OmegaConf.create(FrozenClass())) + validate_frozen_impl(OmegaConf.structured(FrozenClass)) + validate_frozen_impl(OmegaConf.structured(FrozenClass())) -@pytest.mark.skipif(sys.version_info < (3, 6), reason="requires python3.6 or higher") +@pytest.mark.skipif(sys.version_info < (3, 6), reason="requires python3.6 or higher") # type: ignore def test_dataclass_frozen() -> None: from tests.structured_conf.dataclass_test_data import FrozenClass - validate_frozen_impl(OmegaConf.create(FrozenClass)) - validate_frozen_impl(OmegaConf.create(FrozenClass())) + validate_frozen_impl(OmegaConf.structured(FrozenClass)) + validate_frozen_impl(OmegaConf.structured(FrozenClass())) diff --git a/tests/test_update.py b/tests/test_update.py index 0b1f60e23..fe76f7ed3 100644 --- a/tests/test_update.py +++ b/tests/test_update.py @@ -1,128 +1,130 @@ import sys +from typing import Any, Dict, List, Union import pytest +from omegaconf import DictConfig, ListConfig, MissingMandatoryValue, OmegaConf from pytest import raises -from omegaconf import MissingMandatoryValue, OmegaConf - -def test_update_map_value(): +def test_update_map_value() -> None: # Replacing an existing key in a map s = "hello: world" c = OmegaConf.create(s) - c.update("hello", "there") + c.update_node("hello", "there") assert {"hello": "there"} == c -def test_update_map_new_keyvalue(): +def test_update_map_new_keyvalue() -> None: # Adding another key to a map s = "hello: world" c = OmegaConf.create(s) - c.update("who", "goes there") + c.update_node("who", "goes there") assert {"hello": "world", "who": "goes there"} == c -def test_update_map_to_value(): +def test_update_map_to_value() -> None: # changing map to single node s = "hello: world" c = OmegaConf.create(s) - c.update("value") + c.update_node("value") assert {"hello": "world", "value": None} == c -def test_update_with_empty_map_value(): +def test_update_with_empty_map_value() -> None: c = OmegaConf.create() - c.update("a", {}) + c.update_node("a", {}) assert {"a": {}} == c -def test_update_with_map_value(): +def test_update_with_map_value() -> None: c = OmegaConf.create() - c.update("a", {"aa": 1, "bb": 2}) + c.update_node("a", {"aa": 1, "bb": 2}) assert {"a": {"aa": 1, "bb": 2}} == c -def test_update_deep_from_empty(): +def test_update_deep_from_empty() -> None: c = OmegaConf.create() - c.update("a.b", 1) + c.update_node("a.b", 1) assert {"a": {"b": 1}} == c -def test_update_deep_with_map(): +def test_update_deep_with_map() -> None: c = OmegaConf.create("a: b") - c.update("a.b", {"c": 1}) + c.update_node("a.b", {"c": 1}) assert {"a": {"b": {"c": 1}}} == c -def test_update_deep_with_value(): +def test_update_deep_with_value() -> None: c = OmegaConf.create() - c.update("a.b", 1) + c.update_node("a.b", 1) assert {"a": {"b": 1}} == c -def test_update_deep_with_map2(): +def test_update_deep_with_map2() -> None: c = OmegaConf.create("a: 1") - c.update("b.c", 2) + c.update_node("b.c", 2) assert {"a": 1, "b": {"c": 2}} == c -def test_update_deep_with_map_update(): +def test_update_deep_with_map_update() -> None: c = OmegaConf.create("a: {b : {c: 1}}") - c.update("a.b.d", 2) + c.update_node("a.b.d", 2) assert {"a": {"b": {"c": 1, "d": 2}}} == c -def test_list_value_update(): +def test_list_value_update() -> None: # List update is always a replace because a list can be merged in too many ways c = OmegaConf.create("a: [1,2]") - c.update("a", [2, 3, 4]) + c.update_node("a", [2, 3, 4]) assert {"a": [2, 3, 4]} == c -def test_override_mandatory_value(): +def test_override_mandatory_value() -> None: c = OmegaConf.create('{a: "???"}') + assert isinstance(c, DictConfig) with raises(MissingMandatoryValue): c.get("a") - c.update("a", 123) + c.update_node("a", 123) assert {"a": 123} == c -def test_update_empty_to_value(): +def test_update_empty_to_value() -> None: """""" s = "" c = OmegaConf.create(s) - c.update("hello") + c.update_node("hello") assert {"hello": None} == c -def test_update_same_value(): +def test_update_same_value() -> None: """""" s = "hello" c = OmegaConf.create(s) - c.update("hello") + c.update_node("hello") assert {"hello": None} == c -def test_update_value_to_map(): +def test_update_value_to_map() -> None: s = "hello" c = OmegaConf.create(s) - c.update("hi", "there") + c.update_node("hi", "there") assert {"hello": None, "hi": "there"} == c -def test_update_map_empty_to_map(): +def test_update_map_empty_to_map() -> None: s = "" c = OmegaConf.create(s) - c.update("hello", "there") + c.update_node("hello", "there") assert {"hello": "there"} == c -def test_update_list(): +def test_update_list() -> None: c = OmegaConf.create([1, 2, 3]) - c.update("1", "abc") - c.update("-1", "last") + assert isinstance(c, ListConfig) + c.update_node("1", "abc") + c.update_node("-1", "last") with raises(IndexError): - c.update("4", "abc") + c.update_node("4", "abc") assert len(c) == 3 assert c[0] == 1 @@ -130,31 +132,32 @@ def test_update_list(): assert c[2] == "last" -def test_update_nested_list(): +def test_update_nested_list() -> None: c = OmegaConf.create(dict(deep=dict(list=[1, 2, 3]))) - c.update("deep.list.1", "abc") - c.update("deep.list.-1", "last") + c.update_node("deep.list.1", "abc") + c.update_node("deep.list.-1", "last") with raises(IndexError): - c.update("deep.list.4", "abc") + c.update_node("deep.list.4", "abc") assert c.deep.list[0] == 1 assert c.deep.list[1] == "abc" assert c.deep.list[2] == "last" -def test_update_list_make_dict(): +def test_update_list_make_dict() -> None: c = OmegaConf.create([None, None]) - c.update("0.a.a", "aa") - c.update("0.a.b", "ab") - c.update("1.b.a", "ba") - c.update("1.b.b", "bb") + assert isinstance(c, ListConfig) + c.update_node("0.a.a", "aa") + c.update_node("0.a.b", "ab") + c.update_node("1.b.a", "ba") + c.update_node("1.b.b", "bb") assert c[0].a.a == "aa" assert c[0].a.b == "ab" assert c[1].b.a == "ba" assert c[1].b.b == "bb" -@pytest.mark.parametrize( +@pytest.mark.parametrize( # type:ignore "cfg,overrides,expected", [ ([1, 2, 3], ["0=bar", "2.a=100"], ["bar", 2, dict(a=100)]), @@ -162,31 +165,35 @@ def test_update_list_make_dict(): ({}, ["foo=bar=10"], {"foo": "bar=10"}), ], ) -def test_merge_with_dotlist(cfg, overrides, expected): +def test_merge_with_dotlist( + cfg: Union[List[Any], Dict[str, Any]], + overrides: List[str], + expected: Union[List[Any], Dict[str, Any]], +) -> None: c = OmegaConf.create(cfg) c.merge_with_dotlist(overrides) assert c == expected -def test_merge_with_cli(): +def test_merge_with_cli() -> None: c = OmegaConf.create([1, 2, 3]) sys.argv = ["app.py", "0=bar", "2.a=100"] c.merge_with_cli() assert c == ["bar", 2, dict(a=100)] -@pytest.mark.parametrize( +@pytest.mark.parametrize( # type:ignore "dotlist, expected", [([], {}), (["foo=1"], {"foo": 1}), (["foo=1", "bar"], {"foo": 1, "bar": None})], ) -def test_merge_empty_with_dotlist(dotlist, expected): +def test_merge_empty_with_dotlist(dotlist: List[str], expected: Dict[str, Any]) -> None: c = OmegaConf.create() c.merge_with_dotlist(dotlist) assert c == expected -@pytest.mark.parametrize("dotlist", ["foo=10", ["foo=1", 10]]) -def test_merge_with_dotlist_errors(dotlist): +@pytest.mark.parametrize("dotlist", ["foo=10", ["foo=1", 10]]) # type:ignore +def test_merge_with_dotlist_errors(dotlist: List[str]) -> None: c = OmegaConf.create() with pytest.raises(ValueError): c.merge_with_dotlist(dotlist) diff --git a/tests/test_utils.py b/tests/test_utils.py index 1e379e818..305fe0f1b 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -6,17 +6,15 @@ import pytest # noinspection PyProtectedMember -from omegaconf import OmegaConf, _utils +from omegaconf import DictConfig, OmegaConf, _utils from omegaconf.errors import ValidationError from omegaconf.nodes import StringNode from . import does_not_raise from .structured_conf.common import Color -# TODO: complete test coverage for utils - -@pytest.mark.parametrize( +@pytest.mark.parametrize( # type: ignore "target_type, value, expectation", [ # Any @@ -66,7 +64,7 @@ (Exception, "nope", pytest.raises(ValueError)), ], ) -def test_maybe_wrap(target_type, value, expectation): +def test_maybe_wrap(target_type: type, value: Any, expectation: Any) -> None: with expectation: from omegaconf.omegaconf import _maybe_wrap @@ -87,12 +85,8 @@ class _TestDataclass: b: bool = True f: float = 3.14 e: _TestEnum = _TestEnum.A - list1: list = field(default_factory=lambda: []) - list2: List = field(default_factory=lambda: []) - list3: List[int] = field(default_factory=lambda: []) - dict1: dict = field(default_factory=lambda: {}) - dict2: Dict = field(default_factory=lambda: {}) - dict3: Dict[str, int] = field(default_factory=lambda: {}) + list1: List[int] = field(default_factory=lambda: []) + dict1: Dict[str, int] = field(default_factory=lambda: {}) @attr.s(auto_attribs=True) @@ -102,19 +96,15 @@ class _TestAttrsClass: b: bool = True f: float = 3.14 e: _TestEnum = _TestEnum.A - list1: list = [] - list2: List = [] - list3: List[int] = [] - dict1: dict = {} - dict2: Dict = {} - dict3: Dict[str, int] = {} + list1: List[int] = [] + dict1: Dict[str, int] = {} class _TestUserClass: pass -@pytest.mark.parametrize( +@pytest.mark.parametrize( # type: ignore "type_, expected", [ (int, True), @@ -129,13 +119,13 @@ class _TestUserClass: (_TestDataclass, True), ], ) -def test_valid_value_annotation_type(type_, expected): +def test_valid_value_annotation_type(type_: type, expected: bool) -> None: from omegaconf.omegaconf import _valid_value_annotation_type assert _valid_value_annotation_type(type_) == expected -@pytest.mark.parametrize( +@pytest.mark.parametrize( # type: ignore "test_cls_or_obj, expectation", [ (_TestDataclass, does_not_raise()), @@ -145,7 +135,7 @@ def test_valid_value_annotation_type(type_, expected): ("invalid", pytest.raises(ValueError)), ], ) -def test_get_structured_config_data(test_cls_or_obj, expectation): +def test_get_structured_config_data(test_cls_or_obj: Any, expectation: Any) -> None: with expectation: d = _utils.get_structured_config_data(test_cls_or_obj) assert d["x"] == 10 @@ -154,14 +144,10 @@ def test_get_structured_config_data(test_cls_or_obj, expectation): assert d["f"] == 3.14 assert d["e"] == _TestEnum.A assert d["list1"] == [] - assert d["list2"] == [] - assert d["list3"] == [] assert d["dict1"] == {} - assert d["dict2"] == {} - assert d["dict3"] == {} -def test_is_dataclass(mocker): +def test_is_dataclass(mocker: Any) -> None: @dataclass class Foo: pass @@ -174,7 +160,7 @@ class Foo: assert not _utils.is_dataclass(10) -def test_is_attr_class(mocker): +def test_is_attr_class(mocker: Any) -> None: @attr.s class Foo: pass @@ -187,7 +173,7 @@ class Foo: assert not _utils.is_attr_class(10) -def test_is_structured_config_frozen_with_invalid_obj(): +def test_is_structured_config_frozen_with_invalid_obj() -> None: with pytest.raises(ValueError): _utils.is_structured_config_frozen(10) @@ -197,7 +183,7 @@ class Dataclass: pass -@pytest.mark.parametrize( +@pytest.mark.parametrize( # type: ignore "value,kind", [ ("foo", _utils.ValueKind.VALUE), @@ -212,26 +198,27 @@ class Dataclass: ("ftp://${host}/path", _utils.ValueKind.STR_INTERPOLATION), ], ) -def test_value_kind(value, kind): +def test_value_kind(value: Any, kind: _utils.ValueKind) -> None: assert _utils.get_value_kind(value) == kind -def test_re_parent(): +def test_re_parent() -> None: + def validate(cfg1: DictConfig) -> None: + assert cfg1._get_parent() is None + assert cfg1.get_node("str")._get_parent() == cfg1 + assert cfg1.get_node("list")._get_parent() == cfg1 + assert cfg1.list.get_node(0)._get_parent() == cfg1.list + cfg = OmegaConf.create({}) + assert isinstance(cfg, DictConfig) cfg.str = StringNode("str") cfg.list = [1] - def validate(): - assert cfg._get_parent() is None - assert cfg.get_node("str")._get_parent() == cfg - assert cfg.get_node("list")._get_parent() == cfg - assert cfg.list.get_node(0)._get_parent() == cfg.list - - validate() + validate(cfg) cfg.get_node("str")._set_parent(None) cfg.get_node("list")._set_parent(None) - cfg.list.get_node(0)._set_parent(None) + cfg.list.get_node(0)._set_parent(None) # type: ignore _utils._re_parent(cfg) - validate() + validate(cfg)