Skip to content

Commit

Permalink
Merge pull request #1136 from dmitriy-serdyuk/fix-timing-again
Browse files Browse the repository at this point in the history
Fix Timing extension again
  • Loading branch information
dwf committed Aug 18, 2016
2 parents a3c8404 + bf8b788 commit 4e266ad
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 20 deletions.
39 changes: 19 additions & 20 deletions blocks/extensions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,22 +604,15 @@ def __init__(self, prefix="", **kwargs):
kwargs.setdefault('before_first_epoch', True)
kwargs.setdefault('after_epoch', True)
super(Timing, self).__init__(**kwargs)
self.current = {
level: {'train': 0, 'read_data': 0}
for level in ['batch', 'epoch']
}
self.previous = {
level: {'train': 0, 'read_data': 0}
for level in ['batch', 'epoch']
}
self.current_index = {
level: 0
for level in ['batch', 'epoch']
}
self.previous_index = {
level: 0
for level in ['batch', 'epoch']
}

def init_dict():
return {
level: {'train': 0, 'read_data': 0}
for level in ['batch', 'epoch']}
self.current = init_dict()
self.previous = init_dict()
self.current_index = init_dict()
self.previous_index = init_dict()
self.prefix = prefix
if self.prefix:
self.prefix += '_'
Expand All @@ -637,10 +630,16 @@ def do(self, which_callback, *args):
elif which_callback == 'after_epoch':
level = 'epoch'
counter = 'epochs_done'
else:
raise ValueError('wrong callback type `{}`'.format(which_callback))
for action in ['train', 'read_data']:
self.previous_index[level] = self.current_index[level]
self.current_index[level] = self.main_loop.log.status[counter]
if self.current_index[level] == self.previous_index[level]:
self.previous_index[level][action] = (
self.current_index[level][action])
self.current_index[level][action] = (
self.main_loop.log.status[counter])
current_index = self.current_index[level][action]
previous_index = self.previous_index[level][action]
if current_index == previous_index:
logger.debug('Timing extension was called twice this %s, '
'log was not updated.', level)
# Nothing to report for this level
Expand All @@ -652,7 +651,7 @@ def do(self, which_callback, *args):
this_time = self.prefix + 'time_{}_this_{}'
current_row[this_time.format(action, level)] = (
(self.current[level][action] - self.previous[level][action]) /
(self.current_index[level] - self.previous_index[level]))
(current_index - previous_index))
total_time = self.prefix + 'time_{}_total'
current_row[total_time.format(action)] = \
self.current[level][action]
2 changes: 2 additions & 0 deletions tests/extensions/test_timing.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,5 @@ def test_timing():
main_loop.log[iterations]['each_time_train_this_epoch']) / 2,
main_loop.log.current_row['each_second_time_train_this_epoch'],
atol=1e-2)
assert 'each_time_read_data_this_epoch' in main_loop.log[iterations]
assert 'each_second_time_read_data_this_epoch' in main_loop.log[iterations]

0 comments on commit 4e266ad

Please sign in to comment.