Skip to content

Commit

Permalink
more assignment tests
Browse files Browse the repository at this point in the history
  • Loading branch information
omry committed Aug 5, 2019
1 parent 31bb0fe commit 9f740d9
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 21 deletions.
2 changes: 2 additions & 0 deletions omegaconf/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,8 @@ def _deepcopy_impl(self, res, memodict={}):
res.__dict__['parent'] = copy.deepcopy(self.__dict__['parent'], memodict)
res.__dict__['flags'] = copy.deepcopy(self.__dict__['flags'], memodict)



def merge_with(self, *others):
from .listconfig import ListConfig
from .dictconfig import DictConfig
Expand Down
20 changes: 15 additions & 5 deletions omegaconf/dictconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,23 @@ def __setitem__(self, key, value):
if key not in self.content and self._get_flag('struct') is True:
raise KeyError("Accessing unknown key in a struct : {}".format(self.get_full_key(key)))

if key in self and isinstance(value, BaseNode):
self.__dict__['content'][key].set_value(value)
else:
if not isinstance(value, (BaseNode, Config)):
self.__dict__['content'][key] = UntypedNode(value)
input_config_or_node = isinstance(value, (BaseNode, Config))
if key in self:
# BaseNode or Config, assign as is
if input_config_or_node:
self.__dict__['content'][key] = value
else:
# primitive input
if isinstance(self.__dict__['content'][key], Config):
# primitive input replaces config nodes
self.__dict__['content'][key] = value
else:
self.__dict__['content'][key].set_value(value)
else:
if input_config_or_node:
self.__dict__['content'][key] = value
else:
self.__dict__['content'][key] = UntypedNode(value)

# hide content while inspecting in debugger
def __dir__(self):
Expand Down
36 changes: 35 additions & 1 deletion tests/test_base_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,41 @@

import pytest

from omegaconf import OmegaConf, ListConfig, DictConfig, ReadonlyConfigError
from omegaconf import *


@pytest.mark.parametrize('input_, key, value, expected', [
# dict
(dict(), 'foo', 10, dict(foo=10)),
(dict(), 'foo', IntegerNode(10), dict(foo=10)),
(dict(foo=5), 'foo', IntegerNode(10), dict(foo=10)),
# changing type of a node
(dict(foo=StringNode('str')), 'foo', IntegerNode(10), dict(foo=10)),
# list
([0], 0, 10, [10]),
(['a', 'b', 'c'], 1, 10, ['a', 10, 'c']),
([1, 2], 1, IntegerNode(10), [1, 10]),
([1, IntegerNode(2)], 1, IntegerNode(10), [1, 10]),
# changing type of a node
([1, StringNode('str')], 1, IntegerNode(10), [1, 10]),
])
def test_set_value(input_, key, value, expected):
c = OmegaConf.create(input_)
c[key] = value
assert c == expected


@pytest.mark.parametrize('input_, key, value', [
# dict
(dict(foo=IntegerNode(10)), 'foo', 'str'),
# list
([1, IntegerNode(10)], 1, 'str'),
])
def test_set_value_validation_fail(input_, key, value):
c = OmegaConf.create(input_)
with pytest.raises(ValidationError):
c[key] = value


@pytest.mark.parametrize('input_', [
Expand Down
9 changes: 1 addition & 8 deletions tests/test_basic_ops_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,11 @@

import pytest

from omegaconf import OmegaConf, DictConfig, Config
from omegaconf import nodes
from omegaconf import *
from omegaconf.errors import MissingMandatoryValue
from . import IllegalType


def test_setattr_value():
c = OmegaConf.create(dict(a=dict(b=dict(c=1))))
c.a = 9
assert c == dict(a=9)


def test_setattr_deep_value():
c = OmegaConf.create(dict(a=dict(b=dict(c=1))))
c.a.b = 9
Expand Down
5 changes: 0 additions & 5 deletions tests/test_basic_ops_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,11 +201,6 @@ def test_getattr():
getattr(c, "anything")


def test_setitem():
c = OmegaConf.create(['a', 'b', 'c'])
c[1] = 10
assert c == ['a', 10, 'c']


def test_insert():
c = OmegaConf.create(['a', 'b', 'c'])
Expand Down
20 changes: 18 additions & 2 deletions tests/test_nodes.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
import pytest

from omegaconf import OmegaConf, nodes, DictConfig, ListConfig
from omegaconf import *
from omegaconf.errors import ValidationError


def test_base_node():
b = BaseNode()
assert b.value() is None
with pytest.raises(NotImplementedError):
b.set_value(10)


# testing valid conversions
@pytest.mark.parametrize('type_,input_,output_', [
# string
Expand Down Expand Up @@ -44,6 +51,10 @@
def test_valid_inputs(type_, input_, output_):
node = type_(input_)
assert node == output_
assert node == node
assert not (node != output_)
assert not (node != node)
assert str(node) == str(output_)


# testing invalid conversions
Expand All @@ -67,8 +78,13 @@ def test_invalid_inputs(type_, input_):
@pytest.mark.parametrize('input_, expected_type', [
({}, DictConfig),
([], ListConfig),
(5, UntypedNode),
(5.0, UntypedNode),
(True, UntypedNode),
(False, UntypedNode),
('str', UntypedNode),
])
def test_config_type_not_wrapped(input_, expected_type):
def test_assigned_value_node_type(input_, expected_type):
c = OmegaConf.create()
c.foo = input_
assert type(c.get_node('foo')) == expected_type
Expand Down

0 comments on commit 9f740d9

Please sign in to comment.