Skip to content

Commit

Permalink
Merge pull request #1131 from dmitriy-serdyuk/on-error
Browse files Browse the repository at this point in the history
Provide exception to `on_error` extensions
  • Loading branch information
dmitriy-serdyuk committed Aug 1, 2016
2 parents 7f380de + 0b9efa6 commit 2c1d7a5
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 7 deletions.
12 changes: 9 additions & 3 deletions blocks/extensions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,15 @@ def on_resumption(self):
pass

@callback
def on_error(self):
"""The callback invoked when an error occurs."""
pass
def on_error(self, exception):
"""The callback invoked when an error occurs.
Parameters
----------
exception : object
Exception occurred during the main loop run.
"""

@callback
def before_training(self):
Expand Down
2 changes: 1 addition & 1 deletion blocks/main_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def run(self):
self.log.current_row['got_exception'] = traceback.format_exc()
logger.error("Error occured during training." + error_message)
try:
self._run_extensions('on_error')
self._run_extensions('on_error', e)
except Exception:
logger.error(traceback.format_exc())
logger.error("Error occured when running extensions." +
Expand Down
6 changes: 3 additions & 3 deletions tests/test_main_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from multiprocessing import Process

from fuel.datasets import IterableDataset
from mock import MagicMock
from mock import MagicMock, ANY
from numpy.testing import assert_raises
from six.moves import cPickle

Expand Down Expand Up @@ -102,11 +102,11 @@ def test_error():
ext.on_error = MagicMock()
main_loop = MockMainLoop(extensions=[ext, FinishAfter(after_epoch=True)])
assert_raises(KeyError, main_loop.run)
ext.on_error.assert_called_once_with()
ext.on_error.assert_called_once_with(ANY)
assert 'got_exception' in main_loop.log.current_row

ext.on_error = MagicMock(side_effect=AttributeError)
main_loop = MockMainLoop(extensions=[ext, FinishAfter(after_epoch=True)])
assert_raises(KeyError, main_loop.run)
ext.on_error.assert_called_once_with()
ext.on_error.assert_called_once_with(ANY)
assert 'got_exception' in main_loop.log.current_row

0 comments on commit 2c1d7a5

Please sign in to comment.