
Commit

Merge d03633c into cfad351
niboshi committed Aug 14, 2017
2 parents cfad351 + d03633c commit d298cec
Showing 26 changed files with 133 additions and 81 deletions.
102 changes: 81 additions & 21 deletions chainer/gradient_check.py
@@ -110,7 +110,7 @@ def check_backward(func, x_data, y_grad, params=(),
eps=1e-3, atol=1e-5, rtol=1e-4, no_grads=None, dtype=None):
"""Test backward procedure of a given function.
This function automatically check backward-process of given function.
This function automatically checks backward-process of a given function.
For example, when you have a :class:`~chainer.Function` class ``MyFunc``,
that gets two arguments and returns one value, you can make its test like
this::
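The test body itself is collapsed in this diff view. A minimal sketch of such a test, assuming a hypothetical two-input, one-output ``MyFunc`` and illustrative data shapes, could look like::

    import numpy
    from chainer import gradient_check

    def test_my_func():
        func = MyFunc()  # hypothetical Function with two inputs and one output
        x1_data = numpy.random.uniform(-1, 1, (3, 4)).astype(numpy.float32)
        x2_data = numpy.random.uniform(-1, 1, (3, 4)).astype(numpy.float32)
        gy_data = numpy.random.uniform(-1, 1, (3, 4)).astype(numpy.float32)
        gradient_check.check_backward(func, (x1_data, x2_data), gy_data)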
@@ -130,6 +130,22 @@ def check_backward(func, x_data, y_grad, params=(),
To check correctness of the gradients, the function calls
:func:`numerical_grad` to compute the gradients numerically and compares
the results with :func:`chainer.testing.assert_allclose`.
To reduce computational time, it uses a function
:math:`g: \\mathbb{R} \\rightarrow \\mathbb{R}^n` defined as
:math:`g(\\alpha) = f(\\alpha x)`, where :math:`\\alpha \\in \\mathbb{R}`
and :math:`f` is the function you actually want to test.
Its gradient is

.. math::

    g'(\\alpha) = f'(\\alpha x) \\cdot x.

When :math:`\\alpha = 1`, :math:`g'(1) = f'(x) \\cdot x`.
So :math:`g'(1)` is calculated with :func:`numerical_grad` and
compared with the dot product of the gradient of :math:`f` and
:math:`x`.
If input objects (``x1_data`` and/or ``x2_data`` in this example) represent
integer variables, their gradients are ignored.
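In other words, instead of perturbing every element of ``x`` separately, the check perturbs the single scalar :math:`\\alpha` around 1 and compares the numerical derivative of :math:`g` with the dot product of the analytic gradient and ``x``. A self-contained NumPy sketch of this idea (the quadratic ``f`` and its gradient are illustrative assumptions, not Chainer code)::

    import numpy

    def f(x):
        # Illustrative test function: f(x) = sum(x ** 2) / 2.
        return 0.5 * (x ** 2).sum()

    def analytic_grad(x):
        # Gradient of f, as backprop would compute it: df/dx = x.
        return x

    x = numpy.random.uniform(-1, 1, 5)
    eps = 1e-3

    # Numerical derivative of g(alpha) = f(alpha * x) at alpha = 1,
    # using central differences.
    g_prime = (f((1 + eps) * x) - f((1 - eps) * x)) / (2 * eps)

    # Directional derivative from the analytic gradient: f'(x) . x.
    directional = analytic_grad(x).dot(x)

    assert numpy.allclose(g_prime, directional, atol=1e-5, rtol=1e-4)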
@@ -200,9 +216,9 @@ def check_backward(func, x_data, y_grad, params=(),
:func:`chainer.testing.assert_allclose`.
no_grads (list of bool): Flags to skip variables for gradient assertion.
It should be the same length as ``x_data``.
dtype (~numpy.dtype): ``x_data`` and ``y_grad`` are casted to this
dtype when calculating numerical gradients. Only float types and
``None`` are allowed.
dtype (~numpy.dtype): ``x_data``, ``y_grad`` and ``params`` are cast
to this dtype when calculating numerical gradients. Only float
types and ``None`` are allowed.
See:
:func:`numerical_grad`
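For reference, with this change ``dtype`` can be used to run the numerical differentiation in a higher precision than the inputs (and now also than the ``params``). A hedged usage sketch, assuming ``chainer.functions.tanh`` and illustrative tolerances::

    import numpy
    from chainer import functions, gradient_check

    x = numpy.random.uniform(-1, 1, (3, 4)).astype(numpy.float16)
    gy = numpy.random.uniform(-1, 1, (3, 4)).astype(numpy.float16)

    # The inputs stay float16 for the backward pass, but the numerical
    # gradients are computed after casting to float64.
    gradient_check.check_backward(
        functions.tanh, x, gy, dtype=numpy.float64, atol=1e-2, rtol=1e-2)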
@@ -241,40 +257,84 @@ def check_backward(func, x_data, y_grad, params=(),
# `Variable.backward` method calls `Function.backward` of its creator.
y[0].backward()

param_data = [p.data for p in params]
if dtype is None:
casted_xs = [variable.Variable(x) for x in x_data]
else:
if numpy.dtype(dtype).kind != 'f':
raise ValueError('`dtype` is allowed only float type')
if len(params) > 0:
raise ValueError('`dtype` is available only if `params` is empty')
casted_xs = [variable.Variable(x.astype(dtype, copy=False)
if x.dtype.kind == 'f' else x)
for x in x_data]

def f():
ys = func(*casted_xs)
ys = _as_tuple(ys)
return tuple(y.data for y in ys)

if no_grads is None:
no_grads = [x.dtype.kind != 'f' for x in xs]
else:
if len(no_grads) != len(xs):
raise ValueError(
'Length of no_grads param and xs should be same.')
for skip, x, cx in six.moves.zip(no_grads, xs, casted_xs):
casted_data = [x.data.copy() for x in casted_xs]
for skip, x in six.moves.zip(no_grads, xs):
if skip:
assert x.grad is None
continue
gx, = numerical_grad(f, (cx.data,), y_grad, eps=eps)
testing.assert_allclose(gx, x.grad, atol=atol, rtol=rtol)
if dtype is None:
assert gx.dtype == x.grad.dtype
else:
assert gx.dtype.kind == 'f' and gx.dtype == dtype
if x.grad is None:
raise RuntimeError(
'gradients of some arguments are not calculated')

xp = cuda.get_array_module(*xs)
one = xp.array(1., dtype)

def g():
# This function is called twice in `numerical_grad`.
# `one` is `1 + epsilon` or `1 - epsilon` in these calls.
# See the document of `numerical_grad`.
for skip, cx, data in six.moves.zip(no_grads, casted_xs, casted_data):
if skip:
continue
# astype is required to store data with the given type
data = (one * data).astype(data.dtype)
if numpy.isscalar(data):
data = xp.array(data)
cx.data = data
for param, data in six.moves.zip(params, param_data):
if dtype is not None:
param_dtype = dtype
else:
param_dtype = param.dtype
# The inner astype is required to calculate __mul__ in
# `param_dtype` when data is a low-precision float.
# The outer one is required to store data with the given type.
param.data = (one * data.astype(param_dtype)).astype(param_dtype)
ys = func(*casted_xs)
ys = _as_tuple(ys)
ys_data = tuple(y.data for y in ys)
for skip, cx, data in six.moves.zip(no_grads, casted_xs, casted_data):
if skip:
continue
cx.data = data
for param, data in six.moves.zip(params, param_data):
param.data = data
return ys_data

gx, = numerical_grad(g, (one,), y_grad, eps=eps)
gx_accum = 0
for skip, x, cx in six.moves.zip(no_grads, xs, casted_xs):
if skip:
continue
gxi = x.grad.ravel()
cxi = cx.data.ravel()
if dtype is not None:
gxi = gxi.astype(dtype)
cxi = cxi.astype(dtype)
gx_accum += gxi.dot(cxi)

for p in params:
gp, = numerical_grad(f, (p.data,), y_grad, eps=eps)
testing.assert_allclose(gp, p.grad, atol=atol, rtol=rtol)
assert gp.dtype is p.grad.dtype
gpi = p.grad.ravel()
pi = p.data.ravel()
if dtype is not None:
gpi = gpi.astype(dtype)
pi = pi.astype(dtype)
gx_accum += gpi.dot(pi)

testing.assert_allclose(gx, gx_accum, atol=atol, rtol=rtol)
@@ -66,7 +66,7 @@ def test_forward_gpu(self):
def check_backward(self, x_data, y_grad):
gradient_check.check_backward(
lambda x: functions.maxout(x, self.pool_size, self.axis),
x_data, y_grad, eps=0.125)
x_data, y_grad, eps=0.125, dtype='d')

@condition.retry(3)
def test_backward_cpu(self):
3 changes: 2 additions & 1 deletion tests/chainer_tests/functions_tests/array_tests/test_cast.py
@@ -48,7 +48,8 @@ def test_forward_gpu(self):
def check_backward(self, x_data, g_data):
func = functions.Cast(self.out_type)
gradient_check.check_backward(
func, x_data, g_data, eps=2.0 ** -2, atol=1e-3, rtol=1e-3)
func, x_data, g_data, dtype='d',
eps=2.0 ** -2, atol=1e-2, rtol=1e-3)

def test_backward_cpu(self):
self.check_backward(self.x, self.g)
@@ -55,8 +55,7 @@ def check_backward(self, xs_data, g_data):
def func(*xs):
return functions.dstack(xs)

gradient_check.check_backward(
func, xs_data, g_data, eps=2.0 ** -2, atol=1e-3, rtol=1e-3)
gradient_check.check_backward(func, xs_data, g_data, dtype='d')

def test_backward_cpu(self):
self.check_backward(self.xs, self.g)
12 changes: 6 additions & 6 deletions tests/chainer_tests/functions_tests/array_tests/test_get_item.py
@@ -64,8 +64,8 @@ def test_forward_gpu(self):
self.check_forward(cuda.to_gpu(self.x_data))

def check_backward(self, x_data, y_grad):
gradient_check.check_backward(functions.GetItem(self.slices),
(x_data,), y_grad)
gradient_check.check_backward(
functions.GetItem(self.slices), (x_data,), y_grad, dtype='d')

def test_backward_cpu(self):
self.check_backward(self.x_data, self.gy_data)
@@ -136,8 +136,8 @@ def test_forward_gpu(self):
self.check_forward(cuda.to_gpu(self.x_data))

def check_backward(self, x_data, y_grad):
gradient_check.check_backward(functions.GetItem(self.slices),
(x_data,), y_grad)
gradient_check.check_backward(
functions.GetItem(self.slices), (x_data,), y_grad, dtype='d')

def test_backward_cpu(self):
self.check_backward(self.x_data, self.gy_data)
@@ -192,8 +192,8 @@ def check_backward(self, x_data, y_grad):
s = chainer.cuda.cupy.array(s, dtype=numpy.int32)
slices.append(s)
slices = tuple(slices)
gradient_check.check_backward(functions.GetItem(slices),
(x_data,), y_grad)
gradient_check.check_backward(
functions.GetItem(slices), (x_data,), y_grad, dtype='d')

@attr.gpu
def test_backward_gpu(self):
@@ -56,7 +56,7 @@ def func(*xs):
return functions.hstack(xs)

gradient_check.check_backward(
func, xs_data, g_data, eps=2.0 ** -2, atol=1e-3, rtol=1e-3)
func, xs_data, g_data, dtype='d', atol=1e-3, rtol=1e-3)

def test_backward_cpu(self):
self.check_backward(self.xs, self.g)
6 changes: 3 additions & 3 deletions tests/chainer_tests/functions_tests/array_tests/test_pad.py
@@ -33,7 +33,7 @@ def setUp(self):
self.check_backward_options = {'dtype': numpy.float64}
if self.dtype == numpy.float16:
self.check_backward_options.update({
'atol': 2 ** -6, 'rtol': 2 ** -6})
'atol': 2 ** -5, 'rtol': 2 ** -5})

def check_forward(self, x_data):
y = functions.pad(x_data, self.pad_width, self.mode)
@@ -91,8 +91,8 @@ def setUp(self):
self.g = numpy.random.uniform(-1, 1, out_shape).astype(self.dtype)
self.check_backward_options = {'dtype': numpy.float64}
if self.dtype == numpy.float16:
self.check_backward_options = {
'atol': 2 ** -6, 'rtol': 2 ** -6}
self.check_backward_options.update({
'atol': 2 ** -5, 'rtol': 2 ** -5})

def check_forward(self, x_data):
y = functions.pad(x_data, self.pad_width, mode=self.mode,
@@ -48,7 +48,7 @@ def test_forward_gpu(self):
def check_backward(self, x_data, ind_data, g_data):
fun = functions.Permutate(axis=self.axis, inv=self.inv)
gradient_check.check_backward(
fun, (x_data, ind_data), g_data)
fun, (x_data, ind_data), g_data, dtype='d', atol=0.001, rtol=0.001)

def test_backward_cpu(self):
self.check_backward(self.x, self.indices, self.g)
@@ -110,7 +110,7 @@ def setUp(self):
def check_backward(self, x, output_shape, grads):
gradient_check.check_backward(
functions.ResizeImages(output_shape),
(x,), (grads,), atol=1e-2, rtol=1e-3, eps=1e-5)
(x,), (grads,), dtype='d', atol=1e-2, rtol=1e-3, eps=1e-5)

@condition.retry(3)
def test_backward_cpu(self):
@@ -45,7 +45,8 @@ def test_forward_gpu(self):

def check_backward(self, x_data, g_data):
gradient_check.check_backward(
functions.Rollaxis(self.axis, self.start), x_data, g_data)
functions.Rollaxis(self.axis, self.start), x_data, g_data,
dtype='d')

def test_backward_cpu(self):
self.check_backward(self.x, self.g)
@@ -33,9 +33,9 @@ def setUp(self):
0, 2, self.out_shape).astype(numpy.int32)
self.gy_data = numpy.random.uniform(
-1, 1, self.out_shape).astype(self.dtype)
self.check_backward_options = {}
self.check_backward_options = {'atol': 0.01, 'rtol': 0.01}
if self.dtype == numpy.float16:
self.check_backward_options = {'atol': 0.05, 'rtol': 0.05}
self.check_backward_options = {'atol': 0.1, 'rtol': 0.1}

def check_forward(self, x_data, t_data):
x = chainer.Variable(x_data)
@@ -57,7 +57,8 @@ def test_forward_gpu(self):
def check_backward(self, x_data, t_data, gy_data):
gradient_check.check_backward(
functions.SelectItem(),
(x_data, t_data), gy_data, eps=0.01, **self.check_backward_options)
(x_data, t_data), gy_data, eps=0.01, dtype='d',
**self.check_backward_options)

def test_backward_cpu(self):
self.check_backward(self.x_data, self.t_data, self.gy_data)
@@ -33,10 +33,6 @@ def setUp(self):
del yshape[self.axis]
self.gys = [numpy.random.uniform(-1, 1, yshape).astype(self.dtype)
for _ in range(self.shape[self.axis])]
self.check_backward_options = {}
if self.dtype == numpy.float16:
self.check_backward_options = {
'eps': 2 ** -5, 'atol': 1e-3, 'rtol': 1e-2}

def check_forward(self, x_data):
x = chainer.Variable(x_data)
@@ -59,8 +55,7 @@ def check_backward(self, x_data, gys_data):
def f(x):
return separate.separate(x, self.axis)

gradient_check.check_backward(
f, x_data, gys_data, **self.check_backward_options)
gradient_check.check_backward(f, x_data, gys_data, dtype='d')

def test_backward_cpu(self):
self.check_backward(self.x, self.gys)
@@ -64,7 +64,7 @@ def func(*xs):
return functions.stack(xs, self.axis)

gradient_check.check_backward(
func, xs_data, g_data, eps=2.0 ** -2, atol=1e-3, rtol=1e-3)
func, xs_data, g_data, eps=2.0 ** -2, dtype='d')

def test_backward_cpu(self):
self.check_backward(self.xs, self.g)
@@ -56,7 +56,7 @@ def func(*xs):
return functions.vstack(xs)

gradient_check.check_backward(
func, xs_data, g_data, eps=2.0 ** -2, atol=1e-3, rtol=1e-3)
func, xs_data, g_data, eps=2.0 ** -2, dtype='d')

def test_backward_cpu(self):
self.check_backward(self.xs, self.g)
@@ -87,7 +87,7 @@ def _black_out(x, t, W, samples):

gradient_check.check_backward(
_black_out, (x_data, t_data, w_data, samples_data),
gy_data, atol=1.e-3)
gy_data, dtype='d', atol=1e-2)

@condition.retry(3)
def test_backward_cpu(self):
@@ -80,7 +80,8 @@ def test_forward_gpu_no_cudnn(self):
def check_backward(self, x0_data, x1_data, t_data, gy_data):
gradient_check.check_backward(
functions.Contrastive(self.margin, self.reduce),
(x0_data, x1_data, t_data), gy_data, rtol=1e-4, atol=1e-4)
(x0_data, x1_data, t_data), gy_data, dtype='d',
rtol=1e-3, atol=1e-4)

@condition.retry(3)
def test_backward_cpu(self):
@@ -67,7 +67,7 @@ def test_forward_gpu(self):
def check_backward(self, x_data, t_data):
gradient_check.check_backward(
functions.Hinge(self.norm), (x_data, t_data), None,
eps=0.01, atol=1e-4)
dtype='d', rtol=1e-3, atol=1e-4)

@condition.retry(3)
def test_backward_cpu(self):
@@ -127,7 +127,7 @@ def check_backward(self, x_data, t_data, y_grad):

gradient_check.check_backward(
functions.SigmoidCrossEntropy(),
(x_data, t_data), None, eps=1e-2)
(x_data, t_data), None, atol=1e-4, rtol=1e-3)

def check_backward_no_reduction(
self, x_data, t_data, y_grad):
@@ -137,7 +137,7 @@ def check_backward_no_reduction(

gradient_check.check_backward(
functions.SigmoidCrossEntropy(reduce='no'),
(x_data, t_data), y_grad, eps=1e-2)
(x_data, t_data), y_grad, atol=1e-4, rtol=1e-3)

@condition.retry(3)
def test_backward_cpu(self):