Skip to content

Commit

Permalink
Merge pull request #1161 from dwf/timestamp_better_defaults
Browse files Browse the repository at this point in the history
Add default triggers to Timestamp
  • Loading branch information
dmitriy-serdyuk committed Nov 8, 2016
2 parents 1d6683f + e2c3829 commit 3f002e9
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 2 deletions.
16 changes: 14 additions & 2 deletions blocks/extensions/__init__.py
Expand Up @@ -219,7 +219,7 @@ class SimpleExtension(TrainingExtension):
"""
BOOLEAN_TRIGGERS = frozenset(["before_training", "before_first_epoch",
"before_epoch", "before_batch",
"on_resumption", "on_interrupt",
"on_resumption", "on_interrupt", "on_error",
"after_epoch", "after_batch",
"after_training"])

Expand Down Expand Up @@ -669,14 +669,26 @@ class Timestamp(SimpleExtension):
Separator between the date and time. ISO 8601 specifies 'T'.
Here, we default to ' ' (blank space) for human readability.
Notes
-----
By default, triggers after every epoch as well as before training
starts, after training finishes, when an error occurs or when training
is interrupted or resumed, as these are all generally useful
circumstances for which to have a timestamp. These can be disabled
by passing `False` as the appropriate keyword argument; see
:class:`SimpleExtension`.
"""
DEFAULT_LOG_RECORD = 'timestamp'

def __init__(self, log_record=DEFAULT_LOG_RECORD, separator=' ',
**kwargs):
self.log_record = log_record
self.separator = separator
kwargs.setdefault('after_epoch', True)
default_callbacks = ['before_training', 'after_epoch', 'on_error',
'on_interrupt', 'on_resumption', 'after_training']
for callback in default_callbacks:
kwargs.setdefault(callback, True)
super(Timestamp, self).__init__(**kwargs)

def do(self, *args):
Expand Down
21 changes: 21 additions & 0 deletions tests/extensions/test_extensions.py
@@ -1,3 +1,4 @@
import re
from mock import Mock
from numpy.testing import assert_raises

Expand Down Expand Up @@ -176,6 +177,26 @@ def check(kwargs):
assert ext.main_loop.log.current_row[log_record] == 'foo'
# Exercise original get_timestamp.
ext.do('after_epoch')
sep = kwargs.get('separator', ' ')
assert bool(re.match(''.join(['[0-9]{4}-[0-9]{2}-[0-9]{2}', sep,
'[0-9]{2}(\\:[0-9]{2}){2}'
'\\.[0-9]+']),
ext.main_loop.log.current_row[log_record]))

yield check, {}
yield check, {'log_record': 'loggy mclogpants'}


def test_timestamp_default_triggers():
def check(callback):
ext = InjectedTimestamp()
ext.main_loop = Mock()
ext.main_loop.log.current_row = {}
ext.dispatch(callback)
assert ext.main_loop.log.current_row.get('timestamp') == 'baz'

callbacks = ['before_training', 'after_epoch', 'on_error',
'on_interrupt', 'on_resumption', 'after_training']

for callback in callbacks:
yield check, callback

0 comments on commit 3f002e9

Please sign in to comment.