Merge 59fdc5d into 8bcac6f
beam2d committed Aug 24, 2017
2 parents 8bcac6f + 59fdc5d commit 26597b0
Showing 6 changed files with 407 additions and 288 deletions.
245 changes: 127 additions & 118 deletions chainer/functions/connection/convolution_2d.py
@@ -3,11 +3,11 @@
import chainer
from chainer import configuration
from chainer import cuda
from chainer import function_node
import chainer.functions
from chainer.utils import argument
from chainer.utils import conv
from chainer.utils import type_check

if cuda.cudnn_enabled:
    cudnn = cuda.cudnn
@@ -25,21 +25,24 @@ def _pair(x):
    return x, x


class Convolution2DFunction(function_node.FunctionNode):

    def __init__(self, stride=1, pad=0, cover_all=False, **kwargs):
        argument.check_unexpected_kwargs(
            kwargs,
            deterministic="deterministic argument is not supported anymore. "
            "Use chainer.using_config('cudnn_deterministic', value) context "
            "where value is either `True` or `False`.",
            requires_x_grad="requires_x_grad argument is not supported "
            "anymore. Just remove the argument. Note that whether to compute "
            "the gradient w.r.t. x is automatically decided during "
            "backpropagation."
        )
        argument.assert_kwargs_empty(kwargs)

        self.sy, self.sx = _pair(stride)
        self.ph, self.pw = _pair(pad)
        self.cover_all = cover_all
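The two rejected keyword arguments above tell callers what to do instead; a minimal sketch of the new calling convention (array shapes here are illustrative assumptions):

import numpy as np
import chainer
import chainer.functions as F

x = np.random.randn(1, 3, 8, 8).astype(np.float32)
W = np.random.randn(5, 3, 3, 3).astype(np.float32)

# The removed `deterministic=True` argument becomes a config context.
with chainer.using_config('cudnn_deterministic', True):
    y = F.convolution_2d(x, W, pad=1)

# The removed `requires_x_grad` argument is simply dropped: whether gx is
# needed is now decided from `indexes` inside backward() during backprop.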

    def check_type_forward(self, in_types):
        n_in = in_types.size()
@@ -64,6 +67,7 @@ def check_type_forward(self, in_types):
        )

    def forward_cpu(self, inputs):
        self.retain_inputs((0, 1))  # retain only x and W
        x, W = inputs[:2]
        b = inputs[2] if len(inputs) == 3 else None

@@ -78,16 +82,17 @@ def check_type_forward(self, in_types):
                         .format(type(W), type(x)))

        kh, kw = W.shape[2:]
        col = conv.im2col_cpu(
            x, kh, kw, self.sy, self.sx, self.ph, self.pw,
            cover_all=self.cover_all)
        y = numpy.tensordot(
            col, W, ((1, 2, 3), (1, 2, 3))).astype(x.dtype, copy=False)
        if b is not None:
            y += b
        return numpy.rollaxis(y, 3, 1),
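The CPU path above reduces convolution to one tensordot over an im2col buffer; a minimal sketch of the shape bookkeeping, where all concrete sizes are illustrative assumptions:

import numpy
from chainer.utils import conv

x = numpy.zeros((2, 3, 5, 5), dtype=numpy.float32)   # (n, c, h, w)
W = numpy.zeros((4, 3, 3, 3), dtype=numpy.float32)   # (out_c, c, kh, kw)
col = conv.im2col_cpu(x, 3, 3, 1, 1, 0, 0)           # (n, c, kh, kw, out_h, out_w)
assert col.shape == (2, 3, 3, 3, 3, 3)
y = numpy.tensordot(col, W, ((1, 2, 3), (1, 2, 3)))  # (n, out_h, out_w, out_c)
y = numpy.rollaxis(y, 3, 1)                          # (n, out_c, out_h, out_w)
assert y.shape == (2, 4, 3, 3)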

    def forward_gpu(self, inputs):
        self.retain_inputs((0, 1))  # retain only x and W
        x, W = inputs[:2]
        b = inputs[2] if len(inputs) == 3 else None

@@ -123,154 +128,157 @@ def forward_gpu(self, inputs):
            x_desc = cudnn.create_tensor_descriptor(x)
            y_desc = cudnn.create_tensor_descriptor(y)

            filter_desc = cudnn.create_filter_descriptor(W)
            conv_desc = cudnn.create_convolution_descriptor(
                (self.ph, self.pw), (self.sy, self.sx), x.dtype)
            if b is not None:
                bias_desc = cudnn.create_tensor_descriptor(
                    b[None, :, None, None])

            workspace_size = cuda.get_max_workspace_size()
            workspace = cuda.cupy.empty((workspace_size,), dtype='b')
            algo = libcudnn.getConvolutionForwardAlgorithm(
                handle, x_desc.value, filter_desc.value,
                conv_desc.value, y_desc.value, _fwd_pref, workspace_size)

            oz_dtype = 'd' if x.dtype == 'd' else 'f'
            one = numpy.array(1, dtype=oz_dtype).ctypes
            zero = numpy.array(0, dtype=oz_dtype).ctypes
            libcudnn.convolutionForward(
                handle, one.data, x_desc.value, x.data.ptr,
                filter_desc.value, W.data.ptr, conv_desc.value,
                algo, workspace.data.ptr, workspace_size, zero.data,
                y_desc.value, y.data.ptr)

            # TODO(beam2d): Support unshared bias
            if b is not None:
                cudnn.add_tensor(
                    handle, one.data, bias_desc.value, b.data.ptr,
                    one.data, y_desc.value, y.data.ptr)
        else:
            # Implementation using im2col
            col = conv.im2col_gpu(
                x, kh, kw, self.sy, self.sx, self.ph, self.pw,
                cover_all=self.cover_all)
            y = cuda.cupy.tensordot(
                col, W, ((1, 2, 3), (1, 2, 3))).astype(x.dtype, copy=False)
            # TODO(beam2d): Support unshared bias
            if b is not None:
                y += b
            y = cuda.cupy.rollaxis(y, 3, 1)

        return y,

    def backward(self, indexes, grad_outputs):
        x, W = self.get_retained_inputs()
        gy, = grad_outputs

        ret = []
        if 0 in indexes:
            xh, xw = x.shape[2:]
            gx = chainer.functions.deconvolution_2d(
                gy, W, stride=(self.sy, self.sx), pad=(self.ph, self.pw),
                outsize=(xh, xw))
            ret.append(gx)
        if 1 in indexes:
            gW, = Convolution2DGradW(self).apply((x, gy))
            ret.append(gW)
        if 2 in indexes:
            gb = chainer.functions.sum(gy, axis=(0, 2, 3))
            ret.append(gb)

        return ret
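A quick sketch of what this backward() produces through the public API; shapes are illustrative assumptions. gx flows through deconvolution_2d, gW through Convolution2DGradW, and gb is a sum over the batch and spatial axes:

import numpy as np
import chainer
import chainer.functions as F

x = chainer.Variable(np.random.randn(2, 3, 7, 7).astype(np.float32))
W = chainer.Variable(np.random.randn(4, 3, 3, 3).astype(np.float32))
b = chainer.Variable(np.random.randn(4).astype(np.float32))

y = F.convolution_2d(x, W, b, stride=1, pad=1)
y.grad = np.ones(y.shape, dtype=np.float32)  # seed gy with ones
y.backward()

assert x.grad.shape == x.shape  # via deconvolution_2d (index 0)
assert W.grad.shape == W.shape  # via Convolution2DGradW (index 1)
assert b.grad.shape == b.shape  # via sum over axes (0, 2, 3) (index 2)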


class Convolution2DGradW(function_node.FunctionNode):

    def __init__(self, conv2d):
        W_node = conv2d.inputs[1]
        self.kh, self.kw = W_node.shape[2:]
        self.sy = conv2d.sy
        self.sx = conv2d.sx
        self.ph = conv2d.ph
        self.pw = conv2d.pw
        self.cover_all = conv2d.cover_all
        self.W_dtype = W_node.dtype

    def forward_cpu(self, inputs):
        self.retain_inputs((0, 1))
        x, gy = inputs
        col = conv.im2col_cpu(
            x, self.kh, self.kw, self.sy, self.sx, self.ph, self.pw,
            cover_all=self.cover_all)
        gW = numpy.tensordot(
            gy, col, ((0, 2, 3), (0, 4, 5))).astype(self.W_dtype, copy=False)
        return gW,

    def forward_gpu(self, inputs):
        self.retain_inputs((0, 1))
        x, gy = inputs
        _, out_c, out_h, out_w = gy.shape
        n, c, h, w = x.shape

        if (self.cover_all or not chainer.should_use_cudnn('>=auto') or
                x.dtype != self.W_dtype):
            col = conv.im2col_gpu(
                x, self.kh, self.kw, self.sy, self.sx, self.ph, self.pw,
                cover_all=self.cover_all)
            gW = cuda.cupy.tensordot(
                gy, col, ((0, 2, 3), (0, 4, 5))).astype(self.W_dtype,
                                                        copy=False)
            return gW,

        gW = cuda.cupy.empty((out_c, c, self.kh, self.kw), dtype=self.W_dtype)
        x = cuda.cupy.ascontiguousarray(x)
        gy = cuda.cupy.ascontiguousarray(gy)

        handle = cudnn.get_handle()
        x_desc = cudnn.create_tensor_descriptor(x)
        gy_desc = cudnn.create_tensor_descriptor(gy)

        filter_desc = cudnn.create_filter_descriptor(gW)
        conv_desc = cudnn.create_convolution_descriptor(
            (self.ph, self.pw), (self.sy, self.sx), x.dtype)

        oz_dtype = 'd' if x.dtype == 'd' else 'f'
        one = numpy.array(1, dtype=oz_dtype).ctypes
        zero = numpy.array(0, dtype=oz_dtype).ctypes

        workspace_size = cuda.get_max_workspace_size()
        workspace = cuda.cupy.empty((workspace_size,), dtype='b')

        if configuration.config.cudnn_deterministic:
            algo = libcudnn.CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1
        else:
            algo = libcudnn.getConvolutionBackwardFilterAlgorithm(
                handle, x_desc.value, gy_desc.value, conv_desc.value,
                filter_desc.value, _bwd_filter_pref, workspace_size)

        libcudnn.convolutionBackwardFilter_v3(
            handle, one.data, x_desc.value, x.data.ptr, gy_desc.value,
            gy.data.ptr, conv_desc.value, algo, workspace.data.ptr,
            workspace_size, zero.data, filter_desc.value, gW.data.ptr)

        return gW,

    def backward(self, indexes, grad_outputs):
        x, gy = self.get_retained_inputs()
        ggW, = grad_outputs

        ret = []
        if 0 in indexes:
            xh, xw = x.shape[2:]
            gx = chainer.functions.deconvolution_2d(
                gy, ggW, stride=(self.sy, self.sx), pad=(self.ph, self.pw),
                outsize=(xh, xw))
            ret.append(gx)
        if 1 in indexes:
            ggy = convolution_2d(
                x, ggW, stride=(self.sy, self.sx), pad=(self.ph, self.pw),
                cover_all=self.cover_all)
            ret.append(ggy)

        return ret
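Because this backward() is itself written with differentiable FunctionNodes, second-order gradients come essentially for free. A hedged sketch, assuming chainer.grad from the same v3 API line is available:

import numpy as np
import chainer
import chainer.functions as F

x = chainer.Variable(np.random.randn(2, 3, 7, 7).astype(np.float32))
W = chainer.Variable(np.random.randn(4, 3, 3, 3).astype(np.float32))
y = F.convolution_2d(x, W, pad=1)
gy = chainer.Variable(np.ones(y.shape, dtype=np.float32))

# First-order gradient w.r.t. W, keeping the graph for re-differentiation;
# this runs Convolution2DGradW.forward.
gW, = chainer.grad([y], [W], grad_outputs=[gy], enable_double_backprop=True)

# Differentiating gW w.r.t. x runs Convolution2DGradW.backward above
# (index 0, the deconvolution_2d branch).
ggx, = chainer.grad([gW], [x], grad_outputs=[
    chainer.Variable(np.ones(gW.shape, dtype=np.float32))])
assert ggx.shape == x.shape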


def convolution_2d(x, W, b=None, stride=1, pad=0, cover_all=False, **kwargs):
@@ -397,9 +405,10 @@ def convolution_2d(x, W, b=None, stride=1, pad=0, cover_all=False, **kwargs):
"context where value is either `True` or `False`.")
argument.assert_kwargs_empty(kwargs)

    fnode = Convolution2DFunction(stride, pad, cover_all)
    if b is None:
        args = x, W
    else:
        args = x, W, b
    y, = fnode.apply(args)
    return y
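Finally, a hedged end-to-end check of the forward/backward pair with Chainer's numerical gradient checker; shapes and tolerances are illustrative assumptions:

import numpy as np
import chainer.functions as F
from chainer import gradient_check

x = np.random.randn(2, 3, 6, 6).astype(np.float32)
W = np.random.randn(4, 3, 3, 3).astype(np.float32)
b = np.random.randn(4).astype(np.float32)
gy = np.random.randn(2, 4, 6, 6).astype(np.float32)  # pad=1 keeps 6x6

# Compares the analytic gradients from backward() with numerical ones.
gradient_check.check_backward(
    lambda x, W, b: F.convolution_2d(x, W, b, stride=1, pad=1),
    (x, W, b), gy, atol=1e-2, rtol=1e-2)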
