From 9e62e7e10e493db1d55cc0b531f4f44a44ea4a8c Mon Sep 17 00:00:00 2001
From: Shupei Fan
Date: Sun, 22 Sep 2024 22:03:18 +0800
Subject: [PATCH 1/2] [metal-kernel] add flash_attn_ext_scalar_f16 implementation

---
 ggml/src/ggml-metal.metal | 288 ++++++++++++++++++++++++++++++++++++++
 1 file changed, 288 insertions(+)

diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal
index 2b200032394b1..122bfb5eb1080 100644
--- a/ggml/src/ggml-metal.metal
+++ b/ggml/src/ggml-metal.metal
@@ -2799,6 +2799,294 @@ kernel void kernel_flash_attn_ext_vec_f16(
 template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
 //template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
 
+half dequantize_load_f16(device const half *xb, short il) {
+    return xb[il];
+}
+
+// block_q8_0 stores a half scale d followed by QK8_0 (= 32) int8 quants;
+// dequantize a single element as d * qs[il]
+half dequantize_load_q8_0(device const block_q8_0 *xb, short il) {
+    device const block_q8_0 *xb_ = &xb[il / QK8_0];
+    return xb_->d * xb_->qs[il % QK8_0];
+}
+
+template<typename block_q, half (*dequantize_load)(device const block_q *, short), int64_t D, int64_t Q = 1, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
+kernel void kernel_flash_attn_ext_scalar_f16(
+        device const  char * q,
+        device const  char * k,
+        device const  char * v,
+        device const  char * mask,
+        device       float * dst,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant   int64_t & ne13,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant  uint64_t & nb13,
+        constant  uint64_t & nb21,
+        constant  uint64_t & nb22,
+        constant  uint64_t & nb23,
+        constant  uint64_t & nb31,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant     float & scale,
+        constant     float & max_bias,
+        constant     float & m0,
+        constant     float & m1,
+        constant  uint32_t & n_head_log2,
+        constant     float & logit_softcap,
+        threadgroup   half * shared [[threadgroup(0)]],
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        uint3  tpitg[[thread_position_in_threadgroup]],
+        uint3    ntg[[threads_per_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+    const short nsg = ntg.y; // number of simdgroups
+
+    const short iq3 = tgpig[2];
+    const short iq2 = tgpig[1];
+    const short iq1 = tgpig[0];
+
+    const short NW = N_SIMDWIDTH;
+    const short SH = (C + Q); // shared memory per simdgroup in (half)
+
+    const short T = D + 2*nsg*SH; // shared memory size per query in (half)
+
+    float slope = 1.0f;
+
+    // ALiBi
+    if (max_bias > 0.0f) {
+        const uint32_t h = iq2;
+
+        const float base = h < n_head_log2 ? m0 : m1;
+        const int   exp  = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+        slope = pow(base, exp);
+    }
+
+    threadgroup half  * sq = (threadgroup half  *) (shared +              0*D); // holds the query data
+    threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
+    threadgroup half  * sr = (threadgroup half  *) (shared +    sgitg*D + 1*T); // scratch buffer for the results
+
+    // store the result for the current query in registers (the O matrix from the paper)
+    half lo[D/NW];
+
+    // load heads from Q to shared memory
+    device const float * q_ = (device const float *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03));
+
+    for (short i = tiisg; i < D; i += NW) {
+        if (iq1 < ne01) {
+            sq[i] = (half) q_[i];
+        } else {
+            sq[i] = 0.0h;
+        }
+    }
+
+    // zero out lo
+    for (short i = tiisg; i < D; i += NW) {
+        lo[i/NW] = 0.0h;
+    }
+
+    // zero out the shared scratch memory (SH entries per simdgroup)
+    for (short i = tiisg; i < SH; i += NW) {
+        ss[i] = 0.0f;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    {
+        float S = 0.0f;
+        float M = -FLT_MAX/2;
+
+        // assume K and V are same shape
+        const short ne22 = ne12;
+        const short ne23 = ne13;
+
+        // broadcast
+        const short rk2 = ne02/ne12;
+        const short rk3 = ne03/ne13;
+
+        const short rv2 = ne02/ne22;
+        const short rv3 = ne03/ne23;
+
+        // k indices
+        const short ik2 = iq2 / rk2;
+        const short ik3 = iq3 / rk3;
+
+        // v indices
+        const short iv2 = iq2 / rv2;
+        const short iv3 = iq3 / rv3;
+
+        // load the queries from shared memory into local memory
+        half mq[D];
+
+        for (short ii = 0; ii < D; ii += NW) {
+            short i = ii + tiisg;
+            mq[i] = sq[i];
+        }
+
+        // pointer to the mask
+        device const half * mp = (device const half *) (mask + iq1*nb31);
+
+        // loop over the KV cache
+        // each simdgroup handles blocks of Q rows and C columns
+        for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
+            const int ic = ic0 + C*sgitg;
+            if (ic >= ne11) {
+                break;
+            }
+
+            // Q*K^T
+            {
+// #pragma unroll
+                for (short cc = 0; cc < C; ++cc) {
+                    float mqk = 0.0f;
+
+                    device const block_q * pk = (device const block_q *) ((device const char *) k + ((ic + cc)*nb11 + ik2*nb12 + ik3*nb13));
+
+#pragma unroll
+                    for (short ii = 0; ii < D; ii += NW) {
+                        const short i = ii + tiisg;
+                        mqk += mq[i] * dequantize_load(pk, i);
+                    }
+
+                    // reduce the results from the threads in the simdgroup
+                    mqk += simd_shuffle_down(mqk, 16);
+                    mqk += simd_shuffle_down(mqk,  8);
+                    mqk += simd_shuffle_down(mqk,  4);
+                    mqk += simd_shuffle_down(mqk,  2);
+                    mqk += simd_shuffle_down(mqk,  1);
+
+                    // mqk = mqk*scale + mask*slope
+                    if (tiisg == 0) {
+                        mqk *= scale;
+
+                        if (logit_softcap != 0.0f) {
+                            mqk = logit_softcap*precise::tanh(mqk);
+                        }
+
+                        if (mask != q) {
+                            mqk += (mp[ic + cc])*slope;
+                        }
+
+                        ss[cc] = mqk;
+                    }
+                }
+            }
+
+            // online softmax
+            {
+                const short p = tiisg;
+
+                const float m = M;
+                const float s = ss[p];
+
+                M = simd_max(max(M, s));
+
+                const float ms = exp(m - M);
+                const float vs = exp(s - M);
+
+                S = S*ms + simd_sum(vs);
+
+                // the P matrix from the paper (Q rows, C columns)
+                ss[p] = vs;
+
+                // O = diag(ms)*O
+#pragma unroll
+                for (short ii = 0; ii < D; ii += NW) {
+                    const short i = ii + tiisg;
+                    lo[i/NW] *= ms;
+                }
+            }
+
+            // O = O + (Q*K^T)*V
+            {
+// #pragma unroll
+                for (short cc = 0; cc < C; ++cc) {
+                    device const block_q * pv = (device const block_q *) ((device const char *) v + ((ic + cc)*nb21 + iv2*nb22 + iv3*nb23));
+
+#pragma unroll
+                    for (short ii = 0; ii < D; ii += NW) {
+                        const short i = ii + tiisg;
+
+                        lo[i/NW] += dequantize_load(pv, i) * ss[cc];
+                    }
+                }
+            }
+        }
+
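+        // at this point each simdgroup holds a partial result for its slice of the KV cache:
+        // lo accumulates sum_c exp(s_c - M)*V_c over the columns it processed, while S and M
+        // hold the running softmax denominator and running maximum for that slice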
+        // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
+        if (tiisg == 0) {
+            ss[0] = S;
+            ss[1] = M;
+        }
+    }
+
+    // store results to shared memory
+    for (short ii = 0; ii < D; ii += NW) {
+        short i = ii + tiisg;
+        sr[i] = lo[ii/NW];
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // parallel reduce
+    for (short r = nsg/2; r > 0; r >>= 1) {
+        if (sgitg < r) {
+            const float S0 = ss[       0];
+            const float S1 = ss[r*SH + 0];
+
+            const float M0 = ss[       1];
+            const float M1 = ss[r*SH + 1];
+
+            const float M = max(M0, M1);
+
+            const float ms0 = exp(M0 - M);
+            const float ms1 = exp(M1 - M);
+
+            const float S = S0*ms0 + S1*ms1;
+
+            if (tiisg == 0) {
+                ss[0] = S;
+                ss[1] = M;
+            }
+
+            // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
+            for (short ii = 0; ii < D; ii += NW) {
+                short i = ii + tiisg;
+                sr[i] = sr[i]*ms0 + sr[i + r*D]*ms1;
+            }
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+
+    // final rescale with 1/S and store to global memory
+    if (sgitg == 0) {
+        const float S = ss[0];
+
+        for (short ii = 0; ii < D; ii += NW) {
+            short i = ii + tiisg;
+            dst[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D + i] = sr[i]/S;
+        }
+    }
+}
+
+template [[host_name("kernel_flash_attn_ext_scalar_f16_h32")]]   kernel flash_attn_ext_f16_t kernel_flash_attn_ext_scalar_f16<half,       dequantize_load_f16,   32>;
+template [[host_name("kernel_flash_attn_ext_scalar_f16_h64")]]   kernel flash_attn_ext_f16_t kernel_flash_attn_ext_scalar_f16<half,       dequantize_load_f16,   64>;
+template [[host_name("kernel_flash_attn_ext_scalar_f16_h96")]]   kernel flash_attn_ext_f16_t kernel_flash_attn_ext_scalar_f16<half,       dequantize_load_f16,   96>;
+template [[host_name("kernel_flash_attn_ext_scalar_f16_h128")]]  kernel flash_attn_ext_f16_t kernel_flash_attn_ext_scalar_f16<half,       dequantize_load_f16,  128>;
+
+template [[host_name("kernel_flash_attn_ext_scalar_q8_0_h32")]]  kernel flash_attn_ext_f16_t kernel_flash_attn_ext_scalar_f16<block_q8_0, dequantize_load_q8_0,  32>;
+template [[host_name("kernel_flash_attn_ext_scalar_q8_0_h64")]]  kernel flash_attn_ext_f16_t kernel_flash_attn_ext_scalar_f16<block_q8_0, dequantize_load_q8_0,  64>;
+template [[host_name("kernel_flash_attn_ext_scalar_q8_0_h96")]]  kernel flash_attn_ext_f16_t kernel_flash_attn_ext_scalar_f16<block_q8_0, dequantize_load_q8_0,  96>;
+template [[host_name("kernel_flash_attn_ext_scalar_q8_0_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_scalar_f16<block_q8_0, dequantize_load_q8_0, 128>;
+
 template <typename T0, typename T1>
 kernel void kernel_cpy(
         device  const void * src0,

From d436f5ba2c59bc8033e2eefe6d2a1d862abb2506 Mon Sep 17 00:00:00 2001
From: Shupei Fan
Date: Tue, 24 Sep 2024 12:33:32 +0800
Subject: [PATCH 2/2] [metal] (HACK!!!) force use kernel_flash_attn_ext_scalar_f16 in FA

---
 ggml/src/ggml-metal.m | 91 +++++++++++++++++++++++++++++++------------
 1 file changed, 66 insertions(+), 25 deletions(-)

diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m
index 8ff16983e0939..1b51a251859fe 100644
--- a/ggml/src/ggml-metal.m
+++ b/ggml/src/ggml-metal.m
@@ -206,6 +206,14 @@
     //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
     //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_F16_H32,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_F16_H64,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_F16_H96,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_F16_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_Q8_0_H32,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_Q8_0_H64,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_Q8_0_H96,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_Q8_0_H128,
     GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
     GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
     GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
@@ -702,6 +710,14 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
         //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,        flash_attn_ext_f16_h256,        ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,      flash_attn_ext_vec_f16_h128,    ctx->support_simdgroup_reduction);
         //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,    flash_attn_ext_vec_f16_h256,    ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_F16_H32,    flash_attn_ext_scalar_f16_h32,   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_F16_H64,    flash_attn_ext_scalar_f16_h64,   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_F16_H96,    flash_attn_ext_scalar_f16_h96,   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_F16_H128,   flash_attn_ext_scalar_f16_h128,  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_Q8_0_H32,   flash_attn_ext_scalar_q8_0_h32,  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_Q8_0_H64,   flash_attn_ext_scalar_q8_0_h64,  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_Q8_0_H96,   flash_attn_ext_scalar_q8_0_h96,  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_Q8_0_H128,  flash_attn_ext_scalar_q8_0_h128, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16,                      cpy_f32_f16,                     true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32,                      cpy_f32_f32,                     true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16,                      cpy_f16_f16,                     true);
@@ -852,15 +868,15 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
         case GGML_OP_LEAKY_RELU:
             return true;
         case GGML_OP_FLASH_ATTN_EXT:
-            if (op->src[1]->type != GGML_TYPE_F16) {
-                return false;
-            }
-            if (op->src[2]->type != GGML_TYPE_F16) {
-                return false;
-            }
-            if (op->src[0]->ne[0] == 256) {
-                return false;
-            }
+            // if (op->src[1]->type != GGML_TYPE_F16) {
+            //     return false;
+            // }
+            // if (op->src[2]->type != GGML_TYPE_F16) {
+            //     return false;
+            // }
+            // if (op->src[0]->ne[0] == 256) {
+            //     return false;
+            // }
             return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
         case GGML_OP_SSM_CONV:
         case GGML_OP_SSM_SCAN:
@@ -2765,6 +2781,8 @@ static void ggml_metal_encode_node(
                 GGML_ASSERT(ne11 % 32 == 0);
 
                 GGML_ASSERT(src0->type == GGML_TYPE_F32);
+                // K, V shall have the same type
+                GGML_ASSERT(src1->type == src2->type);
 
                 GGML_ASSERT(ggml_are_same_shape (src1, src2));
 
@@ -2811,33 +2829,56 @@ static void ggml_metal_encode_node(
 
                 bool use_vec_kernel = false;
 
-                if (ne01 >= 4 || (ne00%128 != 0)) {
+                if (false) {
                     switch (ne00) {
                         case  64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
                         case  80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
                         case  96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
                         case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
                         case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
-                        //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
+                      //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
                         default:
-                                  {
-                                      GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
-                                      GGML_METAL_LOG_ERROR("add template specialization for this size\n");
-                                      GGML_ABORT("add template specialization for this size");
-                                  }
+                            {
+                                GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
+                                GGML_METAL_LOG_ERROR("add template specialization for this size\n");
+                                GGML_ABORT("add template specialization for this size");
+                            }
                     }
                 } else {
                     use_vec_kernel = true;
 
-                    switch (ne00) {
-                        case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
-                        //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
-                        default:
-                            {
-                                GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
-                                GGML_METAL_LOG_ERROR("add template specialization for this size\n");
-                                GGML_ABORT("add template specialization for this size");
-                            }
+                    if (src1->type == GGML_TYPE_F16) {
+                        switch (ne00) {
+                            case  32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_F16_H32 ].pipeline; break;
+                            case  64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_F16_H64 ].pipeline; break;
+                            case  96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_F16_H96 ].pipeline; break;
+                            case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_F16_H128].pipeline; break;
+                          //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
+                            default:
+                                {
+                                    GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
+                                    GGML_METAL_LOG_ERROR("add template specialization for this size\n");
+                                    GGML_ABORT("add template specialization for this size");
+                                }
+                        }
+                    } else if (src1->type == GGML_TYPE_Q8_0) {
+                        switch (ne00) {
+                            case  32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_Q8_0_H32 ].pipeline; break;
+                            case  64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_Q8_0_H64 ].pipeline; break;
+                            case  96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_Q8_0_H96 ].pipeline; break;
+                            case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_Q8_0_H128].pipeline; break;
+                          //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
+                            default:
+                                {
+                                    GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
+                                    GGML_METAL_LOG_ERROR("add template specialization for this size\n");
+                                    GGML_ABORT("add template specialization for this size");
+                                }
+                        }
+                    } else {
+                        GGML_METAL_LOG_ERROR("unsupported K/V type: %d\n", src1->type);
+                        GGML_METAL_LOG_ERROR("add a dequantize_load specialization for this type\n");
+                        GGML_ABORT("add a dequantize_load specialization for this type");
+                    }
                 }