Skip to content

Commit

Permalink
fixes #21
Browse files Browse the repository at this point in the history
  • Loading branch information
omry committed Jul 31, 2019
1 parent 4f256bd commit 981834a
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 45 deletions.
51 changes: 27 additions & 24 deletions omegaconf/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,8 @@ def get_yaml_loader():
class Config(object):

def __init__(self):
"""
Can't be instantiated
"""
raise NotImplementedError
if type(self) == Config:
raise NotImplementedError

def save(self, f):
data = self.pretty()
Expand Down Expand Up @@ -138,7 +136,7 @@ def get_full_key(self, key):
full_key = "{}".format(key)
else:
for parent_key, v in parent.items():
if v == child:
if id(v) == id(child):
if isinstance(child, ListConfig):
full_key = "{}{}".format(parent_key, full_key)
else:
Expand All @@ -149,7 +147,7 @@ def get_full_key(self, key):
full_key = "[{}]".format(key)
else:
for idx, v in enumerate(parent):
if v == child:
if id(v) == id(child):
if isinstance(child, ListConfig):
full_key = "[{}]{}".format(idx, full_key)
else:
Expand Down Expand Up @@ -288,12 +286,10 @@ def _to_content(conf, resolve):
assert isinstance(conf, Config)
if isinstance(conf, DictConfig):
ret = {}
for key, value in conf.items():
for key, value in conf.items(resolve=resolve):
if isinstance(value, Config):
ret[key] = Config._to_content(value, resolve)
else:
if resolve:
value = conf[key]
ret[key] = value
return ret
elif isinstance(conf, ListConfig):
Expand Down Expand Up @@ -365,7 +361,7 @@ def re_parent(node):
# update parents of first level Config nodes to self
assert isinstance(node, (DictConfig, ListConfig))
if isinstance(node, DictConfig):
for _key, value in node.items():
for _key, value in node.items(resolve=False):
if isinstance(value, Config):
value._set_parent(node)
re_parent(value)
Expand Down Expand Up @@ -445,44 +441,51 @@ def is_primitive_type(value):
return isinstance(value, tuple(valid))

@staticmethod
def _item_eq(v1, v2):
def _item_eq(c1, k1, c2, k2):
v1 = c1.content[k1]
v2 = c2.content[k2]
if isinstance(v1, BaseNode):
v1 = v1.value()
if isinstance(v1, str):
# noinspection PyProtectedMember
v1 = c1._resolve_single(v1)
if isinstance(v2, BaseNode):
v2 = v2.value()
if isinstance(v2, str):
# noinspection PyProtectedMember
v2 = c2._resolve_single(v2)

if isinstance(v1, Config) and isinstance(v2, Config):
if not Config._config_eq(v1, v2):
return False
return v1 == v2

@staticmethod
def _list_eq(l1, l2):
assert isinstance(l1, list)
assert isinstance(l2, list)
from .listconfig import ListConfig
assert isinstance(l1, ListConfig)
assert isinstance(l2, ListConfig)
if len(l1) != len(l2):
return False
for i in range(len(l1)):
v1 = l1[i]
v2 = l2[i]
if not Config._item_eq(v1, v2):
if not Config._item_eq(l1, i, l2, i):
return False

return True

@staticmethod
def _dict_eq(d1, d2):
assert isinstance(d1, dict)
assert isinstance(d2, dict)
def _dict_conf_eq(d1, d2):
from .dictconfig import DictConfig
assert isinstance(d1, DictConfig)
assert isinstance(d2, DictConfig)
if len(d1) != len(d2):
return False
k1 = sorted(d1.keys())
k2 = sorted(d2.keys())
if k1 != k2:
return False
for k in k1:
v1 = d1[k]
v2 = d2[k]
if not Config._item_eq(v1, v2):
if not Config._item_eq(d1, k, d2, k):
return False

return True
Expand All @@ -494,8 +497,8 @@ def _config_eq(c1, c2):
assert isinstance(c1, Config)
assert isinstance(c2, Config)
if isinstance(c1, DictConfig) and isinstance(c2, DictConfig):
return Config._dict_eq(c1.content, c2.content)
return DictConfig._dict_conf_eq(c1, c2)
if isinstance(c1, ListConfig) and isinstance(c2, ListConfig):
return Config._list_eq(c1.content, c2.content)
return Config._list_eq(c1, c2)
# if type does not match objects are different
return False
16 changes: 10 additions & 6 deletions omegaconf/dictconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
class DictConfig(Config):

def __init__(self, content, parent=None):
super(DictConfig, self).__init__()
assert isinstance(content, dict)
self.__dict__['frozen_flag'] = None
self.__dict__['content'] = {}
Expand Down Expand Up @@ -84,7 +85,7 @@ def keys(self):
def __iter__(self):
return iter(self.keys())

def items(self):
def items(self, resolve=True):
class MyItems(object):
def __init__(self, m):
self.map = m
Expand All @@ -99,19 +100,22 @@ def __next__(self):

def next(self):
k = next(self.iterator)
v = self.map.content[k]
if isinstance(v, BaseNode):
v = v.value()
if resolve:
v = self.map.get(k)
else:
v = self.map.content[k]
if isinstance(v, BaseNode):
v = v.value()
kv = (k, v)
return kv

return MyItems(self)

def __eq__(self, other):
if isinstance(other, dict):
return Config._dict_eq(self.content, other)
return Config._dict_conf_eq(self, DictConfig(other))
if isinstance(other, DictConfig):
return Config._dict_eq(self.content, other.content)
return Config._dict_conf_eq(self, other)
return NotImplemented

def __ne__(self, other):
Expand Down
5 changes: 3 additions & 2 deletions omegaconf/listconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

class ListConfig(Config):
def __init__(self, content, parent=None):
super(ListConfig, self).__init__()
assert isinstance(content, (list, tuple))
self.__dict__['frozen_flag'] = None
self.__dict__['content'] = []
Expand Down Expand Up @@ -122,9 +123,9 @@ def key1(x):

def __eq__(self, other):
if isinstance(other, list):
return Config._list_eq(self.content, other)
return Config._list_eq(self, ListConfig(other))
if isinstance(other, ListConfig):
return Config._list_eq(self.content, other.content)
return Config._list_eq(self, other)
return NotImplemented

def __ne__(self, other):
Expand Down
51 changes: 38 additions & 13 deletions tests/test_basic_ops_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from omegaconf import OmegaConf, DictConfig, Config
from omegaconf import nodes
from omegaconf.errors import MissingMandatoryValue, FrozenConfigError
from omegaconf.errors import MissingMandatoryValue
from . import IllegalType


Expand Down Expand Up @@ -144,8 +144,27 @@ def foo(a, b):


def test_items():
c = OmegaConf.create('{a: 2, b: 10}')
assert {'a': 2, 'b': 10}.items() == c.items()
c = OmegaConf.create(dict(a=2, b=10))
assert sorted([('a', 2), ('b', 10)]) == sorted(list(c.items()))


def test_items2():
c = OmegaConf.create(dict(a=dict(v=1), b=dict(v=1)))
for k, v in c.items():
v.v = 2

assert c.a.v == 2
assert c.b.v == 2


def test_items_with_interpolation():
c = OmegaConf.create(
dict(
a=2,
b='${a}'
)
)
assert list({'a': 2, 'b': 2}.items()) == list(c.items())


def test_dict_keys():
Expand Down Expand Up @@ -200,15 +219,6 @@ def test_iterate_dictionary():
assert m2 == c


def test_items():
c = OmegaConf.create(dict(a=dict(v=1), b=dict(v=1)))
for k, v in c.items():
v.v = 2

assert c.a.v == 2
assert c.b.v == 2


def test_dict_pop():
c = OmegaConf.create(dict(a=1, b=2))
assert c.pop('a') == 1
Expand Down Expand Up @@ -390,6 +400,22 @@ def eq(a, b):
eq(c2, input2)


@pytest.mark.parametrize('input1, input2', [
(dict(a=12, b='${a}'), dict(a=12, b=12)),
])
def test_dict_eq_with_interpolation(input1, input2):
c1 = OmegaConf.create(input1)
c2 = OmegaConf.create(input2)

def eq(a, b):
assert a == b
assert b == a
assert not a != b
assert not b != a

eq(c1, c2)


@pytest.mark.parametrize('input1, input2', [
(dict(), dict(a=10)),
({}, []),
Expand Down Expand Up @@ -424,7 +450,6 @@ def test_dict_not_eq_with_another_class():
assert OmegaConf.create() != "string"



def test_hash():
c1 = OmegaConf.create(dict(a=10))
c2 = OmegaConf.create(dict(a=10))
Expand Down
25 changes: 25 additions & 0 deletions tests/test_basic_ops_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,15 @@ def test_iterate_list():
assert items[1] == 2


def test_items_with_interpolation():
c = OmegaConf.create([
'foo',
'${0}'
])

assert c == ['foo', 'foo']


def test_list_pop():
c = OmegaConf.create([1, 2, 3, 4])
assert c.pop(0) == 1
Expand Down Expand Up @@ -267,6 +276,22 @@ def eq(a, b):
eq(c2, l2)


@pytest.mark.parametrize('l1,l2', [
([10, '${0}'], [10, 10])
])
def test_list_eq_with_interpolation(l1, l2):
c1 = OmegaConf.create(l1)
c2 = OmegaConf.create(l2)

def eq(a, b):
assert a == b
assert b == a
assert not a != b
assert not b != a

eq(c1, c2)


@pytest.mark.parametrize('input1, input2', [
([], [10]),
([10], [11]),
Expand Down

0 comments on commit 981834a

Please sign in to comment.