Skip to content

Commit

Permalink
Merge pull request #1769 from aonotas/cudnn-ctc
Browse files Browse the repository at this point in the history
Support cuDNN CTC functions
  • Loading branch information
okuta committed Apr 8, 2019
2 parents 7312c4d + cb658d8 commit b40438e
Show file tree
Hide file tree
Showing 4 changed files with 178 additions and 1 deletion.
22 changes: 22 additions & 0 deletions cupy/cuda/cudnn.pxd
Expand Up @@ -116,6 +116,9 @@ cpdef enum:
CUDNN_BATCHNORM_SPATIAL = 1
CUDNN_BATCHNORM_SPATIAL_PERSISTENT = 2

CUDNN_CTC_LOSS_ALGO_DETERMINISTIC = 0
CUDNN_CTC_LOSS_ALGO_NON_DETERMINISTIC = 1

CUDNN_BATCHNORM_OPS_BN = 0
CUDNN_BATCHNORM_OPS_BN_ACTIVATION = 1
CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION = 2
Expand Down Expand Up @@ -522,6 +525,25 @@ cpdef dropoutBackward(
size_t reserveSpace, size_t reserveSpaceSizeInBytes)


###############################################################################
# CTC
###############################################################################

# Creates a cudnnCTCLossDescriptor_t and returns it as an opaque handle.
cpdef size_t createCTCLossDescriptor() except? 0
# Destroys a descriptor previously returned by createCTCLossDescriptor.
cpdef destroyCTCLossDescriptor(size_t ctcLossDesc)
# Sets the compute type (a cudnnDataType_t value) on the descriptor.
cpdef setCTCLossDescriptor(size_t ctcLossDesc, int dataType)
# Reads back the compute type currently stored in the descriptor.
cpdef getCTCLossDescriptor(size_t ctcLossDesc)
# Returns the workspace size in bytes required by CTCLoss for this
# configuration.  labels/labelLengths/inputLengths are raw pointers to
# int arrays, passed as size_t.
cpdef size_t getCTCLossWorkspaceSize(
    size_t handle, size_t probsDesc, size_t gradientsDesc,
    size_t labels, size_t labelLengths, size_t inputLengths,
    int algo, size_t ctcLossDesc) except? 0
# Computes the CTC loss and gradients via cudnnCTCLoss.  All pointer
# arguments (probs, labels, lengths, costs, gradients, workspace) are
# raw addresses passed as size_t.
cpdef CTCLoss(
    size_t handle, size_t probsDesc,
    size_t probs, size_t labels, size_t labelLengths, size_t inputLengths,
    size_t costs, size_t gradientsDesc, size_t gradients, int algo,
    size_t ctcLossDesc, size_t workspace, size_t workSpaceSizeInBytes)


###############################################################################
# RNN
###############################################################################
Expand Down
73 changes: 73 additions & 0 deletions cupy/cuda/cudnn.pyx
Expand Up @@ -74,6 +74,7 @@ cdef extern from 'cupy_cudnn.h' nogil:
ctypedef int NanPropagation 'cudnnNanPropagation_t'
ctypedef int PoolingMode 'cudnnPoolingMode_t'
ctypedef int RNNInputMode 'cudnnRNNInputMode_t'
ctypedef int CTCLossAlgo 'cudnnCTCLossAlgo_t'
ctypedef int RNNMode 'cudnnRNNMode_t'
ctypedef int RNNAlgo 'cudnnRNNAlgo_t'
ctypedef int RNNDataLayout 'cudnnRNNDataLayout_t'
Expand All @@ -95,6 +96,7 @@ cdef extern from 'cupy_cudnn.h' nogil:
ctypedef void* FilterDescriptor 'cudnnFilterDescriptor_t'
ctypedef void* Handle 'cudnnHandle_t'
ctypedef void* PoolingDescriptor 'cudnnPoolingDescriptor_t'
ctypedef void* CTCLossDescriptor 'cudnnCTCLossDescriptor_t'
ctypedef void* RNNDescriptor 'cudnnRNNDescriptor_t'
ctypedef void* RNNDataDescriptor 'cudnnRNNDataDescriptor_t'
ctypedef void* PersistentRNNPlan 'cudnnPersistentRNNPlan_t'
Expand Down Expand Up @@ -509,6 +511,24 @@ cdef extern from 'cupy_cudnn.h' nogil:
TensorDescriptor dydesc, void* dy, TensorDescriptor dxdesc,
void* dx, void* reserveSpace, size_t reserveSpaceSizeInBytes)

    # CTC
    # C prototypes of the cuDNN CTC loss API; each returns a
    # cudnnStatus_t error code checked by the wrappers below.
    int cudnnCreateCTCLossDescriptor(CTCLossDescriptor* ctcLossDesc)
    int cudnnDestroyCTCLossDescriptor(CTCLossDescriptor ctcLossDesc)
    int cudnnSetCTCLossDescriptor(
        CTCLossDescriptor ctcLossDesc, DataType dataType)
    int cudnnGetCTCLossDescriptor(
        CTCLossDescriptor ctcLossDesc, DataType* dataType)
    int cudnnGetCTCLossWorkspaceSize(
        Handle handle, TensorDescriptor probsDesc,
        TensorDescriptor gradientsDesc, int* labels,
        int* labelLengths, int* inputLengths, CTCLossAlgo algo,
        CTCLossDescriptor ctcLossDesc, size_t* sizeInBytes)
    int cudnnCTCLoss(
        Handle handle, TensorDescriptor probsDesc,
        void* probs, int* labels, int* labelLengths, int* inputLengths,
        void* costs, TensorDescriptor gradientsDesc, void* gradients,
        CTCLossAlgo algo, CTCLossDescriptor ctcLossDesc,
        void* workspace, size_t workSpaceSizeInBytes)
# RNN
int cudnnCreateRNNDescriptor(RNNDescriptor* rnnDesc)
int cudnnDestroyRNNDescriptor(RNNDescriptor rnnDesc)
Expand Down Expand Up @@ -1879,6 +1899,59 @@ cpdef dropoutBackward(
check_status(status)


###############################################################################
# CTC
###############################################################################
cpdef size_t createCTCLossDescriptor() except? 0:
    """Creates a cuDNN CTC loss descriptor and returns it as an int handle."""
    cdef CTCLossDescriptor desc
    status = cudnnCreateCTCLossDescriptor(&desc)
    check_status(status)
    return <size_t>desc

cpdef destroyCTCLossDescriptor(size_t ctcLossDesc):
    """Destroys a descriptor created by :func:`createCTCLossDescriptor`."""
    status = cudnnDestroyCTCLossDescriptor(<CTCLossDescriptor>ctcLossDesc)
    check_status(status)

cpdef setCTCLossDescriptor(size_t ctcLossDesc, int dataType):
    """Sets the compute type (cudnnDataType_t) of a CTC loss descriptor."""
    status = cudnnSetCTCLossDescriptor(
        <CTCLossDescriptor>ctcLossDesc, <DataType>dataType)
    check_status(status)

cpdef getCTCLossDescriptor(size_t ctcLossDesc):
    """Returns the compute type currently set on a CTC loss descriptor."""
    cdef DataType compType
    status = cudnnGetCTCLossDescriptor(
        <CTCLossDescriptor>ctcLossDesc, &compType)
    check_status(status)
    return compType

cpdef size_t getCTCLossWorkspaceSize(
        size_t handle, size_t probsDesc, size_t gradientsDesc,
        size_t labels, size_t labelLengths, size_t inputLengths,
        int algo, size_t ctcLossDesc) except? 0:
    """Returns the workspace size in bytes that cudnnCTCLoss requires.

    ``labels``, ``labelLengths`` and ``inputLengths`` are addresses of
    int arrays, passed as ``size_t`` and cast back to ``int*``.
    """
    cdef size_t sizeInBytes
    status = cudnnGetCTCLossWorkspaceSize(
        <Handle>handle, <TensorDescriptor>probsDesc,
        <TensorDescriptor>gradientsDesc,
        <int*>labels, <int*>labelLengths, <int*>inputLengths,
        <CTCLossAlgo>algo, <CTCLossDescriptor>ctcLossDesc, &sizeInBytes)
    check_status(status)
    return sizeInBytes

cpdef CTCLoss(
        size_t handle, size_t probsDesc,
        size_t probs, size_t labels, size_t labelLengths, size_t inputLengths,
        size_t costs, size_t gradientsDesc, size_t gradients,
        int algo, size_t ctcLossDesc,
        size_t workspace, size_t workSpaceSizeInBytes):
    """Invokes cudnnCTCLoss, writing per-sample costs and gradients.

    All pointer arguments are raw addresses passed as ``size_t``.
    ``workspace`` must be at least ``workSpaceSizeInBytes`` bytes, as
    reported by :func:`getCTCLossWorkspaceSize`.
    """
    status = cudnnCTCLoss(
        <Handle>handle, <TensorDescriptor>probsDesc, <void*>probs,
        <int*>labels, <int*>labelLengths, <int*>inputLengths,
        <void*>costs, <TensorDescriptor>gradientsDesc, <void*>gradients,
        <CTCLossAlgo>algo, <CTCLossDescriptor>ctcLossDesc,
        <void*>workspace, <size_t>workSpaceSizeInBytes)
    check_status(status)


###############################################################################
# RNN
###############################################################################
Expand Down
46 changes: 45 additions & 1 deletion cupy/cuda/cupy_cudnn.h
Expand Up @@ -26,6 +26,7 @@ typedef enum {} cudnnActivationMode_t;
typedef enum {} cudnnConvolutionFwdAlgo_t;
typedef enum {} cudnnConvolutionFwdPreference_t;
typedef enum {} cudnnConvolutionMode_t;
typedef enum {} cudnnCTCLossAlgo_t;
typedef enum {} cudnnDataType_t;
typedef enum {} cudnnPoolingMode_t;
typedef enum {} cudnnSoftmaxAlgorithm_t;
Expand All @@ -35,6 +36,7 @@ typedef enum {} cudnnErrQueryMode_t;
typedef struct cudnnRuntimeTag_t cudnnRuntimeTag_t;

typedef void* cudnnConvolutionDescriptor_t;
typedef void* cudnnCTCLossDescriptor_t;
typedef void* cudnnFilterDescriptor_t;
typedef void* cudnnHandle_t;
typedef void* cudnnPoolingDescriptor_t;
Expand Down Expand Up @@ -317,6 +319,26 @@ cudnnStatus_t cudnnActivationBackward_v4(...) {
return CUDNN_STATUS_NOT_SUPPORTED;
}

// CTC
// Fallback stubs for the CTC loss API; every entry point reports
// CUDNN_STATUS_NOT_SUPPORTED.  NOTE(review): presumably compiled only
// when the available cuDNN lacks CTC support -- the surrounding version
// guards are not visible in this chunk.
cudnnStatus_t cudnnCreateCTCLossDescriptor(...) {
    return CUDNN_STATUS_NOT_SUPPORTED;
}
cudnnStatus_t cudnnDestroyCTCLossDescriptor(...) {
    return CUDNN_STATUS_NOT_SUPPORTED;
}
cudnnStatus_t cudnnSetCTCLossDescriptor(...) {
    return CUDNN_STATUS_NOT_SUPPORTED;
}
cudnnStatus_t cudnnGetCTCLossDescriptor(...) {
    return CUDNN_STATUS_NOT_SUPPORTED;
}
cudnnStatus_t cudnnGetCTCLossWorkspaceSize(...) {
    return CUDNN_STATUS_NOT_SUPPORTED;
}
cudnnStatus_t cudnnCTCLoss(...) {
    return CUDNN_STATUS_NOT_SUPPORTED;
}

typedef enum {} cudnnMathType_t;

cudnnStatus_t cudnnSetConvolutionMathType(...) {
Expand Down Expand Up @@ -674,6 +696,28 @@ cudnnStatus_t cudnnReduceTensor(...) {
#define cudnnSetRNNDescriptor_v5 cudnnSetRNNDescriptor

typedef enum {} cudnnMathType_t;
typedef enum {} cudnnCTCLossAlgo_t;
typedef void* cudnnCTCLossDescriptor_t;

// CTC
// Fallback stubs for the CTC loss API; every entry point reports
// CUDNN_STATUS_NOT_SUPPORTED.  NOTE(review): presumably compiled only
// when the available cuDNN lacks CTC support -- the surrounding version
// guards are not visible in this chunk.
cudnnStatus_t cudnnCreateCTCLossDescriptor(...) {
    return CUDNN_STATUS_NOT_SUPPORTED;
}
cudnnStatus_t cudnnDestroyCTCLossDescriptor(...) {
    return CUDNN_STATUS_NOT_SUPPORTED;
}
cudnnStatus_t cudnnSetCTCLossDescriptor(...) {
    return CUDNN_STATUS_NOT_SUPPORTED;
}
cudnnStatus_t cudnnGetCTCLossDescriptor(...) {
    return CUDNN_STATUS_NOT_SUPPORTED;
}
cudnnStatus_t cudnnGetCTCLossWorkspaceSize(...) {
    return CUDNN_STATUS_NOT_SUPPORTED;
}
cudnnStatus_t cudnnCTCLoss(...) {
    return CUDNN_STATUS_NOT_SUPPORTED;
}

cudnnStatus_t cudnnSetConvolutionMathType(...) {
return CUDNN_STATUS_NOT_SUPPORTED;
Expand Down Expand Up @@ -789,7 +833,7 @@ typedef void* cudnnRNNDataDescriptor_t;

typedef enum {} cudnnRNNDataLayout_t;
typedef enum {} cudnnRNNPaddingMode_t;

cudnnStatus_t cudnnSetRNNPaddingMode(...) {
return CUDNN_STATUS_NOT_SUPPORTED;
}
Expand Down
38 changes: 38 additions & 0 deletions cupy/cudnn.pyx
Expand Up @@ -729,6 +729,44 @@ def set_dropout_descriptor(desc, handle, dropout):
cudnn.setDropoutDescriptor(desc.value, handle, dropout, 0, 0, 0)


def _create_ctc_loss_descriptor(data_type):
    # Builds a CTC loss Descriptor configured with ``data_type`` and
    # registered to call destroyCTCLossDescriptor when released.
    desc = Descriptor(cudnn.createCTCLossDescriptor(),
                      py_cudnn.destroyCTCLossDescriptor)
    cudnn.setCTCLossDescriptor(desc.value, data_type)
    return desc


def ctc_loss(core.ndarray probs, labels,
             label_length, input_length, int algo):
    """Computes the CTC loss and its gradients with cuDNN.

    Args:
        probs (cupy.ndarray): Input probabilities; the batch size is taken
            from axis 1.  NOTE(review): assumes the cuDNN CTC layout
            ``(time, batch, classes)`` -- confirm against callers.
        labels: Host array of label values (``.ctypes.data`` is passed to
            cuDNN, so this must be a host-side numpy array of ints).
        label_length: Host int array of per-sample label lengths.
        input_length: Host int array of per-sample input lengths.
        algo (int): ``cudnnCTCLossAlgo_t`` value selecting the algorithm.

    Returns:
        tuple: ``(loss, gradients)`` -- ``loss`` is a float32 array of
        length ``batch_size``; ``gradients`` has the shape and dtype of
        ``probs``.
    """
    batch_size = probs.shape[1]
    # cuDNN takes host pointers to the int label/length arrays.
    labels_ptr = labels.ctypes.data
    label_length_ptr = label_length.ctypes.data
    input_length_ptr = input_length.ctypes.data
    handle = get_handle()
    data_type = get_data_type(probs.dtype)
    # Reuse the shared helper instead of duplicating the descriptor
    # creation/setup sequence inline.
    ctc_desc = _create_ctc_loss_descriptor(data_type)

    gradients = core.ndarray(probs._shape, probs.dtype)
    loss = core.ndarray((batch_size, ), 'f')
    probs_desc = create_tensor_descriptor(probs)
    gradients_desc = create_tensor_descriptor(gradients)

    # Allocate the scratch buffer cudnnCTCLoss requires for this problem.
    work_size = cudnn.getCTCLossWorkspaceSize(
        handle, probs_desc.value, gradients_desc.value,
        labels_ptr, label_length_ptr,
        input_length_ptr, algo, ctc_desc.value)
    workspace = core.ndarray((work_size,), 'b')

    cudnn.CTCLoss(handle, probs_desc.value, probs.data.ptr,
                  labels_ptr, label_length_ptr,
                  input_length_ptr, loss.data.ptr, gradients_desc.value,
                  gradients.data.ptr, algo, ctc_desc.value,
                  workspace.data.ptr, work_size)
    return loss, gradients


def create_rnn_descriptor(hidden_size, num_layers, dropout_desc,
input_mode, direction, mode, data_type, algo=None):
desc = Descriptor(cudnn.createRNNDescriptor(),
Expand Down

0 comments on commit b40438e

Please sign in to comment.