Skip to content

Commit

Permalink
Merge pull request #8243 from kmaehashi/fix-random-generator-32bit
Browse files Browse the repository at this point in the history
Fix overflow of index calculation in random generator API
  • Loading branch information
takagi committed Mar 18, 2024
2 parents 4179286 + b8dbcef commit cf4d4e9
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
8 changes: 4 additions & 4 deletions cupy/random/cupy_distributions.cu
Original file line number Diff line number Diff line change
Expand Up @@ -759,7 +759,7 @@ template<typename T>
struct array_data {}; // opaque type always used as a pointer type

template<typename T>
__device__ T get_index(array_data<T> *value, int id, ssize_t state_size) {
__device__ T get_index(array_data<T> *value, ssize_t id, ssize_t state_size) {
int64_t* data = reinterpret_cast<int64_t*>(value);
intptr_t ptr = reinterpret_cast<intptr_t>(data[0]);
int ndim = data[1];
Expand All @@ -772,11 +772,11 @@ __device__ T get_index(array_data<T> *value, int id, ssize_t state_size) {
}

template<typename T>
__device__ typename std::enable_if<std::is_arithmetic<T>::value, T>::type get_index(T value, int id, ssize_t state_size) {
__device__ typename std::enable_if<std::is_arithmetic<T>::value, T>::type get_index(T value, ssize_t id, ssize_t state_size) {
return value;
}

__device__ rk_binomial_state* get_index(rk_binomial_state *value, int id, ssize_t state_size) {
__device__ rk_binomial_state* get_index(rk_binomial_state *value, ssize_t id, ssize_t state_size) {
return (value + id % state_size);
}

Expand All @@ -785,7 +785,7 @@ __global__ void execute_dist( intptr_t state, ssize_t state_size, intptr_t out,
R* out_ptr = reinterpret_cast<R*>(out);
F func;
T random(blockIdx.x * blockDim.x + threadIdx.x, state);
for (int id = blockIdx.x * blockDim.x + threadIdx.x;
for (ssize_t id = blockIdx.x * blockDim.x + threadIdx.x;
id < size;
id += state_size) {
out_ptr[id] = func(random, (get_index(args, id, state_size))...);
Expand Down
7 changes: 7 additions & 0 deletions tests/cupy_tests/random_tests/test_generator_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,3 +368,10 @@ class TestDrichlet(
GeneratorTestCase
):
pass


@testing.slow
class TestLarge:
def test_large(self):
gen = random.Generator(random.XORWOW(1234))
gen.random(2**31 + 1, dtype=cupy.int8)

0 comments on commit cf4d4e9

Please sign in to comment.