Commit

Merge pull request #31 from cupy/support-64bit-address
Support 64bit address
niboshi committed May 23, 2017
2 parents a1e49b6 + d43d118 commit 5bd9b98
Showing 8 changed files with 118 additions and 41 deletions.
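In outline: CArray and CIndexer stored sizes, shapes, and strides as int, so kernels could not address arrays with more than 2^31 - 1 elements. This change widens those fields to ptrdiff_t on the device side and Py_ssize_t on the Cython side, and adds slow tests that allocate arrays at and past the old limit. A minimal host-side sketch of the failure mode the wider types avoid (illustration only, not code from this patch):

// Illustration only: a flat index of 2^31 elements no longer fits in int.
#include <cstddef>
#include <cstdio>

int main() {
  const std::ptrdiff_t n = std::ptrdiff_t(1) << 31;  // 2,147,483,648 elements
  int as_int = static_cast<int>(n);  // implementation-defined; typically wraps negative
  std::printf("as ptrdiff_t: %td  as int: %d\n", n, as_int);
  return 0;
}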
94 changes: 68 additions & 26 deletions cupy/core/carray.cuh
@@ -135,28 +135,28 @@ __device__ float16 nextafter(float16 x, float16 y) {return float16::nextafter(x,

// CArray
#define CUPY_FOR(i, n) \
-for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
+for (ptrdiff_t i = blockIdx.x * blockDim.x + threadIdx.x; \
i < (n); \
i += blockDim.x * gridDim.x)

template <typename T, int ndim>
class CArray {
private:
T* data_;
-int size_;
-int shape_[ndim];
-int strides_[ndim];
+ptrdiff_t size_;
+ptrdiff_t shape_[ndim];
+ptrdiff_t strides_[ndim];

public:
__device__ int size() const {
return size_;
}

-__device__ const int* shape() const {
+__device__ const ptrdiff_t* shape() const {
return shape_;
}

-__device__ const int* strides() const {
+__device__ const ptrdiff_t* strides() const {
return strides_;
}

@@ -172,7 +172,19 @@ public:
return (*const_cast<CArray<T, ndim>*>(this))[idx];
}

-__device__ T& operator[](int i) {
+__device__ T& operator[](const ptrdiff_t* idx) {
+  char* ptr = reinterpret_cast<char*>(data_);
+  for (int dim = 0; dim < ndim; ++dim) {
+    ptr += strides_[dim] * idx[dim];
+  }
+  return *reinterpret_cast<T*>(ptr);
+}
+
+__device__ T operator[](const ptrdiff_t* idx) const {
+  return (*const_cast<CArray<T, ndim>*>(this))[idx];
+}
+
+__device__ T& operator[](ptrdiff_t i) {
char* ptr = reinterpret_cast<char*>(data_);
for (int dim = ndim; --dim > 0; ) {
ptr += static_cast<ptrdiff_t>(strides_[dim]) * (i % shape_[dim]);
@@ -185,7 +197,7 @@ public:
return *reinterpret_cast<T*>(ptr);
}

-__device__ T operator[](int i) const {
+__device__ T operator[](ptrdiff_t i) const {
return (*const_cast<CArray<T, ndim>*>(this))[i];
}
};
@@ -194,13 +206,21 @@ template <typename T>
class CArray<T, 0> {
private:
T* data_;
-int size_;
+ptrdiff_t size_;

public:
__device__ int size() const {
return size_;
}

+__device__ const ptrdiff_t* shape() const {
+  return NULL;
+}
+
+__device__ const ptrdiff_t* strides() const {
+  return NULL;
+}

__device__ T& operator[](const int* idx) {
return *reinterpret_cast<T*>(data_);
}
@@ -209,58 +229,80 @@ public:
return (*const_cast<CArray<T, 0>*>(this))[idx];
}

-__device__ T& operator[](int i) {
+__device__ T& operator[](const ptrdiff_t* idx) {
+  return *reinterpret_cast<T*>(data_);
+}
+
+__device__ T operator[](const ptrdiff_t* idx) const {
+  return (*const_cast<CArray<T, 0>*>(this))[idx];
+}
+
+__device__ T& operator[](ptrdiff_t i) {
return *reinterpret_cast<T*>(data_);
}

-__device__ T operator[](int i) const {
+__device__ T operator[](ptrdiff_t i) const {
return (*const_cast<CArray<T, 0>*>(this))[i];
}
};

template <int ndim>
class CIndexer {
private:
-int size_;
-int shape_[ndim];
-int index_[ndim];
+ptrdiff_t size_;
+ptrdiff_t shape_[ndim];
+ptrdiff_t index_[ndim];

public:
-__device__ int size() const {
+__device__ ptrdiff_t size() const {
return size_;
}

-__device__ void set(int i) {
-  unsigned int a = i;
-  for (int dim = ndim; --dim > 0; ) {
-    unsigned int s = shape_[dim];
-    index_[dim] = (a % s);
-    a /= s;
-  }
-  index_[0] = a;
-}
+__device__ void set(ptrdiff_t i) {
+  // ndim == 0 case uses partial template specialization
+  if (ndim == 1) {
+    index_[0] = i;
+    return;
+  }
+  if (ndim > 0) {
+    if (size_ > 1LL << 31) {
+      // 64-bit division is very slow on GPU
+      size_t a = static_cast<size_t>(i);
+      for (int dim = ndim; --dim > 0; ) {
+        size_t s = static_cast<size_t>(shape_[dim]);
+        index_[dim] = a % s;
+        a /= s;
+      }
+      index_[0] = a;
+    } else {
+      unsigned int a = static_cast<unsigned int>(i);
+      for (int dim = ndim; --dim > 0; ) {
+        unsigned int s = static_cast<unsigned int>(shape_[dim]);
+        index_[dim] = a % s;
+        a /= s;
+      }
+      index_[0] = a;
+    }
+  }
+}

-__device__ const int* get() const {
+__device__ const ptrdiff_t* get() const {
return index_;
}
};

template <>
class CIndexer<0> {
private:
-int size_;
+ptrdiff_t size_;

public:
__device__ int size() const {
return size_;
}

-__device__ void set(int i) {
+__device__ void set(ptrdiff_t i) {
}

-__device__ const int* get() const {
+__device__ const ptrdiff_t* get() const {
return NULL;
}
};
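The branch on size_ in CIndexer::set exists because 64-bit integer division is emulated in software on GPUs; when the total size fits in 31 bits, the flat index is unraveled with unsigned 32-bit arithmetic instead. A host-side sketch (illustration only; kNdim and the helper names are made up here) showing that the two unravel routines agree whenever the size fits in 32 bits:

#include <cassert>
#include <cstddef>

const int kNdim = 3;

void unravel64(std::ptrdiff_t i, const std::ptrdiff_t* shape, std::ptrdiff_t* index) {
  std::size_t a = static_cast<std::size_t>(i);
  for (int dim = kNdim; --dim > 0; ) {
    std::size_t s = static_cast<std::size_t>(shape[dim]);
    index[dim] = static_cast<std::ptrdiff_t>(a % s);
    a /= s;
  }
  index[0] = static_cast<std::ptrdiff_t>(a);
}

void unravel32(std::ptrdiff_t i, const std::ptrdiff_t* shape, std::ptrdiff_t* index) {
  unsigned int a = static_cast<unsigned int>(i);
  for (int dim = kNdim; --dim > 0; ) {
    unsigned int s = static_cast<unsigned int>(shape[dim]);
    index[dim] = a % s;
    a /= s;
  }
  index[0] = a;
}

int main() {
  const std::ptrdiff_t shape[kNdim] = {2, 3, 4};
  std::ptrdiff_t a[kNdim], b[kNdim];
  for (std::ptrdiff_t i = 0; i < 2 * 3 * 4; ++i) {
    unravel64(i, shape, a);
    unravel32(i, shape, b);
    for (int dim = 0; dim < kNdim; ++dim) assert(a[dim] == b[dim]);
  }
  return 0;
}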
11 changes: 6 additions & 5 deletions cupy/core/carray.pxi
@@ -7,8 +7,8 @@ from cupy.cuda cimport function

cdef struct _CArray:
void* data
-int size
-int shape_and_strides[MAX_NDIM * 2]
+Py_ssize_t size
+Py_ssize_t shape_and_strides[MAX_NDIM * 2]


cdef class CArray(CPointer):
@@ -17,7 +17,8 @@ cdef class CArray(CPointer):
_CArray val

def __init__(self, ndarray arr):
-cdef Py_ssize_t i, ndim = arr.ndim
+cdef Py_ssize_t i
+cdef int ndim = arr.ndim
self.val.data = <void*>arr.data.ptr
self.val.size = arr.size
for i in range(ndim):
@@ -27,8 +28,8 @@


cdef struct _CIndexer:
-int size
-int shape_and_index[MAX_NDIM * 2]
+Py_ssize_t size
+Py_ssize_t shape_and_index[MAX_NDIM * 2]


cdef class CIndexer(CPointer):
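These structs mirror the device-side CArray and CIndexer field for field, with Py_ssize_t standing in for ptrdiff_t; on CPython the two types have the same width on mainstream platforms, which is what keeps the host and device layouts in sync. A hypothetical compile-time check of that invariant (not in the patch; the MAX_NDIM value is assumed here):

// Hypothetical layout check, not from the patch; assumes CPython headers
// are available and that MAX_NDIM is 25.
#include <Python.h>
#include <cstddef>

const int MAX_NDIM = 25;  // assumed value of CuPy's compile-time constant

struct CArrayHeader {  // host-side mirror of _CArray above
  void* data;
  Py_ssize_t size;
  Py_ssize_t shape_and_strides[MAX_NDIM * 2];
};

static_assert(sizeof(Py_ssize_t) == sizeof(std::ptrdiff_t),
              "host Py_ssize_t must match device ptrdiff_t");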
18 changes: 9 additions & 9 deletions cupy/core/core.pyx
@@ -2212,8 +2212,8 @@ cdef _concatenate_kernel = ElementwiseKernel(
axis_ind -= cum_sizes[left];
char* ptr = reinterpret_cast<char*>(x[array_ind]);
for (int j = ndim - 1; j >= 0; --j) {
-int ind[] = {array_ind, j};
-int offset;
+ptrdiff_t ind[] = {array_ind, j};
+ptrdiff_t offset;
if (j == axis) {
offset = axis_ind;
} else {
@@ -2322,8 +2322,8 @@ cdef _take_kernel = ElementwiseKernel(
S wrap_indices = indices % index_range;
if (wrap_indices < 0) wrap_indices += index_range;
-int li = i / (rdim * cdim);
-int ri = i % rdim;
+ptrdiff_t li = i / (rdim * cdim);
+ptrdiff_t ri = i % rdim;
out = a[(li * adim + wrap_indices) * rdim + ri];
''',
'cupy_take')
@@ -2369,8 +2369,8 @@ cdef _scatter_update_kernel = ElementwiseKernel(
'''
S wrap_indices = indices % adim;
if (wrap_indices < 0) wrap_indices += adim;
-int li = i / (rdim * cdim);
-int ri = i % rdim;
+ptrdiff_t li = i / (rdim * cdim);
+ptrdiff_t ri = i % rdim;
a[(li * adim + wrap_indices) * rdim + ri] = v;
''',
'cupy_scatter_update')
@@ -2382,8 +2382,8 @@ cdef _scatter_add_kernel = ElementwiseKernel(
'''
S wrap_indices = indices % adim;
if (wrap_indices < 0) wrap_indices += adim;
-int li = i / (rdim * cdim);
-int ri = i % rdim;
+ptrdiff_t li = i / (rdim * cdim);
+ptrdiff_t ri = i % rdim;
atomicAdd(&a[(li * adim + wrap_indices) * rdim + ri], v[i]);
''',
'cupy_scatter_add')
@@ -3676,7 +3676,7 @@ def _nonzero_1d_kernel(src_dtype, index_dtype):
const CArray<${index_dtype}, 1> scaned_index,
CArray<${index_dtype}, 1> dst){
int thid = blockIdx.x * blockDim.x + threadIdx.x;
-int n = src.size();
+ptrdiff_t n = src.size();
if (thid < n){
if (src[thid] != 0){
dst[scaned_index[thid] - 1] = thid;
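Each of these kernels splits the flat element index i into an outer part li and an inner part ri. Once i can exceed 2^31 - 1, the quotient and remainder must be computed in 64 bits as well, otherwise li and ri would be derived from a truncated value. A small host-side sketch of the decomposition (illustration only; the extents are made up):

#include <cassert>
#include <cstddef>

int main() {
  const std::ptrdiff_t rdim = std::ptrdiff_t(1) << 20;  // inner extent
  const std::ptrdiff_t cdim = 1;                        // index extent
  const std::ptrdiff_t i = std::ptrdiff_t(3) << 31;     // flat index past 2^31
  std::ptrdiff_t li = i / (rdim * cdim);  // outer block index: 6144
  std::ptrdiff_t ri = i % rdim;           // offset within a row: 0
  assert(li == 6144 && ri == 0);
  // Had i been truncated to a 32-bit int first, both values would be wrong.
  return 0;
}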
1 change: 1 addition & 0 deletions cupy/testing/__init__.py
@@ -48,3 +48,4 @@

gpu = attr.gpu
multi_gpu = attr.multi_gpu
+slow = attr.slow
1 change: 1 addition & 0 deletions cupy/testing/attr.py
@@ -2,6 +2,7 @@

gpu = attrib.attr('gpu')
cudnn = attrib.attr('gpu', 'cudnn')
+slow = attrib.attr('slow')


def multi_gpu(gpu_num):
2 changes: 1 addition & 1 deletion tests/cupy_tests/core_tests/test_carray.py
@@ -41,7 +41,7 @@ def test_getitem_idx(self):
y = cupy.empty_like(x)
y = cupy.ElementwiseKernel(
'raw T x', 'int32 y',
-'int idx[] = {i / 12, i / 4 % 3, i % 4}; y = x[idx]',
+'ptrdiff_t idx[] = {i / 12, i / 4 % 3, i % 4}; y = x[idx]',
'test_carray_getitem_idx',
)(x, y)
testing.assert_array_equal(y, x)
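The kernel above hands a ptrdiff_t triple to the raw CArray index operator. For the 2x3x4 test array, {i / 12, i / 4 % 3, i % 4} is exactly the row-major unflattening of i, which a standalone check (illustration only) confirms:

#include <cassert>
#include <cstddef>

int main() {
  for (std::ptrdiff_t i = 0; i < 2 * 3 * 4; ++i) {
    std::ptrdiff_t idx[3] = {i / 12, i / 4 % 3, i % 4};
    // Recombine with the row-major element strides {12, 4, 1}.
    assert(idx[0] * 12 + idx[1] * 4 + idx[2] == i);
  }
  return 0;
}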
24 changes: 24 additions & 0 deletions tests/cupy_tests/creation_tests/test_basic.py
@@ -19,6 +19,18 @@ def test_empty(self, xp, dtype, order):
a.fill(0)
return a

+@testing.slow
+def test_empty_huge_size(self):
+    a = cupy.empty((1024, 2048, 1024), dtype='b')
+    a.fill(123)
+    self.assertTrue((a == 123).all())
+
+@testing.slow
+def test_empty_huge_size_fill0(self):
+    a = cupy.empty((1024, 2048, 1024), dtype='b')
+    a.fill(0)
+    self.assertTrue((a == 0).all())
+
@testing.for_CF_orders()
@testing.for_all_dtypes()
@testing.numpy_cupy_array_equal()
@@ -35,6 +47,18 @@ def test_empty_int(self, xp, dtype, order):
a.fill(0)
return a

+@testing.slow
+def test_empty_int_huge_size(self):
+    a = cupy.empty(2 ** 31, dtype='b')
+    a.fill(123)
+    self.assertTrue((a == 123).all())
+
+@testing.slow
+def test_empty_int_huge_size_fill0(self):
+    a = cupy.empty(2 ** 31, dtype='b')
+    a.fill(0)
+    self.assertTrue((a == 0).all())
+
@testing.for_all_dtypes()
@testing.numpy_cupy_array_equal()
def test_empty_like(self, xp, dtype):
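Each huge-size test touches 1024 * 2048 * 1024 = 2^31 int8 elements, one past INT_MAX = 2^31 - 1, so the fill and comparison kernels are forced through the new ptrdiff_t code paths; the tests are marked slow because they move 2 GiB of device memory.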
8 changes: 8 additions & 0 deletions tests/cupy_tests/math_tests/test_sumprod.py
@@ -48,6 +48,14 @@ def test_sum_axis(self, xp, dtype):
a = testing.shaped_arange((2, 3, 4), xp, dtype)
return a.sum(axis=1)

+@testing.slow
+@testing.with_requires('numpy>=1.10')
+@testing.numpy_cupy_allclose()
+def test_sum_axis_huge(self, xp):
+    a = testing.shaped_random((2048, 1, 1024), xp, 'b')
+    a = xp.broadcast_to(a, (2048, 1024, 1024))
+    return a.sum(axis=0)
+
@testing.for_all_dtypes()
@testing.numpy_cupy_allclose()
def test_external_sum_axis(self, xp, dtype):
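Note the broadcast_to trick: the physical buffer holds only 2048 * 1 * 1024 int8 elements (2 MiB), but the broadcast view spans 2048 * 1024 * 1024 = 2^31 elements, so the reduction is (presumably) pushed through the 64-bit indexing path without requiring 2 GiB of memory.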
