towards an even better backward attention kernel #179

Merged · 4 commits · Apr 19, 2024
Changes from all commits
184 changes: 176 additions & 8 deletions dev/cuda/attention_backward.cu
@@ -28,6 +28,7 @@ OMP_NUM_THREADS=32 ./attention_backward 5
#include <cuda_runtime.h>
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cooperative_groups/scan.h>
#include "common.h"

// ----------------------------------------------------------------------------
@@ -560,6 +561,129 @@ __global__ void softmax_autoregressive_backward_kernel5(float* __restrict__ dpreatt
}
}
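
// -----------------------------------------------------------------------------------------------
// [Illustrative sketch, not part of this diff] Roughly the CPU loop structure that the comment on
// kernel 6 below refers to: independent outer loops over (b, t, h) and inner loops over (t2, t3).
// The exact reference implementation may differ; names follow the conventions used in this file,
// and dpreatt is assumed to start zeroed.
static void softmax_autoregressive_backward_cpu_sketch(float* dpreatt, const float* datt,
                                                        const float* att, int B, int T, int C, int NH) {
    int hs = C / NH;                 // head size
    float scale = 1.0f / sqrtf(hs);
    for (int b = 0; b < B; b++) {
        for (int t = 0; t < T; t++) {
            for (int h = 0; h < NH; h++) {
                const float* att_bth = att + b*NH*T*T + h*T*T + t*T;
                const float* datt_bth = datt + b*NH*T*T + h*T*T + t*T;
                float* dpreatt_bth = dpreatt + b*NH*T*T + h*T*T + t*T;
                // the (t2, t3) interaction below is the work that the GPU kernels tile
                for (int t2 = 0; t2 <= t; t2++) {
                    for (int t3 = 0; t3 <= t; t3++) {
                        float indicator = (t2 == t3) ? 1.0f : 0.0f;
                        float local_derivative = att_bth[t2] * (indicator - att_bth[t3]);
                        dpreatt_bth[t3] += scale * local_derivative * datt_bth[t2];
                    }
                }
            }
        }
    }
}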


// I want `BlockSize` to be statically known to the compiler, thus we get a template here.
Owner review comment: love this comment block

// This kernel takes a step back and looks at the original CPU code again (sketched above for reference).
// We have simple outer loops that are independent, (b, t, h), and then the inner loops over (t2, t3)
// where elements get combined -- this is where we can reuse data and be more efficient.
// => handle b, t, h through block indices; each block does all the work for the (t2, t3) loops cooperatively.
// Now we have two nested loops, and in the innermost instruction we combine indexing from both => this calls for
// loop tiling, and for lifting some of the memory ops out of the loop.
// We're in luck here: if we tile so that t3 is the outer loop, we get a single write op per result, AND we can
// cache the t2-indexed part of the computation, which is the problematic one because it contains a multiplication
// that we then do not have to repeat over and over.
// => do an outer t3 loop where each thread gets one t3 index. Then do an outer t2 loop in steps of BlockSize, and
// prepare BlockSize many elements for the inner loop. Here, each thread calculates one element and stores it in shmem.
// Then, in the inner t2 loop, each thread reads *all* the elements previously stored and does its computations.
// This way, we do 3*BlockSize loads but BlockSize^2 compute steps => this kernel is now entirely compute bound.
// To fix up the compute issues, as above, we replace ifs in the memory reads with min, and also split the inner
// loop into a large region where we don't have to calculate the indicator, and a small, costly region where we do.
template<int BlockSize>
__global__ void __launch_bounds__(BlockSize) softmax_autoregressive_backward_kernel6(float* dpreatt, const float* datt, const float* att,
int B, int T, int C, int NH) {
namespace cg = cooperative_groups;
cg::thread_block block = cg::this_thread_block();
__shared__ float att_bth_s[BlockSize];

int idx = blockIdx.y;
int t = blockIdx.x;

att += idx * T * T;
datt += idx * T * T;
dpreatt += idx * T * T;

int hs = C / NH; // head size
float scale = 1.0f / sqrtf(hs);
const float* att_bth = att + t * T;
const float* datt_bth = datt + t * T;
float* dpreatt_bth = dpreatt + t * T;

int block_steps = ceil_div(t+1, BlockSize);
// Very important: this loop condition needs to be the same for all threads.
// Even if a thread later on is not going to do any work, it needs to participate in the
// data loading process!
for (int t3f = 0; t3f < block_steps; ++t3f) {
int t3 = t3f * BlockSize + block.thread_rank();
float acc = 0.f;
float at3 = att_bth[t3];
for (int t2b = 0; t2b <= t; t2b += BlockSize) {
int end = min(t + 1 - t2b, BlockSize);
block.sync();
{
int t2i = block.thread_rank();
int t2 = min(t, t2b + t2i);
att_bth_s[t2i] = att_bth[t2] * datt_bth[t2];
}

block.sync();
if(t3f * BlockSize == t2b) {
for (int t2i = 0; t2i < end; t2i++) {
int t2 = t2b + t2i;
float indicator = t2 == t3 ? 1.0f : 0.0f;
acc += att_bth_s[t2i] * (indicator - at3);
}
} else {
for (int t2i = 0; t2i < end; t2i++) {
acc += att_bth_s[t2i] * (0.f - at3);
}
}
}
dpreatt_bth[t3] = scale * acc;
}
}
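
// [Explanatory note, derived from the code below] Kernel 7 uses the algebraic form of the softmax
// backward pass. Writing out what kernel 6 accumulates,
//   dpreatt[t3] = scale * sum_{t2 <= t} att[t2] * datt[t2] * (indicator(t2 == t3) - att[t3])
//               = scale * att[t3] * (datt[t3] - sum_{t2 <= t} att[t2] * datt[t2]),
// so the whole (t2, t3) interaction collapses into a single per-row reduction of att * datt,
// computed here with a warp-level cg::reduce plus a cross-warp pass through shared memory;
// the t2 == t3 indicator term is then added back explicitly.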

template<int BlockSize>
__global__ void __launch_bounds__(BlockSize) softmax_autoregressive_backward_kernel7(float* dpreatt, const float* datt, const float* att,
int B, int T, int C, int NH) {
namespace cg = cooperative_groups;
cg::thread_block block = cg::this_thread_block();
cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
__shared__ float block_acc[32];

int idx = blockIdx.y;
int t = blockIdx.x;

att += idx * T * T;
datt += idx * T * T;
dpreatt += idx * T * T;

int hs = C / NH; // head size
float scale = 1.0f / sqrtf(hs);
const float* att_bth = att + t * T;
const float* datt_bth = datt + t * T;
float* dpreatt_bth = dpreatt + t * T;

if(warp.meta_group_rank() == 0) {
block_acc[warp.thread_rank()] = 0;
}

int block_steps = ceil_div(t+1, BlockSize);
// Very important: this loop condition needs to be the same for all threads.
// Even if a thread later on is not going to do any work, it needs to participate in the
// data loading process!
for (int t3f = 0; t3f < block_steps; ++t3f) {
int t3 = t3f * BlockSize + block.thread_rank();

float at3 = att_bth[min(t, t3)];
float local_sum = 0;
for(int t2 = block.thread_rank(); t2 <= t; t2 += BlockSize) {
local_sum += att_bth[t2] * datt_bth[t2];
}
block.sync();
block_acc[warp.meta_group_rank()] = cg::reduce(warp, local_sum, cg::plus<float>{});
block.sync();
local_sum = cg::reduce(warp, block_acc[warp.thread_rank()], cg::plus<float>{});

float acc = -local_sum * at3;
float at_t2_eq_t3 = at3 * datt_bth[min(t, t3)];
acc += (at_t2_eq_t3 * (1.f - at3) - at_t2_eq_t3 * (0.f - at3));
if(t3 <= t) {
dpreatt_bth[t3] = scale * acc;
}
}
}


// ----------------------------------------------------------------------------
// kernel launchers

@@ -644,6 +768,42 @@ void launch_softmax_5(float* dpreatt, float* datt, const float* att, int B, int
softmax_autoregressive_backward_kernel5<<<dim3(num_blocks, B*NH), block_size>>>(dpreatt, datt, att, B, T, C, NH);
}

template<class Launcher>
void dispatch_launch(Launcher&& launch, int block_size) {
switch(block_size) {
case 32:
return launch(std::integral_constant<int, 32>{});
case 64:
return launch(std::integral_constant<int, 64>{});
case 128:
return launch(std::integral_constant<int, 128>{});
case 256:
return launch(std::integral_constant<int, 256>{});
case 512:
return launch(std::integral_constant<int, 512>{});
case 1024:
return launch(std::integral_constant<int, 1024>{});
default:
assert(false && "Invalid block size");
}
}

void launch_softmax_6(float* dpreatt, float* datt, const float* att, int B, int T, int C, int NH, int block_size) {
auto launch = [&](auto int_const) {
softmax_autoregressive_backward_kernel6<int_const.value><<<dim3(T, B * NH), int_const.value>>>(dpreatt, datt, att, B, T, C, NH);
};
dispatch_launch(launch, block_size);
}

void launch_softmax_7(float* dpreatt, float* datt, const float* att, int B, int T, int C, int NH, int block_size) {
auto launch = [&](auto int_const) {
constexpr int block_size = int_const.value;
softmax_autoregressive_backward_kernel7<block_size><<<dim3(T, B * NH), block_size>>>
(dpreatt, datt, att, B, T, C, NH);
};
dispatch_launch(launch, block_size);
}

// the sequence of transformations in this compound op is:
// inp (B,T,3C) -> qkvr (B,T,3C) -> preatt (B,NH,T,T) -> att (B,NH,T,T) -> vaccum (B,T,C) -> out (B,T,C)
template<class SoftmaxKernel>
@@ -671,51 +831,51 @@ void attention_backward1(float* dinp, float* dqkvr, float* dpreatt, float* datt,
unpermute_kernel_backward<<<num_blocks, block_size>>>(dvaccum, dout, B, T, NH, HS);

// backward into datt
cublasCheck(cublasSgemmStridedBatched(cublas_handle,
CUBLAS_OP_T, CUBLAS_OP_N,
T, T, HS,
&alpha,
v, HS, T * HS,
dvaccum, HS, T * HS,
&beta,
datt, T, T * T,
B * NH));

// backward into dv
cublasCheck(cublasSgemmStridedBatched(cublas_handle,
CUBLAS_OP_N, CUBLAS_OP_T,
HS, T, T,
&alpha,
dvaccum, HS, T * HS,
att, T, T * T,
&beta,
dv, HS, T * HS,
B * NH));

// backward into preatt
softmax_autoregressive_backward(dpreatt, datt, att, B, T, C, NH, block_size);
cudaCheck(cudaGetLastError());

// backward into q
cublasCheck(cublasSgemmStridedBatched(cublas_handle,
CUBLAS_OP_N, CUBLAS_OP_N,
HS, T, T,
&alpha,
k, HS, T * HS,
dpreatt, T, T * T,
&beta,
dq, HS, T * HS,
B * NH));
// backward into k
cublasCheck(cublasSgemmStridedBatched(cublas_handle,
CUBLAS_OP_N, CUBLAS_OP_T,
HS, T, T,
&alpha,
q, HS, T * HS,
dpreatt, T, T * T,
&beta,
dk, HS, T * HS,
B * NH));

// backward into inp
num_blocks = ceil_div(B * NH * T * HS, block_size);
@@ -750,6 +910,14 @@ void attention_backward(int kernel_num,
attention_backward1(dinp, dqkvr, dpreatt, datt, dvaccum, dout, inp, qkvr, preatt, att, vaccum, B, T, C, NH,
launch_softmax_5, block_size);
break;
case 6:
attention_backward1(dinp, dqkvr, dpreatt, datt, dvaccum, dout, inp, qkvr, preatt, att, vaccum, B, T, C, NH,
launch_softmax_6, block_size);
break;
case 7:
attention_backward1(dinp, dqkvr, dpreatt, datt, dvaccum, dout, inp, qkvr, preatt, att, vaccum, B, T, C, NH,
launch_softmax_7, block_size);
break;
default:
printf("Invalid kernel number\n");
exit(1);
2 changes: 1 addition & 1 deletion dev/cuda/common.h
@@ -5,7 +5,7 @@


template<class T>
__host__ __device__ T ceil_div(T dividend, T divisor) {
return (dividend + divisor-1) / divisor;
}
