Merge pull request #5640 from toslunar/check-unretained-var-grad
Support unretained `Variable`s in `_check_grad_type`
takagi committed Dec 10, 2018
2 parents db0f2f9 + e69f645 commit 6a7413f
Showing 3 changed files with 71 additions and 29 deletions.
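
Note (not part of the commit): a minimal sketch of the scenario this change targets, assuming the old-style `chainer.Function` API of this era. The class name `UnretainedSum` and its label are hypothetical; the pattern mirrors the `DummyFunction` helpers in the updated tests below. Because `retain_inputs(())` discards the input arrays, the grad type check in backward can no longer look at `x.data`; with this change it falls back to the dtype and shape recorded on the input `VariableNode`, so the deliberate dtype mismatch should still be reported.

import numpy as np

import chainer


class UnretainedSum(chainer.Function):
    """Hypothetical function that discards its inputs after forward."""

    label = 'unretained_sum'

    def forward(self, inputs):
        # Do not keep the input arrays around for backward.
        self.retain_inputs(())
        x, = inputs
        return np.asarray(x.sum(), dtype=np.float32),

    def backward(self, inputs, grads):
        # Deliberately wrong dtype: a float64 grad for a float32 input.
        return np.ones((2,), dtype=np.float64),


x = chainer.Variable(np.array([1.0, 2.0], dtype=np.float32))
y = UnretainedSum()(x)
try:
    y.backward()
except TypeError as e:
    # The input data is not retained, but the node still records its
    # dtype/shape, so the mismatch is expected to surface here.
    print('reported even though the input was not retained:', e)
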
6 changes: 2 additions & 4 deletions chainer/function.py
@@ -159,10 +159,8 @@ def backward(self, target_input_indexes, grad_outputs):
         with cuda.get_device_from_array(*(in_data + grad_out_data)):
             gxs = self._function.backward(in_data, grad_out_data)
 
-        for x, gx in six.moves.zip(inputs, gxs):
-            if x is None:
-                continue
-            variable._check_grad_type(self, x, gx)
+        for x, gx in six.moves.zip(self.inputs, gxs):
+            variable._check_grad_type(self, x, True, gx, False)
 
         # Convert input gradients back to ChainerX
         if xp is chainerx:
49 changes: 33 additions & 16 deletions chainer/variable.py
@@ -20,21 +20,41 @@
 import chainerx
 
 
-def _check_grad_type(func, x, gx):
-    if x.data is None or gx is None:
-        # ``x.data is None`` implies that the data array is not retained
+def _check_grad_type(func, x, is_node_x, gx, is_var_gx):
+    if gx is None:
         return
-    if not chainer.is_arrays_compatible((gx, x.data)):
+    x_grad = gx.array if is_var_gx else gx
+
+    # FIXME: avoid `isinstance`
+    x_data = None if isinstance(x, _ChainerxVariableNodeProps) else x.data
+
+    # TODO(kataoka): Make _update_data_info store the array module.
+    # ``is_node_x and x_data is None`` implies that the data array is not
+    # retained.
+    # ``not is_node_x and x_data is None`` implies that grad of uninitialized
+    # variable is checked here.
+
+    if x_grad is None:
+        # TODO(kataoka): This should be an error.
+        return
+    elif x_data is None and not is_node_x:
+        # TODO(kataoka): This should be an error.
+        return
+    elif not chainer.is_arrays_compatible((x_grad, x_data)):
         msg = ('Type of data and grad mismatch\ngrad: %s != data: %s' %
-               (type(gx), type(x.data)))
+               (type(x_grad), type(x_data)))
         typ = TypeError
-    elif gx.dtype != x.data.dtype:
+    elif x.dtype is None or x.shape is None:
+        # unretained Variable(None)
+        # TODO(kataoka): This should be an error.
+        return
+    elif gx.dtype != x.dtype:
         msg = ('Dtype of data and grad mismatch\ngrad: %s != data: %s' %
-               (gx.dtype, x.data.dtype))
+               (gx.dtype, x.dtype))
         typ = TypeError
-    elif gx.shape != x.data.shape:
+    elif gx.shape != x.shape:
         msg = ('Shape of data and grad mismatch\ngrad: %s != data: %s' %
-               (gx.shape, x.data.shape))
+               (gx.shape, x.shape))
         typ = ValueError
     else:
         return
@@ -52,7 +72,7 @@ def _check_grad_type(func, x, gx):
 Please report this error to the issue tracker with the stack trace,
 the information of your environment, and your script:
 https://github.com/chainer/chainer/issues/new.
-'''.format(type(func).__name__, func.label)
+'''
 
     raise typ(detail + msg)
 
@@ -889,9 +909,7 @@ def grad(self):
 
     @grad.setter
     def grad(self, g):
-        if g is not None:
-            _check_grad_type(None, self, g)
-
+        _check_grad_type(None, self, False, g, False)
         self._set_grad_without_check(g)
 
     def _set_grad_var_without_check(self, gv):
@@ -911,8 +929,7 @@ def grad_var(self):
 
     @grad_var.setter
     def grad_var(self, g):
-        if g is not None:
-            _check_grad_type(None, self, g.array)
+        _check_grad_type(None, self, False, g, True)
         self._set_grad_var_without_check(g)
 
     @property
@@ -1407,7 +1424,7 @@ def add_cand(cand):
                     continue
 
                 for gx_elem in gx:
-                    _check_grad_type(func, x, gx_elem.array)
+                    _check_grad_type(func, x, True, gx_elem, True)
                     del gx_elem  # to reduce memory usage
 
                 if x.creator_node is None:  # leaf
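
Note (not part of the commit): a hedged illustration of what `_check_grad_type` enforces through the public setters, assuming the code as changed above. The `grad` setter hands the raw array to the check (`is_var_gx=False`), while `grad_var` now hands over the `Variable` itself (`is_var_gx=True`) and the check unwraps it via `gx.array`.

import numpy as np

import chainer

v = chainer.Variable(np.zeros((2, 3), dtype=np.float32))

# Matching dtype and shape: accepted.
v.grad = np.ones((2, 3), dtype=np.float32)

# Mismatched dtype: rejected with TypeError.
try:
    v.grad = np.ones((2, 3), dtype=np.float64)
except TypeError as e:
    print('dtype mismatch:', e)

# grad_var now passes the Variable itself to the check; the shape mismatch
# is rejected with ValueError.
try:
    v.grad_var = chainer.Variable(np.ones((3, 2), dtype=np.float32))
except ValueError as e:
    print('shape mismatch:', e)

# For an uninitialized Variable, the ``not is_node_x and x_data is None``
# branch currently returns without raising (marked TODO in the diff).
u = chainer.Variable()
u.grad = np.ones((1,), dtype=np.float32)
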
45 changes: 36 additions & 9 deletions tests/chainer_tests/test_variable.py
@@ -2070,13 +2070,15 @@ class TestVariableBackwardError(unittest.TestCase):
     def setUp(self):
         self.x = np.array([1], np.float32)
 
-    def check_type_mismatch(self, x_data):
+    def check_type_mismatch(self, x_data, retain):
         xp = backend.get_array_module(x_data)
 
         class DummyFunction(chainer.Function):
             label = 'dummy_function'
 
             def forward(self, inputs):
+                if not retain:
+                    self.retain_inputs(())
                 return xp.array(1, np.float32),
 
             def backward(self, inputs, grads):
@@ -2088,19 +2090,28 @@ def backward(self, inputs, grads):
         y.backward()
 
     def test_type_mismatch_cpu(self):
-        self.check_type_mismatch(self.x)
+        self.check_type_mismatch(self.x, True)
 
+    def test_type_mismatch_unretain_cpu(self):
+        self.check_type_mismatch(self.x, False)
+
     @attr.gpu
     def test_type_mismatch_gpu(self):
-        self.check_type_mismatch(cuda.to_gpu(self.x))
+        self.check_type_mismatch(cuda.to_gpu(self.x), True)
 
+    @attr.gpu
+    def test_type_mismatch_unretain_gpu(self):
+        self.check_type_mismatch(cuda.to_gpu(self.x), False)
+
-    def check_dtype_mismatch(self, x_data):
+    def check_dtype_mismatch(self, x_data, retain):
         xp = backend.get_array_module(x_data)
 
         class DummyFunction(chainer.Function):
             label = 'dummy_function'
 
             def forward(self, inputs):
+                if not retain:
+                    self.retain_inputs(())
                 return xp.array(1, np.float32),
 
             def backward(self, inputs, grads):
@@ -2112,19 +2123,28 @@ def backward(self, inputs, grads):
         y.backward()
 
     def test_dtype_mismatch_cpu(self):
-        self.check_dtype_mismatch(self.x)
+        self.check_dtype_mismatch(self.x, True)
 
+    def test_dtype_mismatch_unretain_cpu(self):
+        self.check_dtype_mismatch(self.x, False)
+
     @attr.gpu
     def test_dtype_mismatch_gpu(self):
-        self.check_dtype_mismatch(cuda.to_gpu(self.x))
+        self.check_dtype_mismatch(cuda.to_gpu(self.x), True)
 
-    def check_shape_mismatch(self, x_data):
+    @attr.gpu
+    def test_dtype_mismatch_unretain_gpu(self):
+        self.check_dtype_mismatch(cuda.to_gpu(self.x), False)
+
+    def check_shape_mismatch(self, x_data, retain):
         xp = backend.get_array_module(x_data)
 
         class DummyFunction(chainer.Function):
             label = 'dummy_function'
 
             def forward(self, inputs):
+                if not retain:
+                    self.retain_inputs(())
                 return xp.array(1, np.float32),
 
             def backward(self, inputs, grads):
@@ -2136,11 +2156,18 @@ def backward(self, inputs, grads):
         y.backward()
 
     def test_shape_mismatch_cpu(self):
-        self.check_shape_mismatch(self.x)
+        self.check_shape_mismatch(self.x, True)
 
+    def test_shape_mismatch_unretain_cpu(self):
+        self.check_shape_mismatch(self.x, False)
+
     @attr.gpu
     def test_shape_mismatch_gpu(self):
-        self.check_shape_mismatch(cuda.to_gpu(self.x))
+        self.check_shape_mismatch(cuda.to_gpu(self.x), True)
+
+    @attr.gpu
+    def test_shape_mismatch_unretain_gpu(self):
+        self.check_shape_mismatch(cuda.to_gpu(self.x), False)
 
 
 class TestVariableBackwardErrorTraceback(unittest.TestCase):
