Skip to content

Commit

Permalink
Rename gtar file handle in Trajectory
Browse files Browse the repository at this point in the history
  • Loading branch information
klarh committed May 13, 2020
1 parent b787f08 commit dd2a5db
Showing 1 changed file with 13 additions and 13 deletions.
26 changes: 13 additions & 13 deletions keras_gtar/trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class Trajectory:
def __init__(self, filename, mode='r'):
self.filename = filename
self.mode = mode
self.traj = gtar.GTAR(filename, mode)
self.handle = gtar.GTAR(filename, mode)

def __enter__(self):
return self
Expand All @@ -34,25 +34,25 @@ def __len__(self):
return len(self.frames)

def close(self):
self.traj.close()
self.handle.close()

@property
def frames(self):
(_, frames) = self.traj.framesWithRecordsNamed('weight')
(_, frames) = self.handle.framesWithRecordsNamed('weight')
return frames

def get_weights(self, frame=-1):
"""Returns a list of weight arrays for a model stored at the given frame index
:param frame: integer index of the step to load. Can be negative to count from the end.
"""
(_, frames) = self.traj.framesWithRecordsNamed('weight')
(_, frames) = self.handle.framesWithRecordsNamed('weight')
frame_index = frames[frame]

weight_records = collections.defaultdict(dict)
shape_records = collections.defaultdict(dict)
weight_pattern = re.compile(r'keras/layer/(?P<layer>\d+)/weight/(?P<weight>\d+)')
for rec in self.traj.getRecordTypes():
for rec in self.handle.getRecordTypes():
match = weight_pattern.search(rec.getGroup())
if not match:
continue
Expand All @@ -69,8 +69,8 @@ def get_weights(self, frame=-1):
for weight_index in range(len(records)):
weight_rec = records[weight_index]
shape_rec = shape_records[i][weight_index]
shape = self.traj.getRecord(shape_rec, frame_index)
weight = self.traj.getRecord(weight_rec, frame_index)
shape = self.handle.getRecord(shape_rec, frame_index)
weight = self.handle.getRecord(weight_rec, frame_index)
weight = weight.reshape(shape)
all_weights.append(weight)

Expand All @@ -81,10 +81,10 @@ def load(self, frame=-1):
:param frame: integer index of the step to load. Can be negative to count from the end.
"""
model_description = self.traj.readStr('keras/model.json')
model_description = self.handle.readStr('keras/model.json')
assert model_description

extra_classes = self.traj.readBytes('keras/layer_classes.pkl')
extra_classes = self.handle.readBytes('keras/layer_classes.pkl')
extra_classes = pickle.loads(extra_classes) if extra_classes else {}

model = keras.models.model_from_json(model_description, extra_classes)
Expand All @@ -105,8 +105,8 @@ def save(self, model, frame=None, only_weights=False):
layer_classes = {type(layer).__name__: type(layer) for layer in model.layers}
layer_classes = pickle.dumps(layer_classes)

self.traj.writeStr('keras/model.json', model_json)
self.traj.writeBytes('keras/layer_classes.pkl', layer_classes)
self.handle.writeStr('keras/model.json', model_json)
self.handle.writeBytes('keras/layer_classes.pkl', layer_classes)
else:
assert frame, 'Trying to save only the weights of a model without a frame given'

Expand All @@ -117,8 +117,8 @@ def save(self, model, frame=None, only_weights=False):
for (j, weight) in enumerate(layer.get_weights()):
dtype_string = dtypes[weight.dtype.name]
group = 'keras/layer/{}/weight/{}'.format(i, j)
self.traj.writePath('{}/frames/{}/weight.{}.uni'.format(group, frame, dtype_string), weight)
self.traj.writePath('{}/shape.u32.uni'.format(group), weight.shape)
self.handle.writePath('{}/frames/{}/weight.{}.uni'.format(group, frame, dtype_string), weight)
self.handle.writePath('{}/shape.u32.uni'.format(group), weight.shape)

def save_weights(self, model, frame):
"""Save (only) the current model weights.
Expand Down

0 comments on commit dd2a5db

Please sign in to comment.