Skip to content

Commit

Permalink
Properly index per-batch trajectory storage
Browse files Browse the repository at this point in the history
  • Loading branch information
klarh committed May 29, 2020
1 parent dd2a5db commit bfbdcba
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions keras_gtar/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def __init__(self, filename, period=1, when='post_epoch', append=True,
self.when = when
self.append = append
self.step_offset = step_offset
self.batches = 0

assert when in ('pre_batch', 'post_batch', 'pre_epoch', 'post_epoch')
super().__init__(*args, **kwargs)
Expand All @@ -39,10 +40,12 @@ def _save(self, index, required_time):
self.trajectory.save_weights(self.model, str(index))

def on_batch_begin(self, index, logs={}):
return self._save(index, 'pre_batch')
return self._save(self.batches, 'pre_batch')

def on_batch_end(self, index, logs={}):
return self._save(index, 'post_batch')
result = self._save(self.batches, 'post_batch')
self.batches += 1
return result

def on_epoch_begin(self, index, logs={}):
return self._save(index, 'pre_epoch')
Expand Down

0 comments on commit bfbdcba

Please sign in to comment.