Permalink
Browse files

Optimization 1: Use __restrict__ pointers.

  • Loading branch information...
galv committed Jan 7, 2018
1 parent 12c659e commit e7edd2d5b470f79dfef26a9520b80755ca60977f
Showing with 10 additions and 10 deletions.
  1. +10 −10 include/detail/gpu_ctc_kernels.h
@@ -86,11 +86,11 @@ struct CTASegReduce {
// than the labels. This is much more true for Mandarin than English.
template<typename ProbT, int NT, int VT>
__global__
-void compute_alpha_kernel (const ProbT* probs, const int *label_sizes,
-const int *utt_length, const int *repeats_in_labels,
-const int *labels_without_blanks, const int *label_offsets,
-int *labels_with_blanks, ProbT *alphas,
-ProbT* nll_forward, int stride, int out_dim,
+void compute_alpha_kernel (const ProbT* __restrict__ probs, const int * __restrict__ label_sizes,
+const int * __restrict__ utt_length, const int * __restrict__ repeats_in_labels,
+const int * __restrict__ labels_without_blanks, const int * __restrict__ label_offsets,
+int * __restrict__ labels_with_blanks, ProbT * __restrict__ alphas,
+ProbT* __restrict__ nll_forward, int stride, int out_dim,
int S_memoffset, int T_memoffset, int blank_label) {
ctc_helper::log_plus<ProbT> log_plus_f;
@@ -217,11 +217,11 @@ void compute_alpha_kernel (const ProbT* probs, const int *label_sizes,
// See comments above compute_alphas for more context.
template<typename ProbT, int NT, int VT>
__global__
-void compute_betas_and_grad_kernel (const ProbT* probs, const int *label_sizes,
-const int *utt_length, const int *repeats_in_labels,
-const int *labels_with_blanks, ProbT *alphas,
-const ProbT* nll_forward, ProbT *nll_backward,
-ProbT *grads, int stride, int out_dim,
+void compute_betas_and_grad_kernel (const ProbT* __restrict__ probs, const int * __restrict__ label_sizes,
+const int * __restrict__ utt_length, const int * __restrict__ repeats_in_labels,
+const int * __restrict__ labels_with_blanks, ProbT * __restrict__ alphas,
+const ProbT* __restrict__ nll_forward, ProbT * __restrict__ nll_backward,
+ProbT * __restrict__ grads, int stride, int out_dim,
int S_memoffset, int T_memoffset, int blank_label) {
ctc_helper::log_plus<ProbT> log_plus_f;

0 comments on commit e7edd2d

Please sign in to comment.