Skip to content

Commit

Permalink
Fix a bug when saving models (#508)
Browse files Browse the repository at this point in the history
* add failing unit test

* expand and fix tests
  • Loading branch information
msperber committed Aug 6, 2018
1 parent e06711b commit 128c647
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 6 deletions.
79 changes: 76 additions & 3 deletions test/test_persistence.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import unittest
import copy
import os
import shutil

import yaml

import xnmt
from xnmt import utils, persistence
from xnmt.persistence import Path, YamlPreloader, Serializable, serializable_init, bare
from xnmt import events, param_collections, utils
from xnmt.persistence import Path, YamlPreloader, Serializable, serializable_init, bare, initialize_if_needed,\
save_to_file

class TestPath(unittest.TestCase):

Expand Down Expand Up @@ -254,5 +256,76 @@ def test_format_strings(self):
self.assertEqual(test_obj["c"], "val1/bla")
self.assertListEqual(test_obj["d"], ["bla", "bla.val2"])

class DummyArgClass(Serializable):
yaml_tag = "!DummyArgClass"
@serializable_init
def __init__(self, arg1, arg2):
pass # arg1 and arg2 are purposefully not kept
class DummyArgClass2(Serializable):
yaml_tag = "!DummyArgClass2"
@serializable_init
def __init__(self, v):
self.v = v

class TestSaving(unittest.TestCase):
def setUp(self):
events.clear()
yaml.add_representer(DummyArgClass, xnmt.init_representer)
yaml.add_representer(DummyArgClass2, xnmt.init_representer)
self.out_dir = os.path.join("test", "tmp")
utils.make_parent_dir(os.path.join(self.out_dir, "asdf"))
self.model_file = os.path.join(self.out_dir, "saved.mod")
param_collections.ParamManager.init_param_col()
param_collections.ParamManager.param_col.model_file = self.model_file

def test_shallow(self):
test_obj = yaml.load("""
a: !DummyArgClass
arg1: !DummyArgClass2
_xnmt_id: id1
v: some_val
arg2: !Ref { name: id1 }
""")
preloaded = YamlPreloader.preload_obj(root=test_obj,exp_name="exp1",exp_dir=self.out_dir)
initalized = initialize_if_needed(preloaded)
save_to_file(self.model_file, initalized)

def test_mid(self):
test_obj = yaml.load("""
a: !DummyArgClass
arg1: !DummyArgClass2
v: !DummyArgClass2
_xnmt_id: id1
v: some_val
arg2: !DummyArgClass2
v: !Ref { name: id1 }
""")
preloaded = YamlPreloader.preload_obj(root=test_obj,exp_name="exp1",exp_dir=self.out_dir)
initalized = initialize_if_needed(preloaded)
save_to_file(self.model_file, initalized)

def test_deep(self):
test_obj = yaml.load("""
a: !DummyArgClass
arg1: !DummyArgClass2
v: !DummyArgClass2
v: !DummyArgClass2
_xnmt_id: id1
v: some_val
arg2: !DummyArgClass2
v: !DummyArgClass2
v: !Ref { name: id1 }
""")
preloaded = YamlPreloader.preload_obj(root=test_obj,exp_name="exp1",exp_dir=self.out_dir)
initalized = initialize_if_needed(preloaded)
save_to_file(self.model_file, initalized)

def tearDown(self):
try:
if os.path.isdir(os.path.join("test","tmp")):
shutil.rmtree(os.path.join("test","tmp"))
except:
pass

if __name__ == '__main__':
unittest.main()
9 changes: 6 additions & 3 deletions xnmt/persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,9 +541,12 @@ def _get_child_dict(node, name):

@_get_child.register(Serializable)
def _get_child_serializable(node, name):
if not hasattr(node, name):
raise PathError(f"{node} has no child named {name}")
return getattr(node, name)
if hasattr(node, "serialize_params"):
return _get_child(node.serialize_params, name)
else:
if not hasattr(node, name):
raise PathError(f"{node} has no child named {name}")
return getattr(node, name)


@singledispatch
Expand Down

0 comments on commit 128c647

Please sign in to comment.