Merge pull request #41 from cupy/merge-v1.24
Merge v1.24
niboshi committed May 16, 2017
2 parents 39a9dd5 + da32e3b commit 0049900
Showing 26 changed files with 765 additions and 20 deletions.
2 changes: 1 addition & 1 deletion .pep8
@@ -1,3 +1,3 @@
[pep8]
exclude=caffe_pb*,.eggs,*.egg,build
diff=True

2 changes: 1 addition & 1 deletion .travis.yml
@@ -55,7 +55,7 @@ install:

script:
  - flake8 --config=.flake8.cython
-  - autopep8 -r . --global-config .pep8 | tee check_autopep8
+  - autopep8 -r . --global-config .pep8 --diff | tee check_autopep8
  - test ! -s check_autopep8
  - PYTHONWARNINGS='ignore::FutureWarning,module::DeprecationWarning' nosetests -a '!gpu,!slow' --with-doctest tests/install_tests
  - if [[ $TRAVIS_OS_NAME == "linux" ]]; then
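The fix above matters because, with `--diff`, autopep8 emits only a unified diff of suggested changes, so `check_autopep8` is empty (and `test ! -s` succeeds) exactly when the tree is already PEP 8 clean. A minimal sketch of the same check through autopep8's Python API, with a hypothetical file name:

```python
# Equivalent of the CI check via autopep8's Python API (a sketch;
# 'some_module.py' is a hypothetical file name).
import autopep8

with open('some_module.py') as f:
    source = f.read()
# fix_code returns the PEP 8-formatted source; if nothing changes, the
# file was already clean, mirroring an empty check_autopep8 file.
assert autopep8.fix_code(source) == source, 'file needs autopep8 fixes'
```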
6 changes: 6 additions & 0 deletions README.md
@@ -71,6 +71,12 @@ Do not forget to restart your terminal session (or `source` it) to enable these
And then, reinstall CuPy.


### Multi-GPU Support

Multi-GPU training is supported by `MultiprocessParallelUpdater`.
To use it, install [NCCL](https://github.com/NVIDIA/nccl) by following its installation guide.

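Below is a minimal sketch (not part of this README) of driving `MultiprocessParallelUpdater`, assuming Chainer's training API; `optimizer` and `train` (a dataset prepared elsewhere) are placeholders, and two GPUs are assumed:

```python
# A sketch assuming Chainer's training API; `optimizer` and `train`
# (a dataset prepared elsewhere) are placeholders.
import chainer
from chainer import training

# One iterator per GPU; each worker process consumes its own shard.
iters = [chainer.iterators.SerialIterator(sub, batch_size=32)
         for sub in chainer.datasets.split_dataset_n_random(train, 2)]
updater = training.updaters.MultiprocessParallelUpdater(
    iters, optimizer, devices={'main': 0, 'second': 1})
trainer = training.Trainer(updater, (10, 'epoch'))
trainer.run()
```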

## Run with Docker

We provide the official Docker image.
6 changes: 3 additions & 3 deletions cupy/core/carray.cuh
@@ -163,7 +163,7 @@ public:
  __device__ T& operator[](const int* idx) {
    char* ptr = reinterpret_cast<char*>(data_);
    for (int dim = 0; dim < ndim; ++dim) {
-      ptr += strides_[dim] * idx[dim];
+      ptr += static_cast<ptrdiff_t>(strides_[dim]) * idx[dim];
    }
    return *reinterpret_cast<T*>(ptr);
  }
@@ -175,11 +175,11 @@ public:
  __device__ T& operator[](int i) {
    char* ptr = reinterpret_cast<char*>(data_);
    for (int dim = ndim; --dim > 0; ) {
-      ptr += strides_[dim] * (i % shape_[dim]);
+      ptr += static_cast<ptrdiff_t>(strides_[dim]) * (i % shape_[dim]);
      i /= shape_[dim];
    }
    if (ndim > 0) {
-      ptr += strides_[0] * i;
+      ptr += static_cast<ptrdiff_t>(strides_[0]) * i;
    }

    return *reinterpret_cast<T*>(ptr);
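The casts above matter because `strides_` and `idx` are `int`, so the product is evaluated in 32-bit arithmetic and wraps once a byte offset exceeds the `int` range; widening one operand to `ptrdiff_t` makes the whole product 64-bit. A quick illustration of the wraparound, mimicking the C arithmetic with numpy's fixed-width scalars:

```python
# Demonstrating the 32-bit overflow that static_cast<ptrdiff_t> avoids:
# byte offset = stride * index for an element ~3.4 GiB into a buffer.
import numpy as np

stride = np.int32(4)          # e.g. contiguous float32 stride, in bytes
index = np.int32(900000000)   # element index
print(stride * index)                      # wraps to -694967296
print(np.int64(stride) * np.int64(index))  # 3600000000, correct
```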
10 changes: 9 additions & 1 deletion cupy/core/core.pyx
@@ -367,6 +367,12 @@ cdef class ndarray:
        .. seealso:: :meth:`numpy.ndarray.fill`
        """
+        if isinstance(value, numpy.ndarray):
+            if value.shape != ():
+                raise ValueError(
+                    'non-scalar numpy.ndarray cannot be used for fill')
+            value = value.item()
+
        if value == 0 and self._c_contiguous:
            self.data.memset_async(0, self.nbytes, stream.Stream(True))
        else:
@@ -1219,7 +1225,7 @@ cdef class ndarray:
                shape.push_back(dim)
                strides.push_back(self._strides[j] * s_step)

-                offset += s_start * self._strides[j]
+                offset += max(0, s_start) * self._strides[j]
                j += 1
            elif numpy.isscalar(s):
                ind = int(s)
@@ -1234,6 +1240,8 @@
            else:
                raise TypeError('Invalid index type: %s' % type(slices[i]))

+        # TODO(niboshi): offset can be non-zero even if self.data is an empty
+        # pointer.
        v = self.view()
        v.data = self.data + offset
        v._set_shape_and_strides(shape, strides)
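The first hunk makes `ndarray.fill` accept a zero-dimensional `numpy.ndarray` (unwrapping it with `item()`) while rejecting anything with elements. A quick sketch of the resulting behavior:

```python
# Expected behavior after the fill() change (a sketch).
import numpy
import cupy

a = cupy.zeros((2, 3), dtype=numpy.float32)
a.fill(numpy.array(3.0))   # 0-d ndarray: unwrapped via .item()
try:
    a.fill(numpy.array([1.0, 2.0]))   # non-scalar: rejected
except ValueError as e:
    print(e)   # non-scalar numpy.ndarray cannot be used for fill
```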
9 changes: 5 additions & 4 deletions cupy/core/fusion.py
@@ -1,4 +1,3 @@
-import inspect
import six
from six.moves import builtins
import string
@@ -489,8 +488,6 @@ def _get_fix_code(data_type, fixed_type, operation):


def _get_fusion(func, nin, reduce, post_map, identity, input_types, name=None):
-    if nin is None:
-        nin = len(inspect.getargspec(func).args)
    in_vars = [_FusionVar(i, t) for i, t in enumerate(input_types)]
    mem = _FusionMem(in_vars)
    in_refs = [_FusionRef(_, mem) for _ in in_vars]
@@ -611,7 +608,11 @@ def is_cupy_data(a):
        types = [_.dtype for _ in args]
        key = tuple(types)
        if key not in self._memo:
-            f = _get_fusion(self.func, self.input_num, self.reduce,
+            if self.input_num is not None:
+                nin = self.input_num
+            else:
+                nin = len(args)
+            f = _get_fusion(self.func, nin, self.reduce,
                            self.post_map, self.identity, types)
            self._memo[key] = f
        f = self._memo[key]
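With `inspect.getargspec` gone, the input count is taken from `input_num` when given and otherwise from the actual call (`len(args)`), so fused functions no longer need an introspectable signature. A minimal sketch of the user-facing API this code backs:

```python
# A sketch of cupy.fuse, whose machinery the code above implements;
# the input count is now inferred from the call when not specified.
import cupy

@cupy.fuse()
def squared_diff(x, y):
    return (x - y) * (x - y)

a = cupy.arange(5, dtype=cupy.float32)
b = cupy.ones(5, dtype=cupy.float32)
print(squared_diff(a, b))   # [1. 0. 1. 4. 9.], from one fused kernel
```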
27 changes: 27 additions & 0 deletions cupy/cuda/cudnn.pxd
@@ -51,6 +51,9 @@ cdef extern from *:
    ctypedef void* PoolingDescriptor 'cudnnPoolingDescriptor_t'
    ctypedef void* RNNDescriptor 'cudnnRNNDescriptor_t'
    ctypedef void* TensorDescriptor 'cudnnTensorDescriptor_t'
    ctypedef void* SpatialTransformerDescriptor \
        'cudnnSpatialTransformerDescriptor_t'
    ctypedef void* SamplerType 'cudnnSamplerType_t'


###############################################################################
@@ -147,6 +150,8 @@ cpdef enum:
    CUDNN_LINEAR_INPUT = 0
    CUDNN_SKIP_INPUT = 1

    CUDNN_SAMPLER_BILINEAR = 0


###############################################################################
# Initialization and CUDA cooperation
@@ -366,3 +371,25 @@ cpdef activationBackward_v3(
cpdef size_t createDropoutDescriptor() except *
cpdef destroyDropoutDescriptor(size_t dropoutDesc)
cpdef size_t dropoutGetStatesSize(size_t handle) except *


###############################################################################
# Spatial Transformer
###############################################################################

cpdef size_t createSpatialTransformerDescriptor() except *
cpdef destroySpatialTransformerDescriptor(size_t stDesc)
cpdef setSpatialTransformerDescriptor(
        size_t stDesc, size_t samplerType, int dataType,
        int nbDims, size_t dimA)
cpdef spatialTfGridGeneratorForward(
        size_t handle, size_t stDesc, size_t theta, size_t grid)
cpdef spatialTfGridGeneratorBackward(
        size_t handle, size_t stDesc, size_t dgrid, size_t dtheta)
cpdef spatialTfSamplerForward(
        size_t handle, size_t stDesc, size_t alpha, size_t xDesc,
        size_t x, size_t grid, size_t beta, size_t yDesc, size_t y)
cpdef spatialTfSamplerBackward(
        size_t handle, size_t stDesc, size_t alpha, size_t xDesc,
        size_t x, size_t beta, size_t dxDesc, size_t dx, size_t alphaDgrid,
        size_t dyDesc, size_t dy, size_t grid, size_t betaDgrid, size_t dgrid)
92 changes: 92 additions & 0 deletions cupy/cuda/cudnn.pyx
@@ -359,6 +359,31 @@ cdef extern from "cupy_cudnn.h":
        void* workspace, size_t workSpaceSizeInBytes, FilterDescriptor dwDesc,
        void* dw, void* reserveSpace, size_t reserveSpaceSizeInBytes) nogil

    # Spatial Transformer
    int cudnnCreateSpatialTransformerDescriptor(
        SpatialTransformerDescriptor* stDesc) nogil
    int cudnnDestroySpatialTransformerDescriptor(
        SpatialTransformerDescriptor stDesc) nogil
    int cudnnSetSpatialTransformerNdDescriptor(
        SpatialTransformerDescriptor stDesc, SamplerType samplerType,
        DataType dataType, int nbDims, int dimA[]) nogil
    int cudnnSpatialTfGridGeneratorForward(
        Handle handle, SpatialTransformerDescriptor stDesc,
        void* theta, void* grid) nogil
    int cudnnSpatialTfGridGeneratorBackward(
        Handle handle, SpatialTransformerDescriptor stDesc,
        void* dgrid, void* dtheta) nogil
    int cudnnSpatialTfSamplerForward(
        Handle handle, SpatialTransformerDescriptor stDesc,
        void* alpha, TensorDescriptor xDesc, void* x,
        void* grid, void* beta, TensorDescriptor yDesc, void* y) nogil
    int cudnnSpatialTfSamplerBackward(
        Handle handle, SpatialTransformerDescriptor stDesc,
        void* alpha, TensorDescriptor xDesc, void* x, void* beta,
        TensorDescriptor dxDesc, void* dx, void* alphaDgrid,
        TensorDescriptor dyDesc, void* dy, void* grid,
        void* betaDgrid, void* dgrid) nogil

###############################################################################
# Error handling
###############################################################################
@@ -1299,3 +1324,70 @@ cpdef RNNBackwardWeights(
            <FilterDescriptor>dwDesc, <void*>dw,
            <void*>reserveSpace, reserveSpaceSizeInBytes)
    check_status(status)


# Spatial Transformer

cpdef size_t createSpatialTransformerDescriptor() except *:
    cdef SpatialTransformerDescriptor stDesc
    status = cudnnCreateSpatialTransformerDescriptor(&stDesc)
    check_status(status)
    return <size_t>stDesc


cpdef destroySpatialTransformerDescriptor(size_t stDesc):
    status = cudnnDestroySpatialTransformerDescriptor(
        <SpatialTransformerDescriptor>stDesc)
    check_status(status)


cpdef setSpatialTransformerDescriptor(
        size_t stDesc, size_t samplerType, int dataType,
        int nbDims, size_t dimA):
    status = cudnnSetSpatialTransformerNdDescriptor(
        <SpatialTransformerDescriptor>stDesc, <SamplerType>samplerType,
        <DataType>dataType, nbDims, <int*>dimA)
    check_status(status)


cpdef spatialTfGridGeneratorForward(
        size_t handle, size_t stDesc, size_t theta, size_t grid):
    with nogil:
        status = cudnnSpatialTfGridGeneratorForward(
            <Handle>handle, <SpatialTransformerDescriptor>stDesc,
            <void*>theta, <void*>grid)
    check_status(status)


cpdef spatialTfGridGeneratorBackward(
        size_t handle, size_t stDesc, size_t dgrid, size_t dtheta):
    with nogil:
        status = cudnnSpatialTfGridGeneratorBackward(
            <Handle>handle, <SpatialTransformerDescriptor>stDesc,
            <void*>dgrid, <void*>dtheta)
    check_status(status)


cpdef spatialTfSamplerForward(
        size_t handle, size_t stDesc, size_t alpha, size_t xDesc,
        size_t x, size_t grid, size_t beta, size_t yDesc, size_t y):
    with nogil:
        status = cudnnSpatialTfSamplerForward(
            <Handle>handle, <SpatialTransformerDescriptor>stDesc,
            <void*>alpha, <TensorDescriptor>xDesc, <void*>x, <void*>grid,
            <void*>beta, <TensorDescriptor>yDesc, <void*>y)
    check_status(status)


cpdef spatialTfSamplerBackward(
        size_t handle, size_t stDesc, size_t alpha, size_t xDesc,
        size_t x, size_t beta, size_t dxDesc, size_t dx, size_t alphaDgrid,
        size_t dyDesc, size_t dy, size_t grid, size_t betaDgrid, size_t dgrid):
    with nogil:
        status = cudnnSpatialTfSamplerBackward(
            <Handle>handle, <SpatialTransformerDescriptor>stDesc,
            <void*>alpha, <TensorDescriptor>xDesc, <void*>x, <void*>beta,
            <TensorDescriptor>dxDesc, <void*>dx, <void*>alphaDgrid,
            <TensorDescriptor>dyDesc, <void*>dy, <void*>grid,
            <void*>betaDgrid, <void*>dgrid)
    check_status(status)
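A hypothetical end-to-end use of these wrappers (not from the source): build a bilinear sampler descriptor over an `(n, c, h, w)` float32 activation, then generate a sampling grid from per-sample affine matrices `theta`. The `create()`/`destroy()` handle helpers and `CUDNN_DATA_FLOAT` are assumed to be this module's existing cuDNN wrappers; error cleanup is elided.

```python
# A hypothetical driver for the spatial transformer wrappers; assumes
# cudnn.create()/cudnn.destroy() handle helpers and CUDNN_DATA_FLOAT
# exist in this module, as for the other wrapped cuDNN calls.
import numpy
import cupy
from cupy.cuda import cudnn

n, c, h, w = 2, 3, 8, 8
handle = cudnn.create()

st_desc = cudnn.createSpatialTransformerDescriptor()
dims = numpy.array([n, c, h, w], dtype=numpy.int32)
cudnn.setSpatialTransformerDescriptor(
    st_desc, cudnn.CUDNN_SAMPLER_BILINEAR, cudnn.CUDNN_DATA_FLOAT,
    4, dims.ctypes.data)

# theta: (n, 2, 3) affine matrices; grid: (n, h, w, 2) sampling points.
theta = cupy.random.uniform(-1, 1, (n, 2, 3)).astype(numpy.float32)
grid = cupy.empty((n, h, w, 2), dtype=numpy.float32)
cudnn.spatialTfGridGeneratorForward(
    handle, st_desc, theta.data.ptr, grid.data.ptr)

cudnn.destroySpatialTransformerDescriptor(st_desc)
cudnn.destroy(handle)
```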
30 changes: 30 additions & 0 deletions cupy/cuda/cupy_cudnn.h
@@ -368,6 +368,8 @@ typedef enum {} cudnnRNNInputMode_t;

typedef void* cudnnDropoutDescriptor_t;
typedef void* cudnnRNNDescriptor_t;
typedef void* cudnnSpatialTransformerDescriptor_t;
typedef void* cudnnSamplerType_t;


cudnnStatus_t cudnnSetConvolution2dDescriptor_v5(...) {
@@ -466,6 +468,34 @@ cudnnStatus_t cudnnFindConvolutionBackwardDataAlgorithmEx(...) {
    return CUDNN_STATUS_NOT_SUPPORTED;
}

cudnnStatus_t cudnnCreateSpatialTransformerDescriptor(...) {
    return CUDNN_STATUS_NOT_SUPPORTED;
}

cudnnStatus_t cudnnDestroySpatialTransformerDescriptor(...) {
    return CUDNN_STATUS_NOT_SUPPORTED;
}

cudnnStatus_t cudnnSetSpatialTransformerNdDescriptor(...) {
    return CUDNN_STATUS_NOT_SUPPORTED;
}

cudnnStatus_t cudnnSpatialTfGridGeneratorForward(...) {
    return CUDNN_STATUS_NOT_SUPPORTED;
}

cudnnStatus_t cudnnSpatialTfGridGeneratorBackward(...) {
    return CUDNN_STATUS_NOT_SUPPORTED;
}

cudnnStatus_t cudnnSpatialTfSamplerForward(...) {
    return CUDNN_STATUS_NOT_SUPPORTED;
}

cudnnStatus_t cudnnSpatialTfSamplerBackward(...) {
    return CUDNN_STATUS_NOT_SUPPORTED;
}

#endif // #if defined(CUPY_NO_CUDA) || (CUDNN_VERSION < 5000)


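These variadic stubs are compiled only when CUDA is absent or cuDNN predates version 5000 (per the guard above); they let the extension build and import everywhere, deferring failure to a CUDNN_STATUS_NOT_SUPPORTED return at call time.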
45 changes: 45 additions & 0 deletions cupy/cuda/cupy_nccl.h
@@ -0,0 +1,45 @@
// This file is a stub header file of NCCL for Read the Docs.

#ifndef INCLUDE_GUARD_CUPY_NCCL_H
#define INCLUDE_GUARD_CUPY_NCCL_H

#ifndef CUPY_NO_CUDA

#include <nccl.h>

#else // #ifndef CUPY_NO_CUDA

extern "C" {

typedef struct ncclComm* ncclComm_t;

enum {
    NCCL_UNIQUE_ID_BYTES = 128
};
typedef struct {
    char internal[NCCL_UNIQUE_ID_BYTES];
} ncclUniqueId;

typedef enum {
    ncclSuccess
} ncclResult_t;

typedef enum {} ncclRedOp_t;

typedef enum {} ncclDataType_t;

const char* ncclGetErrorString(...);
ncclResult_t ncclGetUniqueId(...);
ncclResult_t ncclCommInitRank(...);
void ncclCommDestroy(...);
ncclResult_t ncclCommCuDevice(...);
ncclResult_t ncclCommUserRank(...);
ncclResult_t ncclAllReduce(...);
ncclResult_t ncclReduce(...);
ncclResult_t ncclBcast(...);

}

#endif // #ifndef CUPY_NO_CUDA

#endif // #ifndef INCLUDE_GUARD_CUPY_NCCL_H
16 changes: 16 additions & 0 deletions cupy/cuda/nccl.pxd
@@ -0,0 +1,16 @@
"""
Wrapper for NCCL: Optimized primitives for collective multi-GPU communication
"""
cpdef enum:
    NCCL_SUM = 0
    NCCL_PROD = 1
    NCCL_MAX = 2
    NCCL_MIN = 3

    NCCL_CHAR = 0
    NCCL_INT = 1
    NCCL_HALF = 2
    NCCL_FLOAT = 3
    NCCL_DOUBLE = 4
    NCCL_INT64 = 5
    NCCL_UINT64 = 6

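A hypothetical helper (not part of this file) showing how the data-type constants above might be selected from an array's dtype before a collective call:

```python
# A hypothetical dtype-to-NCCL-constant map; the constants come from
# the cpdef enum above, and the int8/NCCL_CHAR pairing is an assumption.
import numpy
from cupy.cuda import nccl

_NCCL_DTYPES = {
    numpy.dtype(numpy.int8): nccl.NCCL_CHAR,
    numpy.dtype(numpy.int32): nccl.NCCL_INT,
    numpy.dtype(numpy.float16): nccl.NCCL_HALF,
    numpy.dtype(numpy.float32): nccl.NCCL_FLOAT,
    numpy.dtype(numpy.float64): nccl.NCCL_DOUBLE,
    numpy.dtype(numpy.int64): nccl.NCCL_INT64,
    numpy.dtype(numpy.uint64): nccl.NCCL_UINT64,
}


def nccl_dtype(array):
    """Return the NCCL type constant for a cupy/numpy array's dtype."""
    return _NCCL_DTYPES[numpy.dtype(array.dtype)]
```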