feat(attention_forward.cu): Use gemm and axpby to do matrix multiplication
FeSens committed Apr 24, 2024
1 parent a89cce5 commit 8ffa323
Showing 1 changed file with 87 additions and 77 deletions.
164 changes: 87 additions & 77 deletions dev/cuda/attention_forward.cu
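
The commit message says the hand-rolled dot-product loop is replaced by CuTe's gemm and axpby. For reference, their conventions are: gemm(A, B, C) accumulates C(m,n) += sum_k A(m,k) * B(n,k), so calling it with Q as A and K as B yields Q @ K^T, and axpby(alpha, X, beta, Y) writes Y = alpha * X + beta * Y. Below is a minimal, self-contained sketch of just those two calls; it is not part of the commit, and the tiny fixed tile sizes, the single-thread launch, and the kernel/host names (tiny_qk_kernel and friends) are illustrative only.

// tiny_gemm_axpby.cu -- hypothetical standalone sketch, not from the commit.
// One thread computes S = softmax_scale * (Q @ K^T) for a tiny fixed-size tile,
// purely to show the cute::gemm / cute::axpby semantics used in the kernel below.
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>
#include <cute/tensor.hpp>

__global__ void tiny_qk_kernel(const float *Q, const float *K, float *S, float softmax_scale)
{
    using namespace cute;
    // Row-major layouts with explicit static strides, mirroring the style of the kernel below.
    Tensor gQ = make_tensor(make_gmem_ptr(Q), make_layout(make_shape(Int<4>{}, Int<8>{}), make_stride(Int<8>{}, Int<1>{}))); // (M,K)
    Tensor gK = make_tensor(make_gmem_ptr(K), make_layout(make_shape(Int<4>{}, Int<8>{}), make_stride(Int<8>{}, Int<1>{}))); // (N,K)
    Tensor gS = make_tensor(make_gmem_ptr(S), make_layout(make_shape(Int<4>{}, Int<4>{}), make_stride(Int<4>{}, Int<1>{}))); // (M,N)

    Tensor rS = make_tensor_like(gS);   // register-backed accumulator with the same (M,N) shape
    clear(rS);                          // rS = 0
    gemm(gQ, gK, rS);                   // rS(m,n) += sum_k gQ(m,k) * gK(n,k), i.e. Q @ K^T
    axpby(softmax_scale, rS, 0.0f, gS); // gS = softmax_scale * rS + 0 * gS
}

int main()
{
    std::vector<float> hQ(4 * 8, 1.0f), hK(4 * 8, 2.0f), hS(4 * 4, 0.0f);
    float *dQ, *dK, *dS;
    cudaMalloc(&dQ, hQ.size() * sizeof(float));
    cudaMalloc(&dK, hK.size() * sizeof(float));
    cudaMalloc(&dS, hS.size() * sizeof(float));
    cudaMemcpy(dQ, hQ.data(), hQ.size() * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(dK, hK.data(), hK.size() * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemset(dS, 0, hS.size() * sizeof(float));
    tiny_qk_kernel<<<1, 1>>>(dQ, dK, dS, 0.125f);
    cudaMemcpy(hS.data(), dS, hS.size() * sizeof(float), cudaMemcpyDeviceToHost);
    printf("S(0,0) = %f (expected 1*2*8 * 0.125 = 2)\n", hS[0]);
    cudaFree(dQ); cudaFree(dK); cudaFree(dS);
    return 0;
}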
@@ -154,89 +154,99 @@ __global__ void attention_query_key_kernel1(float* preatt, const float* inp,
 }
 
 __global__ void attention_query_key_kernel2(float *preatt, const float *inp,
                                             int B, int T, int C, int NH, int HS, int Br, int Bc, int Tr, int Tc, float softmax_scale)
 {
     using namespace cute;
     int block = blockIdx.x;
     int b = blockIdx.y; // batch
     int h = blockIdx.z; // head
 
     int C3 = C * 3;
 
     extern __shared__ float sram[];
     float *smemQ = sram;
     float *smemK = &sram[Br * HS]; // K starts after Q
 
     // define the CuTe layout
     Layout QK_layout = make_layout(make_shape(B, NH, T, HS), make_stride(T * C3, HS, C3, 1));
     Layout Attn_layout = make_layout(make_shape(B, NH, T, T), make_stride(T * T * NH, T * T, T, 1));
-    Layout sQ_layout = make_layout(make_shape(Br, HS), make_stride(HS, 1));
-    Layout sK_layout = make_layout(make_shape(Bc, HS), make_stride(HS, 1));
+    Layout sQ_layout = make_layout(make_shape(Int<16>{}, Int<64>{}), make_stride(Int<64>{}, Int<1>{}));
+    Layout sK_layout = make_layout(make_shape(Int<16>{}, Int<64>{}), make_stride(Int<64>{}, Int<1>{}));
 
     // gmem means global memory
     Tensor mQ = make_tensor(make_gmem_ptr(inp), QK_layout);           // (B, NH, T, HS)
     Tensor mK = make_tensor(make_gmem_ptr(inp + C), QK_layout);       // (B, NH, T, HS)
     Tensor mPreatt = make_tensor(make_gmem_ptr(preatt), Attn_layout); // (B, NH, T, T)
 
     Tensor gQ_block = mQ(b, h, _, _);           // (T,HS)
     Tensor gK_block = mK(b, h, _, _);           // (T,HS)
     Tensor gPreatt_block = mPreatt(b, h, _, _); // (T,T)
 
     Tensor gQ = zipped_divide(gQ_block, make_tile(Br, HS))(make_coord(_, _), block);          // (Br,HS) <- This is a little of a hack, but it works
     Tensor gK = zipped_divide(gK_block, make_tile(Bc, HS));                                   // ((Bc,HS), Tile)
     Tensor gPreatt = zipped_divide(gPreatt_block, make_tile(Br, T))(make_coord(_, _), block); // (Br,T)
 
     // smem means shared memory; it gives a hint to CUTLASS to select algorithms that use this memory efficiently
     Tensor sQ = make_tensor(make_smem_ptr(smemQ), sQ_layout);
     Tensor sK = make_tensor(make_smem_ptr(smemK), sK_layout);
 
-    int t = threadIdx.x / Bc;  // <- This should go away when we make gemm work
-    int t2 = threadIdx.x % Bc; // <- This should go away when we make gemm work
+    int t = threadIdx.x / 16;  // <- This should go away when we make gemm work
+    int t2 = threadIdx.x % 16; // <- This should go away when we make gemm work
 
     // Define the thread layout
-    Layout tQ = make_layout(make_shape(4, 64), make_stride(64, 1));
-    Layout tK = make_layout(make_shape(4, 64), make_stride(64, 1));
+    Layout tQ = make_layout(make_shape(Int<4>{}, Int<64>{}), make_stride(Int<64>{}, Int<1>{}));
+    Layout tK = make_layout(make_shape(Int<4>{}, Int<64>{}), make_stride(Int<64>{}, Int<1>{}));
     Layout tC = make_layout(make_shape(Int<16>{}, Int<16>{}), make_stride(Int<16>{}, Int<1>{}));
 
     // Let's assign each thread to a 'tile' of items, so every thread participates in the copy
     Tensor tQgQ = local_partition(gQ, tQ, threadIdx.x);
     Tensor tQsQ = local_partition(sQ, tQ, threadIdx.x);
+    Tensor tKsK = local_partition(sK, tK, threadIdx.x);
 
     copy(tQgQ, tQsQ);
 
     for (int k_tile = 0; k_tile < Tc; ++k_tile)
     {
         // My intuition says this should be outside the loop
         Tensor tKgK = local_partition(gK(make_coord(_, _), k_tile), tK, threadIdx.x);
-        Tensor tKsK = local_partition(sK, tK, threadIdx.x);
         Tensor ggPreatt = zipped_divide(gPreatt, make_tile(Int<16>{}, Int<16>{}))(make_coord(_, _), k_tile);
 
         copy(tKgK, tKsK);
 
         cp_async_fence();   // Label the end of (potential) cp.async instructions
         cp_async_wait<0>(); // Sync on all (potential) cp.async instructions
         __syncthreads();    // Wait for all threads to write to smem
 
-        Tensor tCggPreatt = local_partition(ggPreatt, tC, threadIdx.x);
-        Tensor tCrC = make_tensor_like(tCggPreatt); // (THR_M,THR_N)
-        // Perform tile multiplication
-        float val = 0.0f;
-        if (t + (block * Br) < t2 + (k_tile * Bc))
-        {
-            val = -INFINITY;
-        }
-        else
-        {
-            // gemm(tQsQ, tKsK, tCrC); //<- This is not working.
-            for (int hi = 0; hi < HS; ++hi)
-            {
-                val += sQ(t, hi) * sK(t2, hi);
-            }
-        }
-        __syncthreads();
-        val *= softmax_scale;
-        ggPreatt(t, t2) = val; // <- This should go when gemm works
+        Tensor tCggPreatt = local_partition(ggPreatt, tC, threadIdx.x, Step<_1,_1>{});
+        Tensor tCsQ = local_partition(sQ, tC, threadIdx.x, Step<_1, X>{}); // (THR_M,BLK_K)
+        Tensor tCsK = local_partition(sK, tC, threadIdx.x, Step< X,_1>{}); // (THR_N,BLK_K)
+
+        Tensor tCrC = make_tensor_like(tCggPreatt); // (THR_M,THR_N)
+
+        if (block < k_tile) {
+            // No need to matrix multiply, this will be all masked
+            axpby(-INFINITY, tCrC, 0.0, tCggPreatt);
+        } else if (block <= k_tile)
+        {
+            // We know that the diagonal is all -INFINITY
+            gemm(tCsQ, tCsK, tCrC); //<- This is not working.
+            __syncthreads();
+            if (t < t2) {
+                tCrC(t, t2) = -INFINITY;
+            }
+
+            __syncthreads();
+            axpby(softmax_scale, tCrC, 0.0, tCggPreatt);
+        }
+        else
+        {
+            gemm(tCsQ, tCsK, tCrC); //<- This is not working.
+            axpby(softmax_scale, tCrC, 0.0, tCggPreatt); // multiply tCrC by softmax_scale and add to tCggPreatt
+        }
+        __syncthreads();
     }
 }
 
 __global__ void attention_softmax_kernel1(float* att, const float* preatt,

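One further reference note, not part of the commit: the three-way branch on block versus k_tile in the kernel above is the causal mask applied at tile granularity. Assuming square 16x16 tiles (Br == Bc == 16, matching the hard-coded 16s), a query-row tile strictly above the diagonal attends only to future keys and can be filled with -INFINITY without a matrix multiply, the diagonal tile needs gemm plus element-wise masking where t < t2, and tiles below the diagonal need only gemm and axpby. A small host-side sketch of that classification follows; the helper and enum names are hypothetical.

// causal_tiles.cu -- hypothetical helper, not from the commit: classifies a
// (query tile, key tile) pair under the causal mask, mirroring the kernel's branches.
#include <cstdio>

enum class TileMask { Full, Diagonal, None };

TileMask classify_tile(int block /* query row tile */, int k_tile /* key column tile */)
{
    if (block < k_tile)  return TileMask::Full;     // every key is in the future -> write -INFINITY, skip gemm
    if (block == k_tile) return TileMask::Diagonal; // gemm, then mask elements with t < t2 inside the tile
    return TileMask::None;                          // every key is in the past   -> gemm + axpby only
}

int main()
{
    const int Tr = 4, Tc = 4; // number of query and key tiles (illustrative)
    for (int block = 0; block < Tr; ++block) {
        for (int k_tile = 0; k_tile < Tc; ++k_tile) {
            switch (classify_tile(block, k_tile)) {
                case TileMask::Full:     printf(" X"); break;  // fully masked
                case TileMask::Diagonal: printf(" \\"); break; // partially masked
                case TileMask::None:     printf(" ."); break;  // unmasked
            }
        }
        printf("\n");
    }
    return 0;
}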