Skip to content

Commit

Permalink
Merge pull request #1110 from dmitriy-serdyuk/fix-timing
Browse files Browse the repository at this point in the history
Improve timing extension
  • Loading branch information
dwf committed Jun 8, 2016
2 parents d7a3060 + b714362 commit b8d1255
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 9 deletions.
41 changes: 37 additions & 4 deletions blocks/extensions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,16 +513,24 @@ class Timing(SimpleExtension):
reading data per batch or epoch. It also reports the time spent
initializing the algorithm.
Parameters
----------
prefix : str
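Prefix prepended, followed by an underscore, to the names of the
per-batch/per-epoch and total timing records written to the log.
Defaults to the empty string (no prefix).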
Notes
-----
Add this extension *before* the :class:`Printing` extension.
When created with callbacks such as ``every_n_batches``, this extension
reports the time averaged over the batches (or epochs) elapsed since it
last reported.
This extension does *not* enable full profiling information. To see a
full profile of the main loop at the end of training, use the
``profile`` configuration (e.g. by setting ``BLOCKS_PROFILE=true``).
"""
def __init__(self, **kwargs):
def __init__(self, prefix="", **kwargs):
kwargs.setdefault('before_first_epoch', True)
kwargs.setdefault('after_epoch', True)
super(Timing, self).__init__(**kwargs)
Expand All @@ -534,6 +542,17 @@ def __init__(self, **kwargs):
level: {'train': 0, 'read_data': 0}
for level in ['batch', 'epoch']
}
self.current_index = {
level: 0
for level in ['batch', 'epoch']
}
self.previous_index = {
level: 0
for level in ['batch', 'epoch']
}
self.prefix = prefix
if self.prefix:
self.prefix += '_'

def do(self, which_callback, *args):
current_row = self.main_loop.log.current_row
Expand All @@ -544,12 +563,26 @@ def do(self, which_callback, *args):
return
if which_callback == 'after_batch':
level = 'batch'
counter = 'iterations_done'
elif which_callback == 'after_epoch':
level = 'epoch'
counter = 'epochs_done'
for action in ['train', 'read_data']:
self.previous_index[level] = self.current_index[level]
self.current_index[level] = self.main_loop.log.status[counter]
if self.current_index[level] == self.previous_index[level]:
logger.debug('Timing extension was called twice this %s, '
'log was not updated.', level)
# Nothing to report for this level
continue

self.previous[level][action] = self.current[level][action]
self.current[level][action] = profile['training', 'epoch', action]
current_row['time_{}_this_{}'.format(action, level)] = \
self.current[level][action] - self.previous[level][action]
current_row['time_{}_total'.format(action)] = \

this_time = self.prefix + 'time_{}_this_{}'
current_row[this_time.format(action, level)] = (
(self.current[level][action] - self.previous[level][action]) /
(self.current_index[level] - self.previous_index[level]))
total_time = self.prefix + 'time_{}_total'
current_row[total_time.format(action)] = \
self.current[level][action]
9 changes: 6 additions & 3 deletions blocks/utils/testing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import os
import sys
import time
from six import wraps
from importlib import import_module
from unittest.case import SkipTest
Expand Down Expand Up @@ -103,15 +104,17 @@ class MockAlgorithm(TrainingAlgorithm):
Also checks that the initialization routine is only called once.
"""
def __init__(self, delay_time=0):
    """Create the mock algorithm.

    Parameters
    ----------
    delay_time : float, optional
        Number of seconds to sleep in every call to ``process_batch``,
        used to simulate a training step that takes measurable time.
        Defaults to 0 (no delay).

    """
    # Flipped to True by ``initialize``, which asserts it runs only once.
    self._initialized = False
    self.delay_time = delay_time

def initialize(self):
    """Mark the algorithm as initialized.

    Raises
    ------
    AssertionError
        If the algorithm has already been initialized.

    """
    # Guard first: a repeated call must fail before the flag is touched.
    assert not self._initialized
    self._initialized = True

def process_batch(self, batch):
    """Record the batch and sleep for ``self.delay_time`` seconds.

    Parameters
    ----------
    batch : object
        The batch to process; stored unchanged in ``self.batch``.

    """
    self.batch = batch
    # The pause stands in for the cost of a real training step.
    pause = self.delay_time
    time.sleep(pause)


class MockMainLoop(MainLoop):
Expand All @@ -121,8 +124,8 @@ class MockMainLoop(MainLoop):
which calls were made.
"""
def __init__(self, delay_time=0, **kwargs):
    """Build a main loop with mock defaults for anything not supplied.

    Parameters
    ----------
    delay_time : float, optional
        Passed to the default :class:`MockAlgorithm`. It has no effect
        when an ``algorithm`` keyword argument is supplied, because the
        default algorithm is then not used. Defaults to 0.
    \*\*kwargs
        Forwarded to the parent class constructor. ``data_stream`` and
        ``algorithm`` are filled in with mock defaults when absent.

    """
    # Default stream: example stream over the ten items of ``range(10)``.
    kwargs.setdefault('data_stream',
                      IterableDataset(range(10)).get_example_stream())
    # Only the delayed mock is registered as the default; a plain
    # ``MockAlgorithm()`` default registered first would take precedence
    # and silently discard ``delay_time``.
    kwargs.setdefault('algorithm', MockAlgorithm(delay_time))
    super(MockMainLoop, self).__init__(**kwargs)
16 changes: 14 additions & 2 deletions tests/extensions/test_timing.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,20 @@
from numpy.testing import assert_allclose

from blocks.extensions import Timing, FinishAfter
from blocks.utils.testing import MockMainLoop


def test_timing():
    """Compare per-epoch timings with the ``every_n_epochs=2`` timing.

    Two :class:`Timing` extensions run side by side: one with prefix
    ``each`` and default callbacks, and one with prefix ``each_second``
    created with ``every_n_epochs=2``. The mean of the ``each`` values
    logged at the end of every epoch is expected to match the value the
    ``each_second`` instance writes to the final log row.
    """
    epochs = 2
    main_loop = MockMainLoop(delay_time=0.1,
                             extensions=[Timing(prefix='each'),
                                         Timing(prefix='each_second',
                                                every_n_epochs=2),
                                         FinishAfter(after_n_epochs=epochs)])
    main_loop.run()
    # Number of iterations in one epoch; the row logged at the end of
    # epoch ``k`` is therefore ``main_loop.log[k * iterations]``.
    iterations = main_loop.log.status['iterations_done'] // epochs
    # Average over *every* epoch's row rather than adding the first
    # epoch's row to itself.
    per_epoch_times = [
        main_loop.log[epoch * iterations]['each_time_train_this_epoch']
        for epoch in range(1, epochs + 1)
    ]
    assert_allclose(
        sum(per_epoch_times) / epochs,
        main_loop.log.current_row['each_second_time_train_this_epoch'],
        atol=1e-2)

0 comments on commit b8d1255

Please sign in to comment.