Merge pull request #6777 from chainer-ci/bp-6725-v10-fix-matmul-int16
[backport] Fix batched matmul for integral numbers
kmaehashi committed Jun 14, 2022
2 parents d7b8b9f + 694f30d commit 3719ba9
Showing 3 changed files with 169 additions and 22 deletions.
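For context (not part of the commit itself), a minimal sketch of the operation this backport fixes: a batched matmul on an integral dtype, with shapes and dtype mirroring the tests added below. It assumes a CuPy build that includes this change.

# Minimal sketch (not from the commit): batched integer matmul mirroring the
# added tests, assuming a CuPy build that includes this fix.
import cupy
import numpy

a = (cupy.arange(256 * 256 * 3 * 2) % 7).astype(cupy.int16).reshape(256, 256, 3, 2)
b = (cupy.arange(256 * 256 * 2 * 4) % 5).astype(cupy.int16).reshape(256, 256, 2, 4)
c = cupy.matmul(a, b)  # with this commit, integral dtypes take the batched integer kernels
assert c.dtype == cupy.int16  # result keeps the integral dtype
expected = numpy.matmul(cupy.asnumpy(a), cupy.asnumpy(b))
assert (cupy.asnumpy(c) == expected).all()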
164 changes: 142 additions & 22 deletions cupy/_core/_routines_linalg.pyx
@@ -92,15 +92,14 @@ cpdef compute_type_to_str(compute_type):
return compute_type


@cupy._util.memoize(for_each_device=True)
def _tensordot_core_int_kernel(config, dtype):
def _tensordot_core_int_kernel_impl(config, dtype, code, name):
# This code is based in the GEMM implementation from MAGMA
# (http://icl.cs.utk.edu/magma/)
code = '''
#define fetch(arr, col, m, n, bound) arr[min(n*col + m, bound)]
template<typename T>
__global__ void _tensordot_core_int_kernel(
__device__ void _tensordot_core_int_kernel_impl(
int M, int N, int K,
const T* A,
const T* B,
@@ -284,31 +283,86 @@ __global__ void _tensordot_core_int_kernel(
}
}
}
'''
''' + code
for k, v in config:
code = '#define ' + k + ' ' + str(v) + '\n' + code
name_expressions = ['_tensordot_core_int_kernel<bool>',
'_tensordot_core_int_kernel<signed char>',
'_tensordot_core_int_kernel<unsigned char>',
'_tensordot_core_int_kernel<short>',
'_tensordot_core_int_kernel<unsigned short>',
'_tensordot_core_int_kernel<int>',
'_tensordot_core_int_kernel<unsigned int>',
'_tensordot_core_int_kernel<long>',
'_tensordot_core_int_kernel<unsigned long>',
'_tensordot_core_int_kernel<long long>',
'_tensordot_core_int_kernel<unsigned long long>']
name_expressions = [f'{name}<bool>',
f'{name}<signed char>',
f'{name}<unsigned char>',
f'{name}<short>',
f'{name}<unsigned short>',
f'{name}<int>',
f'{name}<unsigned int>',
f'{name}<long>',
f'{name}<unsigned long>',
f'{name}<long long>',
f'{name}<unsigned long long>']
mod = cupy.RawModule(code=code, options=('--std=c++11',),
name_expressions=name_expressions)
ker = mod.get_function(
'_tensordot_core_int_kernel<'+get_typename(dtype)+'>')
ker = mod.get_function(name + '<' + get_typename(dtype) + '>')
return ker


cdef ndarray _integral_tensordot_core(
ndarray a, ndarray b, ndarray out, Py_ssize_t m, Py_ssize_t n,
Py_ssize_t k, str dtype, const shape_t& ret_shape):
@cupy._util.memoize(for_each_device=True)
def _tensordot_core_int_kernel(config, dtype):
code = '''
template<typename T>
__global__ void _tensordot_core_int_kernel(
int M, int N, int K,
const T* A,
const T* B,
T * C)
{
_tensordot_core_int_kernel_impl(M, N, K, A, B, C);
}
'''
name = '_tensordot_core_int_kernel'
return _tensordot_core_int_kernel_impl(config, dtype, code, name)


@cupy._util.memoize(for_each_device=True)
def _tensordot_core_int_batched_kernel(config, dtype):
code = '''
template<typename T>
__global__ void _tensordot_core_int_batched_kernel(
int M, int N, int K,
const T* A[], const T* B[],
T* C[])
{
int batchid = blockIdx.z;
_tensordot_core_int_kernel_impl(
M, N, K, A[batchid], B[batchid], C[batchid]
);
}
'''
name = '_tensordot_core_int_batched_kernel'
return _tensordot_core_int_kernel_impl(config, dtype, code, name)


@cupy._util.memoize(for_each_device=True)
def _tensordot_core_int_strided_batched_kernel(config, dtype):
code = '''
template<typename T>
__global__ void _tensordot_core_int_strided_batched_kernel(
int M, int N, int K,
const T* A, long long strideA,
const T* B, long long strideB,
T * C, long long strideC)
{
int batchid = blockIdx.z;
_tensordot_core_int_kernel_impl(
M, N, K,
&A[batchid * strideA],
&B[batchid * strideB],
&C[batchid * strideC]
);
}
'''
name = '_tensordot_core_int_strided_batched_kernel'
return _tensordot_core_int_kernel_impl(config, dtype, code, name)


cdef tuple _integral_tensordot_core_config():
# TODO(leofang): autotune the tuning parameters here? See the discussion
# in this thread: https://groups.google.com/a/icl.utk.edu/g/magma-user/c/igc66uduTfI # NOQA
dim_x=16
@@ -325,6 +379,14 @@ cdef ndarray _integral_tensordot_core(
('DIM_XA', dim_xa), ('DIM_YA', dim_ya),
('DIM_XB', dim_xb), ('DIM_YB', dim_yb),
('THR_M', blk_m // dim_x), ('THR_N', blk_n // dim_y))
return config, dim_x, dim_y, blk_m, blk_n


cdef ndarray _integral_tensordot_core(
ndarray a, ndarray b, ndarray out, Py_ssize_t m, Py_ssize_t n,
Py_ssize_t k, str dtype, const shape_t& ret_shape):

config, dim_x, dim_y, blk_m, blk_n = _integral_tensordot_core_config()
kern = _tensordot_core_int_kernel(config, dtype)
args = (m, n, k, a, b, out)
grid = (int(math.ceil(m / blk_m)), int(math.ceil(n / blk_n)), 1)
@@ -333,6 +395,51 @@ cdef ndarray _integral_tensordot_core(
return out


cdef ndarray _integral_tensordot_core_batched(
ndarray a, ndarray b, ndarray out, Py_ssize_t m, Py_ssize_t n,
Py_ssize_t k, str dtype, Py_ssize_t batch_count):

config, dim_x, dim_y, blk_m, blk_n = _integral_tensordot_core_config()
kern = _tensordot_core_int_batched_kernel(config, dtype)
block = (dim_x, dim_y, 1)
matPtrA = _mat_ptrs(a)
matPtrB = _mat_ptrs(b)
matPtrOut = _mat_ptrs(out)
max_batch_count = 65000
for i in range(0, batch_count, max_batch_count):
ibatch = min(max_batch_count, batch_count - i)
args = (
m, n, k, matPtrA[i:i + ibatch], matPtrB[i:i + ibatch],
matPtrOut[i:i + ibatch])
grid = (int(math.ceil(m / blk_m)), int(math.ceil(n / blk_n)), ibatch)
kern(grid, block, args=args)
return out


cdef ndarray _integral_tensordot_core_strided_batched(
ndarray a, ndarray b, ndarray out, Py_ssize_t m, Py_ssize_t n,
Py_ssize_t k, str dtype, Py_ssize_t batch_count):

config, dim_x, dim_y, blk_m, blk_n = _integral_tensordot_core_config()
kern = _tensordot_core_int_strided_batched_kernel(config, dtype)
block = (dim_x, dim_y, 1)
a = a.reshape((-1,) + a.shape[-2:])
b = b.reshape((-1,) + b.shape[-2:])
out = out.reshape((-1,) + out.shape[-2:])
strideA = _get_stride_for_strided_batched_gemm(a)
strideB = _get_stride_for_strided_batched_gemm(b)
strideOut = _get_stride_for_strided_batched_gemm(out)
max_batch_count = 65000
for i in range(0, batch_count, max_batch_count):
ibatch = min(max_batch_count, batch_count - i)
args = (
m, n, k, a[i:i + ibatch], strideA, b[i:i + ibatch], strideB,
out[i:i + ibatch], strideOut)
grid = (int(math.ceil(m / blk_m)), int(math.ceil(n / blk_n)), ibatch)
kern(grid, block, args=args)
return out


cdef _tensordot_core_mul_sum = ReductionKernel(
'S x, T y', 'U out',
'static_cast<U>(x) * static_cast<U>(y)',
@@ -680,7 +787,7 @@ cpdef ndarray _mat_ptrs(ndarray a):
"""Creates an array of pointers to matrices
Args:
a: A batch of matrices on GPU.
shape: (A, B, C) -> A ptrs to mat o size (B, C)
shape: (A, B, C) -> A ptrs to mat of size (B, C)
shape: (A_1, ..., A_N, B, C) -> A_1*...*A_N ptrs to mat of
size (B, C)
Returns:
@@ -781,7 +888,9 @@ cpdef ndarray matmul(ndarray a, ndarray b, ndarray out=None):
True, True)

ret_dtype = numpy.promote_types(a.dtype, b.dtype)
dtype = numpy.promote_types(ret_dtype, 'f')
dtype = ret_dtype
if dtype.char == 'e':
dtype = numpy.dtype('f')

a = ascontiguousarray(a, dtype)
b = ascontiguousarray(b, dtype)
@@ -872,6 +981,17 @@ cpdef ndarray matmul(ndarray a, ndarray b, ndarray out=None):
else:
c_view = c

if dtype.char not in 'efdFD':
if not use_broadcast:
_integral_tensordot_core_strided_batched(
a, b, c_view, n, m, ka, dtype.char, batchCount)
else:
_integral_tensordot_core_batched(
a, b, c_view, n, m, ka, dtype.char, batchCount)
if out is not c:
elementwise_copy(c, out)
return out

global _cuda_runtime_version
if _cuda_runtime_version < 0:
_cuda_runtime_version = runtime.runtimeGetVersion()
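A side note on the chunked launches above (an inference, not stated in the diff): CUDA caps gridDim.z at 65535, so the batched helpers launch at most 65000 batches per kernel call via the grid's z dimension. A quick check of what that means for the 256×256-batch tests added below:

# Rough arithmetic sketch (not from the commit): number of kernel launches the
# chunked loops above perform for a 256 x 256 batch of matrices.
import math

batch_count = 256 * 256      # 65536 matrices, as in the new tests
max_batch_count = 65000      # chunk size used by the loops above
launches = math.ceil(batch_count / max_batch_count)
assert launches == 2         # one full chunk of 65000 plus a 536-matrix remainder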
5 changes: 5 additions & 0 deletions tests/cupy_tests/core_tests/test_raw.py
@@ -399,6 +399,11 @@ class TestRaw(unittest.TestCase):

def setUp(self):
if hasattr(self, 'clean_up'):
if cupy.cuda.runtime.is_hip:
# Clearing memo triggers recompiling kernels using name
# expressions in other tests, e.g. dot and matmul, which
# hits a nvrtc bug. See #5843, #5945 and #6725.
self.skipTest('Clearing memo hits a nvrtc bug in other tests')
_util.clear_memo()
self.dev = cupy.cuda.runtime.getDevice()
assert self.dev != 1
22 changes: 22 additions & 0 deletions tests/cupy_tests/math_tests/test_matmul.py
@@ -196,6 +196,28 @@ def test_cupy_matmul(self, xp, dtype1, dtype2):
return xp.matmul(x1, x2)


@pytest.mark.parametrize('shape1,shape2', [
((256, 256, 3, 2), (256, 256, 2, 4)),
((256, 256, 3, 2), (2, 4)),
((3, 2), (256, 256, 2, 4))
])
class TestMatmulIntegralLargeBatch:

@testing.for_int_dtypes(name='dtype')
@testing.numpy_cupy_array_equal()
def test_operator_matmul(self, xp, dtype, shape1, shape2):
x1 = testing.shaped_random(shape1, xp, dtype)
x2 = testing.shaped_random(shape2, xp, dtype)
return operator.matmul(x1, x2)

@testing.for_int_dtypes(name='dtype')
@testing.numpy_cupy_array_equal()
def test_cupy_matmul(self, xp, dtype, shape1, shape2):
x1 = testing.shaped_random(shape1, xp, dtype)
x2 = testing.shaped_random(shape2, xp, dtype)
return xp.matmul(x1, x2)


class TestMatmulOverflow(unittest.TestCase):

@testing.for_int_dtypes(name='dtype', no_bool=True)

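To exercise only the new tests locally (a typical invocation, not part of the commit), pytest can be pointed at the added class:

# Hypothetical local invocation (not part of the commit): run only the new
# integral large-batch matmul tests via pytest's programmatic entry point.
import pytest

pytest.main(['tests/cupy_tests/math_tests/test_matmul.py',
             '-k', 'TestMatmulIntegralLargeBatch'])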