From f301f18261b8a0154b4c28de83786408f4778df9 Mon Sep 17 00:00:00 2001 From: David Warde-Farley Date: Thu, 27 Oct 2016 01:44:10 -0400 Subject: [PATCH 1/3] Add 'on_error' to BOOLEAN_TRIGGERS. --- blocks/extensions/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/blocks/extensions/__init__.py b/blocks/extensions/__init__.py index 627ce933..f81f01b6 100644 --- a/blocks/extensions/__init__.py +++ b/blocks/extensions/__init__.py @@ -219,7 +219,7 @@ class SimpleExtension(TrainingExtension): """ BOOLEAN_TRIGGERS = frozenset(["before_training", "before_first_epoch", "before_epoch", "before_batch", - "on_resumption", "on_interrupt", + "on_resumption", "on_interrupt", "on_error", "after_epoch", "after_batch", "after_training"]) From 53477f5d9aae2298835365d7cf487c184cff6979 Mon Sep 17 00:00:00 2001 From: David Warde-Farley Date: Thu, 27 Oct 2016 01:45:38 -0400 Subject: [PATCH 2/3] Improve existing Timestamp test. --- tests/extensions/test_extensions.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/extensions/test_extensions.py b/tests/extensions/test_extensions.py index f9b5ab13..e0cb57d8 100644 --- a/tests/extensions/test_extensions.py +++ b/tests/extensions/test_extensions.py @@ -176,6 +176,11 @@ def check(kwargs): assert ext.main_loop.log.current_row[log_record] == 'foo' # Exercise original get_timestamp. ext.do('after_epoch') + sep = kwargs.get('separator', ' ') + assert bool(re.match(''.join(['[0-9]{4}-[0-9]{2}-[0-9]{2}', sep, + '[0-9]{2}(\\:[0-9]{2}){2}' + '\\.[0-9]+']), + ext.main_loop.log.current_row[log_record])) yield check, {} yield check, {'log_record': 'loggy mclogpants'} From e2c38291c3612d77fd394e4ef34717cd8c2da346 Mon Sep 17 00:00:00 2001 From: David Warde-Farley Date: Thu, 27 Oct 2016 01:47:03 -0400 Subject: [PATCH 3/3] Add default triggers to the Timestamp extension. --- blocks/extensions/__init__.py | 14 +++++++++++++- tests/extensions/test_extensions.py | 16 ++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/blocks/extensions/__init__.py b/blocks/extensions/__init__.py index f81f01b6..92ee1c2f 100644 --- a/blocks/extensions/__init__.py +++ b/blocks/extensions/__init__.py @@ -669,6 +669,15 @@ class Timestamp(SimpleExtension): Separator between the date and time. ISO 8601 specifies 'T'. Here, we default to ' ' (blank space) for human readability. + Notes + ----- + By default, triggers after every epoch as well as before training + starts, after training finishes, when an error occurs or when training + is interrupted or resumed, as these are all generally useful + circumstances for which to have a timestamp. These can be disabled + by passing `False` as the appropriate keyword argument; see + :class:`SimpleExtension`. + """ DEFAULT_LOG_RECORD = 'timestamp' @@ -676,7 +685,10 @@ def __init__(self, log_record=DEFAULT_LOG_RECORD, separator=' ', **kwargs): self.log_record = log_record self.separator = separator - kwargs.setdefault('after_epoch', True) + default_callbacks = ['before_training', 'after_epoch', 'on_error', + 'on_interrupt', 'on_resumption', 'after_training'] + for callback in default_callbacks: + kwargs.setdefault(callback, True) super(Timestamp, self).__init__(**kwargs) def do(self, *args): diff --git a/tests/extensions/test_extensions.py b/tests/extensions/test_extensions.py index e0cb57d8..c9c2eafe 100644 --- a/tests/extensions/test_extensions.py +++ b/tests/extensions/test_extensions.py @@ -1,3 +1,4 @@ +import re from mock import Mock from numpy.testing import assert_raises @@ -184,3 +185,18 @@ def check(kwargs): yield check, {} yield check, {'log_record': 'loggy mclogpants'} + + +def test_timestamp_default_triggers(): + def check(callback): + ext = InjectedTimestamp() + ext.main_loop = Mock() + ext.main_loop.log.current_row = {} + ext.dispatch(callback) + assert ext.main_loop.log.current_row.get('timestamp') == 'baz' + + callbacks = ['before_training', 'after_epoch', 'on_error', + 'on_interrupt', 'on_resumption', 'after_training'] + + for callback in callbacks: + yield check, callback