Skip to content

Commit

Permalink
Add group argument to Trajectory
Browse files Browse the repository at this point in the history
  • Loading branch information
klarh committed Feb 2, 2021
1 parent 8f0151b commit 3aa6135
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 10 deletions.
7 changes: 5 additions & 2 deletions keras_gtar/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,25 @@ class GTARLogger(keras.callbacks.Callback):
:param when: String indicating when to save: one of `pre_batch`, `post_batch`, `pre_epoch`, or `post_epoch`
:param append: If True, append to instead of overwriting the file if it exists already
:param step_offset: Offset to apply to the epoch or batch index
:param group: GTAR group to use to organize multiple sub-trajectories within the same GTAR file, if given
"""
def __init__(self, filename, period=1, when='post_epoch', append=True,
step_offset=0, *args, **kwargs):
step_offset=0, group=None, *args, **kwargs):
self.filename = filename
self.period = period
self.when = when
self.append = append
self.step_offset = step_offset
self.group = group
self.batches = 0

assert when in ('pre_batch', 'post_batch', 'pre_epoch', 'post_epoch')
super().__init__(*args, **kwargs)

def on_train_begin(self, logs={}):
mode = 'a' if self.append else 'w'
self.trajectory = Trajectory(self.filename, mode)
self.trajectory = Trajectory(self.filename, mode, self.group)
self.trajectory.save_model(self.model)

def on_train_end(self, logs={}):
Expand Down
30 changes: 22 additions & 8 deletions keras_gtar/trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@ class Trajectory:
:param filename: File to save or load from
:param mode: File open mode: 'r' (read-only), 'w' (overwrite), or 'a' (append)
:param group: GTAR group prefix to use to organize multiple sub-trajectories within the same GTAR file, if given
"""

def __init__(self, filename, mode='r'):
def __init__(self, filename, mode='r', group=None):
self.filename = filename
self.mode = mode
self.handle = gtar.GTAR(filename, mode)
self.group = group

def __enter__(self):
return self
Expand All @@ -38,21 +40,33 @@ def close(self):

@property
def frames(self):
(_, frames) = self.handle.framesWithRecordsNamed('weight')
(_, frames) = self.handle.framesWithRecordsNamed('weight', group_prefix=self.group)
return frames

def _get_path(self, name):
if self.group is None:
return name
else:
return '{}/{}'.format(self.group, name)

def get_weights(self, frame=-1):
"""Returns a list of weight arrays for a model stored at the given frame index
:param frame: integer index of the step to load. Can be negative to count from the end.
"""
(_, frames) = self.handle.framesWithRecordsNamed('weight')
(_, frames) = self.handle.framesWithRecordsNamed('weight', group_prefix=self.group)
frame_index = frames[frame]

weight_records = collections.defaultdict(dict)
shape_records = collections.defaultdict(dict)
weight_pattern = re.compile(r'keras/layer/(?P<layer>\d+)/weight/(?P<weight>\d+)')
for rec in self.handle.getRecordTypes():
group = rec.getGroup()
invalid_group = (self.group is not None and
not rec.getGroup().startswith(self.group))
if invalid_group:
continue

match = weight_pattern.search(rec.getGroup())
if not match:
continue
Expand Down Expand Up @@ -81,10 +95,10 @@ def load(self, frame=-1):
:param frame: integer index of the step to load. Can be negative to count from the end.
"""
model_description = self.handle.readStr('keras/model.json')
model_description = self.handle.readStr(self._get_path('keras/model.json'))
assert model_description

extra_classes = self.handle.readBytes('keras/layer_classes.pkl')
extra_classes = self.handle.readBytes(self._get_path('keras/layer_classes.pkl'))
extra_classes = pickle.loads(extra_classes) if extra_classes else {}

model = keras.models.model_from_json(model_description, extra_classes)
Expand All @@ -105,8 +119,8 @@ def save(self, model, frame=None, only_weights=False):
layer_classes = {type(layer).__name__: type(layer) for layer in model.layers}
layer_classes = pickle.dumps(layer_classes)

self.handle.writeStr('keras/model.json', model_json)
self.handle.writeBytes('keras/layer_classes.pkl', layer_classes)
self.handle.writeStr(self._get_path('keras/model.json'), model_json)
self.handle.writeBytes(self._get_path('keras/layer_classes.pkl'), layer_classes)
else:
assert frame, 'Trying to save only the weights of a model without a frame given'

Expand All @@ -116,7 +130,7 @@ def save(self, model, frame=None, only_weights=False):
for (i, layer) in enumerate(model.layers):
for (j, weight) in enumerate(layer.get_weights()):
dtype_string = dtypes[weight.dtype.name]
group = 'keras/layer/{}/weight/{}'.format(i, j)
group = self._get_path('keras/layer/{}/weight/{}'.format(i, j))
self.handle.writePath('{}/frames/{}/weight.{}.uni'.format(group, frame, dtype_string), weight)
self.handle.writePath('{}/shape.u32.uni'.format(group), weight.shape)

Expand Down

0 comments on commit 3aa6135

Please sign in to comment.