Merge pull request #5698 from niboshi/fix-gradient-check-test
Add gradient consistency checks in `numerical_grad`
okuta committed Nov 25, 2018
2 parents 170b9e0 + 0fdb8d5 commit efeb055
Showing 2 changed files with 40 additions and 30 deletions.
8 changes: 6 additions & 2 deletions chainer/gradient_check.py
@@ -39,8 +39,8 @@ def numerical_grad(
         inputs (tuple of arrays): Tuple of arrays that should be treated as
             inputs. Each element of them is slightly modified to realize
             numerical gradient by finite differences.
-        grad_outputs (tuple of arrays): Tuple of arrays that are treated as
-            output gradients.
+        grad_outputs (tuple of arrays or scalars): Tuple of arrays or scalars
+            that are treated as output gradients.
         eps (float): Epsilon value of finite differences.
         detect_nondifferentiable (bool):
             ``False`` by default.
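
The docstring change above records that `numerical_grad` now accepts scalar output gradients as well as arrays. A minimal usage sketch (values and shapes illustrative, not taken from the commit):

    import numpy
    from chainer import gradient_check

    x = numpy.random.uniform(-1, 1, (2, 1)).astype(numpy.float32)

    def f():
        # numerical_grad perturbs x in place and calls f with no arguments.
        return x ** 2,

    # A scalar grad_output is now accepted in place of an array.
    gx, = gradient_check.numerical_grad(f, (x,), (1.0,))
    # gx approximates d(1.0 * sum(x ** 2))/dx = 2 * x.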
@@ -115,6 +115,10 @@ def numerical_grad(
     def eval_func(x, i, delta, orig):
         x[i] = orig + delta
         y = _copy_arrays(f())
+        assert len(y) == len(grad_outputs)
+        assert all([
+            gy is None or numpy.isscalar(gy) or y_.shape == gy.shape
+            for y_, gy in zip(y, grad_outputs)])
         x[i] = orig
         return y
 
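These new assertions are the substance of the commit: every evaluation of `f` must return exactly one output per entry of `grad_outputs`, and each array gradient must match the shape of its output, while `None` and scalar gradients are exempt from the shape check. A standalone sketch of the enforced invariant (the helper name is hypothetical):

    import numpy

    def check_grad_outputs_consistency(ys, grad_outputs):
        # One gradient per output of f().
        assert len(ys) == len(grad_outputs)
        # Array gradients must match the output shapes; None and scalars pass.
        assert all(
            gy is None or numpy.isscalar(gy) or y.shape == gy.shape
            for y, gy in zip(ys, grad_outputs))

    ys = (numpy.zeros((2, 1)), numpy.zeros((3,)))
    check_grad_outputs_consistency(ys, (numpy.ones((2, 1)), None))  # passes
    check_grad_outputs_consistency(ys, (1.0, numpy.ones((3,))))     # passes
    # check_grad_outputs_consistency(ys, (numpy.ones((2, 1)),))  # fails: length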
62 changes: 34 additions & 28 deletions tests/chainer_tests/test_gradient_check.py
@@ -34,6 +34,8 @@ def _dot(x, y):
 
 class NumericalGradientTest(unittest.TestCase):
 
+    in_shapes = ((2, 1),)
+    gout_shapes = ((2, 1),)
     eps = None
     atol = 1e-3
     rtol = 1e-3
@@ -45,8 +47,9 @@ def df(self, xs):
         return (2 * xs[0],),
 
     def setUp(self):
-        self.xs = (_uniform(2, 1),)
-        self.gys = (_uniform(2, 1),)
+        self.xs = tuple([_uniform(*s) for s in self.in_shapes])
+        self.gys = tuple([
+            None if s is None else _uniform(*s) for s in self.gout_shapes])
 
     def check_numerical_grad_one(self, f, df, xs, gys, eps):
         dfxs = df(xs)
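
With this refactoring, each test variant only overrides the `in_shapes` and `gout_shapes` class attributes; `setUp` derives the inputs and output gradients from them, mapping a `None` shape to a `None` gradient. A self-contained sketch of the pattern; `_uniform` is re-implemented here on the assumption that it samples uniformly from [-1, 1], as its name suggests:

    import numpy

    def _uniform(*shape):
        # Assumed behavior of the test module's _uniform helper.
        return numpy.random.uniform(-1, 1, shape).astype(numpy.float32)

    in_shapes = ((2, 1), (2, 1))
    gout_shapes = ((2, 1), None, (2, 1))

    xs = tuple(_uniform(*s) for s in in_shapes)
    gys = tuple(None if s is None else _uniform(*s) for s in gout_shapes)
    # xs: two (2, 1) arrays; gys: (array, None, array), as in Test5 below.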
@@ -88,6 +91,9 @@ def test_numerical_grad_gpu(self):
 
 class NumericalGradientTest2(NumericalGradientTest):
 
+    in_shapes = ((),)
+    gout_shapes = ((),)
+
     def f(self, xs):
         return 1,
 
@@ -97,6 +103,8 @@ def df(self, xs):
 
 class NumericalGradientTest3(NumericalGradientTest):
 
+    in_shapes = ((2, 1),)
+    gout_shapes = ((2, 1),)
     # Too small eps causes cancellation of significant digits
     eps = (1e-2, 1e-3)
 
@@ -115,6 +123,8 @@ def setUp(self):
 
 class NumericalGradientTest4(NumericalGradientTest):
 
+    in_shapes = ((2, 1), (2, 1))
+    gout_shapes = ((2, 1), (2, 1), (2, 1))
     atol = 1e-2
     rtol = 1e-2
 
@@ -130,12 +140,11 @@ def df(self, xs):
             (_full_like(xs[0], 2), _full_like(xs[0], 4), _full_like(xs[0], 6)),
             (_full_like(xs[1], 3), _full_like(xs[1], 5), _full_like(xs[1], 7)))
 
-    def setUp(self):
-        self.xs = tuple(_uniform(2, 1) for _ in six.moves.range(2))
-        self.gys = tuple(_uniform(2, 1) for _ in six.moves.range(3))
 
 
-class NumericalGradientTest5(NumericalGradientTest):
+class NumericalGradientTest5(NumericalGradientTest4):
+    in_shapes = ((2, 1), (2, 1))
+    gout_shapes = ((2, 1), None, (2, 1))
 
     def f(self, xs):
         assert len(xs) == 2
@@ -149,16 +158,11 @@ def df(self, xs):
             (_full_like(xs[0], 2), _zeros_like(xs[0]), _full_like(xs[0], 6)),
             (_full_like(xs[1], 3), _zeros_like(xs[1]), _full_like(xs[1], 7)))
 
-    def setUp(self):
-        super(NumericalGradientTest5, self).setUp()
-        self.gys = (_uniform(2, 1), None, _uniform(2, 1))
 
 
 class NumericalGradientTest6(NumericalGradientTest):
 
-    def setUp(self):
-        self.xs = (_uniform(2, 1),)
-        self.gys = (None,)
+    in_shapes = ((2, 1),)
+    gout_shapes = (None,)
 
 
 class NumericalGradientReferenceTest(unittest.TestCase):
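
`NumericalGradientTest5` and `NumericalGradientTest6` above exercise the `None`-gradient path: a `None` entry in `gys` means that output contributes nothing, which is why the expected Jacobian slots for the middle output are `_zeros_like` in Test5's `df`. A small sketch of the identity being tested, assuming `gy=None` behaves as `gy=0`:

    import numpy

    gy0 = numpy.random.uniform(-1, 1, (2, 1)).astype(numpy.float32)
    gy2 = numpy.random.uniform(-1, 1, (2, 1)).astype(numpy.float32)

    # With outputs y0, y1, y2 and dy0/dx0 = 2, dy1/dx0 = 4, dy2/dx0 = 6,
    # a None gradient for y1 drops its term from the accumulated gradient:
    expected_gx0 = 2 * gy0 + 6 * gy2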
@@ -394,18 +398,19 @@ def _func_nan_segment(self, x):
         y[-1 < x < 1] = numpy.nan
         return y,
 
-    def check_positive(self, xp, func_name, inputs, eps, nout):
+    def check_positive(self, xp, func_name, input, eps, nout):
         # Should be non-differentiable
         func = getattr(self, '_func_{}'.format(func_name))
         grad_outputs = [
-            xp.random.uniform(-1, 1, _.shape).astype(_.dtype) for _ in inputs]
+            xp.random.uniform(-1, 1, input.shape).astype(input.dtype)
+            for _ in range(nout)]
 
         def f():
-            return func(*inputs) * nout
+            return func(input) * nout
 
         try:
             gradient_check.numerical_grad(
-                f, inputs, grad_outputs, eps=eps,
+                f, (input,), grad_outputs, eps=eps,
                 detect_nondifferentiable=True)
         except gradient_check.NondifferentiableError:
             pass
@@ -414,46 +419,47 @@ def f():
                 'Function `{}` is expected to be non-differentiable, '
                 'but determined to be differentiable.\n\n'
                 'eps: {}\n'
-                'inputs: {}\n'
+                'input: {}\n'
                 'xp: {}\n'
                 ''.format(
-                    func_name, eps, inputs, xp.__name__))
+                    func_name, eps, input, xp.__name__))
 
-    def check_negative(self, xp, func_name, inputs, eps, nout):
+    def check_negative(self, xp, func_name, input, eps, nout):
         # Should be differentiable
         func = getattr(self, '_func_{}'.format(func_name))
         grad_outputs = [
-            xp.random.uniform(-1, 1, _.shape).astype(_.dtype) for _ in inputs]
+            xp.random.uniform(-1, 1, input.shape).astype(input.dtype)
+            for _ in range(nout)]
 
         def f():
-            return func(*inputs) * nout
+            return func(input) * nout
 
         try:
             gradient_check.numerical_grad(
-                f, inputs, grad_outputs, eps=eps,
+                f, (input,), grad_outputs, eps=eps,
                 detect_nondifferentiable=True)
         except gradient_check.NondifferentiableError as e:
             raise AssertionError(
                 'Function `{}` is expected to be differentiable, '
                 'but determined to be non-differentiable.\n\n'
                 'eps: {}\n'
-                'inputs: {}\n'
+                'input: {}\n'
                 'xp: {}\n\n'
                 '{}: {}'
                 ''.format(
-                    func_name, eps, inputs, xp.__name__,
+                    func_name, eps, input, xp.__name__,
                     e.__class__.__name__, e))
 
     def check(self, xp, nout):
-        inputs = [xp.asarray(self.x).astype(numpy.float32)]
+        input = xp.asarray(self.x).astype(numpy.float32)
         with warnings.catch_warnings():
             if self.ignore_warning:
                 warnings.simplefilter('ignore', self.ignore_warning)
 
             if self.result:
-                self.check_positive(xp, self.func, inputs, self.eps, nout)
+                self.check_positive(xp, self.func, input, self.eps, nout)
             else:
-                self.check_negative(xp, self.func, inputs, self.eps, nout)
+                self.check_negative(xp, self.func, input, self.eps, nout)
 
     def test_cpu(self):
         self.check(numpy, 1)
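The test fix itself: `check_positive` and `check_negative` now take a single `input` array, and `grad_outputs` is built with one entry per *output* (`nout`) rather than one per input. The old construction returned `nout` outputs from `f` against a single grad_output, which is exactly the mismatch the new length assertion rejects. A minimal sketch of correct usage under the new check:

    import numpy
    from chainer import gradient_check

    x = numpy.random.uniform(-1, 1, (2, 1)).astype(numpy.float32)

    def f():
        # Two outputs from one input.
        return x * 2, x * 3

    # One grad_output per output is required by the new assertion.
    gys = (numpy.ones_like(x), numpy.ones_like(x))
    gx, = gradient_check.numerical_grad(f, (x,), gys)
    # Passing a single grad_output here would now raise an AssertionError.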
