Skip to content

Commit

Permalink
Refactor some tests
Browse files Browse the repository at this point in the history
  • Loading branch information
odelalleau committed Mar 25, 2021
1 parent 6898d48 commit 018d521
Showing 1 changed file with 25 additions and 32 deletions.
57 changes: 25 additions & 32 deletions tests/interpolation/built_in_resolvers/test_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,39 +12,29 @@
param(
{"foo": "${oc.dict.keys:{a: 0, b: 1}}"},
"foo",
OmegaConf.create(["a", "b"]),
["a", "b"],
id="dict",
),
param(
{"foo": "${oc.dict.keys:${bar}}", "bar": {"a": 0, "b": 1}},
"foo",
OmegaConf.create(["a", "b"]),
["a", "b"],
id="dictconfig_interpolation",
),
param(
{"foo": "${oc.dict.keys:bar}", "bar": {"a": 0, "b": 1}},
"foo",
OmegaConf.create(["a", "b"]),
["a", "b"],
id="dictconfig_select",
),
param(
{"foo": "${sum:${oc.dict.keys:{1: one, 2: two}}}"},
"foo",
3,
id="nested",
),
],
)
def test_dict_keys(restore_resolvers: Any, cfg: Any, key: Any, expected: Any) -> None:
OmegaConf.register_new_resolver("sum", lambda x: sum(x))

def test_dict_keys(cfg: Any, key: Any, expected: Any) -> None:
cfg = OmegaConf.create(cfg)
val = cfg[key]
assert val == expected
assert type(val) is type(expected)

if isinstance(val, ListConfig):
assert val._parent is cfg
assert isinstance(val, ListConfig)
assert val._parent is cfg


@mark.parametrize(
Expand All @@ -53,34 +43,28 @@ def test_dict_keys(restore_resolvers: Any, cfg: Any, key: Any, expected: Any) ->
param(
{"foo": "${oc.dict.values:{a: 0, b: 1}}"},
"foo",
OmegaConf.create([0, 1]),
[0, 1],
id="dict",
),
param(
{"foo": "${oc.dict.values:${bar}}", "bar": {"a": 0, "b": 1}},
"foo",
OmegaConf.create([0, 1]),
[0, 1],
id="dictconfig",
),
param(
{"foo": "${oc.dict.values:bar}", "bar": {"a": 0, "b": 1}},
"foo",
OmegaConf.create([0, 1]),
[0, 1],
id="dictconfig_select",
),
param(
{"foo": "${sum:${oc.dict.values:{one: 1, two: 2}}}"},
"foo",
3,
id="nested",
),
param(
{
"foo": "${oc.dict.values:${bar}}",
"bar": {"x": {"x0": 0, "x1": 1}, "y": {"y0": 0}},
},
"foo",
OmegaConf.create([{"x0": 0, "x1": 1}, {"y0": 0}]),
[{"x0": 0, "x1": 1}, {"y0": 0}],
id="convert_node_to_list",
),
param(
Expand All @@ -89,21 +73,30 @@ def test_dict_keys(restore_resolvers: Any, cfg: Any, key: Any, expected: Any) ->
"val_ref": "value",
},
"foo",
OmegaConf.create(["value"]),
["value"],
id="dict_with_interpolated_value",
),
],
)
def test_dict_values(restore_resolvers: Any, cfg: Any, key: Any, expected: Any) -> None:
OmegaConf.register_new_resolver("sum", lambda x: sum(x))
def test_dict_values(cfg: Any, key: Any, expected: Any) -> None:

cfg = OmegaConf.create(cfg)
val = cfg[key]
assert val == expected
assert type(val) is type(expected)
assert isinstance(val, ListConfig)
assert val._parent is cfg

if isinstance(val, ListConfig):
assert val._parent is cfg

def test_dict_keys_nested(restore_resolvers: Any) -> None:
OmegaConf.register_new_resolver("sum", lambda x: sum(x))
cfg = OmegaConf.create({"x": "${sum:${oc.dict.keys:{1: one, 2: two}}}"})
assert cfg.x == 3


def test_dict_values_nested(restore_resolvers: Any) -> None:
OmegaConf.register_new_resolver("sum", lambda x: sum(x))
cfg = OmegaConf.create({"x": "${sum:${oc.dict.values:{one: 1, two: 2}}}"})
assert cfg.x == 3


@mark.parametrize(
Expand Down

0 comments on commit 018d521

Please sign in to comment.