Merge pull request #8031 from asi1024/fix-large-argmax
Fix argmax/argmin for large reduction axis
takagi committed Dec 20, 2023
2 parents 1ebddbe + bfd0c34 commit 90a566c
Showing 7 changed files with 37 additions and 12 deletions.
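
The fix in a nutshell: the generated reduction kernels indexed the reduction axis with a 32-bit int (_J), so an axis longer than INT_MAX (0x7fffffff) elements overflowed the counter and argmax/argmin returned wrong results. The diff threads a kernel-side IndexT typedef through the code generators and widens it to long long only when needed. A minimal sketch of the failure mode, mirroring the new regression tests below (note: this allocates roughly 32 GiB of device memory):

import cupy

a = cupy.zeros(2 ** 32 + 1, dtype=cupy.float64)  # reduction axis longer than INT_MAX
a[-1] = 1.0                                      # put the maximum at index 2**32
print(int(a.argmax()))                           # expected 4294967296 with this fix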
4 changes: 2 additions & 2 deletions cupy/_core/_cub_reduction.pyx
@@ -175,9 +175,9 @@ __global__ void ${name}(${params}) {
 // some pre_map_expr uses _J internally...
 #if defined FIRST_PASS
-int _J = (segment_idx + i + e_idx);
+IndexT _J = (segment_idx + i + e_idx);
 #else // only one pass
-int _J = (segment_idx + i + e_idx) % _seg_size;
+IndexT _J = (segment_idx + i + e_idx) % _seg_size;
 #endif
 if (e_idx < tile_size) {
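
For context, a small Python model of the wraparound that motivates the IndexT change; as_int32 is a hypothetical helper that emulates C's signed 32-bit truncation:

def as_int32(x):
    # Emulate two's-complement truncation to a signed 32-bit int.
    x &= 0xffffffff
    return x - (1 << 32) if x >= (1 << 31) else x

j = 2 ** 32                # last index in the new regression test
print(as_int32(j))         # 0: a 32-bit _J wraps and points at the wrong element
print(j)                   # 4294967296: a 64-bit IndexT holds the true index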
12 changes: 10 additions & 2 deletions cupy/_core/_fusion_kernel.pyx
@@ -336,15 +336,23 @@ cdef class FusedKernel:
         self._cuda_params_memo[key] = ret
         return ret
 
+    def _get_typedefs(self, tuple args):
+        index_type = 'int'
+        for array in args:
+            if isinstance(array, _cupy.ndarray) and array.size > 0x7fffffff:
+                index_type = 'long long'
+        return f'typedef {index_type} IndexT;\n'
+
     def execute(self, tuple args, list shapes):
         ndarray_list = self._get_ndarray_list(args, shapes)
         ret = self._get_return_value(ndarray_list)
         reduce_key = self._reduce_dims(ndarray_list)
         inout_args = self._get_inout_args(args, ndarray_list)
         cuda_params = self._get_cuda_params(reduce_key, ndarray_list)
+        typedef = self._get_typedefs(args)
         kern = _cuda_compile(
-            self._submodule_code, self._name, cuda_params, self._cuda_body,
-            self._use_grid_sync)
+            typedef + self._submodule_code,
+            self._name, cuda_params, self._cuda_body, self._use_grid_sync)
 
         block_strides, block_size, shared_mem = (
             self._get_kernel_size(ndarray_list))
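
A hedged sketch of the selection rule implemented by _get_typedefs above: any input with more elements than a signed 32-bit int can count (0x7fffffff == 2**31 - 1) switches the kernel's IndexT from int to long long. pick_index_type is a hypothetical stand-in that takes plain sizes instead of ndarrays:

def pick_index_type(sizes):
    # sizes stands in for ndarray.size of each array argument.
    return 'long long' if any(s > 0x7fffffff for s in sizes) else 'int'

assert pick_index_type([1024, 2 ** 31 - 1]) == 'int'    # still fits in int32
assert pick_index_type([1024, 2 ** 31]) == 'long long'  # one past INT_MAX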
4 changes: 2 additions & 2 deletions cupy/_core/_fusion_op.py
@@ -271,9 +271,9 @@ def emit_submodule_codes(self):
 extern __shared__ char _sdata_raw[];
 _type_reduce *sdata = reinterpret_cast<_type_reduce*>(_sdata_raw);
 unsigned int tid = threadIdx.x;
-int _J = tid >> __popc(block_stride - 1);
+IndexT _J = tid >> __popc(block_stride - 1);
 ptrdiff_t _j = (ptrdiff_t)_J * out_ind.size();
-int J_stride = blockDim.x >> __popc(block_stride - 1);
+IndexT J_stride = blockDim.x >> __popc(block_stride - 1);
 ptrdiff_t j_stride = (ptrdiff_t)J_stride * out_ind.size();
 for (ptrdiff_t _i = (ptrdiff_t)blockIdx.x * block_stride; _i < out_ind.size(); _i += (ptrdiff_t)gridDim.x * block_stride) {
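
The tid >> __popc(block_stride - 1) idiom works because block_stride is a power of two: popcount(2**k - 1) == k, so the shift is an exact division by block_stride. A quick Python check of the identity (fast_div is a hypothetical helper):

def fast_div(tid, block_stride):
    # Only valid when block_stride is a power of two, as in these kernels.
    assert block_stride > 0 and block_stride & (block_stride - 1) == 0
    shift = bin(block_stride - 1).count('1')  # plays the role of __popc
    return tid >> shift

assert all(fast_div(t, 64) == t // 64 for t in range(4096))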
2 changes: 1 addition & 1 deletion cupy/_core/_kernel.pxd
@@ -81,7 +81,7 @@ cdef class _TypeMap:
     # Typedef mapping between C types.
     # This class is immutable.
 
-    cdef:
+    cdef public:
         tuple _pairs
 
     cdef str get_typedef_code(self)
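
The visibility change here is presumably what lets _reduction.pyx, below, read type_map._pairs when it appends the IndexT entry to the type map.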
12 changes: 9 additions & 3 deletions cupy/_core/_reduction.pyx
@@ -67,9 +67,9 @@ extern "C" __global__ void ${name}(${params}) {
 _type_reduce *_sdata = reinterpret_cast<_type_reduce*>(_sdata_raw);
 unsigned int _tid = threadIdx.x;
-int _J_offset = _tid >> __popc(_block_stride - 1); // _tid / _block_stride
+IndexT _J_offset = _tid >> __popc(_block_stride - 1); // _tid / _block_stride
 ptrdiff_t _j_offset = (ptrdiff_t)_J_offset * _out_ind.size();
-int _J_stride = ${block_size} >> __popc(_block_stride - 1);
+IndexT _J_stride = ${block_size} >> __popc(_block_stride - 1);
 ptrdiff_t _j_stride = (ptrdiff_t)_J_stride * _out_ind.size();
 for (ptrdiff_t _i_base = (ptrdiff_t)blockIdx.x * _block_stride;
@@ -78,7 +78,7 @@ extern "C" __global__ void ${name}(${params}) {
 _type_reduce _s = _type_reduce(${identity});
 ptrdiff_t _i =
     _i_base + (_tid & (_block_stride - 1)); // _tid % _block_stride
-int _J = _J_offset;
+IndexT _J = _J_offset;
 for (ptrdiff_t _j = _i + _j_offset; _j < _in_ind.size();
     _j += _j_stride, _J += _J_stride) {
 _in_ind.set(_j);
@@ -346,6 +346,12 @@ cdef class _AbstractReductionKernel:
             raise ValueError(('zero-size array to reduction operation'
                               ' %s which has no identity') % self.name)
 
+        if internal.prod(a_shape) / internal.prod(out_shape) > 0x7fffffff:
+            index_type = ('IndexT', 'int64')
+        else:
+            index_type = ('IndexT', 'int32')
+        type_map = _kernel._TypeMap(type_map._pairs + (index_type,))
+
         in_args = [x if isinstance(x, _ndarray_base) else
                    _scalar.CScalar.from_numpy_scalar_with_dtype(x, t)
                    for x, t in zip(in_args, in_types)]
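
The threshold here keys on the per-output segment length rather than the total input size: prod(a_shape) / prod(out_shape) is how many elements each _J counter walks. A hedged sketch with a hypothetical helper:

import math

def reduction_index_dtype(a_shape, out_shape):
    # Elements folded into each output value; guard zero-size outputs.
    seg = math.prod(a_shape) // max(math.prod(out_shape), 1)
    return 'int64' if seg > 0x7fffffff else 'int32'

assert reduction_index_dtype((4, 2 ** 31), (4,)) == 'int64'   # segment > INT_MAX
assert reduction_index_dtype((4, 2 ** 20), (4,)) == 'int32'   # small segment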
4 changes: 2 additions & 2 deletions cupy/_core/_routines_statistics.pyx
@@ -186,10 +186,10 @@ cdef _min_max_preamble = '''
 template <typename T>
 struct min_max_st{
     T value;
-    int index;
+    IndexT index;
     __device__ min_max_st() : index(-1) { }
     __device__ min_max_st(T v) : value(v), index(0) { }
-    __device__ min_max_st(T v, int i) : value(v), index(i) { }
+    __device__ min_max_st(T v, IndexT i) : value(v), index(i) { }
 };
 template <typename T>
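
argmax and argmin reduce over (value, index) pairs, which is why the struct's index member must widen along with IndexT. A rough Python model of the max-reduction fold (argmax_fold is illustrative only):

def argmax_fold(values):
    # Carry the best (value, index) pair through the scan, like min_max_st.
    best_value, best_index = None, -1
    for i, v in enumerate(values):
        if best_index < 0 or v > best_value:
            best_value, best_index = v, i
    return best_index

assert argmax_fold([3.0, 7.0, 7.0, 1.0]) == 1  # ties keep the first index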
11 changes: 11 additions & 0 deletions tests/cupy_tests/sorting_tests/test_search.py
@@ -83,6 +83,11 @@ def test_argmax_zero_size_axis1(self, xp, dtype):
         a = testing.shaped_random((0, 1), xp, dtype)
         return a.argmax(axis=1)
 
+    @testing.slow
+    def test_argmax_int32_overflow(self):
+        a = testing.shaped_arange((2 ** 32 + 1,), cupy, numpy.float64)
+        assert a.argmax().item() == 2 ** 32
+
     @testing.for_all_dtypes(no_complex=True)
     @testing.numpy_cupy_allclose()
     def test_argmin_all(self, xp, dtype):
@@ -157,6 +162,12 @@ def test_argmin_zero_size_axis1(self, xp, dtype):
         a = testing.shaped_random((0, 1), xp, dtype)
         return a.argmin(axis=1)
 
+    @testing.slow
+    def test_argmin_int32_overflow(self):
+        a = testing.shaped_arange((2 ** 32 + 1,), cupy, numpy.float64)
+        cupy.negative(a, out=a)
+        assert a.argmin().item() == 2 ** 32
+
 
 # TODO(leofang): remove this once CUDA 9.0 is dropped
 def _skip_cuda90(dtype):
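
A note on the expected value: testing.shaped_arange fills the array with increasing values (the argmin variant negates them in place), so the extremum lands at the last position, index 2**32, which a signed 32-bit index cannot represent:

assert 2 ** 32 > 0x7fffffff  # 4294967296 > INT_MAX (2147483647)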
