Skip to content

Commit

Permalink
fix flake8
Browse files Browse the repository at this point in the history
  • Loading branch information
aonotas committed Nov 19, 2018
1 parent dcd64de commit bf0550d
Showing 1 changed file with 28 additions and 28 deletions.
56 changes: 28 additions & 28 deletions cupy/cudnn.pyx
Expand Up @@ -462,38 +462,38 @@ def set_dropout_descriptor(desc, handle, dropout):


def create_ctc_loss_descriptor(data_type):
desc = Descriptor(cudnn.createCTCLossDescriptor(),
py_cudnn.destroyCTCLossDescriptor)
cudnn.setCTCLossDescriptor(desc.value, data_type)
return desc
desc = Descriptor(cudnn.createCTCLossDescriptor(),
py_cudnn.destroyCTCLossDescriptor)
cudnn.setCTCLossDescriptor(desc.value, data_type)
return desc


def ctc_loss(core.ndarray probs, labels_ptr,
label_length_ptr, input_length_ptr, int algo):
batch_size = probs.shape[1]
handle = get_handle()
data_type = get_data_type(probs.dtype)
ctc_desc = Descriptor(cudnn.createCTCLossDescriptor(),
py_cudnn.destroyCTCLossDescriptor)
cudnn.setCTCLossDescriptor(ctc_desc.value, data_type)

gradients = core.ndarray(probs._shape, probs.dtype)
loss = core.ndarray((batch_size, ), 'f')
probs_desc = create_tensor_descriptor(probs)
gradients_desc = create_tensor_descriptor(gradients)

work_size = cudnn.getCTCLossWorkspaceSize(
handle, probs_desc.value, gradients_desc.value,
labels_ptr, label_length_ptr,
input_length_ptr, algo, ctc_desc.value)
workspace = core.ndarray((work_size,), 'b')

cudnn.CTCLoss(handle, probs_desc.value, probs.data.ptr,
labels_ptr, label_length_ptr,
input_length_ptr, loss.data.ptr, gradients_desc.value,
gradients.data.ptr, algo, ctc_desc.value,
workspace.data.ptr, work_size)
return loss, gradients
batch_size = probs.shape[1]
handle = get_handle()
data_type = get_data_type(probs.dtype)
ctc_desc = Descriptor(cudnn.createCTCLossDescriptor(),
py_cudnn.destroyCTCLossDescriptor)
cudnn.setCTCLossDescriptor(ctc_desc.value, data_type)

gradients = core.ndarray(probs._shape, probs.dtype)
loss = core.ndarray((batch_size, ), 'f')
probs_desc = create_tensor_descriptor(probs)
gradients_desc = create_tensor_descriptor(gradients)

work_size = cudnn.getCTCLossWorkspaceSize(
handle, probs_desc.value, gradients_desc.value,
labels_ptr, label_length_ptr,
input_length_ptr, algo, ctc_desc.value)
workspace = core.ndarray((work_size,), 'b')

cudnn.CTCLoss(handle, probs_desc.value, probs.data.ptr,
labels_ptr, label_length_ptr,
input_length_ptr, loss.data.ptr, gradients_desc.value,
gradients.data.ptr, algo, ctc_desc.value,
workspace.data.ptr, work_size)
return loss, gradients


def create_rnn_descriptor(hidden_size, num_layers, dropout_desc,
Expand Down

0 comments on commit bf0550d

Please sign in to comment.