Merge pull request #2858 from anaruse/dilated_convolution_2d
cuDNN v6 dilated convolution
okuta committed Oct 28, 2017
2 parents 0efd39e + 0d256b0 commit 7289901
Showing 7 changed files with 121 additions and 386 deletions.
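
The diffs below thread a new dilate keyword through the im2col/col2im helpers and the cuDNN convolution descriptors, and pass a d= argument to conv.get_conv_outsize. As background, here is a minimal sketch of the output-size arithmetic a dilation factor implies. It is an illustrative re-derivation, not the code in chainer/utils/conv.py, and assumes the usual convention that dilation d stretches a kernel of size k to an effective size of d * (k - 1) + 1.

# Illustrative sketch only (assumed convention, not Chainer's implementation):
# output size of a dilated convolution along one spatial axis.
def conv_outsize(size, k, s, p, cover_all=False, d=1):
    dk = d * (k - 1) + 1              # effective (dilated) kernel size
    if cover_all:
        return (size + p * 2 - dk + s - 1) // s + 1
    return (size + p * 2 - dk) // s + 1

# Example: a 7-pixel axis with a 3-wide kernel, stride 1, no padding and
# dilation 2 behaves like a 5-wide kernel and yields 3 output pixels.
assert conv_outsize(7, k=3, s=1, p=0, d=2) == 3
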
41 changes: 27 additions & 14 deletions chainer/functions/connection/convolution_2d.py
@@ -12,6 +12,7 @@
if cuda.cudnn_enabled:
cudnn = cuda.cudnn
libcudnn = cuda.cudnn.cudnn
_cudnn_version = libcudnn.getVersion()
_fwd_pref = libcudnn.CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT
_bwd_filter_pref = \
libcudnn.CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT
@@ -38,11 +39,12 @@ def __init__(self, stride=1, pad=0, cover_all=False, **kwargs):
"the gradient w.r.t. x is automatically decided during "
"backpropagation."
)
argument.assert_kwargs_empty(kwargs)
dilate, = argument.parse_kwargs(kwargs, ('dilate', 1))

self.sy, self.sx = _pair(stride)
self.ph, self.pw = _pair(pad)
self.cover_all = cover_all
self.dy, self.dx = _pair(dilate)

def check_type_forward(self, in_types):
n_in = in_types.size()
@@ -84,7 +86,7 @@ def forward_cpu(self, inputs):
kh, kw = W.shape[2:]
col = conv.im2col_cpu(
x, kh, kw, self.sy, self.sx, self.ph, self.pw,
cover_all=self.cover_all)
cover_all=self.cover_all, dy=self.dy, dx=self.dx)
y = numpy.tensordot(
col, W, ((1, 2, 3), (1, 2, 3))).astype(x.dtype, copy=False)
if b is not None:
@@ -110,15 +112,16 @@ def forward_gpu(self, inputs):
n, c, h, w = x.shape

out_h = conv.get_conv_outsize(h, kh, self.sy, self.ph,
cover_all=self.cover_all)
cover_all=self.cover_all, d=self.dy)
assert out_h > 0, 'Height in the output should be positive.'
out_w = conv.get_conv_outsize(w, kw, self.sx, self.pw,
cover_all=self.cover_all)
cover_all=self.cover_all, d=self.dx)
assert out_w > 0, 'Width in the output should be positive.'

y = cuda.cupy.empty((n, out_c, out_h, out_w), dtype=x.dtype)
if (not self.cover_all and chainer.should_use_cudnn('>=auto') and
x.dtype == W.dtype):
x.dtype == W.dtype and
((self.dy == 1 and self.dx == 1) or _cudnn_version >= 6000)):
x = cuda.cupy.ascontiguousarray(x)
W = cuda.cupy.ascontiguousarray(W)
if b is not None:
@@ -133,6 +136,7 @@ def forward_gpu(self, inputs):
filter_desc = cudnn.create_filter_descriptor(W)
conv_desc = cudnn.create_convolution_descriptor(
(self.ph, self.pw), (self.sy, self.sx), x.dtype,
dilation=(self.dy, self.dx),
use_tensor_core=use_tensor_core)
if b is not None:
bias_desc = cudnn.create_tensor_descriptor(
@@ -167,7 +171,7 @@ def forward_gpu(self, inputs):
# Implementation using im2col
col = conv.im2col_gpu(
x, kh, kw, self.sy, self.sx, self.ph, self.pw,
cover_all=self.cover_all)
cover_all=self.cover_all, dy=self.dy, dx=self.dx)
y = cuda.cupy.tensordot(
col, W, ((1, 2, 3), (1, 2, 3))).astype(x.dtype, copy=False)
# TODO(beam2d): Support unshared bias
@@ -186,7 +190,7 @@ def backward(self, indexes, grad_outputs):
xh, xw = x.shape[2:]
gx = chainer.functions.deconvolution_2d(
gy, W, stride=(self.sy, self.sx), pad=(self.ph, self.pw),
outsize=(xh, xw))
outsize=(xh, xw), dilate=(self.dy, self.dx))
ret.append(gx)
if 1 in indexes:
gW, = Convolution2DGradW(self).apply((x, gy))
@@ -207,6 +211,8 @@ def __init__(self, conv2d):
self.sx = conv2d.sx
self.ph = conv2d.ph
self.pw = conv2d.pw
self.dy = conv2d.dy
self.dx = conv2d.dx
self.cover_all = conv2d.cover_all
self.W_dtype = W_node.dtype

@@ -215,7 +221,7 @@ def forward_cpu(self, inputs):
x, gy = inputs
col = conv.im2col_cpu(
x, self.kh, self.kw, self.sy, self.sx, self.ph, self.pw,
cover_all=self.cover_all)
cover_all=self.cover_all, dy=self.dy, dx=self.dx)

# NumPy raises an error when the array is not contiguous.
# See: https://github.com/chainer/chainer/issues/2744
@@ -235,10 +241,11 @@ def forward_gpu(self, inputs):
n, c, h, w = x.shape

if (self.cover_all or not chainer.should_use_cudnn('>=auto') or
x.dtype != self.W_dtype):
x.dtype != self.W_dtype or
((self.dy > 1 or self.dx > 1) and _cudnn_version < 6000)):
col = conv.im2col_gpu(
x, self.kh, self.kw, self.sy, self.sx, self.ph, self.pw,
cover_all=self.cover_all)
cover_all=self.cover_all, dy=self.dy, dx=self.dx)
gW = cuda.cupy.tensordot(
gy, col, ((0, 2, 3), (0, 4, 5))).astype(self.W_dtype,
copy=False)
@@ -257,6 +264,7 @@ def forward_gpu(self, inputs):
filter_desc = cudnn.create_filter_descriptor(gW)
conv_desc = cudnn.create_convolution_descriptor(
(self.ph, self.pw), (self.sy, self.sx), x.dtype,
dilation=(self.dy, self.dx),
use_tensor_core=use_tensor_core)

oz_dtype = 'd' if x.dtype == 'd' else 'f'
@@ -294,12 +302,12 @@ def backward(self, indexes, grad_outputs):
xh, xw = x.shape[2:]
gx = chainer.functions.deconvolution_2d(
gy, ggW, stride=(self.sy, self.sx), pad=(self.ph, self.pw),
outsize=(xh, xw))
outsize=(xh, xw), dilate=(self.dy, self.dx))
ret.append(gx)
if 1 in indexes:
ggy = convolution_2d(
x, ggW, stride=(self.sy, self.sx), pad=(self.ph, self.pw),
cover_all=self.cover_all)
cover_all=self.cover_all, dilate=(self.dy, self.dx))
ret.append(ggy)

return ret
@@ -361,6 +369,9 @@ def convolution_2d(x, W, b=None, stride=1, pad=0, cover_all=False, **kwargs):
If ``chainer.configuration.config.cudnn_deterministic`` is ``True`` and
cuDNN version is >= v3, it forces cuDNN to use a deterministic algorithm.
When the dilation factor is greater than one, cuDNN is not used unless
the version is 6.0 or higher.
.. warning::
``deterministic`` argument is not supported anymore since v2.
@@ -385,6 +396,8 @@ def convolution_2d(x, W, b=None, stride=1, pad=0, cover_all=False, **kwargs):
``pad=p`` and ``pad=(p, p)`` are equivalent.
cover_all (bool): If ``True``, all spatial locations are convoluted
into some output pixels.
dilate (int or pair of ints): Dilation factor of filter applications.
``dilate=d`` and ``dilate=(d, d)`` are equivalent.
Returns:
~chainer.Variable:
@@ -427,9 +440,9 @@ def convolution_2d(x, W, b=None, stride=1, pad=0, cover_all=False, **kwargs):
"supported anymore. "
"Use chainer.using_config('cudnn_deterministic', value) "
"context where value is either `True` or `False`.")
argument.assert_kwargs_empty(kwargs)
dilate, = argument.parse_kwargs(kwargs, ('dilate', 1))

fnode = Convolution2DFunction(stride, pad, cover_all)
fnode = Convolution2DFunction(stride, pad, cover_all, dilate=dilate)
if b is None:
args = x, W
else:
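
With the convolution_2d.py changes above, chainer.functions.convolution_2d accepts the dilate keyword described in the new docstring lines (dilate=d and dilate=(d, d) are equivalent). The snippet below is a usage sketch under that post-merge API; the array shapes are illustrative, not taken from the diff.

# Usage sketch (shapes illustrative, assuming the post-merge API in which
# convolution_2d forwards dilate to Convolution2DFunction).
import numpy as np
import chainer.functions as F

x = np.random.randn(1, 3, 32, 32).astype(np.float32)  # N, C_in, H, W
W = np.random.randn(8, 3, 3, 3).astype(np.float32)    # C_out, C_in, kH, kW

# stride 1, pad 2, dilation 2: the 3x3 kernel covers a 5x5 receptive field,
# so a 32x32 input maps to a 32x32 output.
y = F.convolution_2d(x, W, stride=1, pad=2, dilate=2)
print(y.shape)  # (1, 8, 32, 32)
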
50 changes: 33 additions & 17 deletions chainer/functions/connection/deconvolution_2d.py
@@ -13,6 +13,7 @@
if cuda.cudnn_enabled:
cudnn = cuda.cudnn
libcudnn = cuda.cudnn.cudnn
_cudnn_version = libcudnn.getVersion()
_fwd_pref = libcudnn.CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT
_bwd_filter_pref = \
libcudnn.CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT
@@ -41,11 +42,12 @@ def __init__(self, stride=1, pad=0, outsize=None, **kwargs):
"the gradient w.r.t. x is automatically decided during "
"backpropagation."
)
argument.assert_kwargs_empty(kwargs)
dilate, = argument.parse_kwargs(kwargs, ('dilate', 1))

self.sy, self.sx = _pair(stride)
self.ph, self.pw = _pair(pad)
self.outh, self.outw = (None, None) if outsize is None else outsize
self.dy, self.dx = _pair(dilate)

def check_type_forward(self, in_types):
n_in = in_types.size()
@@ -62,17 +64,21 @@ def check_type_forward(self, in_types):

if self.outh is not None:
lower_bound = conv.get_conv_outsize(
self.outh, w_type.shape[2], self.sy, self.ph)
self.outh, w_type.shape[2], self.sy, self.ph,
d=self.dy)
upper_bound = conv.get_conv_outsize(
self.outh, w_type.shape[2], self.sy, self.ph, cover_all=True)
self.outh, w_type.shape[2], self.sy, self.ph, cover_all=True,
d=self.dy)
type_check.expect(
lower_bound <= x_type.shape[2],
x_type.shape[2] <= upper_bound)
if self.outw is not None:
lower_bound = conv.get_conv_outsize(
self.outw, w_type.shape[3], self.sx, self.pw)
self.outw, w_type.shape[3], self.sx, self.pw,
d=self.dx)
upper_bound = conv.get_conv_outsize(
self.outw, w_type.shape[3], self.sx, self.pw, cover_all=True)
self.outw, w_type.shape[3], self.sx, self.pw, cover_all=True,
d=self.dx)
type_check.expect(
lower_bound <= x_type.shape[3],
x_type.shape[3] <= upper_bound)
@@ -109,13 +115,16 @@ def forward_cpu(self, inputs):
# k, m, n, b, h, w -> b, k, m, n, h, w
gcol = numpy.rollaxis(gcol, 3)
if self.outh is None:
self.outh = conv.get_deconv_outsize(h, kh, self.sy, self.ph)
self.outh = conv.get_deconv_outsize(h, kh, self.sy, self.ph,
d=self.dy)
assert self.outh > 0, 'Height in the output should be positive.'
if self.outw is None:
self.outw = conv.get_deconv_outsize(w, kw, self.sx, self.pw)
self.outw = conv.get_deconv_outsize(w, kw, self.sx, self.pw,
d=self.dx)
assert self.outw > 0, 'Width in the output should be positive.'
y = conv.col2im_cpu(
gcol, self.sy, self.sx, self.ph, self.pw, self.outh, self.outw)
gcol, self.sy, self.sx, self.ph, self.pw, self.outh, self.outw,
dy=self.dy, dx=self.dx)
# b, k, h, w
if b is not None:
y += b.reshape(1, b.size, 1, 1)
@@ -140,16 +149,19 @@ def forward_gpu(self, inputs):
n, in_c, in_h, in_w = x.shape
c = W.shape[1] # out_c
if self.outh is None:
self.outh = conv.get_deconv_outsize(in_h, kh, self.sy, self.ph)
self.outh = conv.get_deconv_outsize(in_h, kh, self.sy, self.ph,
d=self.dy)
assert self.outh > 0, 'Height in the output should be positive.'
if self.outw is None:
self.outw = conv.get_deconv_outsize(in_w, kw, self.sx, self.pw)
self.outw = conv.get_deconv_outsize(in_w, kw, self.sx, self.pw,
d=self.dx)
assert self.outw > 0, 'Width in the output should be positive.'

self._set_cover_all(x, W)

if (not self.cover_all and chainer.should_use_cudnn('>=auto') and
x.dtype == W.dtype):
x.dtype == W.dtype and
((self.dy == 1 and self.dx == 1) or _cudnn_version >= 6000)):
x = cuda.cupy.ascontiguousarray(x)
W = cuda.cupy.ascontiguousarray(W)
if b is not None:
@@ -166,6 +178,7 @@ def forward_gpu(self, inputs):
filter_desc = cudnn.create_filter_descriptor(W)
conv_desc = cudnn.create_convolution_descriptor(
(self.ph, self.pw), (self.sy, self.sx), x.dtype,
dilation=(self.dy, self.dx),
use_tensor_core=use_tensor_core)
if b is not None:
bias_desc = cudnn.create_tensor_descriptor(
@@ -209,7 +222,8 @@ def forward_gpu(self, inputs):
# k, m, n, b, h, w -> b, k, m, n, h, w
gcol = cuda.cupy.rollaxis(gcol, 3)
y = conv.col2im_gpu(
gcol, self.sy, self.sx, self.ph, self.pw, self.outh, self.outw)
gcol, self.sy, self.sx, self.ph, self.pw, self.outh, self.outw,
dy=self.dy, dx=self.dx)
if b is not None:
y += b.reshape(1, b.size, 1, 1)
return y,
@@ -224,7 +238,7 @@ def backward(self, indexes, grad_outputs):
self._set_cover_all(x, W)
gx = chainer.functions.convolution_2d(
gy, W, stride=(self.sy, self.sx), pad=(self.ph, self.pw),
cover_all=self.cover_all)
cover_all=self.cover_all, dilate=(self.dy, self.dx))
ret.append(gx)
if 1 in indexes:
if self.cover_all is None:
@@ -241,8 +255,10 @@ def _set_cover_all(self, x, W):
in_h, in_w = x.shape[2:]
kh, kw = W.shape[2:]
self.cover_all = (
in_h != conv.get_conv_outsize(self.outh, kh, self.sy, self.ph) or
in_w != conv.get_conv_outsize(self.outw, kw, self.sx, self.pw))
in_h != conv.get_conv_outsize(self.outh, kh, self.sy,
self.ph, d=self.dy) or
in_w != conv.get_conv_outsize(self.outw, kw, self.sx,
self.pw, d=self.dx))


def deconvolution_2d(x, W, b=None, stride=1, pad=0, outsize=None, **kwargs):
@@ -350,9 +366,9 @@ def deconvolution_2d(x, W, b=None, stride=1, pad=0, outsize=None, **kwargs):
"supported anymore. "
"Use chainer.using_config('cudnn_deterministic', value) "
"context where value is either `True` or `False`.")
argument.assert_kwargs_empty(kwargs)
dilate, = argument.parse_kwargs(kwargs, ('dilate', 1))

func = Deconvolution2DFunction(stride, pad, outsize)
func = Deconvolution2DFunction(stride, pad, outsize, dilate=dilate)
if b is None:
args = x, W
else:
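
The deconvolution side mirrors the convolution changes: Convolution2DFunction.backward now calls deconvolution_2d with dilate=(self.dy, self.dx), so the gradient with respect to x is computed with the same dilation. The round-trip sketch below illustrates the relation that backward pass relies on; shapes and values are illustrative, assuming the post-merge API.

# Round-trip sketch: a dilated convolution followed by deconvolution_2d with
# the same stride/pad/dilate and an explicit outsize restores the input's
# spatial size.
import numpy as np
import chainer.functions as F

x = np.random.randn(1, 3, 32, 32).astype(np.float32)  # N, C_in, H, W
W = np.random.randn(8, 3, 3, 3).astype(np.float32)    # C_out, C_in, kH, kW

y = F.convolution_2d(x, W, stride=2, pad=2, dilate=2)  # (1, 8, 16, 16)
gx = F.deconvolution_2d(y, W, stride=2, pad=2, dilate=2, outsize=(32, 32))
print(gx.shape)  # (1, 3, 32, 32)
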
