Skip to content

Commit

Permalink
Support dtype option to check_double_backward and fix y_grad set up
Browse files Browse the repository at this point in the history
  • Loading branch information
beam2d committed Aug 15, 2017
1 parent 04157f6 commit 4dd002f
Showing 1 changed file with 18 additions and 6 deletions.
24 changes: 18 additions & 6 deletions chainer/gradient_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,8 @@ def g():


def check_double_backward(func, x_data, y_grad, x_grad_grad, params=(),
eps=1e-3, atol=1e-4, rtol=1e-3, no_grads=None):
eps=1e-3, atol=1e-4, rtol=1e-3, no_grads=None,
dtype=None):
"""Test twice differentiation of a given procedure.
This function automatically checks if the backward procedure of ``func``
Expand Down Expand Up @@ -360,22 +361,30 @@ def check_double_backward(func, x_data, y_grad, x_grad_grad, params=(),
of the first order gradients.
"""
def first_order_grad(*xs):
x_data = _as_tuple(x_data)
n_x = len(x_data)

def first_order_grad(*inputs):
xs = inputs[:n_x]
gys = inputs[n_x:]

y = _as_tuple(func(*xs))
# Let all elements of y share the same creator.
# See the comment in check_backward.
y = identity.Identity().apply(y)

_set_y_grad(y, y_grad)
_set_y_grad(y, gys)
y[0].backward()

ret = tuple([x.grad_var for x in xs])
for x in xs:
x.grad_var = None
return ret

check_backward(first_order_grad, x_data, x_grad_grad, params=params,
eps=eps, atol=atol, rtol=rtol, no_grads=no_grads)
inputs = x_data + _as_tuple(y_grad)
check_backward(first_order_grad, inputs, x_grad_grad, params=params,
eps=eps, atol=atol, rtol=rtol, no_grads=no_grads,
dtype=dtype)


def _set_y_grad(y, y_grad):
Expand All @@ -384,7 +393,10 @@ def _set_y_grad(y, y_grad):
raise ValueError(
'`y_grad` must have the same length of output values')
for iy, igy in six.moves.zip(y, y_grad):
iy.grad = igy
if isinstance(igy, variable.Variable):
iy.grad_var = igy
else:
iy.grad = igy
else:
if len(y) != 1:
raise ValueError(
Expand Down

0 comments on commit 4dd002f

Please sign in to comment.