Skip to content

Commit

Permalink
Fixes to serializer for v2
Browse files Browse the repository at this point in the history
  • Loading branch information
neubig committed Jun 29, 2017
1 parent c44b235 commit 7cf915a
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions xnmt/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def save_to_file(self, fname, mod, params):
os.makedirs(dirname)
with open(fname, 'w') as f:
json.dump(self.__to_spec(mod), f)
params.save_all(fname + '.data')
params.save(fname + '.data')

'''
Load a model from a file.
Expand All @@ -37,7 +37,7 @@ def load_from_file(self, fname, param):
with open(fname, 'r') as f:
dict_spec = json.load(f)
mod = self.__from_spec(dict_spec, param)
param.load_all(fname + '.data')
param.populate(fname + '.data')
return mod

def __to_spec(self, obj):
Expand All @@ -50,7 +50,7 @@ def __to_spec(self, obj):
info['__param__'] = [self.__to_spec(x) for x in obj.serialize_params]
elif obj.__class__.__name__ == 'list' or obj.__class__.__name__ == 'dict':
return json.dumps(obj)
elif obj.__class__.__name__ != 'Model':
elif obj.__class__.__name__ != 'ParameterCollection':
raise NotImplementedError("Class %s is not serializable. Try adding serialize_params to it." % obj.__class__.__name__)
return info

Expand All @@ -66,7 +66,7 @@ def __from_spec(self, spec, params):
raise NotImplementedError("Class %s is not deserializable. Try adding serialize_params to it." % spec.__class__.__name__)
elif '__class__' not in spec:
raise NotImplementedError("Dict is not deserializable. Try adding __class__ when saving it:\n %r" % spec)
elif spec['__class__'] == 'Model':
elif spec['__class__'] == 'ParameterCollection':
return params
elif '__param__' not in spec:
raise NotImplementedError("Dict is not deserializable. Try adding __param__ when saving it:\n %r" % spec)
Expand Down

0 comments on commit 7cf915a

Please sign in to comment.