Skip to content

Commit

Permalink
Speed up maxout by exploiting parallelism better
Browse files Browse the repository at this point in the history
  • Loading branch information
danieldk committed Jan 25, 2022
1 parent ae7bcdd commit d022b76
Showing 1 changed file with 12 additions and 15 deletions.
27 changes: 12 additions & 15 deletions thinc/backends/_custom_kernels.cu
Expand Up @@ -88,26 +88,23 @@ void maxout(float* best, int* which,
{
int _loop_start = blockIdx.x * blockDim.x + threadIdx.x;
int _loop_stride = blockDim.x * gridDim.x;
for (int b = _loop_start; b < B; b += _loop_stride)
for (int bo = _loop_start; bo < B * O; bo += _loop_stride)
{
// Go to the regions we're working on
float* best_b = &best[b*O];
int* which_b = &which[b*O];
// Go to the candidates at the output we're working on
const float* cands_bo = &cands[bo * P];

for (int i=0; i < O; ++i)
int best_idx = 0;
float best_val = cands_bo[0];
for (int p=1; p < P; ++p)
{
const float* cands_bi = &cands[b*O*P+(i*P)];
which_b[i] = 0;
best_b[i] = cands_bi[0];
for (int p=1; p < P; ++p)
{
if (cands_bi[p] > best_b[i])
{
which_b[i] = p;
best_b[i] = cands_bi[p];
}
if (cands_bo[p] > best_val) {
best_idx = p;
best_val = cands_bo[p];
}
}

which[bo] = best_idx;
best[bo] = best_val;
}
}

Expand Down

0 comments on commit d022b76

Please sign in to comment.