diff --git a/chainer/functions/connection/convolution_2d.py b/chainer/functions/connection/convolution_2d.py index b0bdfaecffc0..8899c4e327f4 100644 --- a/chainer/functions/connection/convolution_2d.py +++ b/chainer/functions/connection/convolution_2d.py @@ -3,11 +3,11 @@ import chainer from chainer import configuration from chainer import cuda -from chainer import function +from chainer import function_node +import chainer.functions from chainer.utils import argument from chainer.utils import conv from chainer.utils import type_check -from chainer import variable if cuda.cudnn_enabled: cudnn = cuda.cudnn @@ -25,21 +25,24 @@ def _pair(x): return x, x -class Convolution2DFunction(function.Function): +class Convolution2DFunction(function_node.FunctionNode): - def __init__(self, stride=1, pad=0, cover_all=False, requires_x_grad=True, - **kwargs): + def __init__(self, stride=1, pad=0, cover_all=False, **kwargs): argument.check_unexpected_kwargs( - kwargs, deterministic="deterministic argument is not " - "supported anymore. " - "Use chainer.using_config('cudnn_deterministic', value) " - "context where value is either `True` or `False`.") + kwargs, + deterministic="deterministic argument is not supported anymore. " + "Use chainer.using_config('cudnn_deterministic', value) context " + "where value is either `True` or `False`.", + requires_x_grad="requires_x_grad argument is not supported " + "anymore. Just remove the argument. Note that whether to compute " + "the gradient w.r.t. x is automatically decided during " + "backpropagation." + ) argument.assert_kwargs_empty(kwargs) self.sy, self.sx = _pair(stride) self.ph, self.pw = _pair(pad) self.cover_all = cover_all - self.requires_x_grad = requires_x_grad def check_type_forward(self, in_types): n_in = in_types.size() @@ -64,6 +67,7 @@ def check_type_forward(self, in_types): ) def forward_cpu(self, inputs): + self.retain_inputs((0, 1)) # retain only x and W x, W = inputs[:2] b = inputs[2] if len(inputs) == 3 else None @@ -78,16 +82,17 @@ def forward_cpu(self, inputs): .format(type(W), type(x))) kh, kw = W.shape[2:] - self.col = conv.im2col_cpu( + col = conv.im2col_cpu( x, kh, kw, self.sy, self.sx, self.ph, self.pw, cover_all=self.cover_all) y = numpy.tensordot( - self.col, W, ((1, 2, 3), (1, 2, 3))).astype(x.dtype, copy=False) + col, W, ((1, 2, 3), (1, 2, 3))).astype(x.dtype, copy=False) if b is not None: y += b return numpy.rollaxis(y, 3, 1), def forward_gpu(self, inputs): + self.retain_inputs((0, 1)) # retain only x and W x, W = inputs[:2] b = inputs[2] if len(inputs) == 3 else None @@ -123,42 +128,40 @@ def forward_gpu(self, inputs): x_desc = cudnn.create_tensor_descriptor(x) y_desc = cudnn.create_tensor_descriptor(y) - self.filter_desc = cudnn.create_filter_descriptor(W) - self.conv_desc = cudnn.create_convolution_descriptor( + filter_desc = cudnn.create_filter_descriptor(W) + conv_desc = cudnn.create_convolution_descriptor( (self.ph, self.pw), (self.sy, self.sx), x.dtype) if b is not None: - self.bias_desc = cudnn.create_tensor_descriptor( + bias_desc = cudnn.create_tensor_descriptor( b[None, :, None, None]) workspace_size = cuda.get_max_workspace_size() workspace = cuda.cupy.empty((workspace_size,), dtype='b') algo = libcudnn.getConvolutionForwardAlgorithm( - handle, x_desc.value, self.filter_desc.value, - self.conv_desc.value, y_desc.value, _fwd_pref, - workspace_size) + handle, x_desc.value, filter_desc.value, + conv_desc.value, y_desc.value, _fwd_pref, workspace_size) oz_dtype = 'd' if x.dtype == 'd' else 'f' one = numpy.array(1, 
dtype=oz_dtype).ctypes zero = numpy.array(0, dtype=oz_dtype).ctypes libcudnn.convolutionForward( handle, one.data, x_desc.value, x.data.ptr, - self.filter_desc.value, W.data.ptr, self.conv_desc.value, + filter_desc.value, W.data.ptr, conv_desc.value, algo, workspace.data.ptr, workspace_size, zero.data, y_desc.value, y.data.ptr) # TODO(beam2d): Support unshared bias if b is not None: cudnn.add_tensor( - handle, one.data, self.bias_desc.value, b.data.ptr, + handle, one.data, bias_desc.value, b.data.ptr, one.data, y_desc.value, y.data.ptr) else: # Implementation using im2col - self.col = conv.im2col_gpu( + col = conv.im2col_gpu( x, kh, kw, self.sy, self.sx, self.ph, self.pw, cover_all=self.cover_all) y = cuda.cupy.tensordot( - self.col, W, ((1, 2, 3), (1, 2, 3))).astype(x.dtype, - copy=False) + col, W, ((1, 2, 3), (1, 2, 3))).astype(x.dtype, copy=False) # TODO(beam2d): Support unshared bias if b is not None: y += b @@ -166,111 +169,116 @@ def forward_gpu(self, inputs): return y, - def backward_cpu(self, inputs, grad_outputs): - x, W = inputs[:2] - b = inputs[2] if len(inputs) == 3 else None - - gy = grad_outputs[0] - h, w = x.shape[2:] + def backward(self, indexes, grad_outputs): + x, W = self.get_retained_inputs() + gy, = grad_outputs + + ret = [] + if 0 in indexes: + xh, xw = x.shape[2:] + gx = chainer.functions.deconvolution_2d( + gy, W, stride=(self.sy, self.sx), pad=(self.ph, self.pw), + outsize=(xh, xw)) + ret.append(gx) + if 1 in indexes: + gW, = Convolution2DGradW(self).apply((x, gy)) + ret.append(gW) + if 2 in indexes: + gb = chainer.functions.sum(gy, axis=(0, 2, 3)) + ret.append(gb) + + return ret + + +class Convolution2DGradW(function_node.FunctionNode): + + def __init__(self, conv2d): + W_node = conv2d.inputs[1] + self.kh, self.kw = W_node.shape[2:] + self.sy = conv2d.sy + self.sx = conv2d.sx + self.ph = conv2d.ph + self.pw = conv2d.pw + self.cover_all = conv2d.cover_all + self.W_dtype = W_node.dtype + def forward_cpu(self, inputs): + self.retain_inputs((0, 1)) + x, gy = inputs + col = conv.im2col_cpu( + x, self.kh, self.kw, self.sy, self.sx, self.ph, self.pw, + cover_all=self.cover_all) gW = numpy.tensordot( - gy, self.col, ((0, 2, 3), (0, 4, 5))).astype(W.dtype, copy=False) - - if not self.requires_x_grad: - gx = None - else: - gcol = numpy.tensordot(W, gy, (0, 1)).astype(x.dtype, copy=False) - gcol = numpy.rollaxis(gcol, 3) - gx = conv.col2im_cpu(gcol, self.sy, self.sx, self.ph, self.pw, - h, w) - - if b is None: - return gx, gW - else: - gb = gy.sum(axis=(0, 2, 3)) - return gx, gW, gb + gy, col, ((0, 2, 3), (0, 4, 5))).astype(self.W_dtype, copy=False) + return gW, - def backward_gpu(self, inputs, grad_outputs): - x, W = inputs[:2] - b = inputs[2] if len(inputs) == 3 else None - - gy = grad_outputs[0] + def forward_gpu(self, inputs): + self.retain_inputs((0, 1)) + x, gy = inputs _, out_c, out_h, out_w = gy.shape n, c, h, w = x.shape - kh, kw = W.shape[2:] - gW = cuda.cupy.empty_like(W) - gx = None + if (self.cover_all or not chainer.should_use_cudnn('>=auto') or + x.dtype != self.W_dtype): + col = conv.im2col_gpu( + x, self.kh, self.kw, self.sy, self.sx, self.ph, self.pw, + cover_all=self.cover_all) + gW = cuda.cupy.tensordot( + gy, col, ((0, 2, 3), (0, 4, 5))).astype(self.W_dtype, + copy=False) + return gW, - if (not self.cover_all and chainer.should_use_cudnn('>=auto') and - x.dtype == W.dtype): - x = cuda.cupy.ascontiguousarray(x) - W = cuda.cupy.ascontiguousarray(W) - gy = cuda.cupy.ascontiguousarray(gy) + gW = cuda.cupy.empty((out_c, c, self.kh, self.kw), 
dtype=self.W_dtype) + x = cuda.cupy.ascontiguousarray(x) + gy = cuda.cupy.ascontiguousarray(gy) - handle = cudnn.get_handle() - x_desc = cudnn.create_tensor_descriptor(x) - gy_desc = cudnn.create_tensor_descriptor(gy) - oz_dtype = 'd' if x.dtype == 'd' else 'f' - one = numpy.array(1, dtype=oz_dtype).ctypes - zero = numpy.array(0, dtype=oz_dtype).ctypes + handle = cudnn.get_handle() + x_desc = cudnn.create_tensor_descriptor(x) + gy_desc = cudnn.create_tensor_descriptor(gy) - workspace_size = cuda.get_max_workspace_size() - workspace = cuda.cupy.empty((workspace_size,), dtype='b') + filter_desc = cudnn.create_filter_descriptor(gW) + conv_desc = cudnn.create_convolution_descriptor( + (self.ph, self.pw), (self.sy, self.sx), x.dtype) - if configuration.config.cudnn_deterministic: - algo = libcudnn.CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1 - else: - algo = libcudnn.getConvolutionBackwardFilterAlgorithm( - handle, x_desc.value, gy_desc.value, - self.conv_desc.value, self.filter_desc.value, - _bwd_filter_pref, workspace_size) + oz_dtype = 'd' if x.dtype == 'd' else 'f' + one = numpy.array(1, dtype=oz_dtype).ctypes + zero = numpy.array(0, dtype=oz_dtype).ctypes - libcudnn.convolutionBackwardFilter_v3( - handle, one.data, x_desc.value, x.data.ptr, - gy_desc.value, gy.data.ptr, self.conv_desc.value, - algo, workspace.data.ptr, workspace_size, - zero.data, self.filter_desc.value, gW.data.ptr) - - if self.requires_x_grad: - if configuration.config.cudnn_deterministic: - algo = libcudnn.CUDNN_CONVOLUTION_BWD_DATA_ALGO_1 - else: - algo = libcudnn.getConvolutionBackwardDataAlgorithm( - handle, self.filter_desc.value, gy_desc.value, - self.conv_desc.value, x_desc.value, _bwd_data_pref, - workspace_size) - - gx = cuda.cupy.empty_like(x) - libcudnn.convolutionBackwardData_v3( - handle, one.data, self.filter_desc.value, W.data.ptr, - gy_desc.value, gy.data.ptr, self.conv_desc.value, - algo, workspace.data.ptr, workspace_size, - zero.data, x_desc.value, gx.data.ptr) + workspace_size = cuda.get_max_workspace_size() + workspace = cuda.cupy.empty((workspace_size,), dtype='b') - if b is not None: - gb = cuda.cupy.empty_like(b) - libcudnn.convolutionBackwardBias( - handle, one.data, gy_desc.value, gy.data.ptr, - zero.data, self.bias_desc.value, gb.data.ptr) + if configuration.config.cudnn_deterministic: + algo = libcudnn.CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1 else: - gW = cuda.cupy.tensordot( - gy, self.col, ((0, 2, 3), (0, 4, 5))).astype(W.dtype, - copy=False) - if self.requires_x_grad: - gcol = cuda.cupy.tensordot(W, gy, (0, 1)).astype(x.dtype, - copy=False) - gcol = cuda.cupy.rollaxis(gcol, 3) - gx = conv.col2im_gpu( - gcol, self.sy, self.sx, self.ph, self.pw, h, w) - - if b is not None: - gb = gy.sum(axis=(0, 2, 3)) + algo = libcudnn.getConvolutionBackwardFilterAlgorithm( + handle, x_desc.value, gy_desc.value, conv_desc.value, + filter_desc.value, _bwd_filter_pref, workspace_size) + + libcudnn.convolutionBackwardFilter_v3( + handle, one.data, x_desc.value, x.data.ptr, gy_desc.value, + gy.data.ptr, conv_desc.value, algo, workspace.data.ptr, + workspace_size, zero.data, filter_desc.value, gW.data.ptr) + + return gW, + + def backward(self, indexes, grad_outputs): + x, gy = self.get_retained_inputs() + ggW, = grad_outputs + + ret = [] + if 0 in indexes: + xh, xw = x.shape[2:] + gx = chainer.functions.deconvolution_2d( + gy, ggW, stride=(self.sy, self.sx), pad=(self.ph, self.pw), + outsize=(xh, xw)) + ret.append(gx) + if 1 in indexes: + ggy = convolution_2d( + x, ggW, stride=(self.sy, self.sx), pad=(self.ph, self.pw), + 
cover_all=self.cover_all) + ret.append(ggy) - if b is None: - return gx, gW - else: - return gx, gW, gb + return ret def convolution_2d(x, W, b=None, stride=1, pad=0, cover_all=False, **kwargs): @@ -397,9 +405,10 @@ def convolution_2d(x, W, b=None, stride=1, pad=0, cover_all=False, **kwargs): "context where value is either `True` or `False`.") argument.assert_kwargs_empty(kwargs) - requires_x_grad = isinstance(x, variable.Variable) and x.requires_grad - func = Convolution2DFunction(stride, pad, cover_all, requires_x_grad) + fnode = Convolution2DFunction(stride, pad, cover_all) if b is None: - return func(x, W) + args = x, W else: - return func(x, W, b) + args = x, W, b + y, = fnode.apply(args) + return y diff --git a/chainer/functions/connection/deconvolution_2d.py b/chainer/functions/connection/deconvolution_2d.py index 2b8f0845905e..e9ef3b541b69 100644 --- a/chainer/functions/connection/deconvolution_2d.py +++ b/chainer/functions/connection/deconvolution_2d.py @@ -3,11 +3,12 @@ import chainer from chainer import configuration from chainer import cuda -from chainer import function +from chainer import function_node +import chainer.functions +from chainer.functions.connection import convolution_2d from chainer.utils import argument from chainer.utils import conv from chainer.utils import type_check -from chainer import variable if cuda.cudnn_enabled: cudnn = cuda.cudnn @@ -25,21 +26,26 @@ def _pair(x): return x, x -class Deconvolution2DFunction(function.Function): +class Deconvolution2DFunction(function_node.FunctionNode): - def __init__(self, stride=1, pad=0, outsize=None, requires_x_grad=True, - **kwargs): + cover_all = None + + def __init__(self, stride=1, pad=0, outsize=None, **kwargs): argument.check_unexpected_kwargs( - kwargs, deterministic="deterministic argument is not " - "supported anymore. " - "Use chainer.using_config('cudnn_deterministic', value) " - "context where value is either `True` or `False`.") + kwargs, + deterministic="deterministic argument is not supported anymore. " + "Use chainer.using_config('cudnn_deterministic', value) context " + "where value is either `True` or `False`.", + requires_x_grad="requires_x_grad argument is not supported " + "anymore. Just remove the argument. Note that whether to compute " + "the gradient w.r.t. x is automatically decided during " + "backpropagation." 
+ ) argument.assert_kwargs_empty(kwargs) self.sy, self.sx = _pair(stride) self.ph, self.pw = _pair(pad) self.outh, self.outw = (None, None) if outsize is None else outsize - self.requires_x_grad = requires_x_grad def check_type_forward(self, in_types): n_in = in_types.size() @@ -55,17 +61,21 @@ def check_type_forward(self, in_types): ) if self.outh is not None: + lower_bound = conv.get_conv_outsize( + self.outh, w_type.shape[2], self.sy, self.ph) + upper_bound = conv.get_conv_outsize( + self.outh, w_type.shape[2], self.sy, self.ph, cover_all=True) type_check.expect( - x_type.shape[2] == - conv.get_conv_outsize(self.outh, w_type.shape[2], - self.sy, self.ph), - ) + lower_bound <= x_type.shape[2], + x_type.shape[2] <= upper_bound) if self.outw is not None: + lower_bound = conv.get_conv_outsize( + self.outw, w_type.shape[3], self.sx, self.pw) + upper_bound = conv.get_conv_outsize( + self.outw, w_type.shape[3], self.sx, self.pw, cover_all=True) type_check.expect( - x_type.shape[3] == - conv.get_conv_outsize(self.outw, w_type.shape[3], - self.sx, self.pw), - ) + lower_bound <= x_type.shape[3], + x_type.shape[3] <= upper_bound) if type_check.eval(n_in) == 3: b_type = in_types[2] @@ -76,6 +86,7 @@ def check_type_forward(self, in_types): ) def forward_cpu(self, inputs): + self.retain_inputs((0, 1)) # only retain x and W x, W = inputs[:2] b = inputs[2] if len(inputs) == 3 else None @@ -111,6 +122,7 @@ def forward_cpu(self, inputs): return y, def forward_gpu(self, inputs): + self.retain_inputs((0, 1)) # only retain x and W x, W = inputs[:2] b = inputs[2] if len(inputs) == 3 else None @@ -133,7 +145,11 @@ def forward_gpu(self, inputs): if self.outw is None: self.outw = conv.get_deconv_outsize(in_w, kw, self.sx, self.pw) assert self.outw > 0, 'Width in the output should be positive.' 
- if chainer.should_use_cudnn('>=auto') and x.dtype == W.dtype: + + self._set_cover_all(x, W) + + if (not self.cover_all and chainer.should_use_cudnn('>=auto') and + x.dtype == W.dtype): x = cuda.cupy.ascontiguousarray(x) W = cuda.cupy.ascontiguousarray(W) if b is not None: @@ -145,11 +161,11 @@ def forward_gpu(self, inputs): dtype=x.dtype) y_desc = cudnn.create_tensor_descriptor(y) - self.filter_desc = cudnn.create_filter_descriptor(W) - self.conv_desc = cudnn.create_convolution_descriptor( + filter_desc = cudnn.create_filter_descriptor(W) + conv_desc = cudnn.create_convolution_descriptor( (self.ph, self.pw), (self.sy, self.sx), x.dtype) if b is not None: - self.bias_desc = cudnn.create_tensor_descriptor( + bias_desc = cudnn.create_tensor_descriptor( b[None, :, None, None]) oz_dtype = 'd' if x.dtype == 'd' else 'f' @@ -162,19 +178,19 @@ def forward_gpu(self, inputs): algo = libcudnn.CUDNN_CONVOLUTION_BWD_DATA_ALGO_1 else: algo = libcudnn.getConvolutionBackwardDataAlgorithm( - handle, self.filter_desc.value, x_desc.value, - self.conv_desc.value, y_desc.value, _bwd_data_pref, + handle, filter_desc.value, x_desc.value, + conv_desc.value, y_desc.value, _bwd_data_pref, workspace_size) libcudnn.convolutionBackwardData_v3( - handle, one.data, self.filter_desc.value, W.data.ptr, - x_desc.value, x.data.ptr, self.conv_desc.value, + handle, one.data, filter_desc.value, W.data.ptr, + x_desc.value, x.data.ptr, conv_desc.value, algo, workspace.data.ptr, workspace_size, zero.data, y_desc.value, y.data.ptr) if b is not None: cudnn.add_tensor( - handle, one.data, self.bias_desc.value, b.data.ptr, + handle, one.data, bias_desc.value, b.data.ptr, one.data, y_desc.value, y.data.ptr) else: gcol = cuda.cupy.tensordot(W, x, (0, 1)).astype(x.dtype, @@ -190,109 +206,35 @@ def forward_gpu(self, inputs): y += b.reshape(1, b.size, 1, 1) return y, - def backward_cpu(self, inputs, grad_outputs): - x, W = inputs[:2] - b = inputs[2] if len(inputs) == 3 else None - - gy = grad_outputs[0] + def backward(self, indexes, grad_outputs): + x, W = self.get_retained_inputs() + gy, = grad_outputs + + ret = [] + if 0 in indexes: + if self.cover_all is None: + self._set_cover_all(x, W) + gx = chainer.functions.convolution_2d( + gy, W, stride=(self.sy, self.sx), pad=(self.ph, self.pw), + cover_all=self.cover_all) + ret.append(gx) + if 1 in indexes: + if self.cover_all is None: + self._set_cover_all(x, W) + gW, = convolution_2d.Convolution2DGradW(self).apply((gy, x)) + ret.append(gW) + if 2 in indexes: + gb = chainer.functions.sum(gy, axis=(0, 2, 3)) + ret.append(gb) + + return ret + + def _set_cover_all(self, x, W): + in_h, in_w = x.shape[2:] kh, kw = W.shape[2:] - col = conv.im2col_cpu( - gy, kh, kw, self.sy, self.sx, self.ph, self.pw) - gW = numpy.tensordot( - x, col, ([0, 2, 3], [0, 4, 5])).astype(W.dtype, copy=False) - if not self.requires_x_grad: - gx = None - else: - gx = numpy.tensordot( - col, W, ([1, 2, 3], [1, 2, 3])).astype(x.dtype, copy=False) - gx = numpy.rollaxis(gx, 3, 1) - - if b is None: - return gx, gW - else: - gb = gy.sum(axis=(0, 2, 3)) - return gx, gW, gb - - def backward_gpu(self, inputs, grad_outputs): - x, W = inputs[:2] - b = inputs[2] if len(inputs) == 3 else None - - gy = grad_outputs[0] - n, in_c, in_h, in_w = x.shape - _, out_channels, kh, kw = W.shape - c, h, w = gy.shape[1:] - gx = None - - if chainer.should_use_cudnn('>=auto') and x.dtype == W.dtype: - gx = cuda.cupy.empty((n, in_c, in_h, in_w), dtype=x.dtype) - x = cuda.cupy.ascontiguousarray(x) - W = cuda.cupy.ascontiguousarray(W) - gy = 
cuda.cupy.ascontiguousarray(gy) - if b is not None: - b = cuda.cupy.ascontiguousarray(b) - - handle = cudnn.get_handle() - gy_desc = cudnn.create_tensor_descriptor(gy) - gx_desc = cudnn.create_tensor_descriptor(gx) - - # chance to choose implicit-precomp-gemm algorithm - workspace_size = cuda.get_max_workspace_size() - algo = libcudnn.getConvolutionForwardAlgorithm( - handle, gy_desc.value, self.filter_desc.value, - self.conv_desc.value, gx_desc.value, _fwd_pref, - workspace_size) - workspace = cuda.cupy.empty((workspace_size,), dtype='b') - - oz_dtype = 'd' if x.dtype == 'd' else 'f' - one = numpy.array(1, dtype=oz_dtype).ctypes - zero = numpy.array(0, dtype=oz_dtype).ctypes - - libcudnn.convolutionForward( - handle, one.data, gy_desc.value, gy.data.ptr, - self.filter_desc.value, W.data.ptr, - self.conv_desc.value, algo, workspace.data.ptr, workspace_size, - zero.data, gx_desc.value, gx.data.ptr) - # bias backward - if b is not None: - gb = cuda.cupy.empty_like(b) - libcudnn.convolutionBackwardBias( - handle, one.data, gy_desc.value, gy.data.ptr, - zero.data, self.bias_desc.value, gb.data.ptr) - gW = cuda.cupy.empty_like(W) - # filter backward - if configuration.config.cudnn_deterministic: - algo = libcudnn.CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1 - else: - algo = libcudnn.getConvolutionBackwardFilterAlgorithm( - handle, gy_desc.value, gx_desc.value, - self.conv_desc.value, self.filter_desc.value, - _bwd_filter_pref, workspace_size) - - libcudnn.convolutionBackwardFilter_v3( - handle, one.data, gy_desc.value, gy.data.ptr, - gx_desc.value, x.data.ptr, self.conv_desc.value, - algo, workspace.data.ptr, workspace_size, - zero.data, self.filter_desc.value, gW.data.ptr) - else: - # Implementation using im2col - col = conv.im2col_gpu( - gy, kh, kw, self.sy, self.sx, self.ph, self.pw) - - gW = cuda.cupy.tensordot( - x, col, ([0, 2, 3], [0, 4, 5])).astype(W.dtype, copy=False) - if self.requires_x_grad: - gx = cuda.cupy.tensordot( - col, W, ([1, 2, 3], [1, 2, 3])).astype(x.dtype, copy=False) - gx = cuda.cupy.rollaxis(gx, 3, 1) - - # bias backward - if b is not None: - gb = gy.sum(axis=(0, 2, 3)) - - if b is None: - return gx, gW - else: - return gx, gW, gb + self.cover_all = ( + in_h != conv.get_conv_outsize(self.outh, kh, self.sy, self.ph) or + in_w != conv.get_conv_outsize(self.outw, kw, self.sx, self.pw)) def deconvolution_2d(x, W, b=None, stride=1, pad=0, outsize=None, **kwargs): @@ -402,9 +344,10 @@ def deconvolution_2d(x, W, b=None, stride=1, pad=0, outsize=None, **kwargs): "context where value is either `True` or `False`.") argument.assert_kwargs_empty(kwargs) - requires_x_grad = isinstance(x, variable.Variable) and x.requires_grad - func = Deconvolution2DFunction(stride, pad, outsize, requires_x_grad) + func = Deconvolution2DFunction(stride, pad, outsize) if b is None: - return func(x, W) + args = x, W else: - return func(x, W, b) + args = x, W, b + y, = func.apply(args) + return y diff --git a/chainer/utils/conv.py b/chainer/utils/conv.py index 2839bc6f0763..601249788047 100644 --- a/chainer/utils/conv.py +++ b/chainer/utils/conv.py @@ -20,11 +20,14 @@ def get_deconv_outsize(size, k, s, p, cover_all=False): def im2col_cpu( - img, kh, kw, sy, sx, ph, pw, pval=0, cover_all=False, dy=1, dx=1): + img, kh, kw, sy, sx, ph, pw, pval=0, cover_all=False, dy=1, dx=1, + out_h=None, out_w=None): n, c, h, w = img.shape - out_h = get_conv_outsize(h, kh, sy, ph, cover_all, dy) + if out_h is None: + out_h = get_conv_outsize(h, kh, sy, ph, cover_all, dy) assert out_h > 0, 'Height in the output should be 
positive.' - out_w = get_conv_outsize(w, kw, sx, pw, cover_all, dx) + if out_w is None: + out_w = get_conv_outsize(w, kw, sx, pw, cover_all, dx) assert out_w > 0, 'Width in the output should be positive.' img = numpy.pad(img, @@ -43,11 +46,14 @@ def im2col_cpu( return col -def im2col_gpu(img, kh, kw, sy, sx, ph, pw, cover_all=False, dy=1, dx=1): +def im2col_gpu(img, kh, kw, sy, sx, ph, pw, cover_all=False, dy=1, dx=1, + out_h=None, out_w=None): n, c, h, w = img.shape - out_h = get_conv_outsize(h, kh, sy, ph, cover_all, dy) + if out_h is None: + out_h = get_conv_outsize(h, kh, sy, ph, cover_all, dy) assert out_h > 0, 'Height in the output should be positive.' - out_w = get_conv_outsize(w, kw, sx, pw, cover_all, dx) + if out_w is None: + out_w = get_conv_outsize(w, kw, sx, pw, cover_all, dx) assert out_w > 0, 'Width in the output should be positive.' col = cuda.cupy.empty((n, c, kh, kw, out_h, out_w), dtype=img.dtype) diff --git a/tests/chainer_tests/functions_tests/connection_tests/test_convolution_2d.py b/tests/chainer_tests/functions_tests/connection_tests/test_convolution_2d.py index ef92880d622d..2a0972c6791b 100644 --- a/tests/chainer_tests/functions_tests/connection_tests/test_convolution_2d.py +++ b/tests/chainer_tests/functions_tests/connection_tests/test_convolution_2d.py @@ -5,8 +5,7 @@ import chainer from chainer import cuda -from chainer import functions -from chainer.functions.connection import convolution_2d +import chainer.functions as F from chainer import gradient_check from chainer import testing from chainer.testing import attr @@ -49,12 +48,19 @@ def setUp(self): else: self.gy = numpy.random.uniform( -1, 1, (2, 2, 2, 2)).astype(self.x_dtype) + self.ggx = numpy.random.uniform(-1, 1, self.x.shape).astype( + self.x_dtype) + self.ggW = numpy.random.uniform(-1, 1, self.W.shape).astype( + self.W_dtype) + self.ggb = numpy.random.uniform(-1, 1, self.b.shape).astype( + self.x_dtype) self.check_forward_options = {} self.check_backward_options = {'dtype': numpy.float64} + self.check_double_backward_options = {'dtype': numpy.float64} if self.x_dtype == numpy.float16 or self.W_dtype == numpy.float16: - self.check_forward_options = {'atol': 5e-4, 'rtol': 5e-3} - self.check_backward_options = { - 'dtype': numpy.float64, 'atol': 5e-4, 'rtol': 5e-3} + self.check_forward_options.update(atol=5e-4, rtol=5e-3) + self.check_backward_options.update(atol=5e-4, rtol=5e-3) + self.check_double_backward_options.update(atol=5e-3, rtol=5e-2) @attr.gpu def test_forward_consistency(self, nobias=False): @@ -63,7 +69,7 @@ def test_forward_consistency(self, nobias=False): b_cpu = None if nobias else chainer.Variable(self.b) with chainer.using_config('cudnn_deterministic', self.cudnn_deterministic): - y_cpu = functions.convolution_2d( + y_cpu = F.convolution_2d( x_cpu, W_cpu, b_cpu, stride=self.stride, pad=self.pad, cover_all=self.cover_all) @@ -73,7 +79,7 @@ def test_forward_consistency(self, nobias=False): with chainer.using_config('use_cudnn', self.use_cudnn): with chainer.using_config('cudnn_deterministic', self.cudnn_deterministic): - y_gpu = functions.convolution_2d( + y_gpu = F.convolution_2d( x_gpu, W_gpu, b_gpu, stride=self.stride, pad=self.pad, cover_all=self.cover_all) @@ -110,13 +116,15 @@ def check_backward(self, x_data, W_data, b_data, y_grad): if b_data is not None: args = args + (b_data,) + def f(*args): + return F.convolution_2d(*args, stride=self.stride, pad=self.pad, + cover_all=self.cover_all) + with chainer.using_config('use_cudnn', self.use_cudnn): with 
chainer.using_config('cudnn_deterministic', self.cudnn_deterministic): gradient_check.check_backward( - convolution_2d.Convolution2DFunction( - self.stride, self.pad, self.cover_all), - args, y_grad, **self.check_backward_options) + f, args, y_grad, **self.check_backward_options) @condition.retry(3) def test_backward_cpu(self): @@ -152,6 +160,89 @@ def test_backward_gpu_im2col_nobias(self): self.check_backward(cuda.to_gpu(self.x), cuda.to_gpu(self.W), None, cuda.to_gpu(self.gy)) + def check_double_backward(self, x_data, W_data, b_data, y_grad, + x_grad_grad, W_grad_grad, b_grad_grad): + xp = cuda.get_array_module(x_data) + + if not self.c_contiguous: + x_data = xp.asfortranarray(x_data) + W_data = xp.asfortranarray(W_data) + y_grad = xp.asfortranarray(y_grad) + x_grad_grad = xp.asfortranarray(x_grad_grad) + W_grad_grad = xp.asfortranarray(W_grad_grad) + self.assertFalse(x_data.flags.c_contiguous) + self.assertFalse(W_data.flags.c_contiguous) + self.assertFalse(y_grad.flags.c_contiguous) + self.assertFalse(x_grad_grad.flags.c_contiguous) + self.assertFalse(W_grad_grad.flags.c_contiguous) + if b_data is not None: + b = xp.empty((len(b_data) * 2,), dtype=self.b.dtype) + b[::2] = b_data + b_data = b[::2] + self.assertFalse(b_data.flags.c_contiguous) + + ggb = xp.empty((len(b_data) * 2,), dtype=self.b.dtype) + ggb[::2] = b_grad_grad + b_grad_grad = ggb[::2] + self.assertFalse(b_grad_grad.flags.c_contiguous) + + args = (x_data, W_data) + grad_grads = (x_grad_grad, W_grad_grad) + if b_data is not None: + args = args + (b_data,) + grad_grads = grad_grads + (b_grad_grad,) + + def f(*args): + y = F.convolution_2d(*args, stride=self.stride, pad=self.pad, + cover_all=self.cover_all) + return y * y # make the function nonlinear + + with chainer.using_config('use_cudnn', self.use_cudnn): + with chainer.using_config('cudnn_deterministic', + self.cudnn_deterministic): + gradient_check.check_double_backward( + f, args, y_grad, grad_grads, + **self.check_double_backward_options) + + @condition.retry(3) + def test_double_backward_cpu(self): + self.check_double_backward(self.x, self.W, self.b, self.gy, + self.ggx, self.ggW, self.ggb) + + @condition.retry(3) + def test_double_backward_cpu_nobias(self): + self.check_double_backward(self.x, self.W, None, self.gy, + self.ggx, self.ggW, None) + + def check_double_backward_gpu(self, bias=True, im2col=False): + if im2col: + self.use_cudnn = 'never' + self.check_double_backward( + cuda.to_gpu(self.x), cuda.to_gpu(self.W), + cuda.to_gpu(self.b) if bias else None, + cuda.to_gpu(self.gy), cuda.to_gpu(self.ggx), cuda.to_gpu(self.ggW), + cuda.to_gpu(self.ggb) if bias else None) + + @attr.gpu + @condition.retry(3) + def test_double_backward_gpu(self): + self.check_double_backward_gpu() + + @attr.gpu + @condition.retry(3) + def test_double_backward_gpu_nobias(self): + self.check_double_backward_gpu(bias=False) + + @attr.gpu + @condition.retry(3) + def test_double_backward_gpu_im2col(self): + self.check_double_backward_gpu(im2col=True) + + @attr.gpu + @condition.retry(3) + def test_double_backward_gpu_im2col_nobias(self): + self.check_double_backward_gpu(bias=False, im2col=True) + @testing.parameterize(*testing.product({ 'use_cudnn': ['always', 'auto', 'never'], @@ -180,8 +271,7 @@ def setUp(self): def forward(self): x = chainer.Variable(self.x) W = chainer.Variable(self.W) - return functions.convolution_2d( - x, W, None, stride=self.stride, pad=self.pad) + return F.convolution_2d(x, W, None, stride=self.stride, pad=self.pad) def test_call_cudnn_forward(self): with 
chainer.using_config('use_cudnn', self.use_cudnn): @@ -234,21 +324,25 @@ def setUp(self): def test_called(self): with mock.patch( - 'chainer.functions.connection.convolution_2d.libcudnn', - autospec=True) as mlibcudnn: + 'chainer.functions.connection.convolution_2d.libcudnn', + autospec=True + ) as mlibcudnn_conv, mock.patch( + 'chainer.functions.connection.deconvolution_2d.libcudnn', + autospec=True + ) as mlibcudnn_deconv: # cuDNN version >= v3 supports `cudnn_deterministic` option x, W, b, y = self._run() # in Convolution2DFunction.backward_gpu() self.assertFalse( - mlibcudnn.getConvolutionBackwardFilterAlgorithm.called) + mlibcudnn_conv.getConvolutionBackwardFilterAlgorithm.called) self.assertEqual( - mlibcudnn.convolutionBackwardFilter_v3.call_count, 1) + mlibcudnn_conv.convolutionBackwardFilter_v3.call_count, 1) self.assertFalse( - mlibcudnn.getConvolutionBackwardDataAlgorithm.called) + mlibcudnn_deconv.getConvolutionBackwardDataAlgorithm.called) self.assertEqual( - mlibcudnn.convolutionBackwardData_v3.call_count, 1) + mlibcudnn_deconv.convolutionBackwardData_v3.call_count, 1) def test_cudnn_deterministic(self): x1, W1, b1, y1 = self._run() @@ -289,9 +383,8 @@ def _run_forward(self, x_data, W_data, b_data): x = chainer.Variable(x_data) W = chainer.Variable(W_data) b = None if self.nobias else chainer.Variable(b_data) - y = functions.convolution_2d( - x, W, b, stride=self.stride, pad=self.pad, - cover_all=False) + y = F.convolution_2d(x, W, b, stride=self.stride, pad=self.pad, + cover_all=False) return x, W, b, y diff --git a/tests/chainer_tests/functions_tests/connection_tests/test_deconvolution_2d.py b/tests/chainer_tests/functions_tests/connection_tests/test_deconvolution_2d.py index 313edee9251a..f277a20e1da8 100644 --- a/tests/chainer_tests/functions_tests/connection_tests/test_deconvolution_2d.py +++ b/tests/chainer_tests/functions_tests/connection_tests/test_deconvolution_2d.py @@ -6,7 +6,6 @@ import chainer from chainer import cuda import chainer.functions as F -from chainer.functions.connection import deconvolution_2d from chainer import gradient_check from chainer import testing from chainer.testing import attr @@ -67,15 +66,24 @@ def setUp(self): -1, 1, (N, self.in_channels, inh, inw)).astype(self.x_dtype) self.gy = numpy.random.uniform( -1, 1, (N, self.out_channels, outh, outw)).astype(self.x_dtype) + + self.ggx = numpy.random.uniform(-1, 1, self.x.shape).astype( + self.x_dtype) + self.ggW = numpy.random.uniform(-1, 1, self.W.shape).astype( + self.W_dtype) + self.ggb = None if self.nobias else numpy.random.uniform( + -1, 1, self.b.shape).astype(self.x_dtype) + self.test_forward_options = {} self.check_backward_options = {'dtype': numpy.float64} + self.check_double_backward_options = {'dtype': numpy.float64} if self.x_dtype == numpy.float16: - self.test_forward_options = {'atol': 5e-3, 'rtol': 5e-2} - self.check_backward_options = { - 'dtype': numpy.float64, 'atol': 5e-4, 'rtol': 5e-3} + self.test_forward_options.update(atol=5e-3, rtol=5e-2) + self.check_backward_options.update(atol=5e-4, rtol=5e-3) + self.check_double_backward_options.update(atol=5e-3, rtol=5e-2) elif self.W_dtype == numpy.float16: - self.check_backward_options = { - 'dtype': numpy.float64, 'atol': 5e-4, 'rtol': 5e-3} + self.check_backward_options.update(atol=5e-4, rtol=5e-3) + self.check_double_backward_options.update(atol=5e-3, rtol=5e-2) @attr.gpu def test_forward_consistency(self): @@ -129,13 +137,15 @@ def check_backward(self, x_data, W_data, b_data, y_grad): if b_data is not None: args = args + 
(b_data,) + def f(*args): + return F.deconvolution_2d( + *args, stride=self.stride, pad=self.pad, outsize=self.outsize) + with chainer.using_config('use_cudnn', self.use_cudnn): with chainer.using_config('cudnn_deterministic', self.cudnn_deterministic): gradient_check.check_backward( - deconvolution_2d.Deconvolution2DFunction( - self.stride, self.pad, self.outsize), - args, y_grad, **self.check_backward_options) + f, args, y_grad, **self.check_backward_options) @condition.retry(10) def test_backward_cpu(self): @@ -148,6 +158,64 @@ def test_backward_gpu(self): self.check_backward(cuda.to_gpu(self.x), cuda.to_gpu(self.W), b, cuda.to_gpu(self.gy)) + def check_double_backward(self, x_data, W_data, b_data, y_grad, + x_grad_grad, W_grad_grad, b_grad_grad): + xp = cuda.get_array_module(x_data) + + if not self.c_contiguous: + x_data = xp.asfortranarray(x_data) + W_data = xp.asfortranarray(W_data) + y_grad = xp.asfortranarray(y_grad) + x_grad_grad = xp.asfortranarray(x_grad_grad) + W_grad_grad = xp.asfortranarray(W_grad_grad) + self.assertFalse(x_data.flags.c_contiguous) + self.assertFalse(W_data.flags.c_contiguous) + self.assertFalse(y_grad.flags.c_contiguous) + self.assertFalse(x_grad_grad.flags.c_contiguous) + self.assertFalse(W_grad_grad.flags.c_contiguous) + if b_data is not None: + b = xp.empty((len(b_data) * 2,), dtype=self.b.dtype) + b[::2] = b_data + b_data = b[::2] + self.assertFalse(b_data.flags.c_contiguous) + + ggb = xp.empty((len(b_data) * 2,), dtype=self.b.dtype) + ggb[::2] = b_grad_grad + b_grad_grad = ggb[::2] + self.assertFalse(b_grad_grad.flags.c_contiguous) + + args = (x_data, W_data) + grad_grads = (x_grad_grad, W_grad_grad) + if b_data is not None: + args = args + (b_data,) + grad_grads = grad_grads + (b_grad_grad,) + + def f(*args): + y = F.deconvolution_2d( + *args, stride=self.stride, pad=self.pad, outsize=self.outsize) + return y * y # make the function nonlinear + + with chainer.using_config('use_cudnn', self.use_cudnn): + with chainer.using_config('cudnn_deterministic', + self.cudnn_deterministic): + gradient_check.check_double_backward( + f, args, y_grad, grad_grads, + **self.check_double_backward_options) + + @condition.retry(10) + def test_double_backward_cpu(self): + self.check_double_backward(self.x, self.W, self.b, self.gy, + self.ggx, self.ggW, self.ggb) + + @attr.gpu + @condition.retry(10) + def test_double_backward_gpu(self): + self.check_double_backward( + cuda.to_gpu(self.x), cuda.to_gpu(self.W), + None if self.b is None else cuda.to_gpu(self.b), + cuda.to_gpu(self.gy), cuda.to_gpu(self.ggx), cuda.to_gpu(self.ggW), + None if self.ggb is None else cuda.to_gpu(self.ggb)) + @testing.parameterize(*testing.product({ 'use_cudnn': ['always', 'auto', 'never'], diff --git a/tests/chainer_tests/functions_tests/connection_tests/test_depthwise_convolution_2d.py b/tests/chainer_tests/functions_tests/connection_tests/test_depthwise_convolution_2d.py index 6eb6bf5c382b..ad251db0bb34 100644 --- a/tests/chainer_tests/functions_tests/connection_tests/test_depthwise_convolution_2d.py +++ b/tests/chainer_tests/functions_tests/connection_tests/test_depthwise_convolution_2d.py @@ -55,7 +55,7 @@ def check_forward(self, x_data, W_data, b_data): y1 = sum(arys) f2 = convolution_2d.Convolution2DFunction(self.stride, self.pad) - y2 = f2(*args2).data + y2 = f2.apply(args2)[0].data testing.assert_allclose(y1, y2, **self.check_forward_options) def test_forward_cpu(self):
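What the FunctionNode rewrite enables, as a minimal sketch that is not part of the patch: second-order gradients can now flow through convolution_2d, assuming the chainer.grad API with enable_double_backprop from the same FunctionNode line of work is available. Shapes and data below are illustrative.

    import numpy as np
    import chainer
    import chainer.functions as F

    x = chainer.Variable(np.random.randn(2, 3, 8, 8).astype(np.float32))
    W = chainer.Variable(np.random.randn(4, 3, 3, 3).astype(np.float32))
    b = chainer.Variable(np.random.randn(4).astype(np.float32))

    y = F.convolution_2d(x, W, b, stride=1, pad=1)
    loss = F.sum(y * y)  # nonlinear in y so second derivatives are non-trivial

    # First-order gradients, keeping the backward graph alive.
    gx, gW = chainer.grad([loss], [x, W], enable_double_backprop=True)

    # Differentiate the first-order graph again; the old Function-based
    # implementation could not build this second graph.
    ggW, = chainer.grad([F.sum(gx * gx)], [W])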
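For callers that instantiate Convolution2DFunction directly (as the old tests did), the calling convention changes from Function.__call__ to FunctionNode.apply, and the requires_x_grad argument disappears; whether gx is needed is now decided from indexes inside backward. A sketch reusing x, W and b from the snippet above:

    from chainer.functions.connection.convolution_2d import Convolution2DFunction

    fnode = Convolution2DFunction(stride=1, pad=0, cover_all=False)
    y, = fnode.apply((x, W, b))  # apply takes a tuple of inputs and returns a tuple
    # The functional wrapper hides this difference and remains the preferred entry point:
    # y = F.convolution_2d(x, W, b, stride=1, pad=0)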
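The relaxed shape check in Deconvolution2DFunction.check_type_forward accepts any input size between the plain and the cover_all=True conv output sizes for the requested outsize, instead of requiring an exact match. The numbers below are illustrative:

    from chainer.utils import conv

    # outsize height 8, kernel 3, stride 2, pad 0
    lo = conv.get_conv_outsize(8, 3, 2, 0)                  # 3
    hi = conv.get_conv_outsize(8, 3, 2, 0, cover_all=True)  # 4
    # Old check: x.shape[2] == lo.  New check: lo <= x.shape[2] <= hi,
    # which covers inputs produced by cover_all convolutions.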
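The new optional out_h/out_w arguments of im2col_cpu and im2col_gpu let the caller pin the output spatial size instead of recomputing it with get_conv_outsize; when omitted, behaviour is unchanged. A small CPU-side sketch with illustrative sizes:

    import numpy as np
    from chainer.utils import conv

    img = np.arange(25, dtype=np.float32).reshape(1, 1, 5, 5)
    col_a = conv.im2col_cpu(img, 3, 3, 1, 1, 0, 0)                    # size derived: 3x3
    col_b = conv.im2col_cpu(img, 3, 3, 1, 1, 0, 0, out_h=3, out_w=3)  # size pinned
    assert col_a.shape == col_b.shape == (1, 1, 3, 3, 3, 3)
    assert np.array_equal(col_a, col_b)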