
Merge pull request #3782 from okuta/refactoring-cudnn-conv
 Delegate cuDNN convolution operation to CuPy
hvy committed Mar 19, 2018
2 parents a7cd05c + b2797bf commit 551ce0f
Showing 12 changed files with 103 additions and 586 deletions.
13 changes: 7 additions & 6 deletions chainer/backends/cuda.py
@@ -523,9 +523,6 @@ def get_array_module(*args):
return numpy


_max_workspace_size = 8 * 1024 * 1024


def get_max_workspace_size():
"""Gets the workspace size for cuDNN.
@@ -535,7 +532,10 @@ def get_max_workspace_size():
int: The workspace size for cuDNN.
"""
return _max_workspace_size
# To avoid error on no cuDNN environment
if cudnn_enabled:
return cudnn.get_max_workspace_size()
return 0


def set_max_workspace_size(size):
@@ -547,8 +547,9 @@ def set_max_workspace_size(size):
size: The workspace size for cuDNN.
"""
global _max_workspace_size
_max_workspace_size = size
# To avoid error on no cuDNN environment
if cudnn_enabled:
cudnn.set_max_workspace_size(size)


def fuse(*args, **kwargs):
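With this change, Chainer no longer tracks its own module-level `_max_workspace_size`; `get_max_workspace_size` and `set_max_workspace_size` become thin wrappers over CuPy's setting, guarded so that a build without cuDNN simply reports 0. A minimal usage sketch, assuming a cuDNN-enabled CuPy installation (the 256 MiB value is only illustrative):

    import cupy.cudnn
    from chainer.backends import cuda

    # Setting the limit through Chainer now forwards to CuPy, so both
    # libraries see the same cuDNN workspace budget.
    cuda.set_max_workspace_size(256 * 1024 * 1024)

    # Both reads return the value now held by CuPy.
    assert cuda.get_max_workspace_size() == 256 * 1024 * 1024
    assert cupy.cudnn.get_max_workspace_size() == 256 * 1024 * 1024
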
165 changes: 17 additions & 148 deletions chainer/functions/connection/convolution_2d.py
@@ -12,14 +12,7 @@
from chainer.utils import type_check

if cuda.cudnn_enabled:
cudnn = cuda.cudnn
libcudnn = cuda.cuda.cudnn
_cudnn_version = libcudnn.getVersion()
_fwd_pref = libcudnn.CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT
_bwd_filter_pref = \
libcudnn.CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT
_algorithm_fwd = {}
_algorithm_bwd_filter = {}
_cudnn_version = cuda.cuda.cudnn.getVersion()


def _pair(x):
@@ -28,36 +21,6 @@ def _pair(x):
return x, x


def _get_algorithm_fwd(
x, W, y, conv_param, handle, x_desc, filter_desc, conv_desc, y_desc,
workspace):
key = (x.shape, W.shape, y.shape, conv_param)
if key in _algorithm_fwd:
return _algorithm_fwd[key]
ret = libcudnn.findConvolutionForwardAlgorithmEx(
handle, x_desc.value, x.data.ptr, filter_desc.value, W.data.ptr,
conv_desc.value, y_desc.value, y.data.ptr, 1, workspace.data.ptr,
workspace.size)
algo = ret[0]['algo']
_algorithm_fwd[key] = algo
return algo


def _get_algorithm_bwd_filter(
x, dy, dW, conv_param, handle, x_desc, dy_desc, conv_desc, filter_desc,
workspace):
key = (x.shape, dW.shape, dy.shape, conv_param)
if key in _algorithm_bwd_filter:
return _algorithm_bwd_filter[key]
ret = libcudnn.findConvolutionBackwardFilterAlgorithmEx(
handle, x_desc.value, x.data.ptr, dy_desc.value, dy.data.ptr,
conv_desc.value, filter_desc.value, dW.data.ptr, 1,
workspace.data.ptr, workspace.size)
algo = ret[0]['algo']
_algorithm_bwd_filter[key] = algo
return algo


class Convolution2DFunction(function_node.FunctionNode):

_use_ideep = False
@@ -254,68 +217,16 @@ def _forward_grouped_convolution(self, x, W, b):
return y,

def _forward_cudnn(self, x, W, b, y):
x = cuda.cupy.ascontiguousarray(x)
W = cuda.cupy.ascontiguousarray(W)
if b is not None:
b = cuda.cupy.ascontiguousarray(b)

use_tensor_core = chainer.should_use_cudnn_tensor_core(x.dtype)

# cuDNN 7 supports dilation only in *_FWD_ALGO_IMPLICIT_GEMM, but
# it supports Tensor Cores only in *_FWD_ALGO_IMPLICIT_PRECOMP_GEMM.
if use_tensor_core and (self.dx > 1 or self.dy > 1):
use_tensor_core = False

handle = cudnn.get_handle()
x_desc = cudnn.create_tensor_descriptor(x)
y_desc = cudnn.create_tensor_descriptor(y)

filter_desc = cudnn.create_filter_descriptor(W)
conv_param = ((self.ph, self.pw), (self.sy, self.sx), x.dtype)
pad = (self.ph, self.pw)
stride = (self.sy, self.sx)
dilation = (self.dy, self.dx)
conv_desc = cudnn.create_convolution_descriptor(
*conv_param, dilation=dilation,
use_tensor_core=use_tensor_core,
groups=self.groups)
if b is not None:
bias_desc = cudnn.create_tensor_descriptor(
b[None, :, None, None])
workspace_size = cuda.get_max_workspace_size()
workspace = cuda.cupy.empty((workspace_size,), dtype='b')
if configuration.config.autotune and _cudnn_version >= 5000:
algo = _get_algorithm_fwd(
x, W, y, conv_param + (dilation,), handle, x_desc,
filter_desc, conv_desc, y_desc, workspace)
else:
algo = libcudnn.getConvolutionForwardAlgorithm(
handle, x_desc.value, filter_desc.value,
conv_desc.value, y_desc.value, _fwd_pref, workspace_size)

if use_tensor_core:
algo = self._tensor_core_adjust_algo()

oz_dtype = 'd' if x.dtype == 'd' else 'f'
one = numpy.array(1, dtype=oz_dtype).ctypes
zero = numpy.array(0, dtype=oz_dtype).ctypes
libcudnn.convolutionForward(
handle, one.data, x_desc.value, x.data.ptr,
filter_desc.value, W.data.ptr, conv_desc.value,
algo, workspace.data.ptr, workspace_size, zero.data,
y_desc.value, y.data.ptr)

# TODO(beam2d): Support unshared bias
if b is not None:
cudnn.add_tensor(
handle, one.data, bias_desc.value, b.data.ptr,
one.data, y_desc.value, y.data.ptr)

auto_tune = configuration.config.autotune
tensor_core = configuration.config.use_cudnn_tensor_core
cuda.cudnn.convolution_forward(
x, W, b, y, pad, stride, dilation, self.groups,
auto_tune=auto_tune, tensor_core=tensor_core)
return y,

def _tensor_core_adjust_algo(self):
# Only CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM
# supports Tensor-Core in cuDNN7.
return libcudnn.CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM

def backward(self, indexes, grad_outputs):
x, W = self.get_retained_inputs()
gy, = grad_outputs
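The rewritten `_forward_cudnn` above replaces roughly fifty lines of descriptor creation, algorithm selection, workspace allocation, and explicit `convolutionForward`/`add_tensor` calls with a single call into CuPy. A standalone sketch of that delegated call, assuming a cuDNN-enabled CuPy exposing `cupy.cudnn.convolution_forward` with the arguments used in this diff (shapes are illustrative; `tensor_core='auto'` mirrors Chainer's default `use_cudnn_tensor_core` setting):

    import cupy
    import cupy.cudnn

    # Toy NCHW tensors; the output buffer y must be preallocated.
    x = cupy.random.randn(1, 3, 32, 32).astype(cupy.float32)   # input
    W = cupy.random.randn(8, 3, 3, 3).astype(cupy.float32)     # filters
    b = cupy.zeros(8, dtype=cupy.float32)                      # bias
    y = cupy.empty((1, 8, 30, 30), dtype=cupy.float32)         # output

    # pad, stride, dilation are (h, w) pairs and groups is an int; CuPy
    # now chooses the cuDNN algorithm (optionally by autotuning) and
    # manages the workspace internally.
    cupy.cudnn.convolution_forward(
        x, W, b, y, (0, 0), (1, 1), (1, 1), 1,
        auto_tune=True, tensor_core='auto')
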
@@ -480,61 +391,19 @@ def _forward_cudnn(self, x, gy):
iCg = int(iC / self.groups)
gW = cuda.cupy.empty((out_c, iCg, self.kh, self.kw),
dtype=self.W_dtype)
x = cuda.cupy.ascontiguousarray(x)
gy = cuda.cupy.ascontiguousarray(gy)

use_tensor_core = chainer.should_use_cudnn_tensor_core(x.dtype)

# cuDNN 7 supports dilation only in *_BWD_FILTER_ALGO_0, but
# it supports Tensor Cores only in *_BWD_FILTER_ALGO_1.
if use_tensor_core and (self.dx > 1 or self.dy > 1):
use_tensor_core = False

handle = cudnn.get_handle()
x_desc = cudnn.create_tensor_descriptor(x)
gy_desc = cudnn.create_tensor_descriptor(gy)

filter_desc = cudnn.create_filter_descriptor(gW)
conv_param = (self.ph, self.pw), (self.sy, self.sx), x.dtype
pad = (self.ph, self.pw)
stride = (self.sy, self.sx)
dilation = (self.dy, self.dx)
conv_desc = cudnn.create_convolution_descriptor(
*conv_param, dilation=dilation,
use_tensor_core=use_tensor_core,
groups=self.groups)

oz_dtype = 'd' if x.dtype == 'd' else 'f'
one = numpy.array(1, dtype=oz_dtype).ctypes
zero = numpy.array(0, dtype=oz_dtype).ctypes

workspace_size = cuda.get_max_workspace_size()
workspace = cuda.cupy.empty((workspace_size,), dtype='b')

if configuration.config.cudnn_deterministic:
algo = libcudnn.CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1
elif configuration.config.autotune and _cudnn_version >= 5000:
algo = _get_algorithm_bwd_filter(
x, gy, gW, conv_param + (dilation,), handle, x_desc, gy_desc,
conv_desc, filter_desc, workspace)
else:
algo = libcudnn.getConvolutionBackwardFilterAlgorithm(
handle, x_desc.value, gy_desc.value, conv_desc.value,
filter_desc.value, _bwd_filter_pref, workspace_size)

if use_tensor_core:
algo = self._tensor_core_adjust_algo()

libcudnn.convolutionBackwardFilter_v3(
handle, one.data, x_desc.value, x.data.ptr, gy_desc.value,
gy.data.ptr, conv_desc.value, algo, workspace.data.ptr,
workspace_size, zero.data, filter_desc.value, gW.data.ptr)
deterministic = configuration.config.cudnn_deterministic
auto_tune = configuration.config.autotune
tensor_core = configuration.config.use_cudnn_tensor_core
cuda.cudnn.convolution_backward_filter(
x, gy, gW, pad, stride, dilation, self.groups,
deterministic=deterministic, auto_tune=auto_tune,
tensor_core=tensor_core)

return gW,

def _tensor_core_adjust_algo(self):
# Only CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1 supports
# Tensor-Core in cuDNN7.
return libcudnn.CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1

def backward(self, indexes, grad_outputs):
x, gy = self.get_retained_inputs()
ggW, = grad_outputs
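The filter-gradient path follows the same pattern, and additionally threads Chainer's `cudnn_deterministic` flag through to CuPy so that a reproducible backward-filter algorithm can be forced. A minimal sketch under the same assumptions as the forward example (shapes illustrative, `gW` preallocated):

    import cupy
    import cupy.cudnn

    x = cupy.random.randn(1, 3, 32, 32).astype(cupy.float32)    # forward input
    gy = cupy.random.randn(1, 8, 30, 30).astype(cupy.float32)   # gradient w.r.t. y
    gW = cupy.empty((8, 3, 3, 3), dtype=cupy.float32)           # filter-gradient buffer

    # deterministic=True pins cuDNN to a reproducible algorithm at some
    # cost in speed; auto_tune benchmarks candidate algorithms when
    # determinism is not required.
    cupy.cudnn.convolution_backward_filter(
        x, gy, gW, (0, 0), (1, 1), (1, 1), 1,
        deterministic=False, auto_tune=True, tensor_core='auto')
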
121 changes: 19 additions & 102 deletions chainer/functions/connection/convolution_nd.py
@@ -5,23 +5,11 @@
from chainer.backends import cuda
from chainer import configuration
from chainer import function_node
from chainer.functions.connection import convolution_2d
from chainer.utils import conv
from chainer.utils import conv_nd
from chainer.utils import type_check


if cuda.cudnn_enabled:
cudnn = cuda.cudnn
libcudnn = cuda.cuda.cudnn
_cudnn_version_ = libcudnn.getVersion()
_fwd_pref = libcudnn.CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT
_bwd_filter_pref = \
libcudnn.CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT
_bwd_data_pref = \
libcudnn.CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT


class ConvolutionND(function_node.FunctionNode):

def __init__(self, ndim, stride=1, pad=0, cover_all=False):
@@ -89,8 +77,6 @@ def _forward_cudnn(self, x, W, b):
dims = x.shape[2:]
stride = self.stride
pad = self.pad
ndim = self.ndim
colon = slice(None)

# Make empty array for result.
outs = tuple(
@@ -99,55 +85,13 @@ def _forward_cudnn(self, x, W, b):
assert all(out > 0 for out in outs), 'Output sizes should be positive.'
y_shape = (n, out_c) + outs # (n, c_O, out_1, out_2, ..., out_N)
y = cuda.cupy.empty(y_shape, dtype=x.dtype)

# Convert to C-contiguous arrays.
x = cuda.cupy.ascontiguousarray(x)
W = cuda.cupy.ascontiguousarray(W)
if b is not None:
b = cuda.cupy.ascontiguousarray(b)

# Get cuDNN handler and descriptors.
handle = cudnn.get_handle()
x_desc = cudnn.create_tensor_descriptor(x)
y_desc = cudnn.create_tensor_descriptor(y)

self.filter_desc = cudnn.create_filter_descriptor(W)
self.conv_param = (pad, stride, x.dtype)
self.conv_desc = cudnn.create_convolution_descriptor(*self.conv_param)
if b is not None:
b_index = (None, colon) + (None,) * ndim
self.bias_desc = cudnn.create_tensor_descriptor(b[b_index])

# Find cuDNN algorithm to be used.
workspace_size = cuda.get_max_workspace_size()
workspace = cuda.cupy.empty((workspace_size,), dtype='b')
if configuration.config.autotune and _cudnn_version_ >= 5000:
algo = convolution_2d._get_algorithm_fwd(
x, W, y, self.conv_param, handle, x_desc, self.filter_desc,
self.conv_desc, y_desc, workspace)
else:
algo = libcudnn.getConvolutionForwardAlgorithm(
handle, x_desc.value, self.filter_desc.value,
self.conv_desc.value, y_desc.value, _fwd_pref,
workspace_size)

# cuDNN forward computation.
oz_dtype = 'd' if x.dtype == 'd' else 'f'
one = numpy.array(1, dtype=oz_dtype).ctypes
zero = numpy.array(0, dtype=oz_dtype).ctypes
libcudnn.convolutionForward(
handle, one.data, x_desc.value, x.data.ptr,
self.filter_desc.value, W.data.ptr, self.conv_desc.value,
algo, workspace.data.ptr, workspace_size, zero.data,
y_desc.value, y.data.ptr)

# Add bias if given.
# TODO(takagi) Support unshared bias
if b is not None:
cudnn.add_tensor(
handle, one.data, self.bias_desc.value, b.data.ptr,
one.data, y_desc.value, y.data.ptr)

dilation = (1,) * self.ndim
groups = 1
auto_tune = configuration.config.autotune
tensor_core = configuration.config.use_cudnn_tensor_core
cuda.cudnn.convolution_forward(
x, W, b, y, pad, stride, dilation, groups,
auto_tune=auto_tune, tensor_core=tensor_core)
return y,

def forward(self, inputs):
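The N-dimensional forward path above reuses the same CuPy entry point; since `ConvolutionND` exposes neither dilation nor grouped convolution at this point, the call site fixes `dilation` to ones and `groups` to 1. A sketch of the corresponding 3-D call, under the same assumptions about `cupy.cudnn` as above:

    import cupy
    import cupy.cudnn

    ndim = 3
    x = cupy.random.randn(1, 2, 8, 8, 8).astype(cupy.float32)   # NCDHW input
    W = cupy.random.randn(4, 2, 3, 3, 3).astype(cupy.float32)   # filters
    y = cupy.empty((1, 4, 6, 6, 6), dtype=cupy.float32)         # output buffer

    # b=None skips the bias addition; dilation and groups are fixed
    # because ConvolutionND does not expose them.
    cupy.cudnn.convolution_forward(
        x, W, None, y, (0,) * ndim, (1,) * ndim, (1,) * ndim, 1,
        auto_tune=False, tensor_core='auto')
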
@@ -241,51 +185,24 @@ def _forward_xp(self, x, gy, xp):
return gW,

def _forward_cudnn(self, x, gy):
# Convert to C-contiguous arrays.
x = cuda.cupy.ascontiguousarray(x)
gy = cuda.cupy.ascontiguousarray(gy)

# Make empty arrays for result.
out_c = gy.shape[1]
in_c = x.shape[1]
gW = cuda.cupy.empty(
(out_c, in_c) + self.ksize, dtype=self.W_dtype)

# Get cuDNN handler and descriptors.
use_tensor_core = chainer.should_use_cudnn_tensor_core(x.dtype)

handle = cudnn.get_handle()
x_desc = cudnn.create_tensor_descriptor(x)
gy_desc = cudnn.create_tensor_descriptor(gy)

filter_desc = cudnn.create_filter_descriptor(gW)
conv_param = (self.pad, self.stride, self.W_dtype)
conv_desc = cudnn.create_convolution_descriptor(
*conv_param, use_tensor_core=use_tensor_core)

# Compute gradients.
oz_dtype = 'd' if x.dtype == 'd' else 'f'
one = numpy.array(1, dtype=oz_dtype).ctypes
zero = numpy.array(0, dtype=oz_dtype).ctypes

workspace_size = cuda.get_max_workspace_size()
workspace = cuda.cupy.empty((workspace_size,), dtype='b')

# Compute filter weight gradient.
if configuration.config.autotune and _cudnn_version_ >= 5000:
algo = convolution_2d._get_algorithm_bwd_filter(
x, gy, gW, conv_param, handle, x_desc, gy_desc,
conv_desc, filter_desc, workspace)
else:
algo = libcudnn.getConvolutionBackwardFilterAlgorithm(
handle, x_desc.value, gy_desc.value, conv_desc.value,
filter_desc.value, _bwd_filter_pref, workspace_size)

libcudnn.convolutionBackwardFilter_v3(
handle, one.data, x_desc.value, x.data.ptr,
gy_desc.value, gy.data.ptr, conv_desc.value,
algo, workspace.data.ptr, workspace_size,
zero.data, filter_desc.value, gW.data.ptr)
# Compute
pad = self.pad
stride = self.stride
dilation = (1,) * self.ndim
groups = 1
deterministic = configuration.config.cudnn_deterministic
auto_tune = configuration.config.autotune
tensor_core = configuration.config.use_cudnn_tensor_core
cuda.cudnn.convolution_backward_filter(
x, gy, gW, pad, stride, dilation, groups,
deterministic=deterministic, auto_tune=auto_tune,
tensor_core=tensor_core)

return gW,

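At the user level nothing changes: the new call sites read the existing configuration flags, so autotuning, Tensor Core use, and deterministic filter gradients are still controlled through `chainer.using_config`. A usage sketch, assuming a GPU with cuDNN available:

    import numpy
    import chainer
    import chainer.functions as F
    from chainer.backends import cuda

    x = cuda.to_gpu(numpy.random.randn(1, 3, 16, 16, 16).astype(numpy.float32))
    W = cuda.to_gpu(numpy.random.randn(4, 3, 3, 3, 3).astype(numpy.float32))

    # These flags are forwarded to cupy.cudnn.convolution_forward /
    # convolution_backward_filter by the call sites shown in this diff.
    with chainer.using_config('use_cudnn', 'always'), \
            chainer.using_config('autotune', True):
        y = F.convolution_nd(x, W)
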
