New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support cuDNN CTC functions #1769
Conversation
cupy/cudnn.pyx
Outdated
@@ -382,6 +382,13 @@ def set_dropout_descriptor(desc, handle, dropout): | |||
cudnn.setDropoutDescriptor(desc.value, handle, dropout, 0, 0, 0) | |||
|
|||
|
|||
def create_ctc_loss_descriptor(data_type): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I want to hide descriptor
and wrap cuDNN interface because raw cuDNN API is too complex.
This is high level API example.
Line 311 in 834f9f3
def activation_forward(core.ndarray x, int mode, double coef=0.0): |
Cloud you remove this function from this PR?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK. Thank you for your comment.
Do you mean should I define the other high level API function like following?
# in cupy/cudnn.pyx
def ctc_loss(data_type, other_args_for_cudnnCTCLoss):
# create descriptor
desc = Descriptor(cudnn.createCTCLossDescriptor(),
py_cudnn.destroyCTCLossDescriptor)
cudnn.setCTCLossDescriptor(desc.value, data_type)
# compute workspace
getCTCLossWorkspaceSize(hoge, worksize)
# allocate worksize
workspace = ...
# compute CTC loss
CTCLoss(hoge, workspace, other_args_for_cudnnCTCLoss)
return loss, gradients
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, something like that.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Cloud you rename this function or delete this?
def create_ctc_loss_descriptor(data_type): | |
def _create_ctc_loss_descriptor(data_type): |
jenkins, test this please. |
Jenkins CI test (for commit 2309250, target branch master) failed with status FAILURE. |
Co-Authored-By: aonotas <nanigashi03@gmail.com>
Co-Authored-By: aonotas <nanigashi03@gmail.com>
7b26f2d
to
c206bc6
Compare
Can you check this comments when you can time? |
In ctc function, I found these things:
so I define arguments wiht CPU pointer in high-level API. Please give me some advice. and Do you know why this error occurs? |
In such case, using |
You need to fix the dummy header file? |
cupy/cudnn.pyx
Outdated
@@ -382,6 +382,13 @@ def set_dropout_descriptor(desc, handle, dropout): | |||
cudnn.setDropoutDescriptor(desc.value, handle, dropout, 0, 0, 0) | |||
|
|||
|
|||
def create_ctc_loss_descriptor(data_type): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Cloud you rename this function or delete this?
def create_ctc_loss_descriptor(data_type): | |
def _create_ctc_loss_descriptor(data_type): |
jenkins, test this please. |
Jenkins CI test (for commit 02df19a, target branch master) failed with status FAILURE. |
@aonotas Do you have some time to update this PR? |
Jenkins CI test (for commit 02df19a, target branch master) failed with status FAILURE. |
I fix the code (just rename the function). |
chainerci please |
jenkis, test this please. |
Jenkins CI test (for commit cb658d8, target branch master) failed with status FAILURE. |
LGTM! |
Add cuDNN CTC functions.
This cuDNN CTC is supported from v7.