Skip to content

Commit

Permalink
initial implementation of repeat syntax (#468)
Browse files Browse the repository at this point in the history
  • Loading branch information
msperber authored and neubig committed Jul 17, 2018
1 parent a0dd0cf commit 7c473ae
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 5 deletions.
4 changes: 2 additions & 2 deletions test/test_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def test_load_referenced_serialized_top(self):
with open(f"{self.out_dir}/tmp1.yaml", "w") as f_out:
yaml.dump(DummyClass(arg1="v1"), f_out)
test_obj = yaml.load(f"!LoadSerialized {{ filename: {self.out_dir}/tmp1.yaml }}")
loaded_obj = YamlPreloader._load_referenced_serialized(test_obj)
loaded_obj = YamlPreloader._load_serialized(test_obj)
self.assertIsInstance(loaded_obj, DummyClass)
self.assertEqual(loaded_obj.arg1, "v1")

Expand All @@ -196,7 +196,7 @@ def test_load_referenced_serialized_nested(self):
val: !LoadSerialized
filename: {self.out_dir}/tmp1.yaml
""")
loaded_obj = YamlPreloader._load_referenced_serialized(test_obj)
loaded_obj = YamlPreloader._load_serialized(test_obj)
self.assertIsInstance(loaded_obj["b"], DummyClass)
self.assertIsInstance(loaded_obj["b"].arg1, DummyClass)

Expand Down
35 changes: 32 additions & 3 deletions xnmt/persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
* :class:`Serializable`: must be subclassed by all components that are specified in a YAML file.
* :class:`Ref`: a reference that points somewhere in the object hierarchy, for both convenience and to realize parameter sharing.
* :class:`Repeat`: a syntax for creating a list components with same configuration but without parameter sharing.
* :class:`YamlPreloader`: pre-loads YAML contents so that some infrastructure can be set up, but does not initialize components.
* :meth:`initialize_if_needed`, :meth:`initialize_object`: initialize a preloaded YAML tree, taking care of resolving references etc.
* :meth:`save_to_file`: saves a YAML file along with registered DyNet parameters
Expand Down Expand Up @@ -304,7 +305,6 @@ def resolve_path(self, named_paths: Dict[str, 'Path']) -> 'Path':
else:
raise ValueError(f"Could not resolve path of reference {self}")


class Path(object):
"""
A relative or absolute path in the component hierarchy.
Expand Down Expand Up @@ -426,6 +426,21 @@ def ancestors(self) -> Set['Path']:
ret.add(a)
return ret

class Repeat(Serializable):
"""
A special object that is replaced by a list of components with identical configuration but not with shared params.
This can be specified anywhere in the config hierarchy where normally a list is expected.
A common use case is a multi-layer neural architecture, where layer configurations are repeated many times.
It is replaced in the preloader and cannot be instantiated directly.
"""
yaml_tag = "!Repeat"
@serializable_init
def __init__(self, times: int, content: Any):
self.times = times
self.content = content
raise ValueError("Repeat cannot be instantiated")


_subcol_rand = random.Random()

Expand Down Expand Up @@ -861,12 +876,14 @@ def preload_obj(root: Any, exp_name: str, exp_dir: str) -> UninitializedYamlObje

YamlPreloader._format_strings(root, placeholders) # do this both before and after resolving !LoadSerialized

root = YamlPreloader._load_referenced_serialized(root)
root = YamlPreloader._load_serialized(root)

random_search_report = YamlPreloader._instantiate_random_search(root)
if random_search_report:
setattr(root, 'random_search_report', random_search_report)

YamlPreloader._resolve_repeat(root)

# if arguments were not given in the YAML file and are set to a bare(Serializable) by default, copy the bare object
# into the object hierarchy so it can be used w/ param sharing etc.
YamlPreloader._resolve_bare_default_args(root)
Expand All @@ -876,7 +893,7 @@ def preload_obj(root: Any, exp_name: str, exp_dir: str) -> UninitializedYamlObje
return UninitializedYamlObject(root)

@staticmethod
def _load_referenced_serialized(root: Any) -> Any:
def _load_serialized(root: Any) -> Any:
for path, node in _traverse_tree(root, traversal_order=_TraversalOrder.ROOT_LAST):
if isinstance(node, LoadSerialized):
LoadSerialized._check_wellformed(node)
Expand Down Expand Up @@ -970,6 +987,18 @@ def _instantiate_random_search(experiment):
param_report[path] = v
return param_report

@staticmethod
def _resolve_repeat(root):
for path, node in _traverse_tree(root, traversal_order=_TraversalOrder.ROOT_LAST):
if isinstance(node, Repeat):
expanded = []
for _ in range(node.times):
expanded.append(copy.deepcopy(node.content))
if len(path) == 0:
root = expanded
else:
_set_descendant(root, path, expanded)

@staticmethod
def _resolve_bare_default_args(root: Any) -> None:
for path, node in _traverse_tree(root):
Expand Down

0 comments on commit 7c473ae

Please sign in to comment.