From 8ffa323fc11c6e69c1dce68e6f3d1a25fff4919d Mon Sep 17 00:00:00 2001
From: FeSens
Date: Wed, 24 Apr 2024 00:48:23 -0700
Subject: [PATCH] feat(attention_forward.cu): Use gemm and axpby to do matrix multiplication

---
 dev/cuda/attention_forward.cu | 164 ++++++++++++++++++----------------
 1 file changed, 87 insertions(+), 77 deletions(-)

diff --git a/dev/cuda/attention_forward.cu b/dev/cuda/attention_forward.cu
index fa7665093..d18fe7d2e 100644
--- a/dev/cuda/attention_forward.cu
+++ b/dev/cuda/attention_forward.cu
@@ -154,89 +154,99 @@ __global__ void attention_query_key_kernel1(float* preatt, const float* inp,
 }
 
 __global__ void attention_query_key_kernel2(float *preatt, const float *inp,
-  int B, int T, int C, int NH, int HS, int Br, int Bc, int Tr, int Tc, float softmax_scale)
+                                            int B, int T, int C, int NH, int HS, int Br, int Bc, int Tr, int Tc, float softmax_scale)
 {
-  using namespace cute;
-  int block = blockIdx.x;
-  int b = blockIdx.y; // batch
-  int h = blockIdx.z; // head
-
-  int C3 = C * 3;
-
-  extern __shared__ float sram[];
-  float *smemQ = sram;
-  float *smemK = &sram[Br * HS]; // K starts after Q
-
-  // define the CuTe layout
-  Layout QK_layout = make_layout(make_shape(B, NH, T, HS), make_stride(T * C3, HS, C3, 1));
-  Layout Attn_layout = make_layout(make_shape(B, NH, T, T), make_stride(T * T * NH, T * T, T, 1));
-  Layout sQ_layout = make_layout(make_shape(Br, HS), make_stride(HS, 1));
-  Layout sK_layout = make_layout(make_shape(Bc, HS), make_stride(HS, 1));
-
-  // gmem means global memory
-  Tensor mQ = make_tensor(make_gmem_ptr(inp), QK_layout);           // (B, NH, T, HS)
-  Tensor mK = make_tensor(make_gmem_ptr(inp + C), QK_layout);       // (B, NH, T, HS)
-  Tensor mPreatt = make_tensor(make_gmem_ptr(preatt), Attn_layout); // (B, NH, T, T)
-
-  Tensor gQ_block = mQ(b, h, _, _);           // (T,HS)
-  Tensor gK_block = mK(b, h, _, _);           // (T,HS)
-  Tensor gPreatt_block = mPreatt(b, h, _, _); // (T,T)
-
-  Tensor gQ = zipped_divide(gQ_block, make_tile(Br, HS))(make_coord(_, _), block); // (Br,HS) <- This is a little of a hack, but it works
-  Tensor gK = zipped_divide(gK_block, make_tile(Bc, HS));                          // ((Bc,HS), Tile)
-  Tensor gPreatt = zipped_divide(gPreatt_block, make_tile(Br, T))(make_coord(_, _), block); // (Br,T)
-
-  // smem means static memory, it gives a hint to cutlass to select algorithms that efficiently use this memory
-  Tensor sQ = make_tensor(make_smem_ptr(smemQ), sQ_layout);
-  Tensor sK = make_tensor(make_smem_ptr(smemK), sK_layout);
-
-  int t = threadIdx.x / Bc;  // <- This should go away when we make gemm work
-  int t2 = threadIdx.x % Bc; // <- This should go away when we make gemm work
-
-  // Define the thread layout
-  Layout tQ = make_layout(make_shape(4, 64), make_stride(64, 1));
-  Layout tK = make_layout(make_shape(4, 64), make_stride(64, 1));
-  Layout tC = make_layout(make_shape(Int<16>{}, Int<16>{}), make_stride(Int<16>{}, Int<1>{}));
-
-  // Lets assing each thread to 'tile' of items, so every thread participates in the copy
-  Tensor tQgQ = local_partition(gQ, tQ, threadIdx.x);
-  Tensor tQsQ = local_partition(sQ, tQ, threadIdx.x);
-
-  copy(tQgQ, tQsQ);
-
-  for (int k_tile = 0; k_tile < Tc; ++k_tile)
-  {
-    // My intuition says this should be outside the loop
-    Tensor tKgK = local_partition(gK(make_coord(_, _), k_tile), tK, threadIdx.x);
-    Tensor tKsK = local_partition(sK, tK, threadIdx.x);
-    Tensor ggPreatt = zipped_divide(gPreatt, make_tile(Int<16>{}, Int<16>{}))(make_coord(_, _), k_tile);
+    using namespace cute;
+    int block = blockIdx.x;
+    int b = blockIdx.y; // batch
+    int h = blockIdx.z; // head
 
-    copy(tKgK, tKsK);
+    int C3 = C * 3;
 
-    cp_async_fence();   // Label the end of (potential) cp.async instructions
-    cp_async_wait<0>(); // Sync on all (potential) cp.async instructions
-    __syncthreads();    // Wait for all threads to write to smem
+    extern __shared__ float sram[];
+    float *smemQ = sram;
+    float *smemK = &sram[Br * HS]; // K starts after Q
 
-    Tensor tCggPreatt = local_partition(ggPreatt, tC, threadIdx.x);
-    Tensor tCrC = make_tensor_like(tCggPreatt); // (THR_M,THR_N)
-    // Perform tile multiplication
-    float val = 0.0f;
-    if (t + (block * Br) < t2 + (k_tile * Bc))
-    {
-      val = -INFINITY;
-    }
-    else
+    // define the CuTe layout
+    Layout QK_layout = make_layout(make_shape(B, NH, T, HS), make_stride(T * C3, HS, C3, 1));
+    Layout Attn_layout = make_layout(make_shape(B, NH, T, T), make_stride(T * T * NH, T * T, T, 1));
+
+    Layout sQ_layout = make_layout(make_shape(Int<16>{}, Int<64>{}), make_stride(Int<64>{}, Int<1>{}));
+    Layout sK_layout = make_layout(make_shape(Int<16>{}, Int<64>{}), make_stride(Int<64>{}, Int<1>{}));
+
+    // gmem means global memory
+    Tensor mQ = make_tensor(make_gmem_ptr(inp), QK_layout);           // (B, NH, T, HS)
+    Tensor mK = make_tensor(make_gmem_ptr(inp + C), QK_layout);       // (B, NH, T, HS)
+    Tensor mPreatt = make_tensor(make_gmem_ptr(preatt), Attn_layout); // (B, NH, T, T)
+
+    Tensor gQ_block = mQ(b, h, _, _);           // (T,HS)
+    Tensor gK_block = mK(b, h, _, _);           // (T,HS)
+    Tensor gPreatt_block = mPreatt(b, h, _, _); // (T,T)
+
+    Tensor gQ = zipped_divide(gQ_block, make_tile(Br, HS))(make_coord(_, _), block); // (Br,HS) <- This is a little of a hack, but it works
+    Tensor gK = zipped_divide(gK_block, make_tile(Bc, HS));                          // ((Bc,HS), Tile)
+    Tensor gPreatt = zipped_divide(gPreatt_block, make_tile(Br, T))(make_coord(_, _), block); // (Br,T)
+
+    // smem means shared memory; it gives CUTLASS a hint to select algorithms that use this memory efficiently
+    Tensor sQ = make_tensor(make_smem_ptr(smemQ), sQ_layout);
+    Tensor sK = make_tensor(make_smem_ptr(smemK), sK_layout);
+
+    int t = threadIdx.x / 16;  // <- This should go away when we make gemm work
+    int t2 = threadIdx.x % 16; // <- This should go away when we make gemm work
+
+    // Define the thread layout
+    Layout tQ = make_layout(make_shape(Int<4>{}, Int<64>{}), make_stride(Int<64>{}, Int<1>{}));
+    Layout tK = make_layout(make_shape(Int<4>{}, Int<64>{}), make_stride(Int<64>{}, Int<1>{}));
+    Layout tC = make_layout(make_shape(Int<16>{}, Int<16>{}), make_stride(Int<16>{}, Int<1>{}));
+
+    // Let's assign each thread to a 'tile' of items, so every thread participates in the copy
+    Tensor tQgQ = local_partition(gQ, tQ, threadIdx.x);
+    Tensor tQsQ = local_partition(sQ, tQ, threadIdx.x);
+    Tensor tKsK = local_partition(sK, tK, threadIdx.x);
+
+    copy(tQgQ, tQsQ);
+
+    for (int k_tile = 0; k_tile < Tc; ++k_tile)
     {
-      // gemm(tQsQ, tKsK, tCrC); //<- This is not working.
-      for (int hi = 0; hi < HS; ++hi)
-      {
-        val += sQ(t, hi) * sK(t2, hi);
-      }
+        // My intuition says this should be outside the loop
+        Tensor tKgK = local_partition(gK(make_coord(_, _), k_tile), tK, threadIdx.x);
+        Tensor ggPreatt = zipped_divide(gPreatt, make_tile(Int<16>{}, Int<16>{}))(make_coord(_, _), k_tile);
+
+        copy(tKgK, tKsK);
+
+        cp_async_fence();   // Label the end of (potential) cp.async instructions
+        cp_async_wait<0>(); // Sync on all (potential) cp.async instructions
+        __syncthreads();    // Wait for all threads to write to smem
+
+        Tensor tCggPreatt = local_partition(ggPreatt, tC, threadIdx.x, Step<_1,_1>{});
+        Tensor tCsQ = local_partition(sQ, tC, threadIdx.x, Step<_1, X>{}); // (THR_M,BLK_K)
+        Tensor tCsK = local_partition(sK, tC, threadIdx.x, Step< X,_1>{}); // (THR_N,BLK_K)
+
+        Tensor tCrC = make_tensor_like(tCggPreatt); // (THR_M,THR_N)
+
+        if (block < k_tile) {
+            // No need to matrix multiply, this tile is entirely masked
+            axpby(-INFINITY, tCrC, 0.0, tCggPreatt);
+        } else if (block <= k_tile)
+        {
+            // Diagonal tile: everything strictly above the diagonal gets masked to -INFINITY
+            gemm(tCsQ, tCsK, tCrC); //<- This is not working.
+            __syncthreads();
+            if (t < t2) {
+                tCrC(t, t2) = -INFINITY;
+            }
+
+            __syncthreads();
+            axpby(softmax_scale, tCrC, 0.0, tCggPreatt);
+
+        }
+        else
+        {
+            gemm(tCsQ, tCsK, tCrC); //<- This is not working.
+            axpby(softmax_scale, tCrC, 0.0, tCggPreatt); // write softmax_scale * tCrC into tCggPreatt (beta = 0)
+        }
+        __syncthreads();
     }
-    __syncthreads();
-    val *= softmax_scale;
-    ggPreatt(t, t2) = val; // <- This should go when gemm works
-  }
 }
 
 __global__ void attention_softmax_kernel1(float* att, const float* preatt,
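
Note (reviewer sketch, not part of the patch): the three branches in the new loop classify each (Br x Bc) tile of the causal mask. When block < k_tile the whole tile lies above the diagonal and is written directly as -INFINITY; when block == k_tile it is the diagonal tile, so the gemm result is masked where t < t2 before the axpby scaling; when block > k_tile the tile is fully unmasked and only gemm plus the axpby scaling is needed. For reference, the same pre-attention scores computed with plain host loops, assuming the strides implied by QK_layout / Attn_layout (Q at offset 0 and K at offset C inside each token's 3*C slot); the function name preatt_reference and its exact signature are illustrative only:

#include <math.h>

// Host-side reference for what attention_query_key_kernel2 is meant to produce.
// preatt has shape (B, NH, T, T); inp has shape (B, T, 3*C) with C = NH * HS.
void preatt_reference(float *preatt, const float *inp,
                      int B, int T, int C, int NH, float softmax_scale)
{
    int HS = C / NH; // head size
    int C3 = 3 * C;  // Q, K, V are packed per token
    for (int b = 0; b < B; b++) {
        for (int h = 0; h < NH; h++) {
            for (int t = 0; t < T; t++) {          // query row
                for (int t2 = 0; t2 < T; t2++) {   // key column
                    float val;
                    if (t2 > t) {
                        // causal mask: future keys get -INFINITY,
                        // matching "t + block * Br < t2 + k_tile * Bc" in the old kernel
                        val = -INFINITY;
                    } else {
                        const float *q = inp + b * T * C3 + t * C3 + h * HS;      // query vector
                        const float *k = inp + b * T * C3 + t2 * C3 + h * HS + C; // key vector
                        val = 0.0f;
                        for (int i = 0; i < HS; i++) {
                            val += q[i] * k[i]; // the dot product that gemm() tiles
                        }
                        val *= softmax_scale;   // the scaling that axpby() applies
                    }
                    preatt[b * NH * T * T + h * T * T + t * T + t2] = val;
                }
            }
        }
    }
}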