Skip to content

Commit

Permalink
Fixed keras model builder class datatype
Browse files Browse the repository at this point in the history
Wasn't fully working correctly after new datatypes update..
  • Loading branch information
markgw committed Aug 6, 2019
1 parent 7993b22 commit b4ad297
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions src/python/pimlico/datatypes/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import os
from pimlico.core.dependencies.python import keras_dependency
from pimlico.datatypes import PimlicoDatatype
from pimlico.datatypes.base import DatatypeWriteError

from pimlico.utils.core import import_member

Expand Down Expand Up @@ -131,14 +132,18 @@ def load_model(self, override_params=None):
class Writer(object):
required_tasks = ["architecture", "weights"]

def __init__(self, base_dir, build_params, builder_class_path, **kwargs):
super(Writer, self).__init__(base_dir, **kwargs)
def __init__(self, *args, **kwargs):
build_params = kwargs.pop("build_params", {})
if "builder_class_path" not in kwargs:
raise DatatypeWriteError("builder_class_path must be supplied for a Keras model builder class writer")
build_params["builder_class_path"] = kwargs.pop("builder_class_path")

super(KerasModelBuilderClass.Writer, self).__init__(*args, **kwargs)
self.weights_filename = os.path.join(self.data_dir, "weights.hdf5")
build_params["builder_class_path"] = builder_class_path
self.build_params = build_params

def __enter__(self):
super(Writer, self).__enter__()
super(KerasModelBuilderClass.Writer, self).__enter__()
# Store the model-building hyperparameters as JSON
with open(os.path.join(self.data_dir, "build_params.json"), "w") as f:
json.dump(self.build_params, f, indent=4)
Expand Down

0 comments on commit b4ad297

Please sign in to comment.