Skip to content

Commit

Permalink
update getCTCLossWorkspaceSize
Browse files Browse the repository at this point in the history
  • Loading branch information
aonotas committed Nov 2, 2018
1 parent a939329 commit 7b26f2d
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 5 deletions.
7 changes: 3 additions & 4 deletions cupy/cuda/cudnn.pxd
Expand Up @@ -456,16 +456,15 @@ cpdef size_t createCTCLossDescriptor() except? 0
cpdef destroyCTCLossDescriptor(size_t ctcLossDesc)
cpdef setCTCLossDescriptor(size_t ctcLossDesc, int dataType)
cpdef getCTCLossDescriptor(size_t ctcLossDesc)
cpdef getCTCLossWorkspaceSize(
cpdef size_t getCTCLossWorkspaceSize(
size_t handle, size_t probsDesc, size_t gradientsDesc,
size_t labels, size_t labelLengths, size_t inputLengths,
int algo, size_t ctcLossDesc)
cpdef CTCLoss(
size_t handle, size_t probsDesc,
size_t probs, size_t labels, size_t labelLengths, size_t inputLengths,
size_t costs, size_t gradientsDesc, size_t gradients,
int algo, size_t ctcLossDesc,
size_t workspace, size_t workSpaceSizeInBytes)
size_t costs, size_t gradientsDesc, size_t gradients, int algo,
size_t ctcLossDesc, size_t workspace, size_t workSpaceSizeInBytes)


###############################################################################
Expand Down
2 changes: 1 addition & 1 deletion cupy/cuda/cudnn.pyx
Expand Up @@ -1740,7 +1740,7 @@ cpdef size_t getCTCLossWorkspaceSize(
<int*>labels, <int*>labelLengths, <int*>inputLengths,
<CTCLossAlgo>algo, <CTCLossDescriptor>ctcLossDesc, &sizeInBytes)
check_status(status)
return sizeInBytes
return <size_t>sizeInBytes

cpdef CTCLoss(
size_t handle, size_t probsDesc,
Expand Down

0 comments on commit 7b26f2d

Please sign in to comment.