Skip to content

Commit

Permalink
Support YAML merge tags
Browse files Browse the repository at this point in the history
This adds support for YAML merge tags (<< *ref) while retaining the
sanity check for duplicate keys. Note that the spec for merge keys
(https://yaml.org/type/merge.html) explicitly states that keys in the
current mapping override the ones in the merged mapping. Hence, the
check for duplicates is applied to scalar keys of the current mapping
only.

Fixes omry#470.
  • Loading branch information
jgehring committed Feb 3, 2021
1 parent 5a798e0 commit 09b85d8
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 23 deletions.
37 changes: 14 additions & 23 deletions omegaconf/_utils.py
Expand Up @@ -81,27 +81,21 @@ def yaml_is_bool(b: str) -> bool:


def get_yaml_loader() -> Any:
# Custom constructor that checks for duplicate keys
# (from https://gist.github.com/pypt/94d747fe5180851196eb)
def no_duplicates_constructor(
loader: yaml.Loader, node: yaml.Node, deep: bool = False
) -> Any:
mapping: Dict[str, Any] = {}
for key_node, value_node in node.value:
key = loader.construct_object(key_node, deep=deep)
value = loader.construct_object(value_node, deep=deep)
if key in mapping:
raise yaml.constructor.ConstructorError(
"while constructing a mapping",
node.start_mark,
f"found duplicate key {key}",
key_node.start_mark,
)
mapping[key] = value
return loader.construct_mapping(node, deep)

class OmegaConfLoader(yaml.SafeLoader): # type: ignore
pass
def construct_mapping(self, node: yaml.Node, deep: bool = False) -> Any:
keys = set()
for key_node, value_node in node.value:
if key_node.tag != yaml.resolver.BaseResolver.DEFAULT_SCALAR_TAG:
continue
if key_node.value in keys:
raise yaml.constructor.ConstructorError(
"while constructing a mapping",
node.start_mark,
f"found duplicate key {key_node.value}",
key_node.start_mark,
)
keys.add(key_node.value)
return super().construct_mapping(node, deep=deep)

loader = OmegaConfLoader
loader.add_implicit_resolver(
Expand All @@ -126,9 +120,6 @@ class OmegaConfLoader(yaml.SafeLoader): # type: ignore
]
for key, resolvers in loader.yaml_implicit_resolvers.items()
}
loader.add_constructor(
yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, no_duplicates_constructor
)
return loader


Expand Down
23 changes: 23 additions & 0 deletions tests/test_serialization.py
Expand Up @@ -183,6 +183,29 @@ def test_load_duplicate_keys_sub() -> None:
os.unlink(fp.name)


def test_load_merge_duplicates() -> None:
try:
with tempfile.NamedTemporaryFile(delete=False) as fp:
content = dedent(
"""\
a: &A
x: 1
b: &B
y: 2
c:
<<: *A
<<: *B
x: 3
z: 1
"""
)
fp.write(content.encode("utf-8"))
cfg = OmegaConf.load(fp.name)
assert cfg == {"a": {"x": 1}, "b": {"y": 2}, "c": {"x": 3, "y": 2, "z": 1}}
finally:
os.unlink(fp.name)


def test_load_empty_file(tmpdir: str) -> None:
empty = Path(tmpdir) / "test.yaml"
empty.touch()
Expand Down

0 comments on commit 09b85d8

Please sign in to comment.