speed up tabulate cuda kernel by reducing shm usage (#830)
* removed the shared memory (shm) buffer used in the `tabulate_fusion_fifth_order_polynomial` cuda kernel (see the accumulation sketch below)

* made the use of `MTILE` and `KTILE` consistent in the tabulate kernels

* cleaned up `warp_idx` indexing in the tabulate grad kernel (the neighbor loop now starts at `warp_idx`, and the hardcoded 32 is replaced with `WARP_SIZE`)
darelbeida committed Jul 7, 2021
1 parent 914c054 commit 34c9bc9
Showing 1 changed file with 19 additions and 23 deletions.
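
For context, a minimal, self-contained sketch of the optimization behind the first bullet (toy kernels with hypothetical names, not the deepmd-kit code): each thread only ever touches its own column of the per-block accumulator, so the dynamically sized shared-memory buffer can be replaced by a small per-thread register array, which also eliminates the `__syncthreads()` barrier and the third (dynamic shared memory) kernel-launch parameter.

```cuda
// Toy illustration of the pattern in this commit; all names are hypothetical.
#include <cstdio>
#include <cuda_runtime.h>

constexpr int MTILE = 4;

// Old pattern: each thread accumulates into its own column of a
// dynamically sized shared-memory buffer (MTILE * blockDim.x floats).
__global__ void accumulate_shared(float* out, const float* in, int n) {
  extern __shared__ float buf[];
  const int tid = threadIdx.x;
  for (int kk = 0; kk < MTILE; kk++) buf[kk * blockDim.x + tid] = 0.f;
  __syncthreads();
  for (int ii = 0; ii < n; ii++)
    for (int kk = 0; kk < MTILE; kk++)
      buf[kk * blockDim.x + tid] += in[ii * MTILE + kk];
  for (int kk = 0; kk < MTILE; kk++)
    out[kk * blockDim.x + tid] = buf[kk * blockDim.x + tid];
}

// New pattern: the same per-column partial sums kept in registers; no
// shared memory, no barrier, no dynamic launch parameter needed.
__global__ void accumulate_registers(float* out, const float* in, int n) {
  const int tid = threadIdx.x;
  float sum[MTILE] = {0.f};
  for (int ii = 0; ii < n; ii++)
    for (int kk = 0; kk < MTILE; kk++)
      sum[kk] += in[ii * MTILE + kk];
  for (int kk = 0; kk < MTILE; kk++)
    out[kk * blockDim.x + tid] = sum[kk];
}

int main() {
  const int n = 8, threads = 32;
  float *in, *out;
  cudaMallocManaged(&in, n * MTILE * sizeof(float));
  cudaMallocManaged(&out, MTILE * threads * sizeof(float));
  for (int i = 0; i < n * MTILE; i++) in[i] = 1.f;

  // The old launch needs the dynamic shared-memory size as a third argument...
  accumulate_shared<<<1, threads, sizeof(float) * MTILE * threads>>>(out, in, n);
  cudaDeviceSynchronize();
  printf("shared:    out[0] = %g\n", out[0]);  // expect 8

  // ...the register version does not (cf. the launch-site change below).
  accumulate_registers<<<1, threads>>>(out, in, n);
  cudaDeviceSynchronize();
  printf("registers: out[0] = %g\n", out[0]);  // expect 8

  cudaFree(in);
  cudaFree(out);
  return 0;
}
```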
42 changes: 19 additions & 23 deletions source/lib/src/cuda/tabulate.cu
@@ -73,17 +73,13 @@ __global__ void tabulate_fusion_fifth_order_polynomial(
     const int nnei,
     const int last_layer_size)
 {
-  extern __shared__ int _data[];
   const int block_idx = blockIdx.x;   // nloc
   const int thread_idx = threadIdx.x; // last_layer_size
   FPTYPE ago = __shfl_sync(0xffffffff, em_x[block_idx * nnei + nnei - 1], 0);
   bool unloop = false;
   int breakpoint = nnei - 1;
-  FPTYPE * iteratorC = (FPTYPE*) &_data[0];
-  for (int kk = 0; kk < MTILE; kk++)
-    iteratorC[kk * last_layer_size + thread_idx] = 0.f;
-  __syncthreads();
 
+  FPTYPE sum[MTILE] = {0.f};
   for (int ii = 0; ii < nnei; ii++) {
     FPTYPE var[6];
     FPTYPE xx = em_x[block_idx * nnei + ii];
@@ -102,12 +98,12 @@ __global__ void tabulate_fusion_fifth_order_polynomial(
     FPTYPE res = var[0] + (var[1] + (var[2] + (var[3] + (var[4] + var[5] * xx) * xx) * xx) * xx) * xx;
 
     for (int kk = 0; kk < MTILE; kk++) {
-      iteratorC[kk * last_layer_size + thread_idx] += (nnei - breakpoint) * em[block_idx * nnei * MTILE + ii * MTILE + kk] * res;
+      sum[kk] += (nnei - breakpoint) * em[block_idx * nnei * MTILE + ii * MTILE + kk] * res;
     }
     if (unloop) break;
   }
   for (int ii = 0; ii < MTILE; ii++) {
-    out[block_idx * MTILE * last_layer_size + ii * last_layer_size + thread_idx] = iteratorC[ii * last_layer_size + thread_idx];
+    out[block_idx * MTILE * last_layer_size + ii * last_layer_size + thread_idx] = sum[ii];
   }
 }
 
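The nested `res = var[0] + (var[1] + ...) * xx` expression in both kernels is Horner's method for evaluating the tabulated fifth-order polynomial. A standalone sketch (hypothetical helper name, for illustration only):

```cuda
// Horner evaluation of the tabulated fifth-order polynomial used above:
//   y(x)  = v0 + v1*x + v2*x^2 + v3*x^3 + v4*x^4 + v5*x^5
// The grad kernel evaluates the derivative in the same nested form:
//   y'(x) = v1 + 2*v2*x + 3*v3*x^2 + 4*v4*x^3 + 5*v5*x^4
template <typename FPTYPE>
__device__ inline FPTYPE horner5(const FPTYPE v[6], const FPTYPE x) {
  return v[0] + (v[1] + (v[2] + (v[3] + (v[4] + v[5] * x) * x) * x) * x) * x;
}
```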
@@ -133,8 +129,8 @@ __global__ void tabulate_fusion_grad_fifth_order_polynomial(
   extern __shared__ int _data[];
   const int block_idx = blockIdx.x;   // nloc
   const int thread_idx = threadIdx.x; // KTILE * WARP_SIZE, usally 128 here~
-  int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
-  int lane_idx = threadIdx.x % 32;
+  int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / WARP_SIZE, 0);
+  int lane_idx = threadIdx.x % WARP_SIZE;
   int breakpoint = nnei - 1;
   bool unloop = false;
   FPTYPE * iteratorA = (FPTYPE *)&_data[0]; // dy
@@ -145,16 +141,16 @@ __global__ void tabulate_fusion_grad_fifth_order_polynomial(
   }
   __syncthreads();
   FPTYPE ago = __shfl_sync(0xffffffff, em_x[block_idx * nnei + nnei - 1], 0);
-  for (int ii = 0; ii < nnei; ii += KTILE) {
-    FPTYPE xx = em_x[block_idx * nnei + ii + warp_idx];
+  for (int ii = warp_idx; ii < nnei; ii += KTILE) {
+    FPTYPE xx = em_x[block_idx * nnei + ii];
     if (ago == xx) {
       unloop = true;
-      breakpoint = ii + warp_idx;
+      breakpoint = ii;
     }
 
     int table_idx = 0;
     locate_xx(xx, table_idx, lower, upper, max, stride0, stride1);
-    FPTYPE sum[KTILE] = {0.f};
+    FPTYPE sum[MTILE] = {0.f};
     FPTYPE Csub = 0.f;
     for (int jj = lane_idx; jj < last_layer_size; jj += WARP_SIZE) {
       FPTYPE var[6];
@@ -167,25 +163,25 @@ __global__ void tabulate_fusion_grad_fifth_order_polynomial(
       var[5] = table[table_idx * last_layer_size * 6 + 6 * jj + 5];
       FPTYPE res = var[0] + (var[1] + (var[2] + (var[3] + (var[4] + var[5] * xx) * xx) * xx) * xx) * xx;
 
-      for (int kk = 0; kk < KTILE; kk++) {
+      for (int kk = 0; kk < MTILE; kk++) {
         sum[kk] += (nnei - breakpoint) * iteratorA[kk * last_layer_size + jj] * res;
       }
-      res = em[block_idx * nnei * MTILE + (ii + warp_idx) * 4 + 0] * iteratorA[0 * last_layer_size + jj];
-      res += em[block_idx * nnei * MTILE + (ii + warp_idx) * 4 + 1] * iteratorA[1 * last_layer_size + jj];
-      res += em[block_idx * nnei * MTILE + (ii + warp_idx) * 4 + 2] * iteratorA[2 * last_layer_size + jj];
-      res += em[block_idx * nnei * MTILE + (ii + warp_idx) * 4 + 3] * iteratorA[3 * last_layer_size + jj];
+      res = em[block_idx * nnei * MTILE + ii * 4 + 0] * iteratorA[0 * last_layer_size + jj];
+      res += em[block_idx * nnei * MTILE + ii * 4 + 1] * iteratorA[1 * last_layer_size + jj];
+      res += em[block_idx * nnei * MTILE + ii * 4 + 2] * iteratorA[2 * last_layer_size + jj];
+      res += em[block_idx * nnei * MTILE + ii * 4 + 3] * iteratorA[3 * last_layer_size + jj];
       Csub += (nnei - breakpoint) * (var[1] + (2 * var[2] + (3 * var[3] + (4 * var[4] + 5 * var[5] * xx) * xx) * xx) * xx) * res;
     }
     __syncwarp();
-    for (int kk = 0; kk < KTILE; kk++) {
+    for (int kk = 0; kk < MTILE; kk++) {
       warp_reduce(sum[kk]);
     }
     warp_reduce(Csub);
     if (lane_idx == 0) {
-      for (int kk = 0; kk < KTILE; kk++) {
-        dy_dem[block_idx * nnei * MTILE + (ii + warp_idx) * 4 + kk] = sum[kk];
+      for (int kk = 0; kk < MTILE; kk++) {
+        dy_dem[block_idx * nnei * MTILE + ii * 4 + kk] = sum[kk];
       }
-      dy_dem_x[block_idx * nnei + ii + warp_idx] = Csub;
+      dy_dem_x[block_idx * nnei + ii] = Csub;
     }
     if (unloop) break;
   }
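
In the grad kernel, each of the KTILE warps owns one neighbor index per iteration (`ii = warp_idx, warp_idx + KTILE, ...`) while its 32 lanes stride across `last_layer_size`; `warp_reduce` then folds the per-lane partials so that lane 0 holds the warp's total before the single write guarded by `lane_idx == 0`. The helper itself is not part of this diff; a typical shuffle-based reduction (an assumed shape, not necessarily the exact deepmd-kit version) would be:

```cuda
// Assumed shape of the warp_reduce helper used above: a halving shuffle
// reduction. After log2(32) = 5 steps, lane 0 of a fully active warp
// holds the sum of val over all 32 lanes.
template <typename FPTYPE>
__device__ inline void warp_reduce(FPTYPE& val) {
  for (int offset = 16; offset > 0; offset >>= 1)
    val += __shfl_down_sync(0xffffffff, val, offset);
}
```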
@@ -204,7 +200,7 @@ void tabulate_fusion_gpu_cuda(
     const int last_layer_size)
 {
   if (nloc <= 0) {return;}
-  tabulate_fusion_fifth_order_polynomial<FPTYPE, MM, KK> <<<nloc, last_layer_size, sizeof(FPTYPE) * MM * last_layer_size>>>(
+  tabulate_fusion_fifth_order_polynomial<FPTYPE, MM, KK> <<<nloc, last_layer_size>>>(
       out,
       table, em_x, em, table_info[0], table_info[1], table_info[2], table_info[3], table_info[4], nnei, last_layer_size);
 }
