Skip to content

Commit

Permalink
Merge b1c7f10 into d0ebd36
Browse files Browse the repository at this point in the history
  • Loading branch information
rezoo committed Apr 4, 2018
2 parents d0ebd36 + b1c7f10 commit 84b27ab
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 0 deletions.
1 change: 1 addition & 0 deletions chainer/training/extensions/__init__.py
Expand Up @@ -7,6 +7,7 @@
from chainer.training.extensions.linear_shift import LinearShift # NOQA
from chainer.training.extensions.log_report import LogReport # NOQA
from chainer.training.extensions.micro_average import MicroAverage # NOQA
from chainer.training.extensions.nan_killer import NaNKiller # NOQA
from chainer.training.extensions.parameter_statistics import ParameterStatistics # NOQA
from chainer.training.extensions.plot_report import PlotReport # NOQA
from chainer.training.extensions.print_report import PrintReport # NOQA
Expand Down
21 changes: 21 additions & 0 deletions chainer/training/extensions/nan_killer.py
@@ -0,0 +1,21 @@
from chainer.training import extension


class NaNKiller(extension.Extension):
"""Trainer extension to raise RuntimeError if parameters contain NaN.
Although parameters including NaN are unnecessary in most cases,
:class:`~chainer.training.Trainer` will continue to compute even if
the parameters in a given optimizer diverge. This extension is aimed to
reduce unnecessary computations by throwing ``RuntimeError``
if the parameters contain NaN.
"""

def __call__(self, trainer):
optimizers = trainer.updater.get_all_optimizers()
for optimizer in optimizers.values():
target = optimizer.target
xp = target.xp
for param in target.params():
if xp.isnan(param.array).any():
raise RuntimeError('NaN detected. R.I.P.')
1 change: 1 addition & 0 deletions docs/source/reference/training.rst
Expand Up @@ -67,6 +67,7 @@ The typical use case is to use :class:`~chainer.training.extensions.Evaluator` t
chainer.training.extensions.Evaluator
chainer.training.extensions.MicroAverage

chainer.training.extensions.NaNKiller
chainer.training.extensions.ParameterStatistics

chainer.training.extensions.observe_lr
Expand Down
@@ -0,0 +1,88 @@
import os
import shutil
import tempfile
import unittest

import numpy

import chainer
from chainer import links
from chainer import testing
from chainer.testing import attr
from chainer import training


class Model(chainer.Chain):

def __init__(self):
super(Model, self).__init__()
with self.init_scope():
self.l = links.Linear(1, 3)

def __call__(self, x):
return self.l(x)


class Dataset(chainer.dataset.DatasetMixin):

def __init__(self, values):
self.values = values

def __len__(self):
return len(self.values)

def get_example(self, i):
return numpy.array([self.values[i]], numpy.float32), numpy.int32(i % 2)


class TestNaNKiller(unittest.TestCase):

def setUp(self):
self.n_data = 4
self.n_epochs = 3

self.model = Model()
self.classifier = links.Classifier(self.model)
self.optimizer = chainer.optimizers.Adam()
self.optimizer.setup(self.classifier)

self.dataset = Dataset([i for i in range(self.n_data)])
self.iterator = chainer.iterators.SerialIterator(
self.dataset, 1, shuffle=False)
self.temp_dir = tempfile.mkdtemp()

def tearDown(self):
shutil.rmtree(self.temp_dir)

def prepare(self, dirname='test', device=None):
outdir = os.path.join(self.temp_dir, dirname)
self.updater = training.updaters.StandardUpdater(
self.iterator, self.optimizer, device=device)
self.trainer = training.Trainer(
self.updater, (self.n_epochs, 'epoch'), out=outdir)
self.trainer.extend(training.extensions.NaNKiller())

def test_trainer(self):
self.prepare(dirname='test_trainer')
self.trainer.run()

def test_nan_killer(self):
self.prepare(dirname='test_nan_killer')
self.model.l.W.array[1, 0] = numpy.nan
with self.assertRaises(RuntimeError):
self.trainer.run(show_loop_exception_msg=False)

@attr.gpu
def test_trainer_gpu(self):
self.prepare(dirname='test_trainer_gpu', device=0)
self.trainer.run()

@attr.gpu
def test_nan_killer_gpu(self):
self.prepare(dirname='test_nan_killer_gpu', device=0)
self.model.l.W.array[:] = numpy.nan
with self.assertRaises(RuntimeError):
self.trainer.run(show_loop_exception_msg=False)


testing.run_module(__name__, __file__)

0 comments on commit 84b27ab

Please sign in to comment.