Skip to content

Commit

Permalink
Merge pull request #3208 from okuta/fix-nan-check
Browse files Browse the repository at this point in the history
Fix nan check
  • Loading branch information
niboshi committed Aug 23, 2017
2 parents 1d00f3a + ff41355 commit 265e910
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 5 deletions.
11 changes: 6 additions & 5 deletions chainer/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -929,11 +929,12 @@ def get_grad(node):
if gx is None:
continue
gx_data = gx.data
cuda.get_device_from_array(gx_data).use()
if cuda.get_array_module(gx_data).isnan(gx_data).any():
msg = ('NaN is detected on backward computation of '
'{}'.format(func.label))
raise RuntimeError(msg)
if gx_data.dtype.kind == 'f':
cuda.get_device_from_array(gx_data).use()
if cuda.get_array_module(gx_data).isnan(gx_data).any():
raise RuntimeError(
'NaN is detected on backward computation of '
'{}'.format(func.label))

if not retain_grad:
for y in outputs:
Expand Down
15 changes: 15 additions & 0 deletions tests/chainer_tests/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -1146,6 +1146,21 @@ def test_traceback_cpu(self):
def test_traceback_gpu(self):
self.check_traceback(cuda.to_gpu(self.x))

def test_raise(self):
x = np.array([1], np.float32)
x = chainer.Variable(x)
y = F.identity(x)
y.grad = np.array([np.nan], np.float32)
with self.assertRaises(RuntimeError):
y.backward()

def test_int(self):
x = np.array([1], np.int)
x = chainer.Variable(x)
y = F.identity(x)
y.grad = np.array([0], np.int)
y.backward()


@testing.parameterize(*testing.product({
'in_shape': [(4, 3, 2)],
Expand Down

0 comments on commit 265e910

Please sign in to comment.