diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 14f7dcf4f41ad..e3db937dd2c58 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -8135,7 +8135,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( } // V /= S - const float S_inv = 1.0f/S; + const float S_inv = S == 0.0f ? 0.0f : 1.0f/S; ggml_vec_scale_f32(DV, VKQ32, S_inv); // dst indices diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index b88d6b11d30e3..5e5bdf90a557d 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -953,7 +953,53 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad( //ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_PAD + 21); //ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_PAD + 22); //ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_PAD + 23); - ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 24); + //ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_PAD + 24); + ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 25); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_blk( + ggml_metal_library_t lib, + const struct ggml_tensor * op, + int32_t nqptg, + int32_t ncpsg) { + assert(op->op == GGML_OP_FLASH_ATTN_EXT); + GGML_UNUSED(op); + + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_%s", + "flash_attn_ext_blk"); + + snprintf(name, 256, "%s_nqptg=%d_ncpsg=%d", + base, + nqptg, + ncpsg); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + //ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_BLK + 0); + //ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_BLK + 1); + //ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_BLK + 2); + //ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_BLK + 3); + + //ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_BLK + 20); + //ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_BLK + 21); + //ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_BLK + 22); + //ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_BLK + 23); + ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_BLK + 24); + ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_BLK + 25); res = ggml_metal_library_compile_pipeline(lib, base, name, cv); diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index ef049507384d8..1034e4bbf6596 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -141,6 +141,12 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad( bool has_mask, int32_t ncpsg); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_blk( + ggml_metal_library_t lib, + const struct ggml_tensor * op, + int32_t nqptg, + int32_t ncpsg); + ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext( ggml_metal_library_t lib, const struct ggml_tensor * op, diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 52a6393b250be..529daab763521 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -70,11 +70,19 @@ // function constants offsets #define FC_FLASH_ATTN_EXT_PAD 100 -#define FC_FLASH_ATTN_EXT 200 -#define 
FC_FLASH_ATTN_EXT_VEC 300 -#define FC_FLASH_ATTN_EXT_VEC_REDUCE 400 -#define FC_MUL_MV 500 -#define FC_MUL_MM 600 +#define FC_FLASH_ATTN_EXT_BLK 200 +#define FC_FLASH_ATTN_EXT 300 +#define FC_FLASH_ATTN_EXT_VEC 400 +#define FC_FLASH_ATTN_EXT_VEC_REDUCE 500 +#define FC_MUL_MV 600 +#define FC_MUL_MM 700 + +// op-specific constants +#define OP_FLASH_ATTN_EXT_NQPTG 8 +#define OP_FLASH_ATTN_EXT_NCPSG 64 + +#define OP_FLASH_ATTN_EXT_VEC_NQPTG 1 +#define OP_FLASH_ATTN_EXT_VEC_NCPSG 32 // kernel argument structs // @@ -262,6 +270,17 @@ typedef struct { uint64_t nb33; } ggml_metal_kargs_flash_attn_ext_pad; +typedef struct { + int32_t ne01; + int32_t ne30; + int32_t ne31; + int32_t ne32; + int32_t ne33; + uint64_t nb31; + uint64_t nb32; + uint64_t nb33; +} ggml_metal_kargs_flash_attn_ext_blk; + typedef struct { int32_t ne01; int32_t ne02; diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 9e50fb1940cb9..22803b3512a70 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -1904,19 +1904,19 @@ size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) { const bool has_mask = op->src[3] != nullptr; if (ggml_metal_op_flash_attn_ext_use_vec(op)) { - const bool has_kvpad = ne11 % 32 != 0; + const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0; if (has_kvpad) { - res += 32*( + res += OP_FLASH_ATTN_EXT_VEC_NCPSG*( nb11*ne12*ne13 + nb21*ne22*ne23 + (has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0)); } } else { - const bool has_kvpad = ne11 % 64 != 0; + const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_NCPSG != 0; if (has_kvpad) { - res += 64*( + res += OP_FLASH_ATTN_EXT_NCPSG*( nb11*ne12*ne13 + nb21*ne22*ne23 + (has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0)); @@ -1926,6 +1926,44 @@ size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) { return res; } +size_t ggml_metal_op_flash_attn_ext_extra_blk(const ggml_tensor * op) { + assert(op->op == GGML_OP_FLASH_ATTN_EXT); + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + //GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + //GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + //GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + //GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne); + //GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb); + GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne); + GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb); + + size_t res = 0; + + const bool has_mask = op->src[3] != nullptr; + + if (!has_mask) { + return res; + } + + const bool is_vec = ggml_metal_op_flash_attn_ext_use_vec(op); + + // this optimization is not useful for the vector kernels + if (is_vec) { + return res; + } + + const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPTG : OP_FLASH_ATTN_EXT_NQPTG; + const int ncpsg = is_vec ? 
OP_FLASH_ATTN_EXT_VEC_NCPSG : OP_FLASH_ATTN_EXT_NCPSG; + + const int64_t ne1 = (ne01 + nqptg - 1)/nqptg; + const int64_t ne0 = (ne30 + ncpsg - 1)/ncpsg; + + res += GGML_PAD(ggml_type_size(GGML_TYPE_I8)*ne0*ne1*ne32*ne33, 32); + + return res; +} + size_t ggml_metal_op_flash_attn_ext_extra_tmp(const ggml_tensor * op) { assert(op->op == GGML_OP_FLASH_ATTN_EXT); @@ -2020,18 +2058,23 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { ggml_metal_buffer_id bid_pad = bid_dst; bid_pad.offs += ggml_nbytes(op); - ggml_metal_buffer_id bid_tmp = bid_pad; - bid_tmp.offs += ggml_metal_op_flash_attn_ext_extra_pad(op); + ggml_metal_buffer_id bid_blk = bid_pad; + bid_blk.offs += ggml_metal_op_flash_attn_ext_extra_pad(op); + + ggml_metal_buffer_id bid_tmp = bid_blk; + bid_tmp.offs += ggml_metal_op_flash_attn_ext_extra_blk(op); if (!ggml_metal_op_flash_attn_ext_use_vec(op)) { // half8x8 kernel - const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! - const int64_t ncpsg = 64; // cache values per simdgroup !! sync with kernel template arguments !! + const int nqptg = OP_FLASH_ATTN_EXT_NQPTG; // queries per threadgroup + const int ncpsg = OP_FLASH_ATTN_EXT_NCPSG; // cache values per simdgroup GGML_ASSERT(nqptg <= 32); GGML_ASSERT(nqptg % 8 == 0); GGML_ASSERT(ncpsg % 32 == 0); + bool need_sync = false; + const bool has_kvpad = ne11 % ncpsg != 0; if (has_kvpad) { @@ -2069,11 +2112,46 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1); - ggml_metal_op_concurrency_reset(ctx); + need_sync = true; } else { assert(ggml_metal_op_flash_attn_ext_extra_pad(op) == 0); } + if (has_mask) { + assert(ggml_metal_op_flash_attn_ext_extra_blk(op) != 0); + + ggml_metal_kargs_flash_attn_ext_blk args0 = { + /*.ne01 =*/ ne01, + /*.ne30 =*/ ne30, + /*.ne31 =*/ ne31, + /*.ne32 =*/ ne32, + /*.ne33 =*/ ne33, + /*.nb31 =*/ nb31, + /*.nb32 =*/ nb32, + /*.nb33 =*/ nb33, + }; + + ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_blk(lib, op, nqptg, ncpsg); + + ggml_metal_encoder_set_pipeline(enc, pipeline0); + ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0); + ggml_metal_encoder_set_buffer (enc, bid_src3, 1); + ggml_metal_encoder_set_buffer (enc, bid_blk, 2); + + const int32_t nblk1 = ((ne01 + nqptg - 1)/nqptg); + const int32_t nblk0 = ((ne30 + ncpsg - 1)/ncpsg); + + ggml_metal_encoder_dispatch_threadgroups(enc, nblk0, nblk1, ne32*ne33, 32, 1, 1); + + need_sync = true; + } else { + assert(ggml_metal_op_flash_attn_ext_extra_blk(op) == 0); + } + + if (need_sync) { + ggml_metal_op_concurrency_reset(ctx); + } + const int is_q = ggml_is_quantized(op->src[1]->type) ? 1 : 0; // 2*(2*ncpsg) @@ -2150,7 +2228,8 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_set_buffer (enc, bid_src3, 4); ggml_metal_encoder_set_buffer (enc, bid_src4, 5); ggml_metal_encoder_set_buffer (enc, bid_pad, 6); - ggml_metal_encoder_set_buffer (enc, bid_dst, 7); + ggml_metal_encoder_set_buffer (enc, bid_blk, 7); + ggml_metal_encoder_set_buffer (enc, bid_dst, 8); ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); @@ -2158,14 +2237,16 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { #undef FATTN_SMEM } else { // half4x4 kernel - const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !! - const int64_t ncpsg = 32; // cache values per simdgroup !! 
sync with kernel template arguments !! - const int64_t nkpsg = 1*ncpsg; + const int nqptg = OP_FLASH_ATTN_EXT_VEC_NQPTG; // queries per threadgroup + const int ncpsg = OP_FLASH_ATTN_EXT_VEC_NCPSG; // cache values per simdgroup !! sync with kernel template arguments !! + const int nkpsg = 1*ncpsg; GGML_ASSERT(nqptg <= 32); GGML_ASSERT(nqptg % 1 == 0); GGML_ASSERT(ncpsg % 32 == 0); + bool need_sync = false; + const bool has_kvpad = ne11 % ncpsg != 0; if (has_kvpad) { @@ -2203,11 +2284,15 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1); - ggml_metal_op_concurrency_reset(ctx); + need_sync = true; } else { assert(ggml_metal_op_flash_attn_ext_extra_pad(op) == 0); } + if (need_sync) { + ggml_metal_op_concurrency_reset(ctx); + } + // ne00 + 2*ncpsg*(nsg) // for each query, we load it as f16 in shared memory (ne00) // and store the soft_max values and the mask diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h index 6a6d8a7977a7c..d4cb9446212d9 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.h +++ b/ggml/src/ggml-metal/ggml-metal-ops.h @@ -40,6 +40,7 @@ size_t ggml_metal_op_mul_mat_id_extra_ids(const struct ggml_tensor * op); bool ggml_metal_op_flash_attn_ext_use_vec(const struct ggml_tensor * op); size_t ggml_metal_op_flash_attn_ext_extra_pad(const struct ggml_tensor * op); +size_t ggml_metal_op_flash_attn_ext_extra_blk(const struct ggml_tensor * op); size_t ggml_metal_op_flash_attn_ext_extra_tmp(const struct ggml_tensor * op); int ggml_metal_op_concat (ggml_metal_op_t ctx, int idx); diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp index e53f37b29c1a4..7afc881fa7012 100644 --- a/ggml/src/ggml-metal/ggml-metal.cpp +++ b/ggml/src/ggml-metal/ggml-metal.cpp @@ -194,6 +194,7 @@ static size_t ggml_backend_metal_buffer_type_get_alloc_size(ggml_backend_buffer_ case GGML_OP_FLASH_ATTN_EXT: { res += ggml_metal_op_flash_attn_ext_extra_pad(tensor); + res += ggml_metal_op_flash_attn_ext_extra_blk(tensor); res += ggml_metal_op_flash_attn_ext_extra_tmp(tensor); } break; default: diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 7434ee62ef8be..ac2a2f10f9a64 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -4451,7 +4451,7 @@ kernel void kernel_leaky_relu_f32_4( constant bool FC_flash_attn_ext_pad_has_mask [[function_constant(FC_FLASH_ATTN_EXT_PAD + 0)]]; -constant int32_t FC_flash_attn_ext_pad_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_PAD + 24)]]; +constant int32_t FC_flash_attn_ext_pad_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_PAD + 25)]]; // pad the last chunk of C elements of k and v into a an extra pad buffer kernel void kernel_flash_attn_ext_pad( @@ -4519,6 +4519,65 @@ kernel void kernel_flash_attn_ext_pad( } } +constant int32_t FC_flash_attn_ext_blk_nqptg [[function_constant(FC_FLASH_ATTN_EXT_BLK + 24)]]; +constant int32_t FC_flash_attn_ext_blk_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_BLK + 25)]]; + +// scan the blocks of the mask that are not masked +// 0 - masked (i.e. full of -INF, skip) +// 1 - not masked (i.e. 
at least one element of the mask is not -INF) +kernel void kernel_flash_attn_ext_blk( + constant ggml_metal_kargs_flash_attn_ext_blk & args, + device const char * mask, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]]) { + // block size C x Q + const int32_t Q = FC_flash_attn_ext_blk_nqptg; + const int32_t C = FC_flash_attn_ext_blk_ncpsg; + + constexpr short NW = N_SIMDWIDTH; + + const int32_t i3 = tgpig[2]/args.ne32; + const int32_t i2 = tgpig[2]%args.ne32; + const int32_t i1 = tgpig[1]; + const int32_t i0 = tgpig[0]; + + char res = i0*C + C > args.ne30 ? 1 : 0; + + device const half * mask_src = (device const half *) (mask + (i1*Q)*args.nb31 + i2*args.nb32 + i3*args.nb33) + i0*C + tiisg; + + // fast route + if (res == 0) { + if (simd_max(*mask_src) > -MAXHALF/2) { + res = 1; + } + } + + // detailed check of the elements of the block + if ((C > NW || Q > 1) && res == 0) { + half m = -MAXHALF; + + FOR_UNROLL (short j = 0; j < Q; ++j) { + FOR_UNROLL (short ii = 0; ii < C/NW; ++ii) { + m = max(m, mask_src[ii*NW]); + } + + mask_src += args.nb31/2; + } + + if (simd_max(m) > -MAXHALF/2) { + res = 1; + } + } + + const int32_t nblk1 = ((args.ne01 + Q - 1)/Q); + const int32_t nblk0 = ((args.ne30 + C - 1)/C); + + if (tiisg == 0) { + dst[((i3*args.ne32 + i2)*nblk1 + i1)*nblk0 + i0] = res; + } +} + constant bool FC_flash_attn_ext_has_mask [[function_constant(FC_FLASH_ATTN_EXT + 0)]]; constant bool FC_flash_attn_ext_has_sinks [[function_constant(FC_FLASH_ATTN_EXT + 1)]]; constant bool FC_flash_attn_ext_has_bias [[function_constant(FC_FLASH_ATTN_EXT + 2)]]; @@ -4573,6 +4632,7 @@ void kernel_flash_attn_ext_impl( device const char * mask, device const char * sinks, device const char * pad, + device const char * blk, device char * dst, threadgroup half * shmem_f16, uint3 tgpig, @@ -4638,6 +4698,13 @@ void kernel_flash_attn_ext_impl( pm2[jj] = (device const half2 *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33); } + { + const int32_t nblk1 = ((args.ne01 + Q - 1)/Q); + const int32_t nblk0 = ((args.ne11 + C - 1)/C); + + blk += (((iq3%args.ne33)*args.ne32 + (iq2%args.ne32))*nblk1 + iq1/Q)*nblk0; + } + { q += iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03; @@ -4697,11 +4764,14 @@ void kernel_flash_attn_ext_impl( // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns - for (int ic0 = 0; ic0 < args.ne11; ic0 += C) { - int ic = ic0; + for (int ic0 = 0; ; ++ic0) { + int ic = ic0*C; + if (ic >= args.ne11) { + break; + } // the last partial chunk uses the pad buffer as source - if (FC_flash_attn_ext_has_kvpad && ic0 + C > args.ne11) { + if (FC_flash_attn_ext_has_kvpad && ic + C > args.ne11) { k = pad; v = k + args.nb11*C*args.ne_12_2*args.ne_12_3; mask = v + args.nb21*C*args.ne_12_2*args.ne_12_3; @@ -4740,6 +4810,14 @@ void kernel_flash_attn_ext_impl( // read the mask into shared mem if (FC_flash_attn_ext_has_mask) { + if (blk[ic0] == 0) { + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { + pm2[jj] += NW; + } + + continue; + } + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { const short j = jj*NSG + sgitg; @@ -4752,6 +4830,9 @@ void kernel_flash_attn_ext_impl( pm2[jj] += NW; } +#if 0 + // note: old -INF block optimization - obsoleted by pre-computing non-masked blocks + threadgroup_barrier(mem_flags::mem_threadgroup); // used to detect blocks full of -INF @@ -4770,6 +4851,7 @@ void kernel_flash_attn_ext_impl( continue; } +#endif } // Q*K^T @@ -4787,26 +4869,24 @@ void 
kernel_flash_attn_ext_impl( constexpr short NC = (C/8)/NSG; - // TODO: not good to unroll for large contexts - not sure why? + // note: do not unroll for large heads + #pragma unroll (DK <= 64 ? NC : 1) for (short cc = 0; cc < NC; ++cc) { qk8x8_t mqk = make_filled_simdgroup_matrix((qk_t) 0.0f); - if (DK8 % 16 != 0) { + if (DK % 16 != 0) { k8x8_t mk; q8x8_t mq; FOR_UNROLL (short i = 0; i < DK8; ++i) { simdgroup_barrier(mem_flags::mem_none); - simdgroup_load(mk, pk, NS10, 0, true); - simdgroup_load(mq, pq, DK); + simdgroup_load(mk, pk + 8*i, NS10, 0, true); + simdgroup_load(mq, pq + 8*i, DK); simdgroup_barrier(mem_flags::mem_none); simdgroup_multiply_accumulate(mqk, mq, mk, mqk); - - pk += 8; - pq += 8; } } else { k8x8_t mk[2]; @@ -4815,26 +4895,22 @@ void kernel_flash_attn_ext_impl( FOR_UNROLL (short i = 0; i < DK8/2; ++i) { simdgroup_barrier(mem_flags::mem_none); - simdgroup_load(mk[0], pk + 0*8, NS10, 0, true); - simdgroup_load(mk[1], pk + 1*8, NS10, 0, true); + simdgroup_load(mq[0], pq + 0*8 + 16*i, DK); + simdgroup_load(mq[1], pq + 1*8 + 16*i, DK); - simdgroup_load(mq[0], pq + 0*8, DK); - simdgroup_load(mq[1], pq + 1*8, DK); + simdgroup_load(mk[0], pk + 0*8 + 16*i, NS10, 0, true); + simdgroup_load(mk[1], pk + 1*8 + 16*i, NS10, 0, true); simdgroup_barrier(mem_flags::mem_none); simdgroup_multiply_accumulate(mqk, mq[0], mk[0], mqk); simdgroup_multiply_accumulate(mqk, mq[1], mk[1], mqk); - - pk += 16; - pq += 16; } } simdgroup_store(mqk, ps, SH, 0, false); - pk += 8*(NSG*NS10 - DK8); - pq += 8*(NSG*0 - DK8); + pk += 8*(NSG*NS10); ps += 8*(NSG); } } else { @@ -4968,27 +5044,50 @@ void kernel_flash_attn_ext_impl( } { - auto sst = ss; - device const v_t * pv = (device const v_t *) (v + ic*args.nb21); pv += 8*sgitg; - FOR_UNROLL (short cc = 0; cc < C/8; ++cc) { - s8x8_t vs; - simdgroup_load(vs, sst, SH, 0, false); + if (DV <= 64) { + FOR_UNROLL (short cc = 0; cc < C/8; ++cc) { + s8x8_t vs; + simdgroup_load(vs, ss + 8*cc, SH, 0, false); - FOR_UNROLL (short ii = 0; ii < NO; ++ii) { - v8x8_t mv; + FOR_UNROLL (short ii = 0; ii < NO/2; ++ii) { + v8x8_t mv[2]; - simdgroup_load(mv, pv, NS20, 0, false); - simdgroup_multiply_accumulate(lo[ii], vs, mv, lo[ii]); + simdgroup_load(mv[0], pv + 0*NSG + 16*ii*NSG, NS20, 0, false); + simdgroup_load(mv[1], pv + 8*NSG + 16*ii*NSG, NS20, 0, false); - pv += 8*NSG; + simdgroup_multiply_accumulate(lo[2*ii + 0], vs, mv[0], lo[2*ii + 0]); + simdgroup_multiply_accumulate(lo[2*ii + 1], vs, mv[1], lo[2*ii + 1]); + } + + pv += 8*NS20; } + } else { + FOR_UNROLL (short cc = 0; cc < (C/8)/2; ++cc) { + s8x8_t vs[2]; + + simdgroup_load(vs[0], ss + 16*cc + 0, SH, 0, false); + simdgroup_load(vs[1], ss + 16*cc + 8, SH, 0, false); - pv += 8*(NS20 - NO*NSG); - sst += 8; + FOR_UNROLL (short ii = 0; ii < NO/2; ++ii) { + v8x8_t mv[4]; + + simdgroup_load(mv[0], pv + 0*NSG + 16*ii*NSG + 0*8*NS20, NS20, 0, false); + simdgroup_load(mv[1], pv + 8*NSG + 16*ii*NSG + 0*8*NS20, NS20, 0, false); + simdgroup_load(mv[2], pv + 0*NSG + 16*ii*NSG + 1*8*NS20, NS20, 0, false); + simdgroup_load(mv[3], pv + 8*NSG + 16*ii*NSG + 1*8*NS20, NS20, 0, false); + + simdgroup_multiply_accumulate(lo[2*ii + 0], vs[0], mv[0], lo[2*ii + 0]); + simdgroup_multiply_accumulate(lo[2*ii + 1], vs[0], mv[1], lo[2*ii + 1]); + simdgroup_multiply_accumulate(lo[2*ii + 0], vs[1], mv[2], lo[2*ii + 0]); + simdgroup_multiply_accumulate(lo[2*ii + 1], vs[1], mv[3], lo[2*ii + 1]); + } + + pv += 2*8*NS20; + } } } @@ -5102,7 +5201,7 @@ void kernel_flash_attn_ext_impl( device float4 * dst4 = (device float4 *) dst + 
((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4; - const float scale = 1.0f/S[jj]; + const float scale = S[jj] == 0.0 ? 0.0f : 1.0f/S[jj]; if (DV4 % NW == 0) { FOR_UNROLL (short ii = 0; ii < DV4/NW; ++ii) { @@ -5147,8 +5246,8 @@ template< void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &), short DK, // K head size short DV, // V head size - short Q = 8, // queries per threadgroup - short C = 64> // cache items per threadgroup + short Q = OP_FLASH_ATTN_EXT_NQPTG, // queries per threadgroup + short C = OP_FLASH_ATTN_EXT_NCPSG> // cache items per threadgroup kernel void kernel_flash_attn_ext( constant ggml_metal_kargs_flash_attn_ext & args, device const char * q, @@ -5157,13 +5256,14 @@ kernel void kernel_flash_attn_ext( device const char * mask, device const char * sinks, device const char * pad, + device const char * blk, device char * dst, threadgroup half * shmem_f16 [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { #define FWD_TMPL q_t, q4_t, q8x8_t, k_t, k4x4_t, k8x8_t, v_t, v4x4_t, v8x8_t, qk_t, qk8x8_t, s_t, s2_t, s8x8_t, o_t, o4_t, o8x8_t, kd4x4_t, nl_k, deq_k, vd4x4_t, nl_v, deq_v, DK, DV, Q, C -#define FWD_ARGS args, q, k, v, mask, sinks, pad, dst, shmem_f16, tgpig, tiisg, sgitg +#define FWD_ARGS args, q, k, v, mask, sinks, pad, blk, dst, shmem_f16, tgpig, tiisg, sgitg switch (FC_flash_attn_ext_nsg) { // note: disabled cases to reduce library load time //case 1: kernel_flash_attn_ext_impl(FWD_ARGS); break; @@ -5310,9 +5410,9 @@ template< void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &), short DK, // K head size short DV, // V head size - short NE = 4, // head elements per thread - short Q = 1, // queries per threadgroup - short C = 32, // cache items per threadgroup + short NE, // head elements per thread + short Q, // queries per threadgroup + short C, // cache items per threadgroup short NSG> // number of simd groups void kernel_flash_attn_ext_vec_impl( constant ggml_metal_kargs_flash_attn_ext_vec & args, @@ -5427,8 +5527,8 @@ void kernel_flash_attn_ext_vec_impl( // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns - for (int ic0 = (int) iwg*C*NSG; ic0 < args.ne11; ic0 += (int) NWG*C*NSG) { - int ic = ic0 + C*sgitg; + for (int ic0 = iwg*NSG + sgitg; ; ic0 += NWG*NSG) { + int ic = ic0*C; if (ic >= args.ne11) { break; } @@ -5721,7 +5821,7 @@ void kernel_flash_attn_ext_vec_impl( device float4 * dst4 = (device float4 *) dst; device float * dst1 = (device float *) dst + nrows*DV*NWG; // the S and M are stored after the results - const float S = NWG == 1 ? 1.0f/ss[0] : 1.0f; + const float S = NWG == 1 ? (ss[0] == 0.0f ? 0.0f : 1.0f/ss[0]) : 1.0f; // interleave the workgroup data for (short i = tiisg; i < DV4; i += NW) { @@ -5759,8 +5859,8 @@ template< short DK, // K head size short DV, // V head size short NE = 4, // head elements per thread - short Q = 1, // queries per threadgroup - short C = 32> // cache items per threadgroup + short Q = OP_FLASH_ATTN_EXT_VEC_NQPTG, // queries per threadgroup + short C = OP_FLASH_ATTN_EXT_VEC_NCPSG> // cache items per threadgroup kernel void kernel_flash_attn_ext_vec( constant ggml_metal_kargs_flash_attn_ext_vec & args, device const char * q, @@ -5899,7 +5999,8 @@ kernel void kernel_flash_attn_ext_vec_reduce( const float m = simd_max(M); const float ms = exp(M - m); - S = 1.0f/simd_sum(S*ms); + S = simd_sum(S*ms); + S = S == 0.0f ? 
0.0f : 1.0f/S; const short DV4 = DV/4;
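
Note on the guarded reciprocals: the changes on the CPU path (S_inv) and in the Metal kernels (scale, S) cover the case where every key position of a query row is masked out. All exp(s - M) terms are then zero, the softmax denominator S stays 0, and the previous 1.0f/S produced inf/NaN in the output row. A minimal sketch of the guarded normalization, assuming a plain float accumulator (the helper name and types are illustrative, not part of ggml):

    #include <vector>

    // Normalize an accumulated V*softmax(QK^T) row by the softmax denominator S.
    // A fully masked row yields S == 0; emit zeros instead of inf/NaN.
    static void normalize_row(std::vector<float> & row, float S) {
        const float S_inv = (S == 0.0f) ? 0.0f : 1.0f/S;
        for (float & v : row) {
            v *= S_inv;
        }
    }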
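Note on the block mask: kernel_flash_attn_ext_blk produces one int8 flag per nqptg x ncpsg tile of the KQ mask, 0 when the tile is entirely -INF (so the main kernel can skip the whole KV block without loading K/V or computing Q*K^T) and 1 otherwise; tiles that run past ne30 are forced to 1 because they are completed from the pad buffer. A rough host-side equivalent for a single [n_q x n_kv] mask slice, assuming a row-major float mask and a "> -1e30f" test in place of the kernel's -MAXHALF/2 threshold (names and signature are illustrative):

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // One byte per Q x C tile of the mask: 1 if at least one element is not -INF.
    static std::vector<int8_t> build_block_mask(const std::vector<float> & mask,
                                                int n_q, int n_kv,
                                                int Q = 8, int C = 64) {
        const int nblk1 = (n_q  + Q - 1)/Q; // tiles along the query dimension
        const int nblk0 = (n_kv + C - 1)/C; // tiles along the KV dimension

        std::vector<int8_t> blk(nblk1*nblk0, 0);

        for (int i1 = 0; i1 < nblk1; ++i1) {
            for (int i0 = 0; i0 < nblk0; ++i0) {
                int8_t res = 0;

                for (int j = i1*Q; j < std::min((i1 + 1)*Q, n_q) && res == 0; ++j) {
                    for (int i = i0*C; i < std::min((i0 + 1)*C, n_kv); ++i) {
                        if (mask[(size_t) j*n_kv + i] > -1e30f) {
                            res = 1;
                            break;
                        }
                    }
                }

                blk[i1*nblk0 + i0] = res;
            }
        }

        return blk;
    }

The consumer side in kernel_flash_attn_ext_impl indexes this buffer with the KV block counter: when blk[ic0] == 0 it only advances the per-query mask pointers and continues, which replaces the previous in-loop scan of shared memory for -INF blocks (now kept under #if 0).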
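Note on the scratch sizing: ggml_metal_op_flash_attn_ext_extra_blk reserves one byte per tile across all mask heads and batches, rounded up to 32 bytes, and only for the non-vector kernels when a mask is present. The scratch regions are stacked after the destination tensor in the order pad, blk, tmp, which is how bid_pad, bid_blk and bid_tmp derive their offsets. A sketch of that arithmetic under those assumptions (GGML_PAD written out explicitly; the offsets struct is illustrative):

    #include <cstddef>
    #include <cstdint>

    static size_t pad_to(size_t n, size_t align) {
        return ((n + align - 1)/align)*align;
    }

    // Bytes needed for the per-tile block mask of one FLASH_ATTN_EXT op.
    // ne01: number of queries, ne30: mask length along KV, ne32/ne33: mask heads/batches.
    static size_t flash_attn_blk_size(int64_t ne01, int64_t ne30, int64_t ne32, int64_t ne33,
                                      int64_t nqptg = 8, int64_t ncpsg = 64) {
        const int64_t nblk1 = (ne01 + nqptg - 1)/nqptg;
        const int64_t nblk0 = (ne30 + ncpsg - 1)/ncpsg;

        return pad_to((size_t)(nblk0*nblk1*ne32*ne33), 32); // GGML_TYPE_I8: 1 byte per flag
    }

    // Scratch layout after the destination tensor: [dst | pad | blk | tmp].
    struct fattn_offsets {
        size_t pad;
        size_t blk;
        size_t tmp;
    };

    static fattn_offsets flash_attn_offsets(size_t dst_bytes, size_t pad_bytes, size_t blk_bytes) {
        return { dst_bytes, dst_bytes + pad_bytes, dst_bytes + pad_bytes + blk_bytes };
    }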