Skip to content

Commit

Permalink
formatted warp_idx used in tabulate kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
darelbeida committed Jul 6, 2021
1 parent ad3e30e commit 7d2be2d
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions source/lib/src/cuda/tabulate.cu
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,11 @@ __global__ void tabulate_fusion_grad_fifth_order_polynomial(
}
__syncthreads();
FPTYPE ago = __shfl_sync(0xffffffff, em_x[block_idx * nnei + nnei - 1], 0);
for (int ii = 0; ii < nnei; ii += KTILE) {
FPTYPE xx = em_x[block_idx * nnei + ii + warp_idx];
for (int ii = warp_idx; ii < nnei; ii += KTILE) {
FPTYPE xx = em_x[block_idx * nnei + ii];
if (ago == xx) {
unloop = true;
breakpoint = ii + warp_idx;
breakpoint = ii;
}

int table_idx = 0;
Expand All @@ -166,10 +166,10 @@ __global__ void tabulate_fusion_grad_fifth_order_polynomial(
for (int kk = 0; kk < MTILE; kk++) {
sum[kk] += (nnei - breakpoint) * iteratorA[kk * last_layer_size + jj] * res;
}
res = em[block_idx * nnei * MTILE + (ii + warp_idx) * 4 + 0] * iteratorA[0 * last_layer_size + jj];
res += em[block_idx * nnei * MTILE + (ii + warp_idx) * 4 + 1] * iteratorA[1 * last_layer_size + jj];
res += em[block_idx * nnei * MTILE + (ii + warp_idx) * 4 + 2] * iteratorA[2 * last_layer_size + jj];
res += em[block_idx * nnei * MTILE + (ii + warp_idx) * 4 + 3] * iteratorA[3 * last_layer_size + jj];
res = em[block_idx * nnei * MTILE + ii * 4 + 0] * iteratorA[0 * last_layer_size + jj];
res += em[block_idx * nnei * MTILE + ii * 4 + 1] * iteratorA[1 * last_layer_size + jj];
res += em[block_idx * nnei * MTILE + ii * 4 + 2] * iteratorA[2 * last_layer_size + jj];
res += em[block_idx * nnei * MTILE + ii * 4 + 3] * iteratorA[3 * last_layer_size + jj];
Csub += (nnei - breakpoint) * (var[1] + (2 * var[2] + (3 * var[3] + (4 * var[4] + 5 * var[5] * xx) * xx) * xx) * xx) * res;
}
__syncwarp();
Expand All @@ -179,9 +179,9 @@ __global__ void tabulate_fusion_grad_fifth_order_polynomial(
warp_reduce(Csub);
if (lane_idx == 0) {
for (int kk = 0; kk < MTILE; kk++) {
dy_dem[block_idx * nnei * MTILE + (ii + warp_idx) * 4 + kk] = sum[kk];
dy_dem[block_idx * nnei * MTILE + ii * 4 + kk] = sum[kk];
}
dy_dem_x[block_idx * nnei + ii + warp_idx] = Csub;
dy_dem_x[block_idx * nnei + ii] = Csub;
}
if (unloop) break;
}
Expand Down

0 comments on commit 7d2be2d

Please sign in to comment.