Skip to content

Commit

Permalink
Merge pull request #1126 from pbrakel/add_before_batch
Browse files Browse the repository at this point in the history
add before_batch to SimpleExtension
  • Loading branch information
rizar committed Jul 5, 2016
2 parents 6dd819f + 3f32dd1 commit 7f380de
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 3 deletions.
7 changes: 4 additions & 3 deletions blocks/extensions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,9 +211,10 @@ class SimpleExtension(TrainingExtension):
"""
BOOLEAN_TRIGGERS = frozenset(["before_training", "before_first_epoch",
"before_epoch", "on_resumption",
"on_interrupt", "after_epoch",
"after_batch", "after_training"])
"before_epoch", "before_batch",
"on_resumption", "on_interrupt",
"after_epoch", "after_batch",
"after_training"])

INTEGER_TRIGGERS = frozenset(["after_n_epochs", "after_n_batches",
"every_n_epochs", "every_n_batches"])
Expand Down
16 changes: 16 additions & 0 deletions tests/extensions/test_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,19 @@ def do(self, *args):
comp.do.assert_called_once_with('before_training')
a.do.assert_called_once_with('after_training')
b.do.assert_called_once_with('after_batch')


def test_simple_extension_before_batch_callback():

class Foo(SimpleExtension):
def __init__(self, **kwargs):
self.do = Mock()
super(Foo, self).__init__(**kwargs)

def do(self, which_callback, *args):
pass

ext = Foo(before_batch=True)
ext.main_loop = Mock()
ext.dispatch('before_batch')
ext.do.assert_called_once_with('before_batch')

0 comments on commit 7f380de

Please sign in to comment.