diff --git a/chainerx_cc/chainerx/index_iterator.h b/chainerx_cc/chainerx/index_iterator.h
index 4122fa75f22f..9c7f5c68e729 100644
--- a/chainerx_cc/chainerx/index_iterator.h
+++ b/chainerx_cc/chainerx/index_iterator.h
@@ -64,10 +64,28 @@ class IndexIterator {
     CHAINERX_HOST_DEVICE void Set(int64_t i) {
         CHAINERX_ASSERT(total_size_ > 0);
         raw_index_ = i;
+#ifdef __CUDA_ARCH__
+        // TODO(ecastill) add 32-bit case
+        // 64-bit division is very slow on GPU
+        uint64_t a = static_cast<uint64_t>(i);
+        for (int8_t dim = kNdim; --dim > 0;) {
+            uint64_t s = static_cast<uint64_t>(shape_[dim]);
+            if (s & (s - 1)) {
+                uint64_t t = a / s;
+                index_[dim] = static_cast<int64_t>(a - t * s);
+                a = t;
+            } else {  // exp of 2
+                index_[dim] = static_cast<int64_t>(a & (s - 1));
+                a >>= __popcll(s - 1);
+            }
+        }
+        index_[0] = a;
+#else
         for (int8_t j = kNdim; --j >= 0;) {
             index_[j] = i % shape_[j];  // NOLINT(cppcoreguidelines-pro-bounds-constant-array-index)
             i /= shape_[j];
         }
+#endif
     }

     const int64_t* shape_;
@@ -128,6 +146,8 @@ class IndexIterator<0> {
     CHAINERX_HOST_DEVICE const int64_t* index() const { return &raw_index_; }

+    CHAINERX_HOST_DEVICE void Set(int64_t i) { raw_index_ = i; }
+
private:
    int64_t raw_index_{0};
};

@@ -247,10 +267,31 @@ class IndexIterator {
     CHAINERX_HOST_DEVICE void Set(int64_t i) {
         CHAINERX_ASSERT(total_size_ > 0);
         raw_index_ = i;
+        if (ndim_ == 0) {
+            return;
+        }
+#ifdef __CUDA_ARCH__
+        // TODO(ecastill) add 32-bit case
+        // 64-bit division is very slow on GPU
+        uint64_t a = static_cast<uint64_t>(i);
+        for (int8_t dim = ndim_; --dim > 0;) {
+            uint64_t s = static_cast<uint64_t>(shape_[dim]);
+            if (s & (s - 1)) {
+                uint64_t t = a / s;
+                index_[dim] = static_cast<int64_t>(a - t * s);
+                a = t;
+            } else {  // exp of 2
+                index_[dim] = static_cast<int64_t>(a & (s - 1));
+                a >>= __popcll(s - 1);
+            }
+        }
+        index_[0] = a;
+#else
         for (int8_t j = ndim_; --j >= 0;) {
             index_[j] = i % shape_[j];  // NOLINT(cppcoreguidelines-pro-bounds-constant-array-index)
             i /= shape_[j];
         }
+#endif
     }

     const int64_t* shape_;
diff --git a/tests/chainerx_tests/unit_tests/test_array_index.py b/tests/chainerx_tests/unit_tests/test_array_index.py
index 43d3e65cfe0f..46bfa126763d 100644
--- a/tests/chainerx_tests/unit_tests/test_array_index.py
+++ b/tests/chainerx_tests/unit_tests/test_array_index.py
@@ -6,6 +6,29 @@ def test_newaxis():
     assert chainerx.newaxis is None


+@pytest.mark.parametrize('xp', [chainerx])
+@pytest.mark.parametrize_device(['native:0', 'cuda:0'])
+@pytest.mark.parametrize('shape, transpose', [
+    ((1,), None),
+    ((2,), None),
+    ((2, 3), None),
+    ((2, 3, 4), None),
+    ((2, 3, 4, 5), None),
+    ((2, 3, 4, 5, 6), None),
+    ((2, 3), (0, 1)),
+    ((2, 3, 4), (0, 2)),
+    ((2, 3, 4, 5), (0, 2)),
+    ((2, 3, 4, 5, 6), (1, 3)),
+])
+def test_array_indexing(xp, device, shape, transpose):
+    a = xp.zeros(shape=shape, dtype=chainerx.int8, device=device)
+    if transpose:
+        a = a.swapaxes(*transpose)
+        assert not a.is_contiguous
+    a += 1
+    assert a.sum() == a.size
+
+
 @pytest.mark.slow
 @pytest.mark.parametrize('xp', [chainerx])
 @pytest.mark.parametrize_device(['cuda:0'])
@@ -13,7 +36,7 @@ def test_newaxis():
     (64, 32, 6*1024*4),  # Less than 2^32 elems
     (64, 32, 6*1024*512),  # More than 2^32 elems
 ])
-def test_array_contiguous_indexing(xp, device, shape):
+def test_large_array_contiguous_indexing(xp, device, shape):
     try:
         a = xp.zeros(shape=shape, dtype=chainerx.int8, device=device)
     except chainerx.ChainerxError as ex:
@@ -31,7 +54,7 @@ def test_newaxis():
     (64, 32, 6*1024*4),  # Less than 2^32 elems
     (64, 32, 6*1024*512)  # More than 2^32 elems
 ])
-def test_array_noncontiguous_indexing(xp, device, shape):
+def test_large_array_noncontiguous_indexing(xp, device, shape):
     try:
         a = xp.zeros(shape=shape, dtype=chainerx.int8, device=device)
     except chainerx.ChainerxError as ex: