relax ln backward hidden size restriction (#161)
neopro12 committed Aug 19, 2021
1 parent 4a0b1b0 commit 88a4e6d
Showing 1 changed file with 144 additions and 117 deletions.
261 changes: 144 additions & 117 deletions lightseq/training/csrc/kernels/normalize_kernels.cu
@@ -220,26 +220,28 @@ __global__ void ker_ln_bw_dgamma_dbetta(T *gamma_grad, T *betta_grad,
   float dbetta = 0;
   float dgamma = 0;
   float dout, val;
-  if (means == nullptr) {
-    float vbetta = (float)betta[idx];
-    float vgamma = (float)gamma[idx];
-    for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
-      dout = (float)out_grad[offset];
-      // inp_or_out is output
-      val = (float)inp_or_out[offset];
-      dbetta += dout;
-      dgamma += ((val - vbetta) / add_eps(vgamma) * dout);
-      offset += y_stride;
-    }
-  } else {
-    for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
-      dout = (float)out_grad[offset];
-      // inp_or_out is input
-      val = (float)inp_or_out[offset];
-      dbetta += dout;
-      dgamma += ((val - (float)means[r]) * rsqrtf((float)vars[r] + LN_EPSILON) *
-                 dout);
-      offset += y_stride;
-    }
-  }
+  if (idx < width) {
+    if (means == nullptr) {
+      float vbetta = (float)betta[idx];
+      float vgamma = (float)gamma[idx];
+      for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
+        dout = (float)out_grad[offset];
+        // inp_or_out is output
+        val = (float)inp_or_out[offset];
+        dbetta += dout;
+        dgamma += ((val - vbetta) / add_eps(vgamma) * dout);
+        offset += y_stride;
+      }
+    } else {
+      for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
+        dout = (float)out_grad[offset];
+        // inp_or_out is input
+        val = (float)inp_or_out[offset];
+        dbetta += dout;
+        dgamma += ((val - (float)means[r]) *
+                   rsqrtf((float)vars[r] + LN_EPSILON) * dout);
+        offset += y_stride;
+      }
+    }
+  }
 
@@ -256,8 +258,8 @@ __global__ void ker_ln_bw_dgamma_dbetta(T *gamma_grad, T *betta_grad,
     s2 += g.shfl_down(s2, i);
   }
 
-  if (threadIdx.x == 0) {
-    int pos = blockIdx.x * TILE_DIM + threadIdx.y;
+  int pos = blockIdx.x * TILE_DIM + threadIdx.y;
+  if (threadIdx.x == 0 && idx < width) {
     betta_grad[pos] = s1;
     gamma_grad[pos] = s2;
   }
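
A note on the pattern in the two hunks above: each thread's partial sums start at zero, only the accumulation loop and the final store are guarded by `idx < width`, and the shared-memory transpose plus the tile-wide `shfl_down` stay unguarded, so every lane still participates in the reduction and out-of-range lanes simply contribute zeros. Condensed to its shape (illustrative sketch of the committed code; `load` and `grad` are placeholders for the real gathers and the gamma/betta gradient stores, and `g` is the tile partition from this kernel):

    float partial = 0.f;                    // stays zero for out-of-range columns
    if (idx < width) {                      // guard the gather only
      for (int r = threadIdx.y; r < rows; r += TILE_DIM)
        partial += load(r, idx);            // placeholder for the real loads
    }
    // ... transpose through shared memory, then reduce across the tile:
    for (int i = 1; i < TILE_DIM; i <<= 1)  // all lanes must participate
      partial += g.shfl_down(partial, i);
    int pos = blockIdx.x * TILE_DIM + threadIdx.y;
    if (threadIdx.x == 0 && idx < width)    // guard the scatter
      grad[pos] = partial;                  // placeholder for the real stores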
@@ -297,57 +299,66 @@ template <typename T>
 __global__ void ker_ln_bw_dinp(T *inp_grad, const T *out_grad,
                                const T *residual_grad, const T *inp_or_out,
                                const T *gamma, const T *betta, const T *vars,
-                               const T *means) {
-  float hidden_dim = blockDim.x * 4;
-  int offset = blockIdx.x * blockDim.x + threadIdx.x;
-  float var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON);
-
-  // step 0. dxhat = dout * gamma
-  float4 dxhat = ((const float4 *)out_grad)[offset];
-  float4 vgamma = ((const float4 *)gamma)[threadIdx.x];
-  dxhat.x *= vgamma.x;
-  dxhat.y *= vgamma.y;
-  dxhat.z *= vgamma.z;
-  dxhat.w *= vgamma.w;
-
-  /*
-  step 1. xhat = (output - betta) / gamma or
-          (input - mean) * rsqrtf(var)
-  */
-  float4 xhat = ((const float4 *)inp_or_out)[offset];
-  if (means == nullptr) {
-    // inp_or_out is output, xhat = (output - betta) / gamma
-    float4 vbetta = ((const float4 *)betta)[threadIdx.x];
-    xhat.x = (xhat.x - vbetta.x) / add_eps(vgamma.x);
-    xhat.y = (xhat.y - vbetta.y) / add_eps(vgamma.y);
-    xhat.z = (xhat.z - vbetta.z) / add_eps(vgamma.z);
-    xhat.w = (xhat.w - vbetta.w) / add_eps(vgamma.w);
-  } else {
-    // inp_or_out is input, xhat = (input - mean) * rsqrtf(var)
-    float fmean = (float)means[blockIdx.x];
-    xhat.x = (xhat.x - fmean) * var_rsqrt;
-    xhat.y = (xhat.y - fmean) * var_rsqrt;
-    xhat.z = (xhat.z - fmean) * var_rsqrt;
-    xhat.w = (xhat.w - fmean) * var_rsqrt;
-  }
+                               const T *means, int hidden_dim) {
+  int offset = blockIdx.x * hidden_dim + threadIdx.x;
+  float4 dxhat, xhat;
+  float var_rsqrt;
+
+  if (threadIdx.x < hidden_dim) {
+    // step 0. dxhat = dout * gamma
+    dxhat = ((const float4 *)out_grad)[offset];
+    float4 vgamma = ((const float4 *)gamma)[threadIdx.x];
+    dxhat.x *= vgamma.x;
+    dxhat.y *= vgamma.y;
+    dxhat.z *= vgamma.z;
+    dxhat.w *= vgamma.w;
+
+    /*
+    step 1. xhat = (output - betta) / gamma or
+            (input - mean) * rsqrtf(var)
+    */
+    xhat = ((const float4 *)inp_or_out)[offset];
+    var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON);
+    if (means == nullptr) {
+      // inp_or_out is output, xhat = (output - betta) / gamma
+      float4 vbetta = ((const float4 *)betta)[threadIdx.x];
+      xhat.x = (xhat.x - vbetta.x) / add_eps(vgamma.x);
+      xhat.y = (xhat.y - vbetta.y) / add_eps(vgamma.y);
+      xhat.z = (xhat.z - vbetta.z) / add_eps(vgamma.z);
+      xhat.w = (xhat.w - vbetta.w) / add_eps(vgamma.w);
+    } else {
+      // inp_or_out is input, xhat = (input - mean) * rsqrtf(var)
+      float fmean = (float)means[blockIdx.x];
+      xhat.x = (xhat.x - fmean) * var_rsqrt;
+      xhat.y = (xhat.y - fmean) * var_rsqrt;
+      xhat.z = (xhat.z - fmean) * var_rsqrt;
+      xhat.w = (xhat.w - fmean) * var_rsqrt;
+    }
+  }
 
   /* step2. block reduce sum for dxhat and dxhat*xhat */
-  float sum_dxhat = dxhat.x + dxhat.y + dxhat.z + dxhat.w;
-  float sum_dxhat_xhat =
-      dxhat.x * xhat.x + dxhat.y * xhat.y + dxhat.z * xhat.z + dxhat.w * xhat.w;
-  float reduce_val[2] = {sum_dxhat, sum_dxhat_xhat};
+  float reduce_val[2] = {0.f, 0.f};
+  if (threadIdx.x < hidden_dim) {
+    reduce_val[0] = dxhat.x + dxhat.y + dxhat.z + dxhat.w;
+    reduce_val[1] = dxhat.x * xhat.x + dxhat.y * xhat.y + dxhat.z * xhat.z +
+                    dxhat.w * xhat.w;
+  }
   blockReduce<ReduceType::kSum, 2>(reduce_val);
   __shared__ float s_sum_dxhat, s_sum_dxhat_xhat;
   if (threadIdx.x == 0) {
-    s_sum_dxhat = reduce_val[0] / hidden_dim;
-    s_sum_dxhat_xhat = reduce_val[1] / hidden_dim;
+    float mean_dim = hidden_dim * 4;
+    s_sum_dxhat = reduce_val[0] / mean_dim;
+    s_sum_dxhat_xhat = reduce_val[1] / mean_dim;
   }
   __syncthreads();
 
   /*
   step3. compute input gradient
-  (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / hidden_dim) * rsqrt(var)
+  (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var)
   */
+  if (threadIdx.x >= hidden_dim) {
+    return;
+  }
   dxhat.x = (dxhat.x - s_sum_dxhat - xhat.x * s_sum_dxhat_xhat) * var_rsqrt;
   dxhat.y = (dxhat.y - s_sum_dxhat - xhat.y * s_sum_dxhat_xhat) * var_rsqrt;
   dxhat.z = (dxhat.z - s_sum_dxhat - xhat.z * s_sum_dxhat_xhat) * var_rsqrt;
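
For reference, steps 0 through 3 in this kernel implement the standard layer-norm input gradient. With per-row mean \mu, variance \sigma^2, normalized value \hat{x} = (x - \mu)/\sqrt{\sigma^2 + \epsilon}, and m the per-row hidden size (mean_dim above), the quantity being assembled is:

\[
d\hat{x} = dy \odot \gamma, \qquad
dx = \frac{1}{\sqrt{\sigma^2 + \epsilon}}
     \left( d\hat{x}
     - \frac{1}{m}\sum_{j=1}^{m} d\hat{x}_j
     - \hat{x} \cdot \frac{1}{m}\sum_{j=1}^{m} d\hat{x}_j\,\hat{x}_j \right)
\]

s_sum_dxhat and s_sum_dxhat_xhat hold the two means, and var_rsqrt supplies the leading 1/\sqrt{\sigma^2 + \epsilon} factor.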
@@ -369,74 +380,80 @@ __global__ void ker_ln_bw_dinp<__half>(__half *inp_grad, const __half *out_grad,
                                        const __half *residual_grad,
                                        const __half *inp_or_out,
                                        const __half *gamma, const __half *betta,
-                                       const __half *vars,
-                                       const __half *means) {
-  float hidden_dim = blockDim.x * 8;
-  int offset = blockIdx.x * blockDim.x + threadIdx.x;
-  float var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON);
-
-  // step 0. dxhat = dout * gamma
-  float4 vtmp = ((const float4 *)out_grad)[offset];
-  __half2 *tmp_h2 = reinterpret_cast<__half2 *>(&vtmp);
-  float4 vgamma = ((const float4 *)gamma)[threadIdx.x];
-  __half2 *gamma_h2 = reinterpret_cast<__half2 *>(&vgamma);
-  float2 dxhat[4];
-  float sum_dxhat = 0;
-#pragma unroll
-  for (int i = 0; i < 4; i++) {
-    float2 vdout = __half22float2(tmp_h2[i]);
-    float2 vgamma = __half22float2(gamma_h2[i]);
-    dxhat[i].x = vdout.x * vgamma.x;
-    dxhat[i].y = vdout.y * vgamma.y;
-    sum_dxhat += dxhat[i].x + dxhat[i].y;
-  }
-
-  /*
-  step 1. xhat = (output - betta) / gamma or
-          (input - mean) * rsqrtf(var)
-  */
-  vtmp = ((const float4 *)inp_or_out)[offset];
-  float2 xhat[4];
-  float sum_dxhat_xhat = 0;
-  if (means == nullptr) {
-    // inp_or_out is output, xhat = (output - betta) / gamma
-    float4 vbetta = ((const float4 *)betta)[threadIdx.x];
-    __half2 *betta_h2 = reinterpret_cast<__half2 *>(&vbetta);
-#pragma unroll
-    for (int i = 0; i < 4; i++) {
-      float2 vout = __half22float2(tmp_h2[i]);
-      float2 vgamma = __half22float2(gamma_h2[i]);
-      float2 vbetta = __half22float2(betta_h2[i]);
-      xhat[i].x = (vout.x - vbetta.x) / add_eps(vgamma.x);
-      xhat[i].y = (vout.y - vbetta.y) / add_eps(vgamma.y);
-      sum_dxhat_xhat += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y;
-    }
-  } else {
-    // inp_or_out is input, xhat = (input - mean) * rsqrtf(var)
-    float fmean = (float)means[blockIdx.x];
-#pragma unroll
-    for (int i = 0; i < 4; i++) {
-      float2 vinp = __half22float2(tmp_h2[i]);
-      xhat[i].x = (vinp.x - fmean) * var_rsqrt;
-      xhat[i].y = (vinp.y - fmean) * var_rsqrt;
-      sum_dxhat_xhat += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y;
-    }
-  }
-
-  /* step2. block reduce sum for dxhat and dxhat*xhat */
-  float reduce_val[2] = {sum_dxhat, sum_dxhat_xhat};
+                                       const __half *vars, const __half *means,
+                                       int hidden_dim) {
+  int offset = blockIdx.x * hidden_dim + threadIdx.x;
+
+  float2 dxhat[4], xhat[4];
+  float var_rsqrt;
+  float4 vtmp;
+  __half2 *tmp_h2;
+  float reduce_val[2] = {0.f, 0.f};
+
+  if (threadIdx.x < hidden_dim) {
+    // step 0. dxhat = dout * gamma
+    vtmp = ((const float4 *)out_grad)[offset];
+    tmp_h2 = reinterpret_cast<__half2 *>(&vtmp);
+    float4 gamma_f4 = ((const float4 *)gamma)[threadIdx.x];
+    __half2 *gamma_h2 = reinterpret_cast<__half2 *>(&gamma_f4);
+#pragma unroll
+    for (int i = 0; i < 4; i++) {
+      float2 vdout = __half22float2(tmp_h2[i]);
+      float2 vgamma = __half22float2(gamma_h2[i]);
+      dxhat[i].x = vdout.x * vgamma.x;
+      dxhat[i].y = vdout.y * vgamma.y;
+      reduce_val[0] += dxhat[i].x + dxhat[i].y;
+    }
+
+    /*
+    step 1. xhat = (output - betta) / gamma or
+            (input - mean) * rsqrtf(var)
+    */
+    vtmp = ((const float4 *)inp_or_out)[offset];
+    var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON);
+    if (means == nullptr) {
+      // inp_or_out is output, xhat = (output - betta) / gamma
+      float4 vbetta = ((const float4 *)betta)[threadIdx.x];
+      __half2 *betta_h2 = reinterpret_cast<__half2 *>(&vbetta);
+#pragma unroll
+      for (int i = 0; i < 4; i++) {
+        float2 vout = __half22float2(tmp_h2[i]);
+        float2 vgamma = __half22float2(gamma_h2[i]);
+        float2 vbetta = __half22float2(betta_h2[i]);
+        xhat[i].x = (vout.x - vbetta.x) / add_eps(vgamma.x);
+        xhat[i].y = (vout.y - vbetta.y) / add_eps(vgamma.y);
+        reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y;
+      }
+    } else {
+      // inp_or_out is input, xhat = (input - mean) * rsqrtf(var)
+      float fmean = (float)means[blockIdx.x];
+#pragma unroll
+      for (int i = 0; i < 4; i++) {
+        float2 vinp = __half22float2(tmp_h2[i]);
+        xhat[i].x = (vinp.x - fmean) * var_rsqrt;
+        xhat[i].y = (vinp.y - fmean) * var_rsqrt;
+        reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y;
+      }
+    }
+  }
+
+  /* step2. block reduce sum for dxhat and dxhat*xhat */
   blockReduce<ReduceType::kSum, 2>(reduce_val);
   __shared__ float s_sum_dxhat, s_sum_dxhat_xhat;
   if (threadIdx.x == 0) {
-    s_sum_dxhat = reduce_val[0] / hidden_dim;
-    s_sum_dxhat_xhat = reduce_val[1] / hidden_dim;
+    float mean_dim = hidden_dim * 8;
+    s_sum_dxhat = reduce_val[0] / mean_dim;
+    s_sum_dxhat_xhat = reduce_val[1] / mean_dim;
   }
   __syncthreads();
 
   /*
   step3. compute input gradient
-  (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / hidden_dim) * rsqrt(var)
+  (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var)
   */
+  if (threadIdx.x >= hidden_dim) {
+    return;
+  }
   if (residual_grad) {
     // Add the residual grad,
     // usually in pre-layer-norm for transformer layer
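
The same restructuring appears in both dinp kernels: blockReduce is built on warp shuffles and __syncthreads, so every thread in the block has to reach it. That is why the new threadIdx.x < hidden_dim guard wraps only the loads and stores, reduce_val starts at {0.f, 0.f} so padding threads contribute nothing, and the early return happens only after the barriers. A standalone sketch of why the guard cannot simply be an early return (hypothetical kernel; dummy_reduce stands in for this file's blockReduce):

    #include <cuda_runtime.h>

    __device__ float dummy_reduce(float v) {
      __shared__ float s;
      if (threadIdx.x == 0) s = 0.f;
      __syncthreads();               // every thread must reach these barriers;
      atomicAdd(&s, v);              // an early return above them would hang
      __syncthreads();               // or corrupt the reduction
      return s;
    }

    __global__ void guarded(const float *in, float *out, int n) {
      // launched as guarded<<<1, nthread>>> with nthread >= n
      float v = 0.f;                 // padding threads contribute zero
      if (threadIdx.x < n) v = in[threadIdx.x];
      float total = dummy_reduce(v); // unguarded: all threads participate
      if (threadIdx.x >= n) return;  // safe only after the barriers
      out[threadIdx.x] = v / total;
    }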
@@ -492,16 +509,21 @@ void launch_ln_bw<float>(float *gamma_grad, float *betta_grad, float *inp_grad,
                          const float *means, int batch, int hidden_dim,
                          cudaStream_t stream[2]) {
   // compute grad of gamma and betta
-  dim3 grid_dim(hidden_dim / TILE_DIM);
+  dim3 grid_dim(((hidden_dim + TILE_DIM - 1) / TILE_DIM) * TILE_DIM);
   dim3 block_dim(TILE_DIM, TILE_DIM);
   ker_ln_bw_dgamma_dbetta<float><<<grid_dim, block_dim, 0, stream[0]>>>(
       gamma_grad, betta_grad, out_grad, inp_or_out, gamma, betta, vars, means,
       batch, hidden_dim);
 
   // compute grad of input
+  if (hidden_dim % 4 != 0 || hidden_dim > 4096) {
+    throw std::runtime_error("hidden_dim % 4 != 0 || hidden_dim > 4096");
+  }
   hidden_dim >>= 2;
-  ker_ln_bw_dinp<<<batch, hidden_dim, 0, stream[1]>>>(
-      inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, means);
+  int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS);
+  ker_ln_bw_dinp<<<batch, nthread, 0, stream[1]>>>(
+      inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, means,
+      hidden_dim);
 }
 
 template <>
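
This launcher shows what was actually relaxed. The old code passed hidden_dim >> 2 straight to the kernel as the block size, which in practice tied the legal hidden sizes to whatever thread counts blockReduce could handle; the new code rounds the thread count up to a warp multiple (capped at MAX_THREADS), lets the in-kernel guards mask the padding threads, and replaces the implicit constraint with an explicit check: any hidden size that is a multiple of 4 and at most 4096 is accepted. The gamma/betta grid is likewise rounded up so a width that is not a multiple of TILE_DIM is still covered. The acceptance rule, pulled out for illustration (hypothetical helper mirroring the throw above):

    #include <stdexcept>

    // Accepts e.g. 1000, 2052, 4096; rejects 1001 (not a multiple of 4) and 8192.
    void check_ln_bw_float_hidden(int hidden_dim) {
      if (hidden_dim % 4 != 0 || hidden_dim > 4096) {
        throw std::runtime_error("hidden_dim % 4 != 0 || hidden_dim > 4096");
      }
    }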
@@ -512,14 +534,19 @@ void launch_ln_bw<__half>(__half *gamma_grad, __half *betta_grad,
                           const __half *vars, const __half *means, int batch,
                           int hidden_dim, cudaStream_t stream[2]) {
   // compute grad of gamma and betta
-  dim3 grid_dim(hidden_dim / TILE_DIM);
+  dim3 grid_dim(((hidden_dim + TILE_DIM - 1) / TILE_DIM) * TILE_DIM);
   dim3 block_dim(TILE_DIM, TILE_DIM);
   ker_ln_bw_dgamma_dbetta<__half><<<grid_dim, block_dim, 0, stream[0]>>>(
       gamma_grad, betta_grad, out_grad, inp_or_out, gamma, betta, vars, means,
       batch, hidden_dim);
 
   // compute grad of input
+  if (hidden_dim % 8 != 0 || hidden_dim > 8192) {
+    throw std::runtime_error("hidden_dim % 8 != 0 || hidden_dim > 8192");
+  }
   hidden_dim >>= 3;
-  ker_ln_bw_dinp<<<batch, hidden_dim, 0, stream[1]>>>(
-      inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, means);
+  int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS);
+  ker_ln_bw_dinp<<<batch, nthread, 0, stream[1]>>>(
+      inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, means,
+      hidden_dim);
 }
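
Worked numbers for the __half path, to make the bookkeeping concrete: each thread consumes one float4, i.e. eight halves, so the hidden_dim handed to the kernel is the element count divided by 8, and the means in step 2 divide by mean_dim = hidden_dim * 8. A runnable check of the launch arithmetic (H = 1000 is an arbitrary example; MAX_THREADS = 1024 is assumed to match this file's headers):

    #include <algorithm>
    #include <cstdio>

    constexpr int MAX_THREADS = 1024;  // assumption, matching lightseq's cap

    int main() {
      int H = 1000;                    // hidden size: % 8 == 0 and <= 8192
      int hidden_dim = H >> 3;         // float4 groups per row -> 125
      int nthread = std::min(((hidden_dim + 31) / 32) * 32, MAX_THREADS);
      std::printf("threads=%d idle=%d mean_dim=%d\n", nthread,
                  nthread - hidden_dim, hidden_dim * 8);
      // prints: threads=128 idle=3 mean_dim=1000
      return 0;
    }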
