Skip to content

Commit

Permalink
Merge 7bf7ea8 into 986d861
Browse files Browse the repository at this point in the history
  • Loading branch information
kuenishi committed Apr 16, 2018
2 parents 986d861 + 7bf7ea8 commit be4ff82
Show file tree
Hide file tree
Showing 7 changed files with 174 additions and 9 deletions.
2 changes: 1 addition & 1 deletion chainer/functions/array/split_axis.py
Expand Up @@ -108,7 +108,7 @@ def _ideep_is_supported(self, inputs):
if indices[0] == 0:
return False # Sequence starting with 0
for i in six.moves.range(1, len(indices)):
if indices[i-1] == indices[i]:
if indices[i - 1] == indices[i]:
return False # Sequence with duplicate index
else:
if self.sections == 1:
Expand Down
19 changes: 18 additions & 1 deletion chainer/training/extension.py
Expand Up @@ -98,6 +98,22 @@ def initialize(self, trainer):
"""
pass

def on_error(self, trainer, exc, tb):
    """Handle the error during training before finalize.

    This method is called when an exception is thrown during the
    training loop. An extension that needs error handling different
    from finalize can override this method to handle errors. The
    default implementation does nothing.

    Args:
        trainer (Trainer): Trainer object that runs the training loop.
        exc (Exception): arbitrary exception thrown during update loop.
        tb (Traceback): traceback of the exception

    """
    pass

def serialize(self, serializer):
"""Serializes the extension state.
Expand All @@ -109,7 +125,7 @@ def serialize(self, serializer):


def make_extension(trigger=None, default_name=None, priority=None,
finalizer=None, initializer=None, **kwargs):
finalizer=None, initializer=None, on_error=None, **kwargs):
"""Decorator to make given functions into trainer extensions.
This decorator just adds some attributes to a given function. The value of
Expand Down Expand Up @@ -144,6 +160,7 @@ def decorator(ext):
ext.default_name = default_name or ext.__name__
ext.priority = priority
ext.finalize = finalizer
ext.on_error = on_error
ext.initialize = initializer
return ext

Expand Down
35 changes: 29 additions & 6 deletions chainer/training/extensions/_snapshot.py
Expand Up @@ -6,7 +6,8 @@
from chainer import utils


def snapshot_object(target, filename, savefun=npz.save_npz):
def snapshot_object(target, filename, snapshot_on_error=False,
savefun=npz.save_npz):
"""Returns a trainer extension to take snapshots of a given object.
This extension serializes the given object and saves it to the output
Expand All @@ -27,21 +28,32 @@ def snapshot_object(target, filename, savefun=npz.save_npz):
the :meth:`str.format` method. For example,
``'snapshot_{.updater.iteration}'`` is converted to
``'snapshot_10000'`` at the 10,000th iteration.
snapshot_on_error (bool): Whether to take a snapshot when the trainer
loop fails with an exception.
savefun: Function to save the object. It takes two arguments: the
output file path and the object to serialize.
Returns:
An extension function.
"""
@extension.make_extension(trigger=(1, 'epoch'), priority=-100)
error_handler = None
if snapshot_on_error:
def h(trainer, exception, exc_info):
    # Error-time snapshot. Save ``target`` -- the same object the regular
    # extension call serializes -- rather than the trainer, so the file
    # written on failure holds the object the caller asked to snapshot.
    # ``exception`` and ``exc_info`` are accepted but not used here.
    _snapshot_object(trainer, target, filename.format(trainer),
                     savefun)
error_handler = h

@extension.make_extension(trigger=(1, 'epoch'), priority=-100,
on_error=error_handler)
def snapshot_object(trainer):
_snapshot_object(trainer, target, filename.format(trainer), savefun)
_snapshot_object(trainer, target, filename.format(trainer),
savefun)

return snapshot_object


def snapshot(savefun=npz.save_npz,
def snapshot(savefun=npz.save_npz, snapshot_on_error=False,
filename='snapshot_iter_{.updater.iteration}'):
"""Returns a trainer extension to take snapshots of the trainer.
Expand All @@ -66,14 +78,25 @@ def snapshot(savefun=npz.save_npz,
Args:
savefun: Function to save the trainer. It takes two arguments: the
output file path and the trainer object.
snapshot_on_error (bool): Whether to take a snapshot when the trainer
loop fails with an exception.
filename (str): Name of the file into which the trainer is serialized.
It can be a format string, where the trainer object is passed to
the :meth:`str.format` method.
"""
@extension.make_extension(trigger=(1, 'epoch'), priority=-100)
error_handler = None
if snapshot_on_error:
def h(trainer, exception, exc_info):
    # Error-time snapshot of the whole trainer. ``exception`` and
    # ``exc_info`` are accepted but not used here.
    path = filename.format(trainer)
    _snapshot_object(trainer, trainer, path, savefun)
error_handler = h

@extension.make_extension(trigger=(1, 'epoch'), priority=-100,
on_error=error_handler)
def snapshot(trainer):
_snapshot_object(trainer, trainer, filename.format(trainer), savefun)
_snapshot_object(trainer, trainer, filename.format(trainer),
savefun)

return snapshot

Expand Down
13 changes: 13 additions & 0 deletions chainer/training/trainer.py
Expand Up @@ -315,6 +315,19 @@ def run(self, show_loop_exception_msg=True):
traceback.print_tb(sys.exc_info()[2])
f.write('Will finalize trainer extensions and updater before '
'reraising the exception.\n')
for _, entry in extensions:
handler = getattr(entry.extension, 'on_error', None)
if handler:
try:
# Every handler is guaranteed to be called: an exception
# raised by a handler is only printed and then ignored,
# and the handler's return value is ignored as well.
handler(self, e, sys.exc_info()[2])
except Exception as he:
f.write('Exception in error handler: {}\n'.format(he))
traceback.print_tb(sys.exc_info()[2])
f.write('Traceback (most recent call last):\n')
six.reraise(*sys.exc_info())
finally:
for _, entry in extensions:
Expand Down
3 changes: 2 additions & 1 deletion examples/mnist/train_mnist.py
Expand Up @@ -85,7 +85,8 @@ def main():

# Take a snapshot for each specified epoch
frequency = args.epoch if args.frequency == -1 else max(1, args.frequency)
trainer.extend(extensions.snapshot(), trigger=(frequency, 'epoch'))
trainer.extend(extensions.snapshot(snapshot_on_error=True),
trigger=(frequency, 'epoch'))

# Write a log of evaluation statistics for each epoch
trainer.extend(extensions.LogReport())
Expand Down
Expand Up @@ -4,6 +4,7 @@
import mock

from chainer import testing
from chainer import training
from chainer.training import extensions


Expand Down Expand Up @@ -48,4 +49,39 @@ def test_clean_up_tempdir(self):
self.assertEqual(len(left_tmps), 0)


class TestSnapshotOnError(unittest.TestCase):
    """Checks that ``snapshot_object`` writes a file when training fails."""

    def setUp(self):
        trainer = testing.get_trainer_with_mock_updater()
        trainer.out = '.'
        trainer._done = True
        self.trainer = trainer
        self.filename = 'myfile-deadbeef.dat'

    def tearDown(self):
        # Remove the snapshot file a test may have left behind.
        if not os.path.exists(self.filename):
            return
        os.remove(self.filename)

    def test_on_error(self):

        class TheOnlyError(Exception):
            pass

        # Extension that fails whenever it is invoked.
        @training.make_extension(trigger=(1, 'iteration'), priority=100)
        def exception_raiser(trainer):
            raise TheOnlyError()

        self.trainer.extend(exception_raiser)

        snapshot_ext = extensions.snapshot_object(
            self.trainer, self.filename, snapshot_on_error=True)
        self.trainer.extend(snapshot_ext)
        self.trainer._done = False

        # No snapshot file may exist before the run.
        self.assertFalse(os.path.exists(self.filename))

        with self.assertRaises(TheOnlyError):
            self.trainer.run()

        # The error handler must have written the snapshot file.
        self.assertTrue(os.path.exists(self.filename))


testing.run_module(__name__, __file__)
75 changes: 75 additions & 0 deletions tests/chainer_tests/training_tests/test_trainer.py
@@ -1,4 +1,5 @@
import time
import traceback
import unittest

from chainer import testing
Expand All @@ -23,6 +24,29 @@ def initialize(self, trainer):
trainer.is_initialized = True


class ErrorHandlingExtension(training.extension.Extension):
    """Extension that records whether its ``on_error`` hook was invoked."""

    def __init__(self):
        # Set to True by on_error().
        self.is_error_handled = False

    def initialize(self, trainer):
        pass

    def __call__(self, trainer):
        pass

    def finalize(self):
        pass

    def on_error(self, trainer, exception, tb):
        """Print the given traceback and mark the error as handled."""
        traceback.print_tb(tb)
        self.is_error_handled = True


class TheOnlyError(Exception):
    """Exception type raised on purpose by the tests in this module."""


class DummyCallableClass(object):

def __init__(self, test_case):
Expand Down Expand Up @@ -169,5 +193,56 @@ def dummy_extension_2(trainer):
self.trainer.run()
self.assertEqual(self.called_order, [2, 1])

def test_exception_handler(self):
    """Both kinds of ``on_error`` hook run, and finalize still happens."""
    handling_ext = ErrorHandlingExtension()
    self.trainer.extend(
        handling_ext, trigger=(1, 'iteration'), priority=1)
    self.assertFalse(handling_ext.is_error_handled)

    # Filled in by the function-style error handler below.
    record = {}

    def exception_handler(trainer, exp, tb):
        record['called'] = True

    @training.make_extension(trigger=(1, 'iteration'), priority=100,
                             on_error=exception_handler)
    def exception_raiser(trainer):
        raise TheOnlyError()

    self.trainer.extend(exception_raiser)

    finalizable = DummyExtension(self)
    self.trainer.extend(finalizable)

    with self.assertRaises(TheOnlyError):
        self.trainer.run()

    self.assertTrue(record['called'])
    self.assertTrue(handling_ext.is_error_handled)
    self.assertTrue(finalizable.is_finalized)

def test_exception_in_exception_handler(self):
    """A failing error handler does not replace the training error."""
    handling_ext = ErrorHandlingExtension()
    self.trainer.extend(
        handling_ext, trigger=(1, 'iteration'), priority=1)
    self.assertFalse(handling_ext.is_error_handled)

    # Error handler that itself fails.
    def exception_handler(trainer, exp, tb):
        raise ValueError('hogehoge')

    @training.make_extension(trigger=(1, 'iteration'), priority=100,
                             on_error=exception_handler)
    def exception_raiser(trainer):
        raise TheOnlyError()

    self.trainer.extend(exception_raiser)

    finalizable = DummyExtension(self)
    self.trainer.extend(finalizable)

    # The training error is the one that reaches the caller.
    with self.assertRaises(TheOnlyError):
        self.trainer.run()

    self.assertTrue(handling_ext.is_error_handled)
    self.assertTrue(finalizable.is_finalized)


testing.run_module(__name__, __file__)

0 comments on commit be4ff82

Please sign in to comment.