Commit

Merge 4361b96 into 27e5968
takagi committed Dec 10, 2018
2 parents 27e5968 + 4361b96 commit 810d8a3
Showing 3 changed files with 67 additions and 23 deletions.
2 changes: 1 addition & 1 deletion chainer/function.py
@@ -142,7 +142,7 @@ def backward(self, target_input_indexes, grad_outputs):
         with cuda.get_device_from_array(*(in_data + grad_out_data)):
             gxs = self._function.backward(in_data, grad_out_data)
         for x, gx in six.moves.zip(self.inputs, gxs):
-            variable._check_grad_type(self, x, gx)
+            variable._check_grad_type(self, x, True, gx, False)

         ret = []
         for i in target_input_indexes:
43 changes: 30 additions & 13 deletions chainer/variable.py
@@ -18,21 +18,39 @@
 from chainer.utils import argument


-def _check_grad_type(func, x, gx):
-    if x.data is None or gx is None:
-        # ``x.data is None`` implies that the data array is not retained
+def _check_grad_type(func, x, is_node_x, gx, is_var_gx):
+    if gx is None:
         return
-    if not chainer.is_arrays_compatible((gx, x.data)):
+    x_grad = gx.array if is_var_gx else gx
+    x_data = x.data
+
+    # TODO(kataoka): Make _update_data_info store the array module.
+    # ``is_node_x and x_data is None`` implies that the data array is not
+    # retained.
+    # ``not is_node_x and x_data is None`` implies that grad of uninitialized
+    # variable is checked here.
+
+    if x_grad is None:
+        # TODO(kataoka): This should be an error.
+        return
+    elif x_data is None and not is_node_x:
+        # TODO(kataoka): This should be an error.
+        return
+    elif not chainer.is_arrays_compatible((x_grad, x_data)):
         msg = ('Type of data and grad mismatch\ngrad: %s != data: %s' %
-               (type(gx), type(x.data)))
+               (type(x_grad), type(x_data)))
         typ = TypeError
-    elif gx.dtype != x.data.dtype:
+    elif x.dtype is None or x.shape is None:
+        # unretained Variable(None)
+        # TODO(kataoka): This should be an error.
+        return
+    elif gx.dtype != x.dtype:
         msg = ('Dtype of data and grad mismatch\ngrad: %s != data: %s' %
-               (gx.dtype, x.data.dtype))
+               (gx.dtype, x.dtype))
         typ = TypeError
-    elif gx.shape != x.data.shape:
+    elif gx.shape != x.shape:
         msg = ('Shape of data and grad mismatch\ngrad: %s != data: %s' %
-               (gx.shape, x.data.shape))
+               (gx.shape, x.shape))
         typ = ValueError
     else:
         return
@@ -50,7 +68,7 @@ def _check_grad_type(func, x, gx):
 Please report this error to the issue tracker with the stack trace,
 the information of your environment, and your script:
 https://github.com/chainer/chainer/issues/new.
-'''.format(type(func).__name__, func.label)
+'''

     raise typ(detail + msg)

@@ -694,8 +712,7 @@ def grad_var(self):

     @grad_var.setter
     def grad_var(self, g):
-        if g is not None:
-            _check_grad_type(None, self, g.array)
+        _check_grad_type(None, self, False, g, True)
         self._grad_var = g

     @property
@@ -1053,7 +1070,7 @@ def add_cand(cand):
                     continue

                 for gx_elem in gx:
-                    _check_grad_type(func, x, gx_elem.array)
+                    _check_grad_type(func, x, True, gx_elem, True)
                     del gx_elem  # to reduce memory usage

                 if x.creator_node is None:  # leaf
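The dtype and shape checks above now read x.dtype and x.shape recorded on the variable node instead of x.data.dtype and x.data.shape, so they still apply when the data array itself is gone. A minimal sketch, not part of the commit, of the mismatches _check_grad_type reports through the grad_var setter shown above (the array shapes and variable names are illustrative):

import numpy as np
import chainer

v = chainer.Variable(np.zeros((2, 3), dtype=np.float32))

# Dtype mismatch between data (float32) and grad (float64) -> TypeError.
try:
    v.grad_var = chainer.Variable(np.zeros((2, 3), dtype=np.float64))
except TypeError as e:
    print(e)  # "Dtype of data and grad mismatch ..."

# Shape mismatch between data (2, 3) and grad (3, 2) -> ValueError.
try:
    v.grad_var = chainer.Variable(np.zeros((3, 2), dtype=np.float32))
except ValueError as e:
    print(e)  # "Shape of data and grad mismatch ..."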
45 changes: 36 additions & 9 deletions tests/chainer_tests/test_variable.py
@@ -1318,13 +1318,15 @@ class TestVariableBackwardError(unittest.TestCase):
     def setUp(self):
         self.x = np.array([1], np.float32)

-    def check_type_mismatch(self, x_data):
+    def check_type_mismatch(self, x_data, retain):
         xp = backend.get_array_module(x_data)

         class DummyFunction(chainer.Function):
             label = 'dummy_function'

             def forward(self, inputs):
+                if not retain:
+                    self.retain_inputs(())
                 return xp.array(1, np.float32),

             def backward(self, inputs, grads):
@@ -1336,19 +1338,28 @@ def backward(self, inputs, grads):
         y.backward()

     def test_type_mismatch_cpu(self):
-        self.check_type_mismatch(self.x)
+        self.check_type_mismatch(self.x, True)
+
+    def test_type_mismatch_unretain_cpu(self):
+        self.check_type_mismatch(self.x, False)

     @attr.gpu
     def test_type_mismatch_gpu(self):
-        self.check_type_mismatch(cuda.to_gpu(self.x))
+        self.check_type_mismatch(cuda.to_gpu(self.x), True)
+
+    @attr.gpu
+    def test_type_mismatch_unretain_gpu(self):
+        self.check_type_mismatch(cuda.to_gpu(self.x), False)

-    def check_dtype_mismatch(self, x_data):
+    def check_dtype_mismatch(self, x_data, retain):
         xp = backend.get_array_module(x_data)

         class DummyFunction(chainer.Function):
             label = 'dummy_function'

             def forward(self, inputs):
+                if not retain:
+                    self.retain_inputs(())
                 return xp.array(1, np.float32),

             def backward(self, inputs, grads):
@@ -1360,19 +1371,28 @@ def backward(self, inputs, grads):
         y.backward()

     def test_dtype_mismatch_cpu(self):
-        self.check_dtype_mismatch(self.x)
+        self.check_dtype_mismatch(self.x, True)
+
+    def test_dtype_mismatch_unretain_cpu(self):
+        self.check_dtype_mismatch(self.x, False)

     @attr.gpu
     def test_dtype_mismatch_gpu(self):
-        self.check_dtype_mismatch(cuda.to_gpu(self.x))
+        self.check_dtype_mismatch(cuda.to_gpu(self.x), True)
+
+    @attr.gpu
+    def test_dtype_mismatch_unretain_gpu(self):
+        self.check_dtype_mismatch(cuda.to_gpu(self.x), False)

-    def check_shape_mismatch(self, x_data):
+    def check_shape_mismatch(self, x_data, retain):
         xp = backend.get_array_module(x_data)

         class DummyFunction(chainer.Function):
             label = 'dummy_function'

             def forward(self, inputs):
+                if not retain:
+                    self.retain_inputs(())
                 return xp.array(1, np.float32),

             def backward(self, inputs, grads):
@@ -1384,11 +1404,18 @@ def backward(self, inputs, grads):
         y.backward()

     def test_shape_mismatch_cpu(self):
-        self.check_shape_mismatch(self.x)
+        self.check_shape_mismatch(self.x, True)
+
+    def test_shape_mismatch_unretain_cpu(self):
+        self.check_shape_mismatch(self.x, False)

     @attr.gpu
     def test_shape_mismatch_gpu(self):
-        self.check_shape_mismatch(cuda.to_gpu(self.x))
+        self.check_shape_mismatch(cuda.to_gpu(self.x), True)
+
+    @attr.gpu
+    def test_shape_mismatch_unretain_gpu(self):
+        self.check_shape_mismatch(cuda.to_gpu(self.x), False)


 class TestVariableBackwardErrorTraceback(unittest.TestCase):
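The new *_unretain_* tests above cover the case this merge addresses: an old-style chainer.Function that releases its inputs in forward() still has its gradients validated during backward(), because the check now uses the dtype and shape recorded on the input node. A minimal CPU-only sketch in the same spirit (the class name is made up, and it mirrors rather than reproduces the tests above):

import numpy as np
import chainer


class BadGradFunction(chainer.Function):
    """Returns a float64 gradient for float32 input data."""

    def forward(self, inputs):
        self.retain_inputs(())  # drop the input arrays after forward
        return np.array(1, np.float32),

    def backward(self, inputs, grads):
        return np.array([1], np.float64),  # wrong dtype on purpose


x = chainer.Variable(np.array([1], np.float32))
y = BadGradFunction()(x)
try:
    y.backward()  # raises even though the input data was not retained
except TypeError as e:
    print(e)  # "Dtype of data and grad mismatch ..."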
