Merge pull request #8103 from asi1024/cumsum-cuda
Add cuda `ScanKernel`
mergify[bot] committed Sep 13, 2019
2 parents 9abccfd + ce67abe commit d4965bb
Showing 4 changed files with 127 additions and 8 deletions.
16 changes: 13 additions & 3 deletions chainerx_cc/chainerx/cuda/cuda_device/reduction.cu
@@ -133,9 +133,19 @@ CHAINERX_CUDA_REGISTER_KERNEL(SumKernel, CudaSumKernel);

class CudaCumsumKernel : public CumsumKernel {
public:
-    void Call(const Array& /* a */, int8_t /* axis */, const Array& /* out */) override {
-        // TODO(aksub99): CUDA Implementation is to be suppported.
-        throw NotImplementedError{"CUDA Implementation is not yet supported."};
+    void Call(const Array& a, int8_t axis, const Array& out) override {
+        Device& device = a.device();
+        CHAINERX_ASSERT(a.shape() == out.shape());
+        device.CheckDevicesCompatible(a, out);
+        CudaSetDeviceScope scope{device.index()};
+
+        auto do_sum = [&a, &axis, &out](auto in_pt, auto out_pt) {
+            using In = typename decltype(in_pt)::type;
+            using Out = typename decltype(out_pt)::type;
+            Scan<In, Out>(a, axis, out, SumImpl<In, Out>{});
+        };
+
+        VisitDtype(out.dtype(), [a_dtype = a.dtype(), &do_sum](auto out_pt) { VisitDtype(a_dtype, do_sum, out_pt); });
    }
};
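The nested `VisitDtype` call above turns the two runtime dtypes (input and output) into a pair of compile-time template parameters, so `Scan<In, Out>` can be instantiated for every supported dtype combination. For readers unfamiliar with the pattern, here is a minimal, self-contained sketch of that double dispatch; `VisitType`, `TypeTag`, and `DtypeId` are toy stand-ins for the ChainerX helpers, not the real API:

```cpp
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <utility>

// Toy runtime dtype enum and compile-time type tag (hypothetical names).
enum class DtypeId { kInt32, kFloat32 };

template <typename T>
struct TypeTag {
    using type = T;
};

// Map a runtime dtype to a compile-time type and invoke the visitor with a tag.
template <typename Visitor, typename... Args>
auto VisitType(DtypeId dtype, Visitor&& visitor, Args&&... args) {
    switch (dtype) {
        case DtypeId::kInt32:
            return visitor(TypeTag<int32_t>{}, std::forward<Args>(args)...);
        case DtypeId::kFloat32:
            return visitor(TypeTag<float>{}, std::forward<Args>(args)...);
    }
    throw std::invalid_argument{"unknown dtype"};
}

int main() {
    DtypeId in_dtype = DtypeId::kInt32;
    DtypeId out_dtype = DtypeId::kFloat32;

    // Inner visitor: both tags are available here, so In/Out are compile-time types.
    auto do_scan = [](auto in_tag, auto out_tag) {
        using In = typename decltype(in_tag)::type;
        using Out = typename decltype(out_tag)::type;
        std::cout << "instantiated for sizeof(In)=" << sizeof(In)
                  << " sizeof(Out)=" << sizeof(Out) << '\n';
    };

    // Outer visit resolves Out, inner visit resolves In -- same shape as the kernel code above.
    VisitType(out_dtype, [&](auto out_tag) { VisitType(in_dtype, do_scan, out_tag); });
}
```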

103 changes: 103 additions & 0 deletions chainerx_cc/chainerx/cuda/reduce.cuh
@@ -77,6 +77,70 @@ __global__ void ReductionKernel(
}
}

+template <typename In, typename Out, typename ReductionImpl, int8_t InNdim = kDynamicNdim, int8_t OutNdim = kDynamicNdim>
+__global__ void ScanKernel(
+        ReductionKernelArg<In, Out, InNdim, OutNdim> arg,
+        int out_block_size,
+        int reduce_block_size,
+        ReductionImpl impl,
+        int64_t reduce_len) {
+    int tid = threadIdx.x;
+
+    int64_t len = arg.out_indexer.total_size() / reduce_len;
+    int64_t reduce_block_offset = tid / out_block_size;
+    int64_t reduce_offset = reduce_block_offset * len;
+    int64_t reduce_stride = reduce_block_size * len;
+
+    int64_t out_offset = tid % out_block_size;
+    int64_t out_base = blockIdx.x * out_block_size;
+    int64_t out_stride = gridDim.x * out_block_size;
+
+    auto reduce = [&impl, &arg](auto& it_from, auto& it_to) {
+        auto from = cuda_internal::StorageToDataType<Out>(arg.out[it_from]);
+        auto& to = cuda_internal::StorageToDataType<Out>(arg.out[it_to]);
+        impl.Reduce(from, to);
+        ++it_from;
+        ++it_to;
+    };
+
+    for (int64_t i = out_base + out_offset; i < len; i += out_stride) {
+        // Copy input array to output array
+        auto it_in = arg.in_indexer.It(i + reduce_offset, reduce_stride);
+        auto it_out = arg.out_indexer.It(i + reduce_offset, reduce_stride);
+        for (int64_t j = reduce_block_offset; j < reduce_len; j += reduce_block_size, ++it_in, ++it_out) {
+            auto value = cuda_internal::StorageToDataType<const In>(arg.in[it_in]);
+            arg.out[it_out] = cuda_internal::DataToStorageType<Out>(impl.MapIn(value, j));
+        }
+        __syncthreads();
+
+        int64_t stride = 1;
+
+        // Up-Sweep Phase
+        for (stride = 1; stride * 2 <= reduce_len; stride <<= 1) {
+            int64_t index_from = reduce_block_offset * stride * 2 + stride - 1;
+            int64_t index_to = index_from + stride;
+            auto it_from = arg.out_indexer.It(i + index_from * len, reduce_stride * stride * 2);
+            auto it_to = arg.out_indexer.It(i + index_to * len, reduce_stride * stride * 2);
+            for (int64_t j = index_to; j < reduce_len; j += reduce_block_size * stride * 2) {
+                reduce(it_from, it_to);
+            }
+            __syncthreads();
+        }
+
+        // Down-Sweep Phase
+        for (; stride >= 1; stride >>= 1) {
+            int64_t index_from = reduce_block_offset * stride * 2 + stride * 2 - 1;
+            int64_t index_to = index_from + stride;
+            auto it_from = arg.out_indexer.It(i + index_from * len, reduce_stride * stride * 2);
+            auto it_to = arg.out_indexer.It(i + index_to * len, reduce_stride * stride * 2);
+            for (int64_t j = index_to; j < reduce_len; j += reduce_block_size * stride * 2) {
+                reduce(it_from, it_to);
+            }
+            __syncthreads();
+        }
+    }
+}

} // namespace reduce_detail

// Computes the reduction of the input and stores into the output array.
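The `ScanKernel` added above computes an inclusive scan along the reduction axis with a work-efficient up-sweep/down-sweep: the up-sweep accumulates partial sums at doubling power-of-two strides, and the down-sweep propagates them into the remaining positions, with `__syncthreads()` between rounds. Stripped of the multi-axis indexers and the grid/block decomposition, the same index arithmetic run single-threaded over a plain array looks roughly like this (an illustrative sketch, not the kernel itself):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// In-place inclusive scan using the same up-sweep/down-sweep index pattern as ScanKernel,
// but sequential and over a contiguous buffer.
void InclusiveScan(std::vector<int64_t>& a) {
    int64_t n = static_cast<int64_t>(a.size());

    // Up-sweep: combine pairs at doubling strides, leaving partial sums at
    // positions k * 2 * stride + 2 * stride - 1.
    int64_t stride = 1;
    for (; stride * 2 <= n; stride <<= 1) {
        for (int64_t to = 2 * stride - 1; to < n; to += 2 * stride) {
            a[to] += a[to - stride];
        }
    }

    // Down-sweep: at halving strides, add each partial sum into the element one
    // stride further right, filling in the remaining prefix sums.
    for (; stride >= 1; stride >>= 1) {
        for (int64_t to = 3 * stride - 1; to < n; to += 2 * stride) {
            a[to] += a[to - stride];
        }
    }
}

int main() {
    std::vector<int64_t> a{1, 2, 3, 4, 5, 6, 7, 8};
    InclusiveScan(a);
    for (int64_t v : a) {
        std::cout << v << ' ';  // 1 3 6 10 15 21 28 36
    }
    std::cout << '\n';
}
```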
@@ -188,5 +252,44 @@ void Reduce(const Array& in, const Axes& axis, const Array& out, ReductionImpl&&
            MakeReductionKernelArg<In, Out>(arg), out_block_size, reduce_block_size, impl);
}

+template <typename In, typename Out, typename ReductionImpl>
+void Scan(const Array& in, int8_t axis, const Array& out, ReductionImpl&& impl) {
+    if (out.GetTotalSize() == 0) {
+        return;
+    }
+
+    ReductionArg arg{in, Axes{axis}, out};
+    int64_t reduce_len = in.shape()[axis];
+
+    // TODO(niboshi): Calculate kMaxBlockSize per device
+    std::lock_guard<std::mutex> lock{*cuda_internal::g_mutex};
+    static const int64_t kMaxBlockSize = std::min(
+            reduce_detail::kMaxReductionBlockSize,
+            CudaOccupancyMaxPotentialBlockSize(&reduce_detail::ReductionKernel<In, Out, ReductionImpl>).block_size);
+
+    int64_t reduce_total_size_pow2 = reduce_detail::RoundUpToPowerOf2(std::max(int64_t{1}, reduce_len));
+
+    int64_t reduce_block_size = std::min(kMaxBlockSize, reduce_total_size_pow2);
+    int64_t out_block_size = kMaxBlockSize / reduce_block_size;
+    int64_t out_block_num = (arg.in_shape().GetTotalSize() / reduce_len + out_block_size - 1) / out_block_size;
+
+    int64_t block_size = kMaxBlockSize;
+    int64_t grid_size = std::min(reduce_detail::kMaxGridSize, out_block_num);
+    int64_t shared_mem_size = sizeof(decltype(impl.Identity())) * block_size;
+
+#ifdef NDEBUG  // Optimize only in Release build to save time on development
+    // TODO(sonots): Reconsider the number of statically-optimized kernels in terms of speed and binary size trade-offs.
+    // Currently, only contiguous output arrays are optimized.
+    if (arg.in_strides().ndim() == 1 && arg.out_strides().ndim() == 1) {
+        reduce_detail::ScanKernel<<<grid_size, block_size, shared_mem_size>>>(
+                MakeReductionKernelArg<In, Out, 1, 1>(arg), out_block_size, reduce_block_size, impl, reduce_len);
+        return;
+    }
+#endif  // NDEBUG
+
+    reduce_detail::ScanKernel<<<grid_size, block_size, shared_mem_size>>>(
+            MakeReductionKernelArg<In, Out>(arg), out_block_size, reduce_block_size, impl, reduce_len);
+}

} // namespace cuda
} // namespace chainerx
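The launch sizing in `Scan` splits each thread block between `reduce_block_size` threads that cooperate on one scan line and `out_block_size` independent scan lines handled per block. A small sketch of that arithmetic with made-up numbers, assuming a `kMaxBlockSize` of 512 and the `(100000, 2)` shape from the new tests; in the real code `kMaxBlockSize` comes from `CudaOccupancyMaxPotentialBlockSize` at runtime and `kMaxGridSize` is a ChainerX constant, so both values below are assumptions:

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>

// Round n up to the next power of two (same role as reduce_detail::RoundUpToPowerOf2).
int64_t RoundUpToPowerOf2(int64_t n) {
    int64_t p = 1;
    while (p < n) {
        p <<= 1;
    }
    return p;
}

int main() {
    // Assumed values for illustration only.
    const int64_t kMaxBlockSize = 512;
    const int64_t kMaxGridSize = 0x7fffffff;

    // cumsum over axis 0 of a (100000, 2) array: 100000 elements per scan line, 2 lines.
    int64_t total_size = 100000 * 2;
    int64_t reduce_len = 100000;

    int64_t reduce_block_size = std::min(kMaxBlockSize, RoundUpToPowerOf2(reduce_len));
    int64_t out_block_size = kMaxBlockSize / reduce_block_size;
    int64_t out_block_num = (total_size / reduce_len + out_block_size - 1) / out_block_size;

    int64_t block_size = kMaxBlockSize;
    int64_t grid_size = std::min(kMaxGridSize, out_block_num);

    // Prints: reduce_block_size=512 out_block_size=1 grid_size=2 block_size=512
    std::cout << "reduce_block_size=" << reduce_block_size
              << " out_block_size=" << out_block_size
              << " grid_size=" << grid_size
              << " block_size=" << block_size << '\n';
}
```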
6 changes: 3 additions & 3 deletions chainerx_cc/chainerx/native/reduce.h
@@ -96,15 +96,15 @@ void ReductionKernel(ReductionKernelArg<In, Out, InNdim, OutNdim> arg, Reduction
}

template <typename In, typename Out, typename ReductionImpl, int8_t InNdim = kDynamicNdim, int8_t OutNdim = kDynamicNdim>
-void ScanKernel(ReductionKernelArg<In, Out, InNdim, OutNdim> arg, ReductionImpl&& impl, int64_t reduce_dim) {
-    int64_t len = arg.in_indexer.total_size() / reduce_dim;
+void ScanKernel(ReductionKernelArg<In, Out, InNdim, OutNdim> arg, ReductionImpl&& impl, int64_t reduce_len) {
+    int64_t len = arg.in_indexer.total_size() / reduce_len;
    auto it_in = arg.in_indexer.It(0, len);
    auto it_out = arg.out_indexer.It(0, len);
    for (int64_t i = 0; i < len; ++i) {
        it_in.Restart(i);
        it_out.Restart(i);
        auto accum = impl.Identity();
-        for (int64_t j = 0; j < reduce_dim; ++j, ++it_in, ++it_out) {
+        for (int64_t j = 0; j < reduce_len; ++j, ++it_in, ++it_out) {
            auto in = native_internal::StorageToDataType<const In>(arg.in[it_in]);
            impl.Reduce(impl.MapIn(in, i), accum);
            arg.out[it_out] = native_internal::DataToStorageType<Out>(impl.MapOut(accum));
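Both the native and CUDA scan kernels are generic over a `ReductionImpl` that supplies `Identity`, `MapIn`, `Reduce`, and `MapOut`; for cumsum the same `SumImpl` used by `Sum` is passed in. A rough standalone sketch of how the sequential loop above composes those hooks, with a toy `SumImpl` and plain vectors instead of the strided indexers (the real functor also handles dtype conversions and lives in the reduction kernel headers):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Toy version of the SumImpl reduction functor, for illustration only.
template <typename In, typename Out>
struct SumImpl {
    Out Identity() { return Out{0}; }
    Out MapIn(In in, int64_t /*index*/) { return static_cast<Out>(in); }
    void Reduce(Out next, Out& accum) { accum += next; }
    Out MapOut(Out accum) { return accum; }
};

// Sequential scan over the last axis of a flattened (len, reduce_len) view,
// mirroring the structure of the native ScanKernel loop.
template <typename In, typename Out, typename Impl>
void Scan(const std::vector<In>& in, std::vector<Out>& out, int64_t reduce_len, Impl impl) {
    int64_t len = static_cast<int64_t>(in.size()) / reduce_len;
    for (int64_t i = 0; i < len; ++i) {
        auto accum = impl.Identity();
        for (int64_t j = 0; j < reduce_len; ++j) {
            impl.Reduce(impl.MapIn(in[i * reduce_len + j], j), accum);
            out[i * reduce_len + j] = impl.MapOut(accum);
        }
    }
}

int main() {
    std::vector<int32_t> in{1, 2, 3, 10, 20, 30};  // shape (2, 3), scan along axis 1
    std::vector<int64_t> out(in.size());
    Scan<int32_t, int64_t>(in, out, 3, SumImpl<int32_t, int64_t>{});
    for (int64_t v : out) {
        std::cout << v << ' ';  // 1 3 6 10 30 60
    }
    std::cout << '\n';
}
```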
10 changes: 8 additions & 2 deletions tests/chainerx_tests/unit_tests/routines_tests/test_reduction.py
@@ -47,6 +47,9 @@
    ((2, 3, 4), -2),
    ((2, 3, 4), -1),
    ((2, 3, 4), None),
+    ((100000, 2), None),
+    ((100000, 2), 0),
+    ((100000, 2), 1),
]


@@ -269,11 +272,10 @@ def test_log_softmax_invalid(device, a_shape, axis, dtype):
return chainerx.log_softmax(a, axis=axis)


-@op_utils.op_test(['native:0'])
+@op_utils.op_test(['native:0', 'cuda:0'])
@chainer.testing.parameterize_pytest(
    'in_dtypes,out_dtype', _in_out_dtypes_sum)
@chainer.testing.parameterize_pytest('shape,axis', _cumsum_params)
-# TODO(aksub99): Add cuda device tests when cuda implementation is supported.
class TestCumsum(math_utils.UnaryMathTestBase, op_utils.NumpyOpTest):

    input = 'random'
@@ -287,6 +289,10 @@ def setup(self):
            self.check_double_backward_options.update(
                {'rtol': 1e-2, 'atol': 1e-2})

+        if (numpy.dtype(in_dtype).name in ('float16', 'float32')
+                and numpy.prod(self.shape) > 1000):
+            pytest.skip('Skip large tests for float16/float32 dtypes')
+
    def func(self, xp, a):
        return xp.cumsum(a, axis=self.axis)

