Skip to content

Commit

Permalink
Save layer classes in a more robust, individual manner
Browse files Browse the repository at this point in the history
  • Loading branch information
klarh committed Feb 18, 2022
1 parent 3df56bf commit e1a5a70
Showing 1 changed file with 32 additions and 4 deletions.
36 changes: 32 additions & 4 deletions keras_gtar/trajectory.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import collections
import logging
import pickle
import re

import gtar
from tensorflow import keras

logger = logging.getLogger(__name__)

def all_layers(model):
"""Recursively finds all the layers within a keras model"""
for layer in model.layers:
Expand Down Expand Up @@ -99,16 +102,37 @@ def get_weights(self, frame=-1):

return all_weights

def load(self, frame=-1):
def load(self, frame=-1, extra_classes={}):
"""Loads a model stored at the given frame index
:param frame: integer index of the step to load. Can be negative to count from the end.
:param extra_classes: Dictionary of additional (name: Class) values to use when initializing model.
"""
given_extra_classes = extra_classes
model_description = self.handle.readStr(self._get_path('keras/model.json'))
assert model_description

extra_classes = self.handle.readBytes(self._get_path('keras/layer_classes.pkl'))
extra_classes = pickle.loads(extra_classes) if extra_classes else {}
try:
extra_classes = pickle.loads(extra_classes) if extra_classes else {}
except (AttributeError, ModuleNotFoundError):
logger.warning('Failed to load saved layer classes. '
'Custom layers may not load.', exc_info=True)
extra_classes = {}

(extra_class_rec, extra_class_names) = self.handle.framesWithRecordsNamed(
'layer_class.pkl', group_prefix=self.group)
for name in extra_class_names:
if name not in extra_classes:
try:
content = self.handle.getRecord(extra_class_rec, name)
extra_classes[name] = pickle.loads(content)
except (AttributeError, ModuleNotFoundError):
logger.warning(
'Failed to load saved layer class for {}'.format(name),
exc_info=True)

extra_classes.update(given_extra_classes)

model = keras.models.model_from_json(model_description, extra_classes)

Expand All @@ -126,10 +150,14 @@ def save(self, model, frame=None, only_weights=False):
if not only_weights:
model_json = model.to_json()
layer_classes = {type(layer).__name__: type(layer) for layer in all_layers(model)}
layer_classes = pickle.dumps(layer_classes)
layer_classes_dump = pickle.dumps(layer_classes)

self.handle.writeStr(self._get_path('keras/model.json'), model_json)
self.handle.writeBytes(self._get_path('keras/layer_classes.pkl'), layer_classes)
self.handle.writeBytes(self._get_path('keras/layer_classes.pkl'), layer_classes_dump)
for (name, cls) in layer_classes.items():
path = self._get_path('keras/vars/layer_class.pkl/{}'.format(name))
self.handle.writeBytes(path, pickle.dumps(cls))

else:
assert frame, 'Trying to save only the weights of a model without a frame given'

Expand Down

0 comments on commit e1a5a70

Please sign in to comment.