From a1c004ef2e056cdeffcd47aaac196883bb123a3a Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 18 Jan 2024 17:42:55 +0200 Subject: [PATCH 001/100] ggml : add ggml_flash_attn_ext API --- ggml-metal.m | 50 +++++++ ggml-metal.metal | 29 ++++ ggml.c | 298 ++++++++++++++++++++++++++++++++++++- ggml.h | 9 ++ llama.cpp | 80 +++++----- tests/test-backend-ops.cpp | 28 ++++ 6 files changed, 456 insertions(+), 38 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 912ddc83f7d9c..6d88d5c36a8ad 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -147,6 +147,7 @@ GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16, GGML_METAL_KERNEL_TYPE_CPY_F32_F16, GGML_METAL_KERNEL_TYPE_CPY_F32_F32, GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, @@ -511,6 +512,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16, flash_attn_ext_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true); @@ -665,6 +667,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const case GGML_OP_PAD: case GGML_OP_ARGSORT: case GGML_OP_LEAKY_RELU: + case GGML_OP_FLASH_ATTN_EXT: return true; case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: @@ -2161,6 +2164,53 @@ static bool ggml_metal_graph_compute( [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; + case GGML_OP_FLASH_ATTN_EXT: + { + GGML_ASSERT(src0->type == GGML_TYPE_F16); + + struct ggml_tensor * src2 = gf->nodes[i]->src[2]; + struct ggml_tensor * src3 = gf->nodes[i]->src[3]; + + size_t offs_src2 = 0; + size_t offs_src3 = 0; + + id id_src2 = src2 ? ggml_metal_get_buffer(ctx, src2, &offs_src2) : nil; + id id_src3 = src3 ? ggml_metal_get_buffer(ctx, src3, &offs_src3) : nil; + + float scale; + memcpy(&scale, dst->op_params, sizeof(float)); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16].pipeline; + + // TODO: extend if necessary + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; + [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:4]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:6]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:7]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:8]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:10]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:11]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:12]; + [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:14]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:15]; + [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:16]; + [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:17]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:18]; + [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:19]; + [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:20]; + [encoder setBytes:&scale length:sizeof( float) atIndex:21]; + + const int nth = MIN(1024, ne0); + + [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; case GGML_OP_DUP: case GGML_OP_CPY: case GGML_OP_CONT: diff --git a/ggml-metal.metal b/ggml-metal.metal index 029578dc54dbd..b79a1ba5634a7 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1959,6 +1959,35 @@ kernel void kernel_leaky_relu_f32( dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope; } +kernel void kernel_flash_attn_ext_f16( + device const half * q, + device const half * k, + device const half * v, + device const half * mask, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + constant float & scale, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + // TODO: implement +} + kernel void kernel_cpy_f16_f16( device const half * src0, device half * dst, diff --git a/ggml.c b/ggml.c index cbf2d4bddddb8..e01d938ceb681 100644 --- a/ggml.c +++ b/ggml.c @@ -1650,6 +1650,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "LEAKY_RELU", "FLASH_ATTN", + "FLASH_ATTN_EXT", "FLASH_FF", "FLASH_ATTN_BACK", "WIN_PART", @@ -1674,7 +1675,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS_BACK", }; -static_assert(GGML_OP_COUNT == 72, "GGML_OP_COUNT != 72"); +static_assert(GGML_OP_COUNT == 73, "GGML_OP_COUNT != 73"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -1736,6 +1737,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "leaky_relu(x)", "flash_attn(x)", + "flash_attn_ext(x)", "flash_ff(x)", "flash_attn_back(x)", "win_part(x)", @@ -1760,7 +1762,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "cross_entropy_loss_back(x,y)", }; -static_assert(GGML_OP_COUNT == 72, "GGML_OP_COUNT != 72"); +static_assert(GGML_OP_COUNT == 73, "GGML_OP_COUNT != 73"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -5678,6 +5680,46 @@ struct ggml_tensor * ggml_flash_attn( return result; } +// ggml_flash_attn_ext + +struct ggml_tensor * ggml_flash_attn_ext( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * mask, + float scale) { + GGML_ASSERT(ggml_can_mul_mat(k, q)); + // TODO: check if vT can be multiplied by (k*qT) + if (mask) { + GGML_ASSERT(ggml_is_contiguous(mask)); + GGML_ASSERT(mask->ne[2] == 1); + GGML_ASSERT(mask->ne[3] == 1); + //GGML_ASSERT(ggml_can_repeat_rows(mask, qk)); + } + + bool is_node = false; + + if (q->grad || k->grad || v->grad) { + is_node = true; + } + + //struct ggml_tensor * result = ggml_dup_tensor(ctx, q); + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, q->ne); + + float params[] = { scale }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_FLASH_ATTN_EXT; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = q; + result->src[1] = k; + result->src[2] = v; + result->src[3] = mask; + + return result; +} + // ggml_flash_ff struct ggml_tensor * ggml_flash_ff( @@ -13212,6 +13254,251 @@ static void ggml_compute_forward_flash_attn( } } +// ggml_compute_forward_flash_attn_ext + +static void ggml_compute_forward_flash_attn_ext_f16( + const struct ggml_compute_params * params, + const struct ggml_tensor * q, + const struct ggml_tensor * k, + const struct ggml_tensor * v, + const struct ggml_tensor * mask, + struct ggml_tensor * dst) { + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + GGML_TENSOR_LOCALS(int64_t, neq, q, ne) + GGML_TENSOR_LOCALS(size_t, nbq, q, nb) + GGML_TENSOR_LOCALS(int64_t, nek, k, ne) + GGML_TENSOR_LOCALS(size_t, nbk, k, nb) + GGML_TENSOR_LOCALS(int64_t, nev, v, ne) + GGML_TENSOR_LOCALS(size_t, nbv, v, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t D = neq0; + const int64_t N = neq1; + const int64_t P = nek1 - N; + const int64_t M = P + N; + + const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL); + + GGML_ASSERT(ne0 == D); + GGML_ASSERT(ne1 == N); + GGML_ASSERT(P >= 0); + + GGML_ASSERT(nbq0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t)); + + GGML_ASSERT(neq0 == D); + GGML_ASSERT(nek0 == D); + GGML_ASSERT(nev1 == D); + + GGML_ASSERT(neq1 == N); + GGML_ASSERT(nek1 == N + P); + GGML_ASSERT(nev1 == D); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + if (params->type == GGML_TASK_INIT) { + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + // parallelize by q rows using ggml_vec_dot_f32 + + // total rows in q + const int nr = neq1*neq2*neq3; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + float scale = 1.0f; + memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); + + //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale); + + for (int ir = ir0; ir < ir1; ++ir) { + // q indices + const int iq3 = ir/(neq2*neq1); + const int iq2 = (ir - iq3*neq2*neq1)/neq1; + const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1); + + float * S = (float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32); + + for (int i = M; i < Mup; ++i) { + S[i] = -INFINITY; + } + + if (GGML_VEC_DOT_UNROLL > 2 || nek1 % GGML_VEC_DOT_UNROLL != 0) { + for (int64_t ic = 0; ic < nek1; ++ic) { + // k indices + const int ik3 = iq3; + const int ik2 = iq2 % nek2; + const int ik1 = ic; + + // S indices + const int i1 = ik1; + + ggml_vec_dot_f16(neq0, + S + i1, + (ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), + (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); + } + } else { + for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) { + // k indices + const int ik3 = iq3; + const int ik2 = iq2 % nek2; + const int ik1 = ic; + + // S indices + const int i1 = ik1; + + ggml_vec_dot_f16_unroll(neq0, nbk1, + S + i1, + ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), + (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); + } + } + + // scale + ggml_vec_scale_f32(nek1, S, scale); + + if (mask) { + const float * mp = (float *)((char *) mask->data + (ir%mask->ne[1])*mask->nb[1]); + ggml_vec_acc_f32(M, S, mp); + } + + // softmax + // todo: exclude known -INF S[..] values from max and loop, assuming their results to be zero. + // dont forget to set their S values to zero + { + float max = -INFINITY; + ggml_vec_max_f32(M, &max, S); + + ggml_float sum = 0.0; + { +#ifdef GGML_SOFT_MAX_ACCELERATE + max = -max; + vDSP_vsadd(S, 1, &max, S, 1, Mup); + vvexpf(S, S, &Mup); + ggml_vec_sum_f32(Mup, &sum, S); +#else + uint16_t scvt[GGML_SOFT_MAX_UNROLL]; + ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 }; + + for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) { + float * SS = S + i; + + for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) { + if (SS[j] == -INFINITY) { + SS[j] = 0.0f; + } else { + ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max); + memcpy(&scvt[j], &s, sizeof(uint16_t)); + const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt[j]]); + sump[j] += (ggml_float)val; + SS[j] = val; + } + } + } + + for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) { + sum += sump[i]; + } +#endif + } + + assert(sum > 0.0); + + sum = 1.0/sum; + ggml_vec_scale_f32(M, S, sum); + +#ifndef NDEBUG + for (int i = 0; i < M; ++i) { + assert(!isnan(S[i])); + assert(!isinf(S[i])); + } +#endif + } + + ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32) + Mup); + + for (int64_t i = 0; i < M; i++) { + S16[i] = GGML_FP32_TO_FP16(S[i]); + } + + // todo: exclude known zero S[..] values from dot (reducing nev0 and increasing begin of v and S16). + if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) { + for (int64_t ic = 0; ic < nev1; ++ic) { + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; + + // v indices + const int iv2 = iq2 % nev2; + const int iv3 = iq3; + + ggml_vec_dot_f16(nev0, + (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), + (ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), + S16); + } + } else { + for (int64_t ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) { + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; + + // v indices + const int iv2 = iq2 % nev2; + const int iv3 = iq3; + + ggml_vec_dot_f16_unroll(nev0, nbv1, + (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), + ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), + S16); + } + } + } +} + +static void ggml_compute_forward_flash_attn_ext( + const struct ggml_compute_params * params, + const struct ggml_tensor * q, + const struct ggml_tensor * k, + const struct ggml_tensor * v, + const struct ggml_tensor * mask, + struct ggml_tensor * dst) { + switch (q->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + // ggml_compute_forward_flash_ff static void ggml_compute_forward_flash_ff_f16( @@ -14717,6 +15004,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm const bool masked = t != 0; ggml_compute_forward_flash_attn(params, tensor->src[0], tensor->src[1], tensor->src[2], masked, tensor); } break; + case GGML_OP_FLASH_ATTN_EXT: + { + ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor); + } break; case GGML_OP_FLASH_FF: { ggml_compute_forward_flash_ff(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor->src[4], tensor); @@ -15713,6 +16004,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor GGML_ASSERT(false); // TODO: not implemented } break; case GGML_OP_FLASH_ATTN: + case GGML_OP_FLASH_ATTN_EXT: { struct ggml_tensor * flash_grad = NULL; if (src0->grad || src1->grad || tensor->src[2]->grad) { @@ -16438,6 +16730,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { n_tasks = n_threads; } break; case GGML_OP_FLASH_ATTN: + case GGML_OP_FLASH_ATTN_EXT: { n_tasks = n_threads; } break; @@ -16769,6 +17062,7 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12; } break; case GGML_OP_FLASH_ATTN: + case GGML_OP_FLASH_ATTN_EXT: { const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL); diff --git a/ggml.h b/ggml.h index de8162b8135f3..d76fe9d5c48c9 100644 --- a/ggml.h +++ b/ggml.h @@ -452,6 +452,7 @@ extern "C" { GGML_OP_LEAKY_RELU, GGML_OP_FLASH_ATTN, + GGML_OP_FLASH_ATTN_EXT, GGML_OP_FLASH_FF, GGML_OP_FLASH_ATTN_BACK, GGML_OP_WIN_PART, @@ -1619,6 +1620,14 @@ extern "C" { struct ggml_tensor * v, bool masked); + GGML_API struct ggml_tensor * ggml_flash_attn_ext( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * mask, + float scale); + GGML_API struct ggml_tensor * ggml_flash_attn_back( struct ggml_context * ctx, struct ggml_tensor * q, diff --git a/llama.cpp b/llama.cpp index d28382f7d47b7..cec23c23f1dce 100644 --- a/llama.cpp +++ b/llama.cpp @@ -4205,38 +4205,6 @@ static struct ggml_tensor * llm_build_kqv( 0); cb(k, "k", il); - struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); - cb(kq, "kq", il); - - if (model.arch == LLM_ARCH_PHI2) { - // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs - // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847 - ggml_mul_mat_set_prec(kq, GGML_PREC_F32); - } - - if (max_alibi_bias > 0.0f) { - // temporary branch until we figure out how to handle ggml_alibi through ggml_add - kq = ggml_scale(ctx, kq, kq_scale); - cb(kq, "kq_scaled", il); - - if (max_alibi_bias > 0.0f) { - // TODO: n_head or n_head_kv - // TODO: K-shift is likely not working - // TODO: change to ggml_add - kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, max_alibi_bias); - cb(kq, "kq_scaled_alibi", il); - } - - kq = ggml_add(ctx, kq, kq_mask); - cb(kq, "kq_masked", il); - - kq = ggml_soft_max(ctx, kq); - cb(kq, "kq_soft_max", il); - } else { - kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale); - cb(kq, "kq_soft_max_ext", il); - } - // split cached v into n_head heads struct ggml_tensor * v = ggml_view_3d(ctx, kv.v_l[il], @@ -4246,8 +4214,49 @@ static struct ggml_tensor * llm_build_kqv( 0); cb(v, "v", il); - struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq); - cb(kqv, "kqv", il); + // TODO: determine if we can use flash attention + const bool supports_flash_attn = true; + + struct ggml_tensor * kqv; + + if (supports_flash_attn) { + kqv = ggml_flash_attn_ext(ctx, ggml_cast(ctx, q, GGML_TYPE_F16), k, v, kq_mask, kq_scale); + } else { + struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); + cb(kq, "kq", il); + + if (model.arch == LLM_ARCH_PHI2) { + // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs + // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847 + ggml_mul_mat_set_prec(kq, GGML_PREC_F32); + } + + if (max_alibi_bias > 0.0f) { + // temporary branch until we figure out how to handle ggml_alibi through ggml_add + kq = ggml_scale(ctx, kq, kq_scale); + cb(kq, "kq_scaled", il); + + if (max_alibi_bias > 0.0f) { + // TODO: n_head or n_head_kv + // TODO: K-shift is likely not working + // TODO: change to ggml_add + kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, max_alibi_bias); + cb(kq, "kq_scaled_alibi", il); + } + + kq = ggml_add(ctx, kq, kq_mask); + cb(kq, "kq_masked", il); + + kq = ggml_soft_max(ctx, kq); + cb(kq, "kq_soft_max", il); + } else { + kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale); + cb(kq, "kq_soft_max_ext", il); + } + + kqv = ggml_mul_mat(ctx, v, kq); + cb(kqv, "kqv", il); + } struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3); cb(kqv_merged, "kqv_merged", il); @@ -9490,8 +9499,7 @@ struct llama_context * llama_new_context_with_model( } ctx->backends.push_back(ctx->backend_cpu); - if (!llama_kv_cache_init(ctx->kv_self, ctx->model, type_k, type_v, - cparams.n_ctx, cparams.offload_kqv)) { + if (!llama_kv_cache_init(ctx->kv_self, ctx->model, type_k, type_v, cparams.n_ctx, cparams.offload_kqv)) { LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__); llama_free(ctx); return nullptr; diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 55ce14e0d902c..5693c2197c7c5 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1384,6 +1384,32 @@ struct test_leaky_relu : public test_case { } }; +// GGML_OP_FLASH_ATTN_EXT +struct test_flash_attn_ext : public test_case { + const ggml_type typeq; + const int64_t hs; // head size + const int64_t nh; // num heads + const int64_t kv; // kv size + const int64_t nt; // tokens + + std::string vars() override { + return VARS_TO_STR5(typeq, hs, nh, kv, nt); + } + + test_flash_attn_ext(ggml_type typeq = GGML_TYPE_F16, + int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nt = 8) + : typeq(typeq), hs(hs), nh(nh), kv(kv), nt(nt) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * q = ggml_new_tensor_4d(ctx, typeq, hs, nt, nh, 1); + ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); + ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, hs, nh, 1); + ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv, nt, 1, 1); + ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf(hs)); + return out; + } +}; + // Mixtral MOE struct test_moe : public test_case { const int n_experts; @@ -1650,6 +1676,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_pad()); test_cases.emplace_back(new test_leaky_relu()); + test_cases.emplace_back(new test_flash_attn_ext(GGML_TYPE_F16, 128, 32, 96, 8)); + #if !defined(__SANITIZE_THREAD__) // FIXME: these tests use too much memory with thread sanitizer test_cases.emplace_back(new test_moe(8, 2, 1, 4096, 8*1024)); From fa7ebcca993ec0d47f6ed6a47a8d5ac4f7407262 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 19 Jan 2024 20:06:26 +0200 Subject: [PATCH 002/100] ggml : fix GQA support in ggml_flash_attn_ext --- ggml-metal.metal | 8 ++++---- ggml.c | 23 +++++++++++++++-------- llama.cpp | 4 ++++ 3 files changed, 23 insertions(+), 12 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index b79a1ba5634a7..28847794cb5d8 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1960,10 +1960,10 @@ kernel void kernel_leaky_relu_f32( } kernel void kernel_flash_attn_ext_f16( - device const half * q, - device const half * k, - device const half * v, - device const half * mask, + device const half * q, + device const half * k, + device const half * v, + device const float * mask, device float * dst, constant int64_t & ne00, constant int64_t & ne01, diff --git a/ggml.c b/ggml.c index e01d938ceb681..9cf4784ce4759 100644 --- a/ggml.c +++ b/ggml.c @@ -13307,6 +13307,13 @@ static void ggml_compute_forward_flash_attn_ext_f16( GGML_ASSERT(nb1 <= nb2); GGML_ASSERT(nb2 <= nb3); + // broadcast factors + const int64_t rk2 = neq2/nek2; + const int64_t rk3 = neq3/nek3; + + const int64_t rv2 = neq2/nev2; + const int64_t rv3 = neq3/nev3; + if (params->type == GGML_TASK_INIT) { return; } @@ -13347,8 +13354,8 @@ static void ggml_compute_forward_flash_attn_ext_f16( if (GGML_VEC_DOT_UNROLL > 2 || nek1 % GGML_VEC_DOT_UNROLL != 0) { for (int64_t ic = 0; ic < nek1; ++ic) { // k indices - const int ik3 = iq3; - const int ik2 = iq2 % nek2; + const int ik3 = iq3 / rk3; + const int ik2 = iq2 / rk2; const int ik1 = ic; // S indices @@ -13362,8 +13369,8 @@ static void ggml_compute_forward_flash_attn_ext_f16( } else { for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) { // k indices - const int ik3 = iq3; - const int ik2 = iq2 % nek2; + const int ik3 = iq3 / rk3; + const int ik2 = iq2 / rk2; const int ik1 = ic; // S indices @@ -13452,8 +13459,8 @@ static void ggml_compute_forward_flash_attn_ext_f16( const int i3 = iq3; // v indices - const int iv2 = iq2 % nev2; - const int iv3 = iq3; + const int iv2 = iq2 / rv2; + const int iv3 = iq3 / rv3; ggml_vec_dot_f16(nev0, (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), @@ -13468,8 +13475,8 @@ static void ggml_compute_forward_flash_attn_ext_f16( const int i3 = iq3; // v indices - const int iv2 = iq2 % nev2; - const int iv3 = iq3; + const int iv2 = iq2 / rv2; + const int iv3 = iq3 / rv3; ggml_vec_dot_f16_unroll(nev0, nbv1, (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), diff --git a/llama.cpp b/llama.cpp index cec23c23f1dce..d4bebe5203e9a 100644 --- a/llama.cpp +++ b/llama.cpp @@ -4220,6 +4220,10 @@ static struct ggml_tensor * llm_build_kqv( struct ggml_tensor * kqv; if (supports_flash_attn) { + //printf("q: %4d %4d %4d %4d\n", q->ne[0], q->ne[1], q->ne[2], q->ne[3]); + //printf("k: %4d %4d %4d %4d\n", k->ne[0], k->ne[1], k->ne[2], k->ne[3]); + //printf("v: %4d %4d %4d %4d\n", v->ne[0], v->ne[1], v->ne[2], v->ne[3]); + //printf("m: %4d %4d %4d %4d\n", kq_mask->ne[0], kq_mask->ne[1], kq_mask->ne[2], kq_mask->ne[3]); kqv = ggml_flash_attn_ext(ctx, ggml_cast(ctx, q, GGML_TYPE_F16), k, v, kq_mask, kq_scale); } else { struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); From a9681febd65cbd3f372badc5f4a4d8bc1336d2d9 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 20 Jan 2024 12:26:49 +0200 Subject: [PATCH 003/100] ggml : online attention (CPU) --- ggml-metal.m | 8 +- ggml-metal.metal | 3 +- ggml.c | 249 ++++++++++++++++++------------------- ggml.h | 5 + llama.cpp | 124 ++++++++++-------- tests/test-backend-ops.cpp | 14 +-- 6 files changed, 218 insertions(+), 185 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 6d88d5c36a8ad..4d85dd3ddb319 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2207,9 +2207,15 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:20]; [encoder setBytes:&scale length:sizeof( float) atIndex:21]; + const int nwarps = 4; + + // each warp needs n_embd_head elements + GGML_ASSERT(nwarps*ne00*sizeof(float) <= ctx->device.maxThreadgroupMemoryLength); + [encoder setThreadgroupMemoryLength:nwarps*ne00*sizeof(float) atIndex:0]; + const int nth = MIN(1024, ne0); - [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nwarps, 1)]; } break; case GGML_OP_DUP: case GGML_OP_CPY: diff --git a/ggml-metal.metal b/ggml-metal.metal index 28847794cb5d8..a1e1755a3a605 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1981,7 +1981,8 @@ kernel void kernel_flash_attn_ext_f16( constant uint64_t & nb1, constant uint64_t & nb2, constant uint64_t & nb3, - constant float & scale, + constant float & scale, + threadgroup float * shared [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { diff --git a/ggml.c b/ggml.c index 9cf4784ce4759..e64a328fadb1f 100644 --- a/ggml.c +++ b/ggml.c @@ -817,7 +817,7 @@ do { \ #if defined(__F16C__) // the _mm256_cvt intrinsics require F16C -#define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((__m128i *)(x))) +#define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x))) #define GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0)) #else static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) { @@ -1323,6 +1323,37 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float #endif } +inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, const ggml_fp16_t * restrict x, const float v) { +#if defined(GGML_SIMD) + const int np = (n & ~(GGML_F16_STEP - 1)); + + GGML_F16_VEC vx = GGML_F16_VEC_SET1(v); + + GGML_F16_VEC ax[GGML_F16_ARR]; + GGML_F16_VEC ay[GGML_F16_ARR]; + + for (int i = 0; i < np; i += GGML_F16_STEP) { + for (int j = 0; j < GGML_F16_ARR; j++) { + ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j); + ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); + ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx); + + GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j); + } + } + + // leftovers + for (int i = np; i < n; ++i) { + y[i] += GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(x[i])*v); + } +#else + // scalar + for (int i = 0; i < n; ++i) { + y[i] += GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(x[i])*v); + } +#endif +} + // xs and vs are byte strides of x and v inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int vs, float * restrict y, const float * restrict xv, const float * restrict vv) { @@ -1407,6 +1438,35 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { #endif } +inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) { +#if defined(GGML_SIMD) + const int np = (n & ~(GGML_F16_STEP - 1)); + + GGML_F16_VEC vx = GGML_F16_VEC_SET1(v); + + GGML_F16_VEC ay[GGML_F16_ARR]; + + for (int i = 0; i < np; i += GGML_F16_STEP) { + for (int j = 0; j < GGML_F16_ARR; j++) { + ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); + ay[j] = GGML_F16_VEC_MUL(ay[j], vx); + + GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j); + } + } + + // leftovers + for (int i = np; i < n; ++i) { + y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v); + } +#else + // scalar + for (int i = 0; i < n; ++i) { + y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v); + } +#endif +} + inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, x, x); *s = sqrtf(*s); } inline static void ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; } inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); } @@ -5704,8 +5764,9 @@ struct ggml_tensor * ggml_flash_attn_ext( is_node = true; } - //struct ggml_tensor * result = ggml_dup_tensor(ctx, q); - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, q->ne); + // permute(0, 2, 1, 3) + int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, ne); float params[] = { scale }; ggml_set_op_params(result, params, sizeof(params)); @@ -13281,12 +13342,9 @@ static void ggml_compute_forward_flash_attn_ext_f16( const int64_t D = neq0; const int64_t N = neq1; const int64_t P = nek1 - N; - const int64_t M = P + N; - - const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL); GGML_ASSERT(ne0 == D); - GGML_ASSERT(ne1 == N); + GGML_ASSERT(ne2 == N); GGML_ASSERT(P >= 0); GGML_ASSERT(nbq0 == sizeof(ggml_fp16_t)); @@ -13295,11 +13353,11 @@ static void ggml_compute_forward_flash_attn_ext_f16( GGML_ASSERT(neq0 == D); GGML_ASSERT(nek0 == D); - GGML_ASSERT(nev1 == D); + GGML_ASSERT(nev0 == D); GGML_ASSERT(neq1 == N); GGML_ASSERT(nek1 == N + P); - GGML_ASSERT(nev1 == D); + GGML_ASSERT(nev0 == D); // dst cannot be transposed or permuted GGML_ASSERT(nb0 == sizeof(float)); @@ -13339,151 +13397,87 @@ static void ggml_compute_forward_flash_attn_ext_f16( //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale); + // loop over n_batch and n_head for (int ir = ir0; ir < ir1; ++ir) { // q indices const int iq3 = ir/(neq2*neq1); const int iq2 = (ir - iq3*neq2*neq1)/neq1; const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1); - float * S = (float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32); + float S = 0.0f; + float M = -INFINITY; - for (int i = M; i < Mup; ++i) { - S[i] = -INFINITY; - } + float * V32 = (float *) params->wdata + ith*(2*D + CACHE_LINE_SIZE_F32); + ggml_fp16_t * V16 = (ggml_fp16_t *) (V32 + D); - if (GGML_VEC_DOT_UNROLL > 2 || nek1 % GGML_VEC_DOT_UNROLL != 0) { - for (int64_t ic = 0; ic < nek1; ++ic) { - // k indices - const int ik3 = iq3 / rk3; - const int ik2 = iq2 / rk2; - const int ik1 = ic; + memset(V16, 0, D*sizeof(ggml_fp16_t)); - // S indices - const int i1 = ik1; + const float * mp = mask ? (float *)((char *) mask->data + (ir%mask->ne[1])*mask->nb[1]) : NULL; - ggml_vec_dot_f16(neq0, - S + i1, - (ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), - (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); - } - } else { - for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) { - // k indices - const int ik3 = iq3 / rk3; - const int ik2 = iq2 / rk2; - const int ik1 = ic; + // k indices + const int ik3 = iq3 / rk3; + const int ik2 = iq2 / rk2; - // S indices - const int i1 = ik1; + // v indices + const int iv2 = iq2 / rv2; + const int iv3 = iq3 / rv3; - ggml_vec_dot_f16_unroll(neq0, nbk1, - S + i1, - ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), - (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); + // online softmax / attention + // loop over n_kv and n_head_kv + // ref: https://arxiv.org/pdf/2112.05682.pdf + for (int64_t ic = 0; ic < nek1; ++ic) { + const float mv = mp ? mp[ic] : 0.0f; + if (mv == -INFINITY) { + continue; } - } - // scale - ggml_vec_scale_f32(nek1, S, scale); + float s; - if (mask) { - const float * mp = (float *)((char *) mask->data + (ir%mask->ne[1])*mask->nb[1]); - ggml_vec_acc_f32(M, S, mp); - } + ggml_vec_dot_f16(D, + &s, + (ggml_fp16_t *) ((char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3)), + (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); - // softmax - // todo: exclude known -INF S[..] values from max and loop, assuming their results to be zero. - // dont forget to set their S values to zero - { - float max = -INFINITY; - ggml_vec_max_f32(M, &max, S); + s = s*scale + mv; - ggml_float sum = 0.0; - { -#ifdef GGML_SOFT_MAX_ACCELERATE - max = -max; - vDSP_vsadd(S, 1, &max, S, 1, Mup); - vvexpf(S, S, &Mup); - ggml_vec_sum_f32(Mup, &sum, S); -#else - uint16_t scvt[GGML_SOFT_MAX_UNROLL]; - ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 }; + const float Mold = M; - for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) { - float * SS = S + i; + float ms = 1.0f; + float vs = 1.0f; - for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) { - if (SS[j] == -INFINITY) { - SS[j] = 0.0f; - } else { - ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max); - memcpy(&scvt[j], &s, sizeof(uint16_t)); - const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt[j]]); - sump[j] += (ggml_float)val; - SS[j] = val; - } - } - } + if (s > M) { + M = s; + ms = expf(Mold - M); - for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) { - sum += sump[i]; - } -#endif + // V = V*expf(Mold - M) + ggml_vec_scale_f16(D, V16, ms); + } else { + vs = expf(s - M); } - assert(sum > 0.0); + const ggml_fp16_t * v16 = (const ggml_fp16_t *) ((char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3)); - sum = 1.0/sum; - ggml_vec_scale_f32(M, S, sum); + // V += v*expf(s - M) + ggml_vec_mad_f16(D, V16, v16, vs); -#ifndef NDEBUG - for (int i = 0; i < M; ++i) { - assert(!isnan(S[i])); - assert(!isinf(S[i])); - } -#endif + S = S*ms + vs; } - ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32) + Mup); - - for (int64_t i = 0; i < M; i++) { - S16[i] = GGML_FP32_TO_FP16(S[i]); + // V /= S + for (int64_t d = 0; d < D; ++d) { + V32[d] = GGML_FP16_TO_FP32(V16[d])/S; } - // todo: exclude known zero S[..] values from dot (reducing nev0 and increasing begin of v and S16). - if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) { - for (int64_t ic = 0; ic < nev1; ++ic) { - // dst indices - const int i1 = iq1; - const int i2 = iq2; - const int i3 = iq3; - - // v indices - const int iv2 = iq2 / rv2; - const int iv3 = iq3 / rv3; - - ggml_vec_dot_f16(nev0, - (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), - (ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), - S16); - } - } else { - for (int64_t ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) { - // dst indices - const int i1 = iq1; - const int i2 = iq2; - const int i3 = iq3; + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; - // v indices - const int iv2 = iq2 / rv2; - const int iv3 = iq3 / rv3; + // original + //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float)); - ggml_vec_dot_f16_unroll(nev0, nbv1, - (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), - ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), - S16); - } - } + // permute(0, 2, 1, 3) + memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, V32, nb1); } } @@ -17069,7 +17063,6 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12; } break; case GGML_OP_FLASH_ATTN: - case GGML_OP_FLASH_ATTN_EXT: { const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL); @@ -17081,6 +17074,12 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2 } } break; + case GGML_OP_FLASH_ATTN_EXT: + { + const int64_t ne00 = node->src[0]->ne[0]; // D + + cur = 2*sizeof(float)*ne00*n_tasks; // 2x head size + } break; case GGML_OP_FLASH_FF: { if (node->src[1]->type == GGML_TYPE_F32) { diff --git a/ggml.h b/ggml.h index d76fe9d5c48c9..7bca02f2a2c48 100644 --- a/ggml.h +++ b/ggml.h @@ -1620,6 +1620,11 @@ extern "C" { struct ggml_tensor * v, bool masked); + // q: [n_embd, n_batch, n_head, 1] + // k: [n_embd, n_kv, n_head_kv, 1] + // v: [n_embd, n_kv, n_head_kv, 1] !! not transposed !! + // mask: [n_kv, n_batch, 1, 1] + // res: [n_embd, n_head, n_batch, 1] !! permuted !! GGML_API struct ggml_tensor * ggml_flash_attn_ext( struct ggml_context * ctx, struct ggml_tensor * q, diff --git a/llama.cpp b/llama.cpp index f0a63afef0087..4e6c9f9cc75ea 100644 --- a/llama.cpp +++ b/llama.cpp @@ -95,6 +95,8 @@ #define LLAMA_MAX_NODES 8192 #define LLAMA_MAX_EXPERTS 8 +#define LLAMA_FLASH_ATTN + // // logging // @@ -4167,23 +4169,34 @@ static void llm_build_kv_store( const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(); const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(); - // compute the transposed [n_tokens, n_embd] V matrix - struct ggml_tensor * v_cur_t = ggml_transpose(ctx, ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens)); - //struct ggml_tensor * v_cur_t = ggml_transpose(ctx, v_cur); // TODO: reshape above is likely not needed - cb(v_cur_t, "v_cur_t", il); - struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_k_gqa, (ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa))*kv_head); cb(k_cache_view, "k_cache_view", il); + // important: storing RoPE-ed version of K in the KV cache! + ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view)); + +#if defined(LLAMA_FLASH_ATTN) + // NOTE: the V cache is not transposed when using FLASH attention !! + struct ggml_tensor * v_cache_view = ggml_view_1d(ctx, kv.v_l[il], n_tokens*n_embd_v_gqa, + (ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa))*kv_head); + cb(v_cache_view, "v_cache_view", il); + + ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur, v_cache_view)); + + GGML_UNUSED(n_ctx); +#else + // compute the transposed [n_tokens, n_embd] V matrix + //struct ggml_tensor * v_cur_t = ggml_transpose(ctx, ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens)); + struct ggml_tensor * v_cur_t = ggml_transpose(ctx, v_cur); // TODO: reshape above is likely not needed + cb(v_cur_t, "v_cur_t", il); + struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_v_gqa, ( n_ctx)*ggml_element_size(kv.v_l[il]), (kv_head)*ggml_element_size(kv.v_l[il])); - cb(v_cache_view, "v_cache_view", il); - // important: storing RoPE-ed version of K in the KV cache! - ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view)); ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur_t, v_cache_view)); +#endif } static struct ggml_tensor * llm_build_norm( @@ -4343,68 +4356,77 @@ static struct ggml_tensor * llm_build_kqv( 0); cb(k, "k", il); - // split cached v into n_head heads + struct ggml_tensor * cur; + +#if defined(LLAMA_FLASH_ATTN) + // split cached v into n_head heads (not transposed) struct ggml_tensor * v = ggml_view_3d(ctx, kv.v_l[il], - n_kv, n_embd_head_v, n_head_kv, - ggml_element_size(kv.v_l[il])*n_ctx, - ggml_element_size(kv.v_l[il])*n_ctx*n_embd_head_v, + n_embd_head_v, n_kv, n_head_kv, + ggml_row_size(kv.v_l[il]->type, n_embd_k_gqa), + ggml_row_size(kv.v_l[il]->type, n_embd_head_k), 0); cb(v, "v", il); - // TODO: determine if we can use flash attention - const bool supports_flash_attn = true; + cur = ggml_flash_attn_ext(ctx, ggml_cast(ctx, q, GGML_TYPE_F16), k, v, kq_mask, kq_scale); + //printf("q: %4d %4d %4d %4d\n", q->ne[0], q->ne[1], q->ne[2], q->ne[3]); + //printf("k: %4d %4d %4d %4d\n", k->ne[0], k->ne[1], k->ne[2], k->ne[3]); + //printf("v: %4d %4d %4d %4d\n", v->ne[0], v->ne[1], v->ne[2], v->ne[3]); + //printf("m: %4d %4d %4d %4d\n", kq_mask->ne[0], kq_mask->ne[1], kq_mask->ne[2], kq_mask->ne[3]); + //printf("r: %4d %4d %4d %4d\n", kqv->ne[0], kqv->ne[1], kqv->ne[2], kqv->ne[3]); - struct ggml_tensor * kqv; + cur = ggml_reshape_2d(ctx, cur, n_embd_head_k*n_head, n_tokens); +#else + struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); + cb(kq, "kq", il); - if (supports_flash_attn) { - //printf("q: %4d %4d %4d %4d\n", q->ne[0], q->ne[1], q->ne[2], q->ne[3]); - //printf("k: %4d %4d %4d %4d\n", k->ne[0], k->ne[1], k->ne[2], k->ne[3]); - //printf("v: %4d %4d %4d %4d\n", v->ne[0], v->ne[1], v->ne[2], v->ne[3]); - //printf("m: %4d %4d %4d %4d\n", kq_mask->ne[0], kq_mask->ne[1], kq_mask->ne[2], kq_mask->ne[3]); - kqv = ggml_flash_attn_ext(ctx, ggml_cast(ctx, q, GGML_TYPE_F16), k, v, kq_mask, kq_scale); - } else { - struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); - cb(kq, "kq", il); + if (model.arch == LLM_ARCH_PHI2) { + // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs + // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847 + ggml_mul_mat_set_prec(kq, GGML_PREC_F32); + } - if (model.arch == LLM_ARCH_PHI2) { - // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs - // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847 - ggml_mul_mat_set_prec(kq, GGML_PREC_F32); - } + if (max_alibi_bias > 0.0f) { + // temporary branch until we figure out how to handle ggml_alibi through ggml_add + kq = ggml_scale(ctx, kq, kq_scale); + cb(kq, "kq_scaled", il); if (max_alibi_bias > 0.0f) { - // temporary branch until we figure out how to handle ggml_alibi through ggml_add - kq = ggml_scale(ctx, kq, kq_scale); - cb(kq, "kq_scaled", il); + // TODO: n_head or n_head_kv + // TODO: K-shift is likely not working + // TODO: change to ggml_add + kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, max_alibi_bias); + cb(kq, "kq_scaled_alibi", il); + } - if (max_alibi_bias > 0.0f) { - // TODO: n_head or n_head_kv - // TODO: K-shift is likely not working - // TODO: change to ggml_add - kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, max_alibi_bias); - cb(kq, "kq_scaled_alibi", il); - } + kq = ggml_add(ctx, kq, kq_mask); + cb(kq, "kq_masked", il); - kq = ggml_add(ctx, kq, kq_mask); - cb(kq, "kq_masked", il); + kq = ggml_soft_max(ctx, kq); + cb(kq, "kq_soft_max", il); + } else { + kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale); + cb(kq, "kq_soft_max_ext", il); + } - kq = ggml_soft_max(ctx, kq); - cb(kq, "kq_soft_max", il); - } else { - kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale); - cb(kq, "kq_soft_max_ext", il); - } + // split cached v into n_head heads (transposed) + struct ggml_tensor * v = + ggml_view_3d(ctx, kv.v_l[il], + n_kv, n_embd_head_v, n_head_kv, + ggml_element_size(kv.v_l[il])*n_ctx, + ggml_element_size(kv.v_l[il])*n_ctx*n_embd_head_v, + 0); + cb(v, "v", il); - kqv = ggml_mul_mat(ctx, v, kq); - cb(kqv, "kqv", il); - } + struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq); + cb(kqv, "kqv", il); struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3); cb(kqv_merged, "kqv_merged", il); - struct ggml_tensor * cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_k*n_head, n_tokens); + cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_k*n_head, n_tokens); cb(cur, "kqv_merged_cont", il); +#endif cur = ggml_mul_mat(ctx, wo, cur); if (wo_b) { diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 5693c2197c7c5..a56c0d6c59a64 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1390,21 +1390,21 @@ struct test_flash_attn_ext : public test_case { const int64_t hs; // head size const int64_t nh; // num heads const int64_t kv; // kv size - const int64_t nt; // tokens + const int64_t nb; // batch size std::string vars() override { - return VARS_TO_STR5(typeq, hs, nh, kv, nt); + return VARS_TO_STR5(typeq, hs, nh, kv, nb); } test_flash_attn_ext(ggml_type typeq = GGML_TYPE_F16, - int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nt = 8) - : typeq(typeq), hs(hs), nh(nh), kv(kv), nt(nt) {} + int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8) + : typeq(typeq), hs(hs), nh(nh), kv(kv), nb(nb) {} ggml_tensor * build_graph(ggml_context * ctx) override { - ggml_tensor * q = ggml_new_tensor_4d(ctx, typeq, hs, nt, nh, 1); + ggml_tensor * q = ggml_new_tensor_4d(ctx, typeq, hs, nb, nh, 1); ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); - ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, hs, nh, 1); - ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv, nt, 1, 1); + ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); + ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv, nb, 1, 1); ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf(hs)); return out; } From 1173f49c3bbe30810af4aeb77219eba7e05f658d Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 20 Jan 2024 17:32:28 +0200 Subject: [PATCH 004/100] metal : initial implementation --- ggml-metal.m | 75 +++++++++++++------- ggml-metal.metal | 138 ++++++++++++++++++++++++++++++++++--- ggml.c | 2 +- tests/test-backend-ops.cpp | 4 ++ 4 files changed, 183 insertions(+), 36 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 4d85dd3ddb319..556c53482a75e 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -278,6 +278,10 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ NSURL * libURL = [NSURL fileURLWithPath:libPath]; GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [libPath UTF8String]); ctx->library = [ctx->device newLibraryWithURL:libURL error:&error]; + if (error) { + GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); + return NULL; + } } else { GGML_METAL_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__); @@ -316,13 +320,12 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ //[options setFastMathEnabled:false]; ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error]; + if (error) { + GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); + return NULL; + } } } - - if (error) { - GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); - return NULL; - } } // print MTL GPU family: @@ -396,6 +399,9 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ struct ggml_metal_kernel * kernel = &ctx->kernels[e]; \ kernel->function = [ctx->library newFunctionWithName:@"kernel_"#name]; \ kernel->pipeline = [ctx->device newComputePipelineStateWithFunction:kernel->function error:&error]; \ + GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \ + (int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \ + (int) kernel->pipeline.threadExecutionWidth); \ if (error) { \ GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \ return NULL; \ @@ -2171,12 +2177,28 @@ static bool ggml_metal_graph_compute( struct ggml_tensor * src2 = gf->nodes[i]->src[2]; struct ggml_tensor * src3 = gf->nodes[i]->src[3]; + GGML_ASSERT(ggml_are_same_shape(src1, src2)); + size_t offs_src2 = 0; size_t offs_src3 = 0; - id id_src2 = src2 ? ggml_metal_get_buffer(ctx, src2, &offs_src2) : nil; + GGML_ASSERT(src2); + id id_src2 = ggml_metal_get_buffer(ctx, src2, &offs_src2); + id id_src3 = src3 ? ggml_metal_get_buffer(ctx, src3, &offs_src3) : nil; + const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30); + const int64_t ne31 = src3 ? src3->ne[1] : 0; + const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32); + const int64_t ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33); + + const uint64_t nb30 = src3 ? src3->nb[0] : 0; GGML_UNUSED(nb30); + const uint64_t nb31 = src3 ? src3->nb[1] : 0; + const uint64_t nb32 = src3 ? src3->nb[2] : 0; GGML_UNUSED(nb32); + const uint64_t nb33 = src3 ? src3->nb[3] : 0; GGML_UNUSED(nb33); + + const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t); + float scale; memcpy(&scale, dst->op_params, sizeof(float)); @@ -2197,25 +2219,28 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:10]; [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:11]; [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:12]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:14]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:15]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:16]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:17]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:18]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:19]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:20]; - [encoder setBytes:&scale length:sizeof( float) atIndex:21]; - - const int nwarps = 4; - - // each warp needs n_embd_head elements - GGML_ASSERT(nwarps*ne00*sizeof(float) <= ctx->device.maxThreadgroupMemoryLength); - [encoder setThreadgroupMemoryLength:nwarps*ne00*sizeof(float) atIndex:0]; - - const int nth = MIN(1024, ne0); - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nwarps, 1)]; + [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:14]; + [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:15]; + [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:16]; + [encoder setBytes:&nb10 length:sizeof(uint64_t) atIndex:17]; + [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:18]; + [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:19]; + [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:20]; + [encoder setBytes:&ne31 length:sizeof( int64_t) atIndex:21]; + [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:22]; + [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:23]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:24]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:25]; + [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; + [encoder setBytes:&scale length:sizeof( float) atIndex:27]; + + const int nwarps = 1; + + GGML_ASSERT(2*32*nwarps*ne00*sizeof(float) <= ctx->device.maxThreadgroupMemoryLength); + [encoder setThreadgroupMemoryLength:2*32*nwarps*ne00*sizeof(float) atIndex:0]; + + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 31)/32, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; } break; case GGML_OP_DUP: case GGML_OP_CPY: diff --git a/ggml-metal.metal b/ggml-metal.metal index a1e1755a3a605..5986bcb427f4b 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1960,10 +1960,10 @@ kernel void kernel_leaky_relu_f32( } kernel void kernel_flash_attn_ext_f16( - device const half * q, - device const half * k, - device const half * v, - device const float * mask, + device const char * q, + device const char * k, + device const char * v, + device const char * mask, device float * dst, constant int64_t & ne00, constant int64_t & ne01, @@ -1973,20 +1973,138 @@ kernel void kernel_flash_attn_ext_f16( constant uint64_t & nb01, constant uint64_t & nb02, constant uint64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant int64_t & ne31, + constant uint64_t & nb31, constant int64_t & ne0, constant int64_t & ne1, constant int64_t & ne2, constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, constant float & scale, threadgroup float * shared [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - // TODO: implement + uint3 ntg[[threads_per_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + const int64_t iq3 = tgpig[2]; + const int64_t iq2 = tgpig[1]; + const int64_t iq1 = tgpig[0]*N_SIMDWIDTH + tiisg; + + if (iq1 >= ne01) { + return; + } + + const int64_t D = ne00; + + // TODO: can we move this to the stack? + threadgroup half * V16 = (threadgroup half *) (shared + (2*sgitg*N_SIMDWIDTH + tiisg)*D); + + // initialize with zeros + for (int64_t d = 0; d < D; ++d) { + V16[d] = 0.0h; + } + + threadgroup half * pq = (threadgroup half *) (shared + (2*sgitg*N_SIMDWIDTH + N_SIMDWIDTH)*D + tiisg*D); + + half S = 0.0h; + half M = -INFINITY; + + const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; + + device const float * mp = mask ? (device const float *) (mask + (ir%ne31)*nb31) : nullptr; + + // assume K and V are same shape + const int64_t ne22 = ne12; + const int64_t ne23 = ne13; + + const uint64_t nb21 = nb11; + const uint64_t nb22 = nb12; + const uint64_t nb23 = nb13; + + // broadcast + const int64_t rk2 = ne02/ne12; + const int64_t rk3 = ne03/ne13; + + const int64_t rv2 = ne02/ne22; + const int64_t rv3 = ne03/ne23; + + // k indices + const int64_t ik2 = iq2 / rk2; + const int64_t ik3 = iq3 / rk3; + + // v indices + const int64_t iv2 = iq2 / rv2; + const int64_t iv3 = iq3 / rv3; + + // load Q to shared memory + for (int64_t d = 0; d < D; ++d) { + pq[d] = ((device const half *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[d]; + } + + for (int64_t ic = 0; ic < ne11; ++ic) { + const half mv = mp ? mp[ic] : 0.0h; + if (mv == -INFINITY) { + continue; + } + + half s = 0.0f; + + //device const half * pq = (device const half *) ((device char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)); + device const half * pk = (device const half *) ((device char *) k + ( ic*nb11 + ik2*nb12 + ik3*nb13)); + + for (int64_t d = 0; d < D; ++d) { + s += pk[d] * pq[d]; + } + + s = s*scale + mv; + + const half Mold = M; + + half ms = 1.0f; + half vs = 1.0f; + + if (s > M) { + M = s; + ms = exp(Mold - M); + + // V = V*exp(Mold - M) + for (int64_t d = 0; d < D; ++d) { + V16[d] *= ms; + } + } else { + vs = exp(s - M); + } + + device const half * pv = (device const half *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23)); + + // V += v*exp(s - M) + for (int64_t d = 0; d < D; ++d) { + V16[d] += pv[d] * vs; + } + + S = S*ms + vs; + } + + for (int64_t d = 0; d < D; ++d) { + V16[d] /= S; + } + + // dst indices + const int64_t i1 = iq1; + const int64_t i2 = iq2; + const int64_t i3 = iq3; + + for (int64_t d = 0; d < D; ++d) { + dst[(i3*ne2*ne1 + i2 + i1*ne1)*D + d] = V16[d]; + } } kernel void kernel_cpy_f16_f16( diff --git a/ggml.c b/ggml.c index e64a328fadb1f..10df03c9c619b 100644 --- a/ggml.c +++ b/ggml.c @@ -13419,8 +13419,8 @@ static void ggml_compute_forward_flash_attn_ext_f16( const int ik2 = iq2 / rk2; // v indices - const int iv2 = iq2 / rv2; const int iv3 = iq3 / rv3; + const int iv2 = iq2 / rv2; // online softmax / attention // loop over n_kv and n_head_kv diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index a56c0d6c59a64..51a33c662da56 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1396,6 +1396,10 @@ struct test_flash_attn_ext : public test_case { return VARS_TO_STR5(typeq, hs, nh, kv, nb); } + double max_nmse_err() override { + return 5e-4; + } + test_flash_attn_ext(ggml_type typeq = GGML_TYPE_F16, int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8) : typeq(typeq), hs(hs), nh(nh), kv(kv), nb(nb) {} From 528da7515ef874ab1188ab8f691c36d3e9e0cb20 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 21 Jan 2024 11:13:24 +0200 Subject: [PATCH 005/100] metal : f16 precision --- ggml-metal.m | 6 ++++-- ggml-metal.metal | 40 ++++++++++++++++++++++------------------ 2 files changed, 26 insertions(+), 20 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 556c53482a75e..e67a7c4ef892b 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2237,8 +2237,10 @@ static bool ggml_metal_graph_compute( const int nwarps = 1; - GGML_ASSERT(2*32*nwarps*ne00*sizeof(float) <= ctx->device.maxThreadgroupMemoryLength); - [encoder setThreadgroupMemoryLength:2*32*nwarps*ne00*sizeof(float) atIndex:0]; + const size_t shalf = sizeof(float)/2; + + GGML_ASSERT(2*32*nwarps*ne00*shalf <= ctx->device.maxThreadgroupMemoryLength); + [encoder setThreadgroupMemoryLength:2*32*nwarps*ne00*shalf atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 31)/32, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; } break; diff --git a/ggml-metal.metal b/ggml-metal.metal index 5986bcb427f4b..e4e89b5b3f7bf 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1988,7 +1988,7 @@ kernel void kernel_flash_attn_ext_f16( constant int64_t & ne2, constant int64_t & ne3, constant float & scale, - threadgroup float * shared [[threadgroup(0)]], + threadgroup half * shared [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]], @@ -2003,16 +2003,17 @@ kernel void kernel_flash_attn_ext_f16( } const int64_t D = ne00; + const int64_t D4 = D/4; // TODO: can we move this to the stack? - threadgroup half * V16 = (threadgroup half *) (shared + (2*sgitg*N_SIMDWIDTH + tiisg)*D); + threadgroup half4 * V16 = (threadgroup half4 *) (shared + (2*sgitg*N_SIMDWIDTH + tiisg)*D); // initialize with zeros - for (int64_t d = 0; d < D; ++d) { + for (int64_t d = 0; d < D4; ++d) { V16[d] = 0.0h; } - threadgroup half * pq = (threadgroup half *) (shared + (2*sgitg*N_SIMDWIDTH + N_SIMDWIDTH)*D + tiisg*D); + threadgroup half4 * pq4 = (threadgroup half4 *) (shared + (2*sgitg*N_SIMDWIDTH + N_SIMDWIDTH)*D + tiisg*D); half S = 0.0h; half M = -INFINITY; @@ -2045,8 +2046,8 @@ kernel void kernel_flash_attn_ext_f16( const int64_t iv3 = iq3 / rv3; // load Q to shared memory - for (int64_t d = 0; d < D; ++d) { - pq[d] = ((device const half *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[d]; + for (int64_t d = 0; d < D4; ++d) { + pq4[d] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[d]; } for (int64_t ic = 0; ic < ne11; ++ic) { @@ -2055,15 +2056,16 @@ kernel void kernel_flash_attn_ext_f16( continue; } - half s = 0.0f; + half4 s4 = 0.0f; - //device const half * pq = (device const half *) ((device char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)); - device const half * pk = (device const half *) ((device char *) k + ( ic*nb11 + ik2*nb12 + ik3*nb13)); + device const half4 * pk4 = (device const half4 *) ((device char *) k + ( ic*nb11 + ik2*nb12 + ik3*nb13)); - for (int64_t d = 0; d < D; ++d) { - s += pk[d] * pq[d]; + for (int64_t d = 0; d < D4; ++d) { + s4 += pk4[d] * pq4[d]; } + half s = s4.x + s4.y + s4.z + s4.w; + s = s*scale + mv; const half Mold = M; @@ -2076,24 +2078,24 @@ kernel void kernel_flash_attn_ext_f16( ms = exp(Mold - M); // V = V*exp(Mold - M) - for (int64_t d = 0; d < D; ++d) { + for (int64_t d = 0; d < D4; ++d) { V16[d] *= ms; } } else { vs = exp(s - M); } - device const half * pv = (device const half *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23)); + device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23)); // V += v*exp(s - M) - for (int64_t d = 0; d < D; ++d) { - V16[d] += pv[d] * vs; + for (int64_t d = 0; d < D4; ++d) { + V16[d] += pv4[d] * vs; } S = S*ms + vs; } - for (int64_t d = 0; d < D; ++d) { + for (int64_t d = 0; d < D4; ++d) { V16[d] /= S; } @@ -2102,8 +2104,10 @@ kernel void kernel_flash_attn_ext_f16( const int64_t i2 = iq2; const int64_t i3 = iq3; - for (int64_t d = 0; d < D; ++d) { - dst[(i3*ne2*ne1 + i2 + i1*ne1)*D + d] = V16[d]; + device float4 * dst4 = (device float4 *) dst; + + for (int64_t d = 0; d < D4; ++d) { + dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + d] = (float4) V16[d]; } } From 52ae085750afd37affc4ed18fe092d92c9ccdc5f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 21 Jan 2024 11:38:17 +0200 Subject: [PATCH 006/100] metal : reduce branches --- ggml-metal.metal | 30 ++++++++---------------------- 1 file changed, 8 insertions(+), 22 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index e4e89b5b3f7bf..f3a7efafa6613 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2056,40 +2056,26 @@ kernel void kernel_flash_attn_ext_f16( continue; } - half4 s4 = 0.0f; + device const half4 * pk4 = (device const half4 *) ((device char *) k + (ic*nb11 + ik2*nb12 + ik3*nb13)); + device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23)); - device const half4 * pk4 = (device const half4 *) ((device char *) k + ( ic*nb11 + ik2*nb12 + ik3*nb13)); + half4 s4 = 0.0h; for (int64_t d = 0; d < D4; ++d) { s4 += pk4[d] * pq4[d]; } - half s = s4.x + s4.y + s4.z + s4.w; - - s = s*scale + mv; + half s = (s4.x + s4.y + s4.z + s4.w)*scale + mv; const half Mold = M; - half ms = 1.0f; - half vs = 1.0f; - - if (s > M) { - M = s; - ms = exp(Mold - M); - - // V = V*exp(Mold - M) - for (int64_t d = 0; d < D4; ++d) { - V16[d] *= ms; - } - } else { - vs = exp(s - M); - } + M = max(M, s); - device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23)); + const half ms = exp(Mold - M); + const half vs = exp(s - M); - // V += v*exp(s - M) for (int64_t d = 0; d < D4; ++d) { - V16[d] += pv4[d] * vs; + V16[d] = V16[d]*ms + pv4[d]*vs; } S = S*ms + vs; From b97325800a7727244e737715fa7b5e2bc41afb21 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 21 Jan 2024 12:01:55 +0200 Subject: [PATCH 007/100] metal : specialize for head size --- ggml-metal.m | 259 +++++++++++++++++++++++++---------------------- ggml-metal.metal | 42 +++++++- 2 files changed, 179 insertions(+), 122 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index e67a7c4ef892b..046643146b3f3 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -147,7 +147,9 @@ GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, GGML_METAL_KERNEL_TYPE_CPY_F32_F16, GGML_METAL_KERNEL_TYPE_CPY_F32_F32, GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, @@ -412,125 +414,127 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ // simd_sum and simd_max requires MTLGPUFamilyApple7 - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX, soft_max, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, soft_max_4, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, get_rows_q5_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, ctx->support_simdgroup_reduction); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16, flash_attn_ext_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX, soft_max, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, soft_max_4, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, get_rows_q5_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, ctx->support_simdgroup_reduction); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); } return ctx; @@ -2172,6 +2176,7 @@ static bool ggml_metal_graph_compute( } break; case GGML_OP_FLASH_ATTN_EXT: { + GGML_ASSERT(ne00 % 4 == 0); GGML_ASSERT(src0->type == GGML_TYPE_F16); struct ggml_tensor * src2 = gf->nodes[i]->src[2]; @@ -2202,7 +2207,19 @@ static bool ggml_metal_graph_compute( float scale; memcpy(&scale, dst->op_params, sizeof(float)); - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16].pipeline; + id pipeline = nil; + + switch (ne00) { + case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break; + case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break; + case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break; + default: + { + GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00); + GGML_METAL_LOG_ERROR("add template specialization for this size\n"); + GGML_ASSERT(false && "add template specialization for this size"); + } + } // TODO: extend if necessary [encoder setComputePipelineState:pipeline]; diff --git a/ggml-metal.metal b/ggml-metal.metal index f3a7efafa6613..d97952f2b0871 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1959,6 +1959,43 @@ kernel void kernel_leaky_relu_f32( dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope; } +typedef void (flash_attn_ext_f16_t)( + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant int64_t & ne31, + constant uint64_t & nb31, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant float & scale, + threadgroup half * shared, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]); + +template // head size kernel void kernel_flash_attn_ext_f16( device const char * q, device const char * k, @@ -2002,7 +2039,6 @@ kernel void kernel_flash_attn_ext_f16( return; } - const int64_t D = ne00; const int64_t D4 = D/4; // TODO: can we move this to the stack? @@ -2097,6 +2133,10 @@ kernel void kernel_flash_attn_ext_f16( } } +template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64>; +template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80>; +template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>; + kernel void kernel_cpy_f16_f16( device const half * src0, device half * dst, From 8cde449b8be4e481db2a8790d9320c743b3ed65e Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 21 Jan 2024 12:23:22 +0200 Subject: [PATCH 008/100] wip : 8 rows per simd group --- ggml-metal.m | 10 +-- ggml-metal.metal | 173 ++++++++++++++++++++++++++++++++++++----------- 2 files changed, 139 insertions(+), 44 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 046643146b3f3..0b1119c4eb467 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2252,14 +2252,14 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; [encoder setBytes:&scale length:sizeof( float) atIndex:27]; - const int nwarps = 1; + const int64_t nwarps = 2; - const size_t shalf = sizeof(float)/2; + const size_t smem = nwarps*(2*8*nwarps*ne00 + 128)*(sizeof(float)/2); - GGML_ASSERT(2*32*nwarps*ne00*shalf <= ctx->device.maxThreadgroupMemoryLength); - [encoder setThreadgroupMemoryLength:2*32*nwarps*ne00*shalf atIndex:0]; + GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); + [encoder setThreadgroupMemoryLength:smem atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 31)/32, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, (ne02 + 8*nwarps - 1)/(8*nwarps), ne03) threadsPerThreadgroup:MTLSizeMake(32*nwarps, 1, 1)]; } break; case GGML_OP_DUP: case GGML_OP_CPY: diff --git a/ggml-metal.metal b/ggml-metal.metal index d97952f2b0871..789b19bad6b93 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2031,33 +2031,20 @@ kernel void kernel_flash_attn_ext_f16( uint3 ntg[[threads_per_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - const int64_t iq3 = tgpig[2]; - const int64_t iq2 = tgpig[1]; - const int64_t iq1 = tgpig[0]*N_SIMDWIDTH + tiisg; - - if (iq1 >= ne01) { - return; - } + //const int64_t iq3 = tgpig[2]; + //const int64_t iq2 = tgpig[1]; + //const int64_t iq1 = tgpig[0]*N_SIMDWIDTH + tiisg; - const int64_t D4 = D/4; + const uint nsg = ntg.x/N_SIMDWIDTH; // number of simdgroups - // TODO: can we move this to the stack? - threadgroup half4 * V16 = (threadgroup half4 *) (shared + (2*sgitg*N_SIMDWIDTH + tiisg)*D); + const int64_t iq3 = tgpig[2]; + const int64_t iq2 = tgpig[1]*(8*nsg) + 8*sgitg + tiisg/4; + const int64_t iq1 = tgpig[0]; - // initialize with zeros - for (int64_t d = 0; d < D4; ++d) { - V16[d] = 0.0h; + if (iq2 >= ne02) { + return; } - threadgroup half4 * pq4 = (threadgroup half4 *) (shared + (2*sgitg*N_SIMDWIDTH + N_SIMDWIDTH)*D + tiisg*D); - - half S = 0.0h; - half M = -INFINITY; - - const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; - - device const float * mp = mask ? (device const float *) (mask + (ir%ne31)*nb31) : nullptr; - // assume K and V are same shape const int64_t ne22 = ne12; const int64_t ne23 = ne13; @@ -2081,11 +2068,97 @@ kernel void kernel_flash_attn_ext_f16( const int64_t iv2 = iq2 / rv2; const int64_t iv3 = iq3 / rv3; - // load Q to shared memory - for (int64_t d = 0; d < D4; ++d) { - pq4[d] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[d]; + const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; + + device const float * mp = mask ? (device const float *) (mask + (ir%ne31)*nb31) : nullptr; + +// const int64_t D4 = D/4; +// +// // TODO: can we move this to the stack? +// threadgroup half4x4 * V16 = (threadgroup half4x4 *) (shared); +// +// // initialize with zeros +// for (int64_t d = 0; d < D4; ++d) { +// +// } +// +// threadgroup half4 * pq4 = (threadgroup half4 *) (shared + 4*D); +// +// // load Q to shared memory +// for (int64_t d = 0; d < D4; ++d) { +// pq4[d] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[d]; +// } +// +// half S = 0.0h; +// half M = -INFINITY; +// +// for (int64_t ic = 0; ic < ne11; ++ic) { +// const half mv = mp ? mp[ic] : 0.0h; +// if (mv == -INFINITY) { +// continue; +// } +// +// device const half4 * pk4 = (device const half4 *) ((device char *) k + (ic*nb11 + ik2*nb12 + ik3*nb13)); +// device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23)); +// +// half4 s4 = 0.0h; +// +// for (int64_t d = 0; d < D4; ++d) { +// s4 += pk4[d] * pq4[d]; +// } +// +// half s = (s4.x + s4.y + s4.z + s4.w)*scale + mv; +// +// const half Mold = M; +// +// M = max(M, s); +// +// const half ms = exp(Mold - M); +// const half vs = exp(s - M); +// +// for (int64_t d = 0; d < D4; ++d) { +// V16[d] = V16[d]*ms + pv4[d]*vs; +// } +// +// S = S*ms + vs; +// } +// +// for (int64_t d = 0; d < D4; ++d) { +// V16[d] /= S; +// } +// +// // dst indices +// const int64_t i1 = iq1; +// const int64_t i2 = iq2; +// const int64_t i3 = iq3; +// +// device float4 * dst4 = (device float4 *) dst; +// +// for (int64_t d = 0; d < D4; ++d) { +// dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + d] = (float4) V16[d]; +// } + + const int64_t D4 = D/4; + + threadgroup half4 * pq4 = (threadgroup half4 *) (shared + sgitg*(16*D + 128) ); + threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(16*D + 128) + 8*D); + threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*(16*D + 128) + 16*D); + threadgroup half * ss = (threadgroup half *) (shared + sgitg*(16*D + 128) + 16*D); + + const uint tiih = tiisg%4; // thread index in head + const uint hiisg = tiisg/4; // head index in simdgroup + + // load 8 heads from Q to shared memory + for (int64_t i = 0; i < D4/4; ++i) { + pq4[hiisg*D4 + 4*i + tiih] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[4*i + tiih]; + ps4[hiisg*D4 + 4*i + tiih] = 0.0h; } + simdgroup_barrier(mem_flags::mem_threadgroup); + + half S = 0.0h; + half M = -INFINITY; + for (int64_t ic = 0; ic < ne11; ++ic) { const half mv = mp ? mp[ic] : 0.0h; if (mv == -INFINITY) { @@ -2097,30 +2170,52 @@ kernel void kernel_flash_attn_ext_f16( half4 s4 = 0.0h; - for (int64_t d = 0; d < D4; ++d) { - s4 += pk4[d] * pq4[d]; + for (int64_t i = 0; i < D4/4; ++i) { + s4 += pk4[4*i + tiih] * pq4[hiisg*D4 + 4*i + tiih]; } - half s = (s4.x + s4.y + s4.z + s4.w)*scale + mv; + ss4[hiisg*4 + tiih] = s4; + + simdgroup_barrier(mem_flags::mem_threadgroup); + + if (tiih == 0) { + s4 = ss4[4*hiisg + 0] + ss4[4*hiisg + 1] + ss4[4*hiisg + 2] + ss4[4*hiisg + 3]; + + half s = (s4.x + s4.y + s4.z + s4.w)*scale + mv; - const half Mold = M; + const half Mold = M; - M = max(M, s); + M = max(M, s); - const half ms = exp(Mold - M); - const half vs = exp(s - M); + const half ms = exp(Mold - M); + const half vs = exp(s - M); - for (int64_t d = 0; d < D4; ++d) { - V16[d] = V16[d]*ms + pv4[d]*vs; + S = S*ms + vs; + + ss[2*hiisg + 0] = ms; + ss[2*hiisg + 1] = vs; } - S = S*ms + vs; + simdgroup_barrier(mem_flags::mem_threadgroup); + + const half ms = ss[2*hiisg + 0]; + const half vs = ss[2*hiisg + 1]; + + for (int64_t i = 0; i < D4/4; ++i) { + ps4[hiisg*D4 + 4*i + tiih] = ps4[hiisg*D4 + 4*i + tiih]*ms + pv4[4*i + tiih]*vs; + } } - for (int64_t d = 0; d < D4; ++d) { - V16[d] /= S; + simdgroup_barrier(mem_flags::mem_threadgroup); + + if (tiih == 0) { + for (int64_t i = 0; i < D4; ++i) { + ps4[hiisg*D4 + i] /= S; + } } + simdgroup_barrier(mem_flags::mem_threadgroup); + // dst indices const int64_t i1 = iq1; const int64_t i2 = iq2; @@ -2128,8 +2223,8 @@ kernel void kernel_flash_attn_ext_f16( device float4 * dst4 = (device float4 *) dst; - for (int64_t d = 0; d < D4; ++d) { - dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + d] = (float4) V16[d]; + for (int64_t i = 0; i < D4/4; ++i) { + dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + 4*i + tiih] = (float4) ps4[hiisg*D4 + 4*i + tiih]; } } From f31955f5d12da67f35aa459996a171975fdf269b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 21 Jan 2024 18:01:28 +0200 Subject: [PATCH 009/100] wip : 4 rows per simd group --- ggml-metal.m | 6 +++--- ggml-metal.metal | 39 +++++++++++++++++++++------------------ 2 files changed, 24 insertions(+), 21 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 0b1119c4eb467..abb96d6ec6e44 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2252,14 +2252,14 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; [encoder setBytes:&scale length:sizeof( float) atIndex:27]; - const int64_t nwarps = 2; + const int64_t nwarps = 4; - const size_t smem = nwarps*(2*8*nwarps*ne00 + 128)*(sizeof(float)/2); + const size_t smem = nwarps*(2*4*ne00 + 128)*(sizeof(float)/2); GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); [encoder setThreadgroupMemoryLength:smem atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, (ne02 + 8*nwarps - 1)/(8*nwarps), ne03) threadsPerThreadgroup:MTLSizeMake(32*nwarps, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, (ne02 + 4*nwarps - 1)/(4*nwarps), ne03) threadsPerThreadgroup:MTLSizeMake(32*nwarps, 1, 1)]; } break; case GGML_OP_DUP: case GGML_OP_CPY: diff --git a/ggml-metal.metal b/ggml-metal.metal index 789b19bad6b93..6fdd7fdad4326 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2038,7 +2038,7 @@ kernel void kernel_flash_attn_ext_f16( const uint nsg = ntg.x/N_SIMDWIDTH; // number of simdgroups const int64_t iq3 = tgpig[2]; - const int64_t iq2 = tgpig[1]*(8*nsg) + 8*sgitg + tiisg/4; + const int64_t iq2 = tgpig[1]*(4*nsg) + 4*sgitg + tiisg/8; const int64_t iq1 = tgpig[0]; if (iq2 >= ne02) { @@ -2140,18 +2140,18 @@ kernel void kernel_flash_attn_ext_f16( const int64_t D4 = D/4; - threadgroup half4 * pq4 = (threadgroup half4 *) (shared + sgitg*(16*D + 128) ); - threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(16*D + 128) + 8*D); - threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*(16*D + 128) + 16*D); - threadgroup half * ss = (threadgroup half *) (shared + sgitg*(16*D + 128) + 16*D); + threadgroup half4 * pq4 = (threadgroup half4 *) (shared + sgitg*(2*4*D + 128) ); + threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(2*4*D + 128) + 4*D); + threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*(2*4*D + 128) + 2*4*D); + threadgroup half * ss = (threadgroup half *) (shared + sgitg*(2*4*D + 128) + 2*4*D); - const uint tiih = tiisg%4; // thread index in head - const uint hiisg = tiisg/4; // head index in simdgroup + const uint tiih = tiisg%8; // thread index in head + const uint hiisg = tiisg/8; // head index in simdgroup // load 8 heads from Q to shared memory - for (int64_t i = 0; i < D4/4; ++i) { - pq4[hiisg*D4 + 4*i + tiih] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[4*i + tiih]; - ps4[hiisg*D4 + 4*i + tiih] = 0.0h; + for (int64_t i = 0; i < D4/8; ++i) { + pq4[hiisg*D4 + 8*i + tiih] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[8*i + tiih]; + ps4[hiisg*D4 + 8*i + tiih] = 0.0h; } simdgroup_barrier(mem_flags::mem_threadgroup); @@ -2170,16 +2170,18 @@ kernel void kernel_flash_attn_ext_f16( half4 s4 = 0.0h; - for (int64_t i = 0; i < D4/4; ++i) { - s4 += pk4[4*i + tiih] * pq4[hiisg*D4 + 4*i + tiih]; +#pragma unroll(D4/8) + for (int64_t i = 0; i < D4/8; ++i) { + s4 += pk4[8*i + tiih] * pq4[hiisg*D4 + 8*i + tiih]; } - ss4[hiisg*4 + tiih] = s4; + ss4[hiisg*8 + tiih] = s4; simdgroup_barrier(mem_flags::mem_threadgroup); if (tiih == 0) { - s4 = ss4[4*hiisg + 0] + ss4[4*hiisg + 1] + ss4[4*hiisg + 2] + ss4[4*hiisg + 3]; + s4 = ss4[8*hiisg + 0] + ss4[8*hiisg + 1] + ss4[8*hiisg + 2] + ss4[8*hiisg + 3] + + ss4[8*hiisg + 4] + ss4[8*hiisg + 5] + ss4[8*hiisg + 6] + ss4[8*hiisg + 7]; half s = (s4.x + s4.y + s4.z + s4.w)*scale + mv; @@ -2201,8 +2203,9 @@ kernel void kernel_flash_attn_ext_f16( const half ms = ss[2*hiisg + 0]; const half vs = ss[2*hiisg + 1]; - for (int64_t i = 0; i < D4/4; ++i) { - ps4[hiisg*D4 + 4*i + tiih] = ps4[hiisg*D4 + 4*i + tiih]*ms + pv4[4*i + tiih]*vs; +#pragma unroll(D4/8) + for (int64_t i = 0; i < D4/8; ++i) { + ps4[hiisg*D4 + 8*i + tiih] = ps4[hiisg*D4 + 8*i + tiih]*ms + pv4[8*i + tiih]*vs; } } @@ -2223,8 +2226,8 @@ kernel void kernel_flash_attn_ext_f16( device float4 * dst4 = (device float4 *) dst; - for (int64_t i = 0; i < D4/4; ++i) { - dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + 4*i + tiih] = (float4) ps4[hiisg*D4 + 4*i + tiih]; + for (int64_t i = 0; i < D4/8; ++i) { + dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + 8*i + tiih] = (float4) ps4[hiisg*D4 + 8*i + tiih]; } } From a4b6341c7b2a1977c29e79b17a0e5de3e31a5420 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 21 Jan 2024 18:24:13 +0200 Subject: [PATCH 010/100] wip : template for rows per warp --- ggml-metal.m | 7 ++++--- ggml-metal.metal | 54 +++++++++++++++++++++++++----------------------- 2 files changed, 32 insertions(+), 29 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index abb96d6ec6e44..d521df43ab302 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2252,14 +2252,15 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; [encoder setBytes:&scale length:sizeof( float) atIndex:27]; - const int64_t nwarps = 4; + const int64_t nwarps = 8; + const int64_t nhpw = 4; // heads per warp - const size_t smem = nwarps*(2*4*ne00 + 128)*(sizeof(float)/2); + const size_t smem = nwarps*(2*nhpw*ne00 + 128)*(sizeof(float)/2); GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); [encoder setThreadgroupMemoryLength:smem atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, (ne02 + 4*nwarps - 1)/(4*nwarps), ne03) threadsPerThreadgroup:MTLSizeMake(32*nwarps, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, (ne02 + nhpw*nwarps - 1)/(nhpw*nwarps), ne03) threadsPerThreadgroup:MTLSizeMake(32*nwarps, 1, 1)]; } break; case GGML_OP_DUP: case GGML_OP_CPY: diff --git a/ggml-metal.metal b/ggml-metal.metal index 6fdd7fdad4326..c9876c1033f1f 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1995,7 +1995,7 @@ typedef void (flash_attn_ext_f16_t)( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]); -template // head size +template // head size, rows per warp kernel void kernel_flash_attn_ext_f16( device const char * q, device const char * k, @@ -2036,9 +2036,10 @@ kernel void kernel_flash_attn_ext_f16( //const int64_t iq1 = tgpig[0]*N_SIMDWIDTH + tiisg; const uint nsg = ntg.x/N_SIMDWIDTH; // number of simdgroups + const uint tph = N_SIMDWIDTH/R; // threads per head const int64_t iq3 = tgpig[2]; - const int64_t iq2 = tgpig[1]*(4*nsg) + 4*sgitg + tiisg/8; + const int64_t iq2 = tgpig[1]*(R*nsg) + R*sgitg + tiisg/tph; const int64_t iq1 = tgpig[0]; if (iq2 >= ne02) { @@ -2140,18 +2141,18 @@ kernel void kernel_flash_attn_ext_f16( const int64_t D4 = D/4; - threadgroup half4 * pq4 = (threadgroup half4 *) (shared + sgitg*(2*4*D + 128) ); - threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(2*4*D + 128) + 4*D); - threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*(2*4*D + 128) + 2*4*D); - threadgroup half * ss = (threadgroup half *) (shared + sgitg*(2*4*D + 128) + 2*4*D); + threadgroup half4 * pq4 = (threadgroup half4 *) (shared + sgitg*(2*R*D + 128) + 0*R*D); + threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(2*R*D + 128) + 1*R*D); + threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*(2*R*D + 128) + 2*R*D); + threadgroup half * ss = (threadgroup half *) (shared + sgitg*(2*R*D + 128) + 2*R*D); - const uint tiih = tiisg%8; // thread index in head - const uint hiisg = tiisg/8; // head index in simdgroup + const uint tiih = tiisg%tph; // thread index in head + const uint hiisg = tiisg/tph; // head index in simdgroup - // load 8 heads from Q to shared memory - for (int64_t i = 0; i < D4/8; ++i) { - pq4[hiisg*D4 + 8*i + tiih] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[8*i + tiih]; - ps4[hiisg*D4 + 8*i + tiih] = 0.0h; + // load R heads from Q to shared memory + for (int64_t i = 0; i < D4/tph; ++i) { + pq4[hiisg*D4 + tph*i + tiih] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[tph*i + tiih]; + ps4[hiisg*D4 + tph*i + tiih] = 0.0h; } simdgroup_barrier(mem_flags::mem_threadgroup); @@ -2170,18 +2171,20 @@ kernel void kernel_flash_attn_ext_f16( half4 s4 = 0.0h; -#pragma unroll(D4/8) - for (int64_t i = 0; i < D4/8; ++i) { - s4 += pk4[8*i + tiih] * pq4[hiisg*D4 + 8*i + tiih]; + for (int64_t i = 0; i < D4/tph; ++i) { + s4 += pq4[hiisg*D4 + tph*i + tiih] * pk4[tph*i + tiih]; } - ss4[hiisg*8 + tiih] = s4; + ss4[hiisg*tph + tiih] = s4; simdgroup_barrier(mem_flags::mem_threadgroup); if (tiih == 0) { - s4 = ss4[8*hiisg + 0] + ss4[8*hiisg + 1] + ss4[8*hiisg + 2] + ss4[8*hiisg + 3] + - ss4[8*hiisg + 4] + ss4[8*hiisg + 5] + ss4[8*hiisg + 6] + ss4[8*hiisg + 7]; + s4 = 0.0h; + + for (int64_t i = 0; i < tph; ++i) { + s4 += ss4[hiisg*tph + i]; + } half s = (s4.x + s4.y + s4.z + s4.w)*scale + mv; @@ -2203,9 +2206,8 @@ kernel void kernel_flash_attn_ext_f16( const half ms = ss[2*hiisg + 0]; const half vs = ss[2*hiisg + 1]; -#pragma unroll(D4/8) - for (int64_t i = 0; i < D4/8; ++i) { - ps4[hiisg*D4 + 8*i + tiih] = ps4[hiisg*D4 + 8*i + tiih]*ms + pv4[8*i + tiih]*vs; + for (int64_t i = 0; i < D4/tph; ++i) { + ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]*ms + pv4[tph*i + tiih]*vs; } } @@ -2226,14 +2228,14 @@ kernel void kernel_flash_attn_ext_f16( device float4 * dst4 = (device float4 *) dst; - for (int64_t i = 0; i < D4/8; ++i) { - dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + 8*i + tiih] = (float4) ps4[hiisg*D4 + 8*i + tiih]; + for (int64_t i = 0; i < D4/tph; ++i) { + dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + tph*i + tiih] = (float4) ps4[hiisg*D4 + tph*i + tiih]; } } -template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64>; -template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80>; -template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>; +template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 4>; +template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 4>; +template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 4>; kernel void kernel_cpy_f16_f16( device const half * src0, From 77d08f3272c62900b40d110bf0de7f4466675c71 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 21 Jan 2024 21:04:15 +0200 Subject: [PATCH 011/100] metal : parallelize across KV size --- ggml-metal.m | 8 +-- ggml-metal.metal | 137 +++++++++++++++++------------------------------ 2 files changed, 52 insertions(+), 93 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index d521df43ab302..a60dd779a6f09 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2252,15 +2252,15 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; [encoder setBytes:&scale length:sizeof( float) atIndex:27]; - const int64_t nwarps = 8; - const int64_t nhpw = 4; // heads per warp + const int64_t nwarps = 16; + const int64_t nhptg = 4; // heads per threadgroup - const size_t smem = nwarps*(2*nhpw*ne00 + 128)*(sizeof(float)/2); + const size_t smem = (nhptg*ne00 + nwarps*(nhptg*ne00 + 32))*(sizeof(float)/2); GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); [encoder setThreadgroupMemoryLength:smem atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, (ne02 + nhpw*nwarps - 1)/(nhpw*nwarps), ne03) threadsPerThreadgroup:MTLSizeMake(32*nwarps, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, (ne02 + nhptg - 1)/(nhptg), ne03) threadsPerThreadgroup:MTLSizeMake(32, nwarps, 1)]; } break; case GGML_OP_DUP: case GGML_OP_CPY: diff --git a/ggml-metal.metal b/ggml-metal.metal index c9876c1033f1f..539e26c91c34a 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1995,7 +1995,7 @@ typedef void (flash_attn_ext_f16_t)( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]); -template // head size, rows per warp +template // head size, rows per threadgroup kernel void kernel_flash_attn_ext_f16( device const char * q, device const char * k, @@ -2031,15 +2031,11 @@ kernel void kernel_flash_attn_ext_f16( uint3 ntg[[threads_per_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - //const int64_t iq3 = tgpig[2]; - //const int64_t iq2 = tgpig[1]; - //const int64_t iq1 = tgpig[0]*N_SIMDWIDTH + tiisg; - - const uint nsg = ntg.x/N_SIMDWIDTH; // number of simdgroups - const uint tph = N_SIMDWIDTH/R; // threads per head + const uint nsg = ntg.y; // number of simdgroups + const uint tph = N_SIMDWIDTH/R; // threads per head const int64_t iq3 = tgpig[2]; - const int64_t iq2 = tgpig[1]*(R*nsg) + R*sgitg + tiisg/tph; + const int64_t iq2 = tgpig[1]*R + tiisg/tph; const int64_t iq1 = tgpig[0]; if (iq2 >= ne02) { @@ -2073,94 +2069,30 @@ kernel void kernel_flash_attn_ext_f16( device const float * mp = mask ? (device const float *) (mask + (ir%ne31)*nb31) : nullptr; -// const int64_t D4 = D/4; -// -// // TODO: can we move this to the stack? -// threadgroup half4x4 * V16 = (threadgroup half4x4 *) (shared); -// -// // initialize with zeros -// for (int64_t d = 0; d < D4; ++d) { -// -// } -// -// threadgroup half4 * pq4 = (threadgroup half4 *) (shared + 4*D); -// -// // load Q to shared memory -// for (int64_t d = 0; d < D4; ++d) { -// pq4[d] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[d]; -// } -// -// half S = 0.0h; -// half M = -INFINITY; -// -// for (int64_t ic = 0; ic < ne11; ++ic) { -// const half mv = mp ? mp[ic] : 0.0h; -// if (mv == -INFINITY) { -// continue; -// } -// -// device const half4 * pk4 = (device const half4 *) ((device char *) k + (ic*nb11 + ik2*nb12 + ik3*nb13)); -// device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23)); -// -// half4 s4 = 0.0h; -// -// for (int64_t d = 0; d < D4; ++d) { -// s4 += pk4[d] * pq4[d]; -// } -// -// half s = (s4.x + s4.y + s4.z + s4.w)*scale + mv; -// -// const half Mold = M; -// -// M = max(M, s); -// -// const half ms = exp(Mold - M); -// const half vs = exp(s - M); -// -// for (int64_t d = 0; d < D4; ++d) { -// V16[d] = V16[d]*ms + pv4[d]*vs; -// } -// -// S = S*ms + vs; -// } -// -// for (int64_t d = 0; d < D4; ++d) { -// V16[d] /= S; -// } -// -// // dst indices -// const int64_t i1 = iq1; -// const int64_t i2 = iq2; -// const int64_t i3 = iq3; -// -// device float4 * dst4 = (device float4 *) dst; -// -// for (int64_t d = 0; d < D4; ++d) { -// dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + d] = (float4) V16[d]; -// } - const int64_t D4 = D/4; - threadgroup half4 * pq4 = (threadgroup half4 *) (shared + sgitg*(2*R*D + 128) + 0*R*D); - threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(2*R*D + 128) + 1*R*D); - threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*(2*R*D + 128) + 2*R*D); - threadgroup half * ss = (threadgroup half *) (shared + sgitg*(2*R*D + 128) + 2*R*D); + threadgroup half4 * pq4 = (threadgroup half4 *) (shared + 0*R*D); + threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(R*D + 32) + 1*R*D); + threadgroup half * ss = (threadgroup half *) (shared + sgitg*(R*D + 32) + 2*R*D); const uint tiih = tiisg%tph; // thread index in head const uint hiisg = tiisg/tph; // head index in simdgroup // load R heads from Q to shared memory for (int64_t i = 0; i < D4/tph; ++i) { - pq4[hiisg*D4 + tph*i + tiih] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[tph*i + tiih]; + if (sgitg == 0) { + pq4[hiisg*D4 + tph*i + tiih] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[tph*i + tiih]; + } + ps4[hiisg*D4 + tph*i + tiih] = 0.0h; } - simdgroup_barrier(mem_flags::mem_threadgroup); + threadgroup_barrier(mem_flags::mem_threadgroup); half S = 0.0h; half M = -INFINITY; - for (int64_t ic = 0; ic < ne11; ++ic) { + for (int64_t ic = sgitg; ic < ne11; ic += nsg) { const half mv = mp ? mp[ic] : 0.0h; if (mv == -INFINITY) { continue; @@ -2175,18 +2107,18 @@ kernel void kernel_flash_attn_ext_f16( s4 += pq4[hiisg*D4 + tph*i + tiih] * pk4[tph*i + tiih]; } - ss4[hiisg*tph + tiih] = s4; + ss[hiisg*tph + tiih] = (s4.x + s4.y + s4.z + s4.w); simdgroup_barrier(mem_flags::mem_threadgroup); if (tiih == 0) { - s4 = 0.0h; + half s = 0.0h; for (int64_t i = 0; i < tph; ++i) { - s4 += ss4[hiisg*tph + i]; + s += ss[hiisg*tph + i]; } - half s = (s4.x + s4.y + s4.z + s4.w)*scale + mv; + s = s*scale + mv; const half Mold = M; @@ -2211,9 +2143,34 @@ kernel void kernel_flash_attn_ext_f16( } } - simdgroup_barrier(mem_flags::mem_threadgroup); - if (tiih == 0) { + ss[2*hiisg + 0] = S; + ss[2*hiisg + 1] = M; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // reduce the warps + if (sgitg == 0 && tiih == 0) { + for (int64_t sg = 1; sg < nsg; ++sg) { + const half S0 = S; + const half S1 = ss[sg*(R*D + 32) + 2*hiisg + 0]; + + const half M0 = M; + const half M1 = ss[sg*(R*D + 32) + 2*hiisg + 1]; + + M = max(M0, M1); + + const half ms0 = exp(M0 - M); + const half ms1 = exp(M1 - M); + + S = S0*ms0 + S1*ms1; + + for (int64_t i = 0; i < D4; ++i) { + ps4[hiisg*D4 + i] = ps4[hiisg*D4 + i]*ms0 + ps4[sg*(R*D + 32)/4 + hiisg*D4 + i]*ms1; + } + } + for (int64_t i = 0; i < D4; ++i) { ps4[hiisg*D4 + i] /= S; } @@ -2228,8 +2185,10 @@ kernel void kernel_flash_attn_ext_f16( device float4 * dst4 = (device float4 *) dst; - for (int64_t i = 0; i < D4/tph; ++i) { - dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + tph*i + tiih] = (float4) ps4[hiisg*D4 + tph*i + tiih]; + if (sgitg == 0) { + for (int64_t i = 0; i < D4/tph; ++i) { + dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + tph*i + tiih] = (float4) ps4[hiisg*D4 + tph*i + tiih]; + } } } From 17720fad669eed6171ddf17184da5bab50adeb72 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 21 Jan 2024 22:44:41 +0200 Subject: [PATCH 012/100] metal : parallel reduce across heads --- ggml-metal.m | 4 ++-- ggml-metal.metal | 32 ++++++++++++++++++++------------ 2 files changed, 22 insertions(+), 14 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index a60dd779a6f09..fdfb50d3d03f4 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2252,8 +2252,8 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; [encoder setBytes:&scale length:sizeof( float) atIndex:27]; - const int64_t nwarps = 16; - const int64_t nhptg = 4; // heads per threadgroup + const int64_t nwarps = 32; + const int64_t nhptg = 2; // heads per threadgroup const size_t smem = (nhptg*ne00 + nwarps*(nhptg*ne00 + 32))*(sizeof(float)/2); diff --git a/ggml-metal.metal b/ggml-metal.metal index 539e26c91c34a..919119c8d55af 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2103,6 +2103,7 @@ kernel void kernel_flash_attn_ext_f16( half4 s4 = 0.0h; +#pragma unroll for (int64_t i = 0; i < D4/tph; ++i) { s4 += pq4[hiisg*D4 + tph*i + tiih] * pk4[tph*i + tiih]; } @@ -2114,17 +2115,18 @@ kernel void kernel_flash_attn_ext_f16( if (tiih == 0) { half s = 0.0h; +#pragma unroll for (int64_t i = 0; i < tph; ++i) { s += ss[hiisg*tph + i]; } s = s*scale + mv; - const half Mold = M; + const half m = M; M = max(M, s); - const half ms = exp(Mold - M); + const half ms = exp(m - M); const half vs = exp(s - M); S = S*ms + vs; @@ -2138,6 +2140,7 @@ kernel void kernel_flash_attn_ext_f16( const half ms = ss[2*hiisg + 0]; const half vs = ss[2*hiisg + 1]; +#pragma unroll for (int64_t i = 0; i < D4/tph; ++i) { ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]*ms + pv4[tph*i + tiih]*vs; } @@ -2151,12 +2154,12 @@ kernel void kernel_flash_attn_ext_f16( threadgroup_barrier(mem_flags::mem_threadgroup); // reduce the warps - if (sgitg == 0 && tiih == 0) { + if (sgitg == 0) { for (int64_t sg = 1; sg < nsg; ++sg) { - const half S0 = S; + const half S0 = ss[ 2*hiisg + 0]; const half S1 = ss[sg*(R*D + 32) + 2*hiisg + 0]; - const half M0 = M; + const half M0 = ss[ 2*hiisg + 1]; const half M1 = ss[sg*(R*D + 32) + 2*hiisg + 1]; M = max(M0, M1); @@ -2166,13 +2169,18 @@ kernel void kernel_flash_attn_ext_f16( S = S0*ms0 + S1*ms1; - for (int64_t i = 0; i < D4; ++i) { - ps4[hiisg*D4 + i] = ps4[hiisg*D4 + i]*ms0 + ps4[sg*(R*D + 32)/4 + hiisg*D4 + i]*ms1; + if (tiih == 0) { + ss[2*hiisg + 0] = S; + ss[2*hiisg + 1] = M; + } + + for (int64_t i = 0; i < D4/tph; ++i) { + ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]*ms0 + ps4[sg*(R*D + 32)/4 + hiisg*D4 + tph*i + tiih]*ms1; } } - for (int64_t i = 0; i < D4; ++i) { - ps4[hiisg*D4 + i] /= S; + for (int64_t i = 0; i < D4/tph; ++i) { + ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]/S; } } @@ -2192,9 +2200,9 @@ kernel void kernel_flash_attn_ext_f16( } } -template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 4>; -template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 4>; -template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 4>; +template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 2>; +template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 2>; +template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 2>; kernel void kernel_cpy_f16_f16( device const half * src0, From 1446a12b29f422a0c0040e62c16715a3fb7ce1cb Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 23 Jan 2024 18:27:54 +0200 Subject: [PATCH 013/100] metal : efficient flash_attn_f16 implementation --- ggml-metal.m | 14 +- ggml-metal.metal | 279 +++++++++++++++++++++++-------------- tests/test-backend-ops.cpp | 6 +- 3 files changed, 188 insertions(+), 111 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index fdfb50d3d03f4..7b161c69d5801 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2183,6 +2183,7 @@ static bool ggml_metal_graph_compute( struct ggml_tensor * src3 = gf->nodes[i]->src[3]; GGML_ASSERT(ggml_are_same_shape(src1, src2)); + GGML_ASSERT(src3); size_t offs_src2 = 0; size_t offs_src3 = 0; @@ -2252,15 +2253,20 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; [encoder setBytes:&scale length:sizeof( float) atIndex:27]; - const int64_t nwarps = 32; - const int64_t nhptg = 2; // heads per threadgroup + // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) + const int64_t nsg = ne01 < 4 ? 4 : 2; // simdgroups per threadgroup (a.k.a. warps) - const size_t smem = (nhptg*ne00 + nwarps*(nhptg*ne00 + 32))*(sizeof(float)/2); + const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! + const int64_t ncpsg = 32; // cache values per simdgroup (does not work for other values) + //const size_t smem = nqptg*(nhptg*ne00 + nsg*(nhptg*ne00 + 256))*(sizeof(float)/2); + const size_t smem = nqptg*(ne00 + nsg*(ne00 + 1*ncpsg))*(sizeof(float)/2); + + //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); [encoder setThreadgroupMemoryLength:smem atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, (ne02 + nhptg - 1)/(nhptg), ne03) threadsPerThreadgroup:MTLSizeMake(32, nwarps, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; } break; case GGML_OP_DUP: case GGML_OP_CPY: diff --git a/ggml-metal.metal b/ggml-metal.metal index 919119c8d55af..9b6ceec4e1066 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1995,7 +1995,7 @@ typedef void (flash_attn_ext_f16_t)( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]); -template // head size, rows per threadgroup +template // head size, heads per threadgroup, queries per threadgroup kernel void kernel_flash_attn_ext_f16( device const char * q, device const char * k, @@ -2031,178 +2031,247 @@ kernel void kernel_flash_attn_ext_f16( uint3 ntg[[threads_per_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - const uint nsg = ntg.y; // number of simdgroups - const uint tph = N_SIMDWIDTH/R; // threads per head + const uint nsg = ntg.y; // number of simdgroups const int64_t iq3 = tgpig[2]; - const int64_t iq2 = tgpig[1]*R + tiisg/tph; - const int64_t iq1 = tgpig[0]; + const int64_t iq2 = tgpig[1]; + const int64_t iq1 = tgpig[0]*Q; if (iq2 >= ne02) { return; } - // assume K and V are same shape - const int64_t ne22 = ne12; - const int64_t ne23 = ne13; + const int64_t D4 = D/4; + const int64_t N4 = N_SIMDWIDTH; + const int64_t L4 = (D4 + N4 - 1)/N4; + const int64_t D8 = D/8; + + const int64_t T = D + nsg*(D + 1*C); // shared memory size per query in half + const int64_t T4 = T/4; // shared memory size per query in half4 + + threadgroup half * pq = (threadgroup half *) (shared + 0*D); + threadgroup half4 * pq4 = (threadgroup half4 *) (shared + 0*D); + threadgroup half * ps = (threadgroup half *) (shared + sgitg*(D + 1*C) + 1*D); + threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(D + 1*C) + 1*D); + threadgroup half * ss = (threadgroup half *) (shared + sgitg*(D + 1*C) + 2*D); + + for (int64_t i = 0; i < L4; ++i) { + // load heads from Q to shared memory + for (int64_t j = sgitg; j < Q; j += nsg) { + if (iq1 + j < ne01) { + pq4[j*T4 + N4*i + tiisg] = ((device const half4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)))[N4*i + tiisg]; + } else { + pq4[j*T4 + N4*i + tiisg] = 0.0h; + } + } - const uint64_t nb21 = nb11; - const uint64_t nb22 = nb12; - const uint64_t nb23 = nb13; + // zero out shared memory + for (int64_t j = 0; j < Q; ++j) { + ps4[j*T4 + N4*i + tiisg] = 0.0h; + } + } - // broadcast - const int64_t rk2 = ne02/ne12; - const int64_t rk3 = ne03/ne13; + if (tiisg < C) { + for (int64_t j = 0; j < Q; ++j) { + ss[j*T + 0 + tiisg] = 0.0h; + } + } - const int64_t rv2 = ne02/ne22; - const int64_t rv3 = ne03/ne23; + threadgroup_barrier(mem_flags::mem_threadgroup); - // k indices - const int64_t ik2 = iq2 / rk2; - const int64_t ik3 = iq3 / rk3; + { + half S[Q] = { 0.0h }; + half M[Q] = { -INFINITY }; - // v indices - const int64_t iv2 = iq2 / rv2; - const int64_t iv3 = iq3 / rv3; + // assume K and V are same shape + const int64_t ne22 = ne12; + const int64_t ne23 = ne13; - const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; + const uint64_t nb21 = nb11; + const uint64_t nb22 = nb12; + const uint64_t nb23 = nb13; - device const float * mp = mask ? (device const float *) (mask + (ir%ne31)*nb31) : nullptr; + // broadcast + const int64_t rk2 = ne02/ne12; + const int64_t rk3 = ne03/ne13; - const int64_t D4 = D/4; + const int64_t rv2 = ne02/ne22; + const int64_t rv3 = ne03/ne23; - threadgroup half4 * pq4 = (threadgroup half4 *) (shared + 0*R*D); - threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(R*D + 32) + 1*R*D); - threadgroup half * ss = (threadgroup half *) (shared + sgitg*(R*D + 32) + 2*R*D); + // k indices + const int64_t ik2 = iq2 / rk2; + const int64_t ik3 = iq3 / rk3; - const uint tiih = tiisg%tph; // thread index in head - const uint hiisg = tiisg/tph; // head index in simdgroup + // v indices + const int64_t iv2 = iq2 / rv2; + const int64_t iv3 = iq3 / rv3; - // load R heads from Q to shared memory - for (int64_t i = 0; i < D4/tph; ++i) { - if (sgitg == 0) { - pq4[hiisg*D4 + tph*i + tiih] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[tph*i + tiih]; - } + simdgroup_half8x8 mq[D8]; - ps4[hiisg*D4 + tph*i + tiih] = 0.0h; - } + for (int64_t i = 0; i < D8; ++i) { + simdgroup_load(mq[i], pq + i*8, T); + } - threadgroup_barrier(mem_flags::mem_threadgroup); + // TODO: this can be improved + device const float * mp[Q]; - half S = 0.0h; - half M = -INFINITY; + { + const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; - for (int64_t ic = sgitg; ic < ne11; ic += nsg) { - const half mv = mp ? mp[ic] : 0.0h; - if (mv == -INFINITY) { - continue; + for (int64_t j = 0; j < Q; ++j) { + if (iq1 + j < ne01) { + mp[j] = (device const float *) (mask + ((ir + j)%ne31)*nb31); + } else { + mp[j] = nullptr; + } + } } - device const half4 * pk4 = (device const half4 *) ((device char *) k + (ic*nb11 + ik2*nb12 + ik3*nb13)); - device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23)); + for (int64_t iic = C*sgitg; iic < ne11; iic += C*nsg) { + // skip -INF blocks + // TODO: double-check this + { + float smc = -INFINITY; - half4 s4 = 0.0h; + for (int64_t j = 0; j < Q; ++j) { + const float mc = mp[j] ? mp[j][iic + tiisg] : -INFINITY; + smc = simd_max(max(smc, mc)); + } -#pragma unroll - for (int64_t i = 0; i < D4/tph; ++i) { - s4 += pq4[hiisg*D4 + tph*i + tiih] * pk4[tph*i + tiih]; - } + if (smc == -INFINITY) { + continue; + } + } + + // Q*K^T + { + simdgroup_half8x8 mk; - ss[hiisg*tph + tiih] = (s4.x + s4.y + s4.z + s4.w); + for (int cc = 0; cc < C/8; ++cc) { + simdgroup_half8x8 mqk = make_filled_simdgroup_matrix(0.h); - simdgroup_barrier(mem_flags::mem_threadgroup); + device const half * pk = (device const half *) ((device const char *) k + ((iic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); - if (tiih == 0) { - half s = 0.0h; + for (int64_t i = 0; i < D8; ++i) { + simdgroup_load(mk, pk + i*8, nb11/2, 0, true); -#pragma unroll - for (int64_t i = 0; i < tph; ++i) { - s += ss[hiisg*tph + i]; + simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk); + } + + simdgroup_store(mqk, ss + 8*cc, T, 0, false); + } } - s = s*scale + mv; + // online softmax + for (int64_t j = 0; j < Q; ++j) { + const int64_t p = tiisg; - const half m = M; + const half s = ss[j*T + p]*scale + (mp[j][iic + p]); - M = max(M, s); + half m = M[j]; - const half ms = exp(m - M); - const half vs = exp(s - M); + M[j] = simd_max(max(M[j], s)); - S = S*ms + vs; + const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]); + const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]); - ss[2*hiisg + 0] = ms; - ss[2*hiisg + 1] = vs; - } + S[j] = S[j]*ms + simd_sum(vs); + + for (int64_t i = 0; i < L4; ++i) { + ps4[j*T4 + N4*i + tiisg] *= ms; + } + + ss[j*T + p] = vs; + } + + // (Q*K^T)*V + { + simdgroup_half8x8 mv; + + for (int64_t i = 0; i < D8; ++i) { + simdgroup_half8x8 mp[C/8]; + simdgroup_half8x8 mqkv; - simdgroup_barrier(mem_flags::mem_threadgroup); + simdgroup_load(mqkv, ps + i*8, T, 0, false); - const half ms = ss[2*hiisg + 0]; - const half vs = ss[2*hiisg + 1]; + for (int cc = 0; cc < C/8; ++cc) { + simdgroup_load(mp[cc], ss + 8*cc, T, 0, false); + } + + for (int cc = 0; cc < C/8; ++cc) { + device const half * pv = (device const half *) ((device const char *) v + ((iic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); -#pragma unroll - for (int64_t i = 0; i < D4/tph; ++i) { - ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]*ms + pv4[tph*i + tiih]*vs; + simdgroup_load(mv, pv + i*8, nb21/2, 0, false); + + simdgroup_multiply_accumulate(mqkv, mp[cc], mv, mqkv); + } + + simdgroup_store(mqkv, ps + i*8, T, 0, false); + } + } } - } - if (tiih == 0) { - ss[2*hiisg + 0] = S; - ss[2*hiisg + 1] = M; + for (int64_t j = 0; j < Q; ++j) { + if (tiisg == 0) { + ss[j*T + 0] = S[j]; + ss[j*T + 1] = M[j]; + } + } } threadgroup_barrier(mem_flags::mem_threadgroup); // reduce the warps + // TODO: try parallel reduce if (sgitg == 0) { + half S = { 0.0h }; + half M = { -INFINITY }; + for (int64_t sg = 1; sg < nsg; ++sg) { - const half S0 = ss[ 2*hiisg + 0]; - const half S1 = ss[sg*(R*D + 32) + 2*hiisg + 0]; + for (int64_t j = 0; j < Q; ++j) { + const half S0 = ss[j*T + 0]; + const half S1 = ss[j*T + sg*(D + 1*C) + 0]; - const half M0 = ss[ 2*hiisg + 1]; - const half M1 = ss[sg*(R*D + 32) + 2*hiisg + 1]; + const half M0 = ss[j*T + 1]; + const half M1 = ss[j*T + sg*(D + 1*C) + 1]; - M = max(M0, M1); + M = max(M0, M1); - const half ms0 = exp(M0 - M); - const half ms1 = exp(M1 - M); + const half ms0 = exp(M0 - M); + const half ms1 = exp(M1 - M); - S = S0*ms0 + S1*ms1; + S = S0*ms0 + S1*ms1; - if (tiih == 0) { - ss[2*hiisg + 0] = S; - ss[2*hiisg + 1] = M; - } + if (tiisg == 0) { + ss[j*T + 0] = S; + ss[j*T + 1] = M; + } - for (int64_t i = 0; i < D4/tph; ++i) { - ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]*ms0 + ps4[sg*(R*D + 32)/4 + hiisg*D4 + tph*i + tiih]*ms1; + for (int64_t i = 0; i < L4; ++i) { + ps4[j*T4 + N4*i + tiisg] = ps4[j*T4 + N4*i + tiisg]*ms0 + ps4[j*T4 + sg*(D + 1*C)/4 + N4*i + tiisg]*ms1; + } } } - - for (int64_t i = 0; i < D4/tph; ++i) { - ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]/S; - } } simdgroup_barrier(mem_flags::mem_threadgroup); - // dst indices - const int64_t i1 = iq1; - const int64_t i2 = iq2; - const int64_t i3 = iq3; - device float4 * dst4 = (device float4 *) dst; if (sgitg == 0) { - for (int64_t i = 0; i < D4/tph; ++i) { - dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + tph*i + tiih] = (float4) ps4[hiisg*D4 + tph*i + tiih]; + for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) { + const half S = ss[j*T + 0]; + + for (int64_t i = 0; i < L4; ++i) { + dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + N4*i + tiisg] = (float4) ps4[j*T4 + N4*i + tiisg]/S; + } } } } -template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 2>; -template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 2>; -template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 2>; +template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 8, 32>; +template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 8, 32>; +template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 8, 32>; kernel void kernel_cpy_f16_f16( device const half * src0, diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 51a33c662da56..41ddfcca5b687 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1397,7 +1397,7 @@ struct test_flash_attn_ext : public test_case { } double max_nmse_err() override { - return 5e-4; + return 5e-5; } test_flash_attn_ext(ggml_type typeq = GGML_TYPE_F16, @@ -1680,7 +1680,9 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_pad()); test_cases.emplace_back(new test_leaky_relu()); - test_cases.emplace_back(new test_flash_attn_ext(GGML_TYPE_F16, 128, 32, 96, 8)); + test_cases.emplace_back(new test_flash_attn_ext(GGML_TYPE_F16, 128, 32, 256, 8)); + test_cases.emplace_back(new test_flash_attn_ext(GGML_TYPE_F16, 128, 32, 256, 7)); + test_cases.emplace_back(new test_flash_attn_ext(GGML_TYPE_F16, 128, 32, 256, 1)); #if !defined(__SANITIZE_THREAD__) // FIXME: these tests use too much memory with thread sanitizer From d917746ddb053b73e868fd6e1854ac17b62bd863 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 25 Jan 2024 15:00:49 +0200 Subject: [PATCH 014/100] metal : avoid redundant loads of the attention --- ggml-metal.metal | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index 9b6ceec4e1066..785a60e50eba8 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2184,20 +2184,22 @@ kernel void kernel_flash_attn_ext_f16( ss[j*T + p] = vs; } + simdgroup_barrier(mem_flags::mem_none); + // (Q*K^T)*V { simdgroup_half8x8 mv; + simdgroup_half8x8 mp[C/8]; + for (int cc = 0; cc < C/8; ++cc) { + simdgroup_load(mp[cc], ss + 8*cc, T, 0, false); + } + for (int64_t i = 0; i < D8; ++i) { - simdgroup_half8x8 mp[C/8]; simdgroup_half8x8 mqkv; simdgroup_load(mqkv, ps + i*8, T, 0, false); - for (int cc = 0; cc < C/8; ++cc) { - simdgroup_load(mp[cc], ss + 8*cc, T, 0, false); - } - for (int cc = 0; cc < C/8; ++cc) { device const half * pv = (device const half *) ((device const char *) v + ((iic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); From 432ad04ffaa445a3837b92dce1c03513009ab4ac Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 25 Jan 2024 15:47:52 +0200 Subject: [PATCH 015/100] metal : scale and mask in matrix form --- ggml-metal.metal | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index 785a60e50eba8..ae8f5caeaa75f 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2127,6 +2127,9 @@ kernel void kernel_flash_attn_ext_f16( } } + // prepare diagonal scale matrix + simdgroup_half8x8 mscale(scale); + for (int64_t iic = C*sgitg; iic < ne11; iic += C*nsg) { // skip -INF blocks // TODO: double-check this @@ -2153,11 +2156,16 @@ kernel void kernel_flash_attn_ext_f16( device const half * pk = (device const half *) ((device const char *) k + ((iic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); for (int64_t i = 0; i < D8; ++i) { - simdgroup_load(mk, pk + i*8, nb11/2, 0, true); + simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk); } + // mqk = mqk*scale + mask + simdgroup_float8x8 mm; + simdgroup_load(mm, mp[0] + iic + 8*cc, nb31/sizeof(float), 0, false); + simdgroup_multiply_accumulate(mqk, mqk, mscale, mm); + simdgroup_store(mqk, ss + 8*cc, T, 0, false); } } @@ -2166,7 +2174,8 @@ kernel void kernel_flash_attn_ext_f16( for (int64_t j = 0; j < Q; ++j) { const int64_t p = tiisg; - const half s = ss[j*T + p]*scale + (mp[j][iic + p]); + //const half s = ss[j*T + p]*scale + (mp[j][iic + p]); + const half s = ss[j*T + p]; half m = M[j]; @@ -2203,7 +2212,7 @@ kernel void kernel_flash_attn_ext_f16( for (int cc = 0; cc < C/8; ++cc) { device const half * pv = (device const half *) ((device const char *) v + ((iic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); - simdgroup_load(mv, pv + i*8, nb21/2, 0, false); + simdgroup_load(mv, pv + i*8, nb21/sizeof(half), 0, false); simdgroup_multiply_accumulate(mqkv, mp[cc], mv, mqkv); } From 40ea8cd1aca61294e1987bcb1051317827f1b145 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 25 Jan 2024 16:31:39 +0200 Subject: [PATCH 016/100] metal : fix comment --- ggml-metal.metal | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index ae8f5caeaa75f..9ab9e16c3915a 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1995,7 +1995,7 @@ typedef void (flash_attn_ext_f16_t)( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]); -template // head size, heads per threadgroup, queries per threadgroup +template // head size, queries per threadgroup, cache items per threadgroup kernel void kernel_flash_attn_ext_f16( device const char * q, device const char * k, @@ -2031,16 +2031,12 @@ kernel void kernel_flash_attn_ext_f16( uint3 ntg[[threads_per_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - const uint nsg = ntg.y; // number of simdgroups + const uint nsg = ntg.y; // number of simdgroups const int64_t iq3 = tgpig[2]; const int64_t iq2 = tgpig[1]; const int64_t iq1 = tgpig[0]*Q; - if (iq2 >= ne02) { - return; - } - const int64_t D4 = D/4; const int64_t N4 = N_SIMDWIDTH; const int64_t L4 = (D4 + N4 - 1)/N4; From f9ca5dcbe86a10cfa873814d5f754b7c9108f339 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 25 Jan 2024 17:46:07 +0200 Subject: [PATCH 017/100] llama : avoid ggml_cast, use F32 query --- ggml-metal.m | 4 ++-- ggml-metal.metal | 3 ++- ggml.c | 31 +++++++++++++++++++++++++++---- ggml.h | 4 ++++ llama.cpp | 3 ++- tests/test-backend-ops.cpp | 16 +++++++--------- 6 files changed, 44 insertions(+), 17 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 7b161c69d5801..7b6762e6d9158 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2177,7 +2177,7 @@ static bool ggml_metal_graph_compute( case GGML_OP_FLASH_ATTN_EXT: { GGML_ASSERT(ne00 % 4 == 0); - GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src0->type == GGML_TYPE_F32); struct ggml_tensor * src2 = gf->nodes[i]->src[2]; struct ggml_tensor * src3 = gf->nodes[i]->src[3]; @@ -2254,7 +2254,7 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&scale length:sizeof( float) atIndex:27]; // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) - const int64_t nsg = ne01 < 4 ? 4 : 2; // simdgroups per threadgroup (a.k.a. warps) + const int64_t nsg = ne01 < 4 ? 12 : 2; // simdgroups per threadgroup (a.k.a. warps) const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! const int64_t ncpsg = 32; // cache values per simdgroup (does not work for other values) diff --git a/ggml-metal.metal b/ggml-metal.metal index 9ab9e16c3915a..c9e4dcfe99cd4 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2054,8 +2054,9 @@ kernel void kernel_flash_attn_ext_f16( for (int64_t i = 0; i < L4; ++i) { // load heads from Q to shared memory for (int64_t j = sgitg; j < Q; j += nsg) { + device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); if (iq1 + j < ne01) { - pq4[j*T4 + N4*i + tiisg] = ((device const half4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)))[N4*i + tiisg]; + pq4[j*T4 + N4*i + tiisg] = (half4) q4[N4*i + tiisg]; } else { pq4[j*T4 + N4*i + tiisg] = 0.0h; } diff --git a/ggml.c b/ggml.c index 10df03c9c619b..5e515c03fdb9d 100644 --- a/ggml.c +++ b/ggml.c @@ -4178,6 +4178,8 @@ struct ggml_tensor * ggml_mul_mat( void ggml_mul_mat_set_prec( struct ggml_tensor * a, enum ggml_prec prec) { + GGML_ASSERT(a->op == GGML_OP_MUL_MAT); + const int32_t prec_i32 = (int32_t) prec; ggml_set_op_params_i32(a, 0, prec_i32); @@ -5781,6 +5783,16 @@ struct ggml_tensor * ggml_flash_attn_ext( return result; } +void ggml_flash_attn_ext_set_prec( + struct ggml_tensor * a, + enum ggml_prec prec) { + GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT); + + const int32_t prec_i32 = (int32_t) prec; + + ggml_set_op_params_i32(a, 1, prec_i32); // scale is on first pos +} + // ggml_flash_ff struct ggml_tensor * ggml_flash_ff( @@ -13347,7 +13359,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( GGML_ASSERT(ne2 == N); GGML_ASSERT(P >= 0); - GGML_ASSERT(nbq0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nbq0 == sizeof(float)); GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t)); GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t)); @@ -13408,6 +13420,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( float M = -INFINITY; float * V32 = (float *) params->wdata + ith*(2*D + CACHE_LINE_SIZE_F32); + ggml_fp16_t * Q16 = (ggml_fp16_t *) (V32); // reuse memory ggml_fp16_t * V16 = (ggml_fp16_t *) (V32 + D); memset(V16, 0, D*sizeof(ggml_fp16_t)); @@ -13433,10 +13446,19 @@ static void ggml_compute_forward_flash_attn_ext_f16( float s; + // convert Q to F16 in V32 + { + const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)); + + for (int64_t d = 0; d < D; ++d) { + Q16[d] = GGML_FP32_TO_FP16(pq[d]); + } + } + ggml_vec_dot_f16(D, &s, (ggml_fp16_t *) ((char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3)), - (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); + Q16); s = s*scale + mv; @@ -13488,13 +13510,14 @@ static void ggml_compute_forward_flash_attn_ext( const struct ggml_tensor * v, const struct ggml_tensor * mask, struct ggml_tensor * dst) { - switch (q->type) { - case GGML_TYPE_F16: + switch (dst->op_params[1]) { + case GGML_PREC_DEFAULT: { ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst); } break; default: { + // TODO: implement F32 precision GGML_ASSERT(false); } break; } diff --git a/ggml.h b/ggml.h index 7bca02f2a2c48..e2f74412fde1e 100644 --- a/ggml.h +++ b/ggml.h @@ -1633,6 +1633,10 @@ extern "C" { struct ggml_tensor * mask, float scale); + GGML_API void ggml_flash_attn_ext_set_prec( + struct ggml_tensor * a, + enum ggml_prec prec); + GGML_API struct ggml_tensor * ggml_flash_attn_back( struct ggml_context * ctx, struct ggml_tensor * q, diff --git a/llama.cpp b/llama.cpp index 4e6c9f9cc75ea..550caced4ae57 100644 --- a/llama.cpp +++ b/llama.cpp @@ -4368,7 +4368,8 @@ static struct ggml_tensor * llm_build_kqv( 0); cb(v, "v", il); - cur = ggml_flash_attn_ext(ctx, ggml_cast(ctx, q, GGML_TYPE_F16), k, v, kq_mask, kq_scale); + cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale); + ggml_flash_attn_ext_set_prec(cur, GGML_PREC_DEFAULT); //printf("q: %4d %4d %4d %4d\n", q->ne[0], q->ne[1], q->ne[2], q->ne[3]); //printf("k: %4d %4d %4d %4d\n", k->ne[0], k->ne[1], k->ne[2], k->ne[3]); //printf("v: %4d %4d %4d %4d\n", v->ne[0], v->ne[1], v->ne[2], v->ne[3]); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 41ddfcca5b687..db1244876ce06 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1386,26 +1386,24 @@ struct test_leaky_relu : public test_case { // GGML_OP_FLASH_ATTN_EXT struct test_flash_attn_ext : public test_case { - const ggml_type typeq; const int64_t hs; // head size const int64_t nh; // num heads const int64_t kv; // kv size const int64_t nb; // batch size std::string vars() override { - return VARS_TO_STR5(typeq, hs, nh, kv, nb); + return VARS_TO_STR4(hs, nh, kv, nb); } double max_nmse_err() override { return 5e-5; } - test_flash_attn_ext(ggml_type typeq = GGML_TYPE_F16, - int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8) - : typeq(typeq), hs(hs), nh(nh), kv(kv), nb(nb) {} + test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8) + : hs(hs), nh(nh), kv(kv), nb(nb) {} ggml_tensor * build_graph(ggml_context * ctx) override { - ggml_tensor * q = ggml_new_tensor_4d(ctx, typeq, hs, nb, nh, 1); + ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, nb, nh, 1); ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv, nb, 1, 1); @@ -1680,9 +1678,9 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_pad()); test_cases.emplace_back(new test_leaky_relu()); - test_cases.emplace_back(new test_flash_attn_ext(GGML_TYPE_F16, 128, 32, 256, 8)); - test_cases.emplace_back(new test_flash_attn_ext(GGML_TYPE_F16, 128, 32, 256, 7)); - test_cases.emplace_back(new test_flash_attn_ext(GGML_TYPE_F16, 128, 32, 256, 1)); + test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256, 8)); + test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256, 7)); + test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256, 1)); #if !defined(__SANITIZE_THREAD__) // FIXME: these tests use too much memory with thread sanitizer From 6fea843b246409a3c4b26156745a89e4ba01029b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 25 Jan 2024 17:59:41 +0200 Subject: [PATCH 018/100] metal : add parallel reduce version (disabled) --- ggml-metal.m | 2 +- ggml-metal.metal | 42 +++++++++++++++++++++++++++++++++++++++++- 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 7b6762e6d9158..cf7880c822db5 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2254,7 +2254,7 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&scale length:sizeof( float) atIndex:27]; // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) - const int64_t nsg = ne01 < 4 ? 12 : 2; // simdgroups per threadgroup (a.k.a. warps) + const int64_t nsg = ne01 < 4 ? 12 : 4; // simdgroups per threadgroup (a.k.a. warps) const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! const int64_t ncpsg = 32; // cache values per simdgroup (does not work for other values) diff --git a/ggml-metal.metal b/ggml-metal.metal index c9e4dcfe99cd4..6eb2825df558b 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2230,7 +2230,7 @@ kernel void kernel_flash_attn_ext_f16( threadgroup_barrier(mem_flags::mem_threadgroup); // reduce the warps - // TODO: try parallel reduce +#if 1 if (sgitg == 0) { half S = { 0.0h }; half M = { -INFINITY }; @@ -2261,6 +2261,46 @@ kernel void kernel_flash_attn_ext_f16( } } } +#else + // parallel reduce + // NOTE: this is significantly slower than the serial version above, likely due to the small number of warps + { + half S = { 0.0h }; + half M = { -INFINITY }; + + for (int64_t sg = nsg/2; sg > 0; sg /= 2) { + if (sgitg >= sg) { + continue; + } + + for (int64_t j = 0; j < Q; ++j) { + const half S0 = ss[j*T + 0]; + const half S1 = ss[j*T + sg*(D + 1*C) + 0]; + + const half M0 = ss[j*T + 1]; + const half M1 = ss[j*T + sg*(D + 1*C) + 1]; + + M = max(M0, M1); + + const half ms0 = exp(M0 - M); + const half ms1 = exp(M1 - M); + + S = S0*ms0 + S1*ms1; + + if (tiisg == 0) { + ss[j*T + 0] = S; + ss[j*T + 1] = M; + } + + for (int64_t i = 0; i < L4; ++i) { + ps4[j*T4 + N4*i + tiisg] = ps4[j*T4 + N4*i + tiisg]*ms0 + ps4[j*T4 + sg*(D + 1*C)/4 + N4*i + tiisg]*ms1; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + } + } +#endif simdgroup_barrier(mem_flags::mem_threadgroup); From 77f6976a87f6d034cf0f7a77e14a011da7901911 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 28 Jan 2024 13:15:00 +0200 Subject: [PATCH 019/100] metal : move output into local memory + optimize - the result from each simdgroup now stays in the registers - significantly reduced SRAM usage - more efficient skipping of -INF blocks - avoid simdgroup barrier in hot loop - add comments --- ggml-metal.m | 12 +-- ggml-metal.metal | 220 ++++++++++++++++++++++------------------------- 2 files changed, 110 insertions(+), 122 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index eabc16f416645..a7e126bff5318 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2213,14 +2213,14 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; [encoder setBytes:&scale length:sizeof( float) atIndex:27]; - // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) - const int64_t nsg = ne01 < 4 ? 12 : 4; // simdgroups per threadgroup (a.k.a. warps) - const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! - const int64_t ncpsg = 32; // cache values per simdgroup (does not work for other values) + const int64_t ncpsg = 32; // cache values per simdgroup + + // simdgroups per threadgroup (a.k.a. warps) + // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) + const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(ne11/32, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)) : 4; - //const size_t smem = nqptg*(nhptg*ne00 + nsg*(nhptg*ne00 + 256))*(sizeof(float)/2); - const size_t smem = nqptg*(ne00 + nsg*(ne00 + 1*ncpsg))*(sizeof(float)/2); + const size_t smem = nqptg*(ne00 + nsg*(ncpsg + nqptg))*(sizeof(float)/2); //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); diff --git a/ggml-metal.metal b/ggml-metal.metal index 6eb2825df558b..b564f014de2b6 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1995,6 +1995,7 @@ typedef void (flash_attn_ext_f16_t)( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]); +// ref: https://arxiv.org/pdf/2307.08691.pdf template // head size, queries per threadgroup, cache items per threadgroup kernel void kernel_flash_attn_ext_f16( device const char * q, @@ -2038,39 +2039,45 @@ kernel void kernel_flash_attn_ext_f16( const int64_t iq1 = tgpig[0]*Q; const int64_t D4 = D/4; - const int64_t N4 = N_SIMDWIDTH; - const int64_t L4 = (D4 + N4 - 1)/N4; const int64_t D8 = D/8; + const int64_t NW = N_SIMDWIDTH; + const int64_t L4 = (D4 + NW - 1)/NW; + const int64_t SH = (C + Q); // shared memory per simdgroup in (half) - const int64_t T = D + nsg*(D + 1*C); // shared memory size per query in half - const int64_t T4 = T/4; // shared memory size per query in half4 + const int64_t T = D + nsg*SH; // shared memory size per query in (half) + const int64_t T4 = T/4; // shared memory size per query in (half4) - threadgroup half * pq = (threadgroup half *) (shared + 0*D); - threadgroup half4 * pq4 = (threadgroup half4 *) (shared + 0*D); - threadgroup half * ps = (threadgroup half *) (shared + sgitg*(D + 1*C) + 1*D); - threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(D + 1*C) + 1*D); - threadgroup half * ss = (threadgroup half *) (shared + sgitg*(D + 1*C) + 2*D); + threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data + threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // scratch buffer for attention + threadgroup half * ss = (threadgroup half *) (shared + sgitg*SH + 1*D); // scratch buffer for diagonal matrix + + // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) + simdgroup_half8x8 lo[D8]; for (int64_t i = 0; i < L4; ++i) { // load heads from Q to shared memory for (int64_t j = sgitg; j < Q; j += nsg) { device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); if (iq1 + j < ne01) { - pq4[j*T4 + N4*i + tiisg] = (half4) q4[N4*i + tiisg]; + sq4[j*T4 + NW*i + tiisg] = (half4) q4[NW*i + tiisg]; } else { - pq4[j*T4 + N4*i + tiisg] = 0.0h; + sq4[j*T4 + NW*i + tiisg] = 0.0h; } } + } - // zero out shared memory - for (int64_t j = 0; j < Q; ++j) { - ps4[j*T4 + N4*i + tiisg] = 0.0h; - } + // zero out lo + for (int64_t i = 0; i < D8; ++i) { + lo[i] = make_filled_simdgroup_matrix(0.0h); } + // zero out shared memory SH if (tiisg < C) { for (int64_t j = 0; j < Q; ++j) { - ss[j*T + 0 + tiisg] = 0.0h; + ss[j*T + tiisg] = 0.0h; + if (tiisg < Q) { + ss[j*T + C + tiisg] = 0.0h; + } } } @@ -2103,46 +2110,24 @@ kernel void kernel_flash_attn_ext_f16( const int64_t iv2 = iq2 / rv2; const int64_t iv3 = iq3 / rv3; + // load the queries from shared memory into local memory simdgroup_half8x8 mq[D8]; for (int64_t i = 0; i < D8; ++i) { - simdgroup_load(mq[i], pq + i*8, T); + simdgroup_load(mq[i], sq + i*8, T); } - // TODO: this can be improved - device const float * mp[Q]; + const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; - { - const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; - - for (int64_t j = 0; j < Q; ++j) { - if (iq1 + j < ne01) { - mp[j] = (device const float *) (mask + ((ir + j)%ne31)*nb31); - } else { - mp[j] = nullptr; - } - } - } + // pointer to the mask + device const float * mp = (device const float *) (mask + (ir%ne31)*nb31); // prepare diagonal scale matrix simdgroup_half8x8 mscale(scale); - for (int64_t iic = C*sgitg; iic < ne11; iic += C*nsg) { - // skip -INF blocks - // TODO: double-check this - { - float smc = -INFINITY; - - for (int64_t j = 0; j < Q; ++j) { - const float mc = mp[j] ? mp[j][iic + tiisg] : -INFINITY; - smc = simd_max(max(smc, mc)); - } - - if (smc == -INFINITY) { - continue; - } - } - + // loop over the KV cache + // each simdgroup handles blocks of Q rows and C columns + for (int64_t ic = C*sgitg; ic < ne11; ic += C*nsg) { // Q*K^T { simdgroup_half8x8 mk; @@ -2150,7 +2135,7 @@ kernel void kernel_flash_attn_ext_f16( for (int cc = 0; cc < C/8; ++cc) { simdgroup_half8x8 mqk = make_filled_simdgroup_matrix(0.h); - device const half * pk = (device const half *) ((device const char *) k + ((iic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); + device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); for (int64_t i = 0; i < D8; ++i) { simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); @@ -2160,65 +2145,77 @@ kernel void kernel_flash_attn_ext_f16( // mqk = mqk*scale + mask simdgroup_float8x8 mm; - simdgroup_load(mm, mp[0] + iic + 8*cc, nb31/sizeof(float), 0, false); + simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(float), 0, false); simdgroup_multiply_accumulate(mqk, mqk, mscale, mm); simdgroup_store(mqk, ss + 8*cc, T, 0, false); } } + // used to detect blocks full of -INF + half smax = -INFINITY; + // online softmax for (int64_t j = 0; j < Q; ++j) { const int64_t p = tiisg; - //const half s = ss[j*T + p]*scale + (mp[j][iic + p]); const half s = ss[j*T + p]; - half m = M[j]; - + smax = simd_max(max(smax, s)); M[j] = simd_max(max(M[j], s)); + const half m = M[j]; + const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]); const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]); S[j] = S[j]*ms + simd_sum(vs); - for (int64_t i = 0; i < L4; ++i) { - ps4[j*T4 + N4*i + tiisg] *= ms; + // create an 8x8 diagonal matrix for rescaling the output + if (p == j) { + ss[j*T + C + j] = ms; } + // the P matrix from the paper (Q rows, C columns) ss[j*T + p] = vs; } - simdgroup_barrier(mem_flags::mem_none); + // skip -INF blocks + if (smax == -INFINITY) { + continue; + } - // (Q*K^T)*V + // O = diag(ms)*O { - simdgroup_half8x8 mv; + simdgroup_half8x8 mm; - simdgroup_half8x8 mp[C/8]; - for (int cc = 0; cc < C/8; ++cc) { - simdgroup_load(mp[cc], ss + 8*cc, T, 0, false); - } + simdgroup_load(mm, ss + C, T, 0, false); for (int64_t i = 0; i < D8; ++i) { - simdgroup_half8x8 mqkv; + simdgroup_multiply(lo[i], mm, lo[i]); + } + } - simdgroup_load(mqkv, ps + i*8, T, 0, false); + // O = O + (Q*K^T)*V + { + simdgroup_half8x8 mv; + + for (int cc = 0; cc < C/8; ++cc) { + simdgroup_half8x8 mp; + simdgroup_load(mp, ss + 8*cc, T, 0, false); - for (int cc = 0; cc < C/8; ++cc) { - device const half * pv = (device const half *) ((device const char *) v + ((iic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); + for (int64_t i = 0; i < D8; ++i) { + device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); simdgroup_load(mv, pv + i*8, nb21/sizeof(half), 0, false); - simdgroup_multiply_accumulate(mqkv, mp[cc], mv, mqkv); + simdgroup_multiply_accumulate(lo[i], mp, mv, lo[i]); } - - simdgroup_store(mqkv, ps + i*8, T, 0, false); } } } + // these are needed for reducing the results from the simdgroups (reuse the ss buffer) for (int64_t j = 0; j < Q; ++j) { if (tiisg == 0) { ss[j*T + 0] = S[j]; @@ -2227,58 +2224,30 @@ kernel void kernel_flash_attn_ext_f16( } } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // reduce the warps -#if 1 - if (sgitg == 0) { + // reduce the warps sequentially + for (int64_t sg = 1; sg < nsg; ++sg) { half S = { 0.0h }; half M = { -INFINITY }; - for (int64_t sg = 1; sg < nsg; ++sg) { - for (int64_t j = 0; j < Q; ++j) { - const half S0 = ss[j*T + 0]; - const half S1 = ss[j*T + sg*(D + 1*C) + 0]; - - const half M0 = ss[j*T + 1]; - const half M1 = ss[j*T + sg*(D + 1*C) + 1]; - - M = max(M0, M1); - - const half ms0 = exp(M0 - M); - const half ms1 = exp(M1 - M); - - S = S0*ms0 + S1*ms1; - - if (tiisg == 0) { - ss[j*T + 0] = S; - ss[j*T + 1] = M; - } + threadgroup_barrier(mem_flags::mem_threadgroup); - for (int64_t i = 0; i < L4; ++i) { - ps4[j*T4 + N4*i + tiisg] = ps4[j*T4 + N4*i + tiisg]*ms0 + ps4[j*T4 + sg*(D + 1*C)/4 + N4*i + tiisg]*ms1; - } + // each simdgroup stores its output to shared memory, reusing sq4 + if (sgitg == sg) { + for (int64_t i = 0; i < D8; ++i) { + simdgroup_store(lo[i], sq + i*8, T, 0, false); } } - } -#else - // parallel reduce - // NOTE: this is significantly slower than the serial version above, likely due to the small number of warps - { - half S = { 0.0h }; - half M = { -INFINITY }; - for (int64_t sg = nsg/2; sg > 0; sg /= 2) { - if (sgitg >= sg) { - continue; - } + threadgroup_barrier(mem_flags::mem_threadgroup); + // the first simdgroup accumulates the results from the other simdgroups + if (sgitg == 0) { for (int64_t j = 0; j < Q; ++j) { - const half S0 = ss[j*T + 0]; - const half S1 = ss[j*T + sg*(D + 1*C) + 0]; + const half S0 = ss[j*T + 0]; + const half S1 = ss[j*T + sg*SH + 0]; - const half M0 = ss[j*T + 1]; - const half M1 = ss[j*T + sg*(D + 1*C) + 1]; + const half M0 = ss[j*T + 1]; + const half M1 = ss[j*T + sg*SH + 1]; M = max(M0, M1); @@ -2290,28 +2259,47 @@ kernel void kernel_flash_attn_ext_f16( if (tiisg == 0) { ss[j*T + 0] = S; ss[j*T + 1] = M; - } - for (int64_t i = 0; i < L4; ++i) { - ps4[j*T4 + N4*i + tiisg] = ps4[j*T4 + N4*i + tiisg]*ms0 + ps4[j*T4 + sg*(D + 1*C)/4 + N4*i + tiisg]*ms1; + ss[j*T + C + j ] = ms0; + ss[j*T + C + j + sg*SH] = ms1; } } - threadgroup_barrier(mem_flags::mem_threadgroup); + // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 + { + simdgroup_half8x8 t; + simdgroup_half8x8 ms0; + simdgroup_half8x8 ms1; + + simdgroup_load(ms0, ss + C, T, 0, false); + simdgroup_load(ms1, ss + C + sg*SH, T, 0, false); + + for (int64_t i = 0; i < D8; ++i) { + simdgroup_load (t, sq + i*8, T, 0, false); + simdgroup_multiply(t, ms1, t); + + simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t); + } + } } } -#endif - simdgroup_barrier(mem_flags::mem_threadgroup); + // store result to shared memory (reuse sq4) + if (sgitg == 0) { + for (int64_t i = 0; i < D8; ++i) { + simdgroup_store(lo[i], sq + i*8, T, 0, false); + } + } device float4 * dst4 = (device float4 *) dst; + // final rescale with 1/S and store to global memory if (sgitg == 0) { for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) { const half S = ss[j*T + 0]; for (int64_t i = 0; i < L4; ++i) { - dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + N4*i + tiisg] = (float4) ps4[j*T4 + N4*i + tiisg]/S; + dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + NW*i + tiisg] = (float4) sq4[j*T4 + NW*i + tiisg]/S; } } } From ecc466a460abc7ad73df3b22a3e0957170bcf7b9 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 28 Jan 2024 15:42:57 +0200 Subject: [PATCH 020/100] metal : add tests, fix scaling, support C > 32 --- ggml-metal.m | 6 ++-- ggml-metal.metal | 62 ++++++++++++++++++++------------------ tests/test-backend-ops.cpp | 14 ++++++--- 3 files changed, 46 insertions(+), 36 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index a7e126bff5318..484ef89398e7a 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2213,12 +2213,12 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; [encoder setBytes:&scale length:sizeof( float) atIndex:27]; - const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! - const int64_t ncpsg = 32; // cache values per simdgroup + const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! (multiple of 8) + const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! (multiple of 32) // simdgroups per threadgroup (a.k.a. warps) // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) - const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(ne11/32, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)) : 4; + const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)) : 4; const size_t smem = nqptg*(ne00 + nsg*(ncpsg + nqptg))*(sizeof(float)/2); diff --git a/ggml-metal.metal b/ggml-metal.metal index b564f014de2b6..7b604eb61a177 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2041,7 +2041,6 @@ kernel void kernel_flash_attn_ext_f16( const int64_t D4 = D/4; const int64_t D8 = D/8; const int64_t NW = N_SIMDWIDTH; - const int64_t L4 = (D4 + NW - 1)/NW; const int64_t SH = (C + Q); // shared memory per simdgroup in (half) const int64_t T = D + nsg*SH; // shared memory size per query in (half) @@ -2054,14 +2053,15 @@ kernel void kernel_flash_attn_ext_f16( // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) simdgroup_half8x8 lo[D8]; - for (int64_t i = 0; i < L4; ++i) { - // load heads from Q to shared memory - for (int64_t j = sgitg; j < Q; j += nsg) { - device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); + // load heads from Q to shared memory + for (int64_t j = sgitg; j < Q; j += nsg) { + device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); + + for (int64_t i = tiisg; i < D4; i += NW) { if (iq1 + j < ne01) { - sq4[j*T4 + NW*i + tiisg] = (half4) q4[NW*i + tiisg]; + sq4[j*T4 + i] = (half4) q4[i]; } else { - sq4[j*T4 + NW*i + tiisg] = 0.0h; + sq4[j*T4 + i] = 0.0h; } } } @@ -2072,12 +2072,9 @@ kernel void kernel_flash_attn_ext_f16( } // zero out shared memory SH - if (tiisg < C) { - for (int64_t j = 0; j < Q; ++j) { - ss[j*T + tiisg] = 0.0h; - if (tiisg < Q) { - ss[j*T + C + tiisg] = 0.0h; - } + for (int64_t j = 0; j < Q; ++j) { + for (int64_t i = tiisg; i < SH; i += NW) { + ss[j*T + i] = 0.0h; } } @@ -2157,27 +2154,34 @@ kernel void kernel_flash_attn_ext_f16( // online softmax for (int64_t j = 0; j < Q; ++j) { - const int64_t p = tiisg; - - const half s = ss[j*T + p]; + const half m = M[j]; - smax = simd_max(max(smax, s)); - M[j] = simd_max(max(M[j], s)); + for (int64_t p = tiisg; p < C; p += NW) { + const half s = ss[j*T + p]; - const half m = M[j]; + smax = simd_max(max(smax, s)); + M[j] = simd_max(max(M[j], s)); + } - const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]); - const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]); + const half ms = exp(m - M[j]); - S[j] = S[j]*ms + simd_sum(vs); + S[j] = S[j]*ms; // create an 8x8 diagonal matrix for rescaling the output - if (p == j) { + if (tiisg == j) { ss[j*T + C + j] = ms; } - // the P matrix from the paper (Q rows, C columns) - ss[j*T + p] = vs; + for (int64_t p = tiisg; p < C; p += NW) { + const half s = ss[j*T + p]; + + const half vs = exp(s - M[j]); + + S[j] = S[j] + simd_sum(vs); + + // the P matrix from the paper (Q rows, C columns) + ss[j*T + p] = vs; + } } // skip -INF blocks @@ -2231,7 +2235,7 @@ kernel void kernel_flash_attn_ext_f16( threadgroup_barrier(mem_flags::mem_threadgroup); - // each simdgroup stores its output to shared memory, reusing sq4 + // each simdgroup stores its output to shared memory, reusing sq if (sgitg == sg) { for (int64_t i = 0; i < D8; ++i) { simdgroup_store(lo[i], sq + i*8, T, 0, false); @@ -2284,7 +2288,7 @@ kernel void kernel_flash_attn_ext_f16( } } - // store result to shared memory (reuse sq4) + // store result to shared memory (reuse sq) if (sgitg == 0) { for (int64_t i = 0; i < D8; ++i) { simdgroup_store(lo[i], sq + i*8, T, 0, false); @@ -2298,8 +2302,8 @@ kernel void kernel_flash_attn_ext_f16( for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) { const half S = ss[j*T + 0]; - for (int64_t i = 0; i < L4; ++i) { - dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + NW*i + tiisg] = (float4) sq4[j*T4 + NW*i + tiisg]/S; + for (int64_t i = tiisg; i < D4; i += NW) { + dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sq4[j*T4 + i]/S; } } } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 4c98bef7cf3a6..4093a52f2eef1 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1395,7 +1395,7 @@ struct test_flash_attn_ext : public test_case { } double max_nmse_err() override { - return 5e-5; + return 5e-4; } test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8) @@ -1677,9 +1677,15 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_pad()); test_cases.emplace_back(new test_leaky_relu()); - test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256, 8)); - test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256, 7)); - test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256, 1)); + test_cases.emplace_back(new test_flash_attn_ext(64, 32, 256*8, 8)); + test_cases.emplace_back(new test_flash_attn_ext(64, 32, 256*8, 7)); + test_cases.emplace_back(new test_flash_attn_ext(64, 32, 256*8, 1)); + test_cases.emplace_back(new test_flash_attn_ext(80, 32, 256*8, 8)); + test_cases.emplace_back(new test_flash_attn_ext(80, 32, 256*8, 7)); + test_cases.emplace_back(new test_flash_attn_ext(80, 32, 256*8, 1)); + test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256*8, 8)); + test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256*8, 7)); + test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256*8, 1)); #if !defined(__SANITIZE_THREAD__) // FIXME: these tests use too much memory with thread sanitizer From 3a428a10973a751af72b55b9ef396de9c305c6ac Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 28 Jan 2024 17:47:22 +0200 Subject: [PATCH 021/100] metal : improve precision --- ggml-metal.metal | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index 7b604eb61a177..b6b5fd997b93a 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2120,7 +2120,7 @@ kernel void kernel_flash_attn_ext_f16( device const float * mp = (device const float *) (mask + (ir%ne31)*nb31); // prepare diagonal scale matrix - simdgroup_half8x8 mscale(scale); + simdgroup_float8x8 mscale(scale); // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns @@ -2163,7 +2163,7 @@ kernel void kernel_flash_attn_ext_f16( M[j] = simd_max(max(M[j], s)); } - const half ms = exp(m - M[j]); + const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]); S[j] = S[j]*ms; @@ -2175,7 +2175,7 @@ kernel void kernel_flash_attn_ext_f16( for (int64_t p = tiisg; p < C; p += NW) { const half s = ss[j*T + p]; - const half vs = exp(s - M[j]); + const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]); S[j] = S[j] + simd_sum(vs); @@ -2255,8 +2255,8 @@ kernel void kernel_flash_attn_ext_f16( M = max(M0, M1); - const half ms0 = exp(M0 - M); - const half ms1 = exp(M1 - M); + const half ms0 = M0 == -INFINITY ? 0.0h : exp(M0 - M); + const half ms1 = M1 == -INFINITY ? 0.0h : exp(M1 - M); S = S0*ms0 + S1*ms1; From 8612864108760897261d0d10101f68355899b03f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 28 Jan 2024 18:10:16 +0200 Subject: [PATCH 022/100] ggml : fix f16 mad --- ggml.c | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml.c b/ggml.c index 6bba840d93d0c..fc0886aecf5a1 100644 --- a/ggml.c +++ b/ggml.c @@ -1344,12 +1344,12 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, const // leftovers for (int i = np; i < n; ++i) { - y[i] += GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(x[i])*v); + y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v); } #else // scalar for (int i = 0; i < n; ++i) { - y[i] += GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(x[i])*v); + y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v); } #endif } From 134c81c78dfdeaca988ea2505cc6f0c0aec2d243 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 28 Jan 2024 22:23:40 +0200 Subject: [PATCH 023/100] metal : minor --- ggml-metal.metal | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index b6b5fd997b93a..ad6a4a318f4c3 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2127,15 +2127,14 @@ kernel void kernel_flash_attn_ext_f16( for (int64_t ic = C*sgitg; ic < ne11; ic += C*nsg) { // Q*K^T { - simdgroup_half8x8 mk; - for (int cc = 0; cc < C/8; ++cc) { simdgroup_half8x8 mqk = make_filled_simdgroup_matrix(0.h); device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); for (int64_t i = 0; i < D8; ++i) { - simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); + simdgroup_half8x8 mk; + simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk); } @@ -2192,7 +2191,6 @@ kernel void kernel_flash_attn_ext_f16( // O = diag(ms)*O { simdgroup_half8x8 mm; - simdgroup_load(mm, ss + C, T, 0, false); for (int64_t i = 0; i < D8; ++i) { @@ -2202,8 +2200,6 @@ kernel void kernel_flash_attn_ext_f16( // O = O + (Q*K^T)*V { - simdgroup_half8x8 mv; - for (int cc = 0; cc < C/8; ++cc) { simdgroup_half8x8 mp; simdgroup_load(mp, ss + 8*cc, T, 0, false); @@ -2211,6 +2207,7 @@ kernel void kernel_flash_attn_ext_f16( for (int64_t i = 0; i < D8; ++i) { device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); + simdgroup_half8x8 mv; simdgroup_load(mv, pv + i*8, nb21/sizeof(half), 0, false); simdgroup_multiply_accumulate(lo[i], mp, mv, lo[i]); From 1db22d7032fd55a612e400164cb70ad238bbc055 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 28 Jan 2024 23:08:31 +0200 Subject: [PATCH 024/100] metal : support Q > 8 --- examples/batched-bench/batched-bench.cpp | 2 +- ggml-metal.m | 7 ++- ggml-metal.metal | 80 +++++++++++++++--------- 3 files changed, 55 insertions(+), 34 deletions(-) diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp index 7924db267401c..4992b57f6f9db 100644 --- a/examples/batched-bench/batched-bench.cpp +++ b/examples/batched-bench/batched-bench.cpp @@ -104,7 +104,7 @@ int main(int argc, char ** argv) { ctx_params.seed = 1234; ctx_params.n_ctx = n_kv_max; - ctx_params.n_batch = 512; + ctx_params.n_batch = 2048; ctx_params.mul_mat_q = mmq; ctx_params.n_threads = params.n_threads; diff --git a/ggml-metal.m b/ggml-metal.m index ef799ef57b643..a0dd1d0df5bcb 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2206,8 +2206,11 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; [encoder setBytes:&scale length:sizeof( float) atIndex:27]; - const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! (multiple of 8) - const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! (multiple of 32) + const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! + const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! + + GGML_ASSERT(nqptg % 8 == 0); + GGML_ASSERT(ncpsg % 32 == 0); // simdgroups per threadgroup (a.k.a. warps) // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) diff --git a/ggml-metal.metal b/ggml-metal.metal index ad6a4a318f4c3..08c000cc4c027 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2040,6 +2040,7 @@ kernel void kernel_flash_attn_ext_f16( const int64_t D4 = D/4; const int64_t D8 = D/8; + const int64_t Q8 = Q/8; const int64_t NW = N_SIMDWIDTH; const int64_t SH = (C + Q); // shared memory per simdgroup in (half) @@ -2051,7 +2052,7 @@ kernel void kernel_flash_attn_ext_f16( threadgroup half * ss = (threadgroup half *) (shared + sgitg*SH + 1*D); // scratch buffer for diagonal matrix // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) - simdgroup_half8x8 lo[D8]; + simdgroup_half8x8 lo[Q8][D8]; // load heads from Q to shared memory for (int64_t j = sgitg; j < Q; j += nsg) { @@ -2067,8 +2068,10 @@ kernel void kernel_flash_attn_ext_f16( } // zero out lo - for (int64_t i = 0; i < D8; ++i) { - lo[i] = make_filled_simdgroup_matrix(0.0h); + for (int64_t j = 0; j < Q8; ++j) { + for (int64_t i = 0; i < D8; ++i) { + lo[j][i] = make_filled_simdgroup_matrix(0.0h); + } } // zero out shared memory SH @@ -2108,10 +2111,12 @@ kernel void kernel_flash_attn_ext_f16( const int64_t iv3 = iq3 / rv3; // load the queries from shared memory into local memory - simdgroup_half8x8 mq[D8]; + simdgroup_half8x8 mq[Q8][D8]; - for (int64_t i = 0; i < D8; ++i) { - simdgroup_load(mq[i], sq + i*8, T); + for (int64_t j = 0; j < Q8; ++j) { + for (int64_t i = 0; i < D8; ++i) { + simdgroup_load(mq[j][i], sq + 8*j*T + i*8, T); + } } const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; @@ -2128,7 +2133,10 @@ kernel void kernel_flash_attn_ext_f16( // Q*K^T { for (int cc = 0; cc < C/8; ++cc) { - simdgroup_half8x8 mqk = make_filled_simdgroup_matrix(0.h); + simdgroup_half8x8 mqk[Q8]; + for (int64_t j = 0; j < Q8; ++j) { + mqk[j] = make_filled_simdgroup_matrix(0.h); + } device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); @@ -2136,15 +2144,19 @@ kernel void kernel_flash_attn_ext_f16( simdgroup_half8x8 mk; simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose - simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk); + for (int64_t j = 0; j < Q8; ++j) { + simdgroup_multiply_accumulate(mqk[j], mq[j][i], mk, mqk[j]); + } } // mqk = mqk*scale + mask - simdgroup_float8x8 mm; - simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(float), 0, false); - simdgroup_multiply_accumulate(mqk, mqk, mscale, mm); + for (int64_t j = 0; j < Q8; ++j) { + simdgroup_float8x8 mm; + simdgroup_load(mm, mp + 8*j*(nb31/sizeof(float)) + ic + 8*cc, nb31/sizeof(float), 0, false); + simdgroup_multiply_accumulate(mqk[j], mqk[j], mscale, mm); - simdgroup_store(mqk, ss + 8*cc, T, 0, false); + simdgroup_store(mqk[j], ss + 8*j*T + 8*cc, T, 0, false); + } } } @@ -2166,7 +2178,7 @@ kernel void kernel_flash_attn_ext_f16( S[j] = S[j]*ms; - // create an 8x8 diagonal matrix for rescaling the output + // create a QxQ diagonal matrix for rescaling the output if (tiisg == j) { ss[j*T + C + j] = ms; } @@ -2189,28 +2201,30 @@ kernel void kernel_flash_attn_ext_f16( } // O = diag(ms)*O - { + for (int64_t j = 0; j < Q8; ++j) { simdgroup_half8x8 mm; - simdgroup_load(mm, ss + C, T, 0, false); + simdgroup_load(mm, ss + 8*j*T + C + 8*j, T, 0, false); for (int64_t i = 0; i < D8; ++i) { - simdgroup_multiply(lo[i], mm, lo[i]); + simdgroup_multiply(lo[j][i], mm, lo[j][i]); } } // O = O + (Q*K^T)*V { for (int cc = 0; cc < C/8; ++cc) { - simdgroup_half8x8 mp; - simdgroup_load(mp, ss + 8*cc, T, 0, false); + device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); for (int64_t i = 0; i < D8; ++i) { - device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); + simdgroup_half8x8 mk; + simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false); - simdgroup_half8x8 mv; - simdgroup_load(mv, pv + i*8, nb21/sizeof(half), 0, false); + for (int64_t j = 0; j < Q8; ++j) { + simdgroup_half8x8 mv; + simdgroup_load(mv, ss + 8*j*T + 8*cc, T, 0, false); - simdgroup_multiply_accumulate(lo[i], mp, mv, lo[i]); + simdgroup_multiply_accumulate(lo[j][i], mv, mk, lo[j][i]); + } } } } @@ -2234,8 +2248,10 @@ kernel void kernel_flash_attn_ext_f16( // each simdgroup stores its output to shared memory, reusing sq if (sgitg == sg) { - for (int64_t i = 0; i < D8; ++i) { - simdgroup_store(lo[i], sq + i*8, T, 0, false); + for (int64_t j = 0; j < Q8; ++j) { + for (int64_t i = 0; i < D8; ++i) { + simdgroup_store(lo[j][i], sq + 8*j*T + i*8, T, 0, false); + } } } @@ -2267,19 +2283,19 @@ kernel void kernel_flash_attn_ext_f16( } // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 - { + for (int64_t j = 0; j < Q8; ++j) { simdgroup_half8x8 t; simdgroup_half8x8 ms0; simdgroup_half8x8 ms1; - simdgroup_load(ms0, ss + C, T, 0, false); - simdgroup_load(ms1, ss + C + sg*SH, T, 0, false); + simdgroup_load(ms0, ss + 8*j*T + C + 8*j, T, 0, false); + simdgroup_load(ms1, ss + 8*j*T + C + 8*j + sg*SH, T, 0, false); for (int64_t i = 0; i < D8; ++i) { - simdgroup_load (t, sq + i*8, T, 0, false); + simdgroup_load (t, sq + 8*j*T + i*8, T, 0, false); simdgroup_multiply(t, ms1, t); - simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t); + simdgroup_multiply_accumulate(lo[j][i], ms0, lo[j][i], t); } } } @@ -2287,8 +2303,10 @@ kernel void kernel_flash_attn_ext_f16( // store result to shared memory (reuse sq) if (sgitg == 0) { - for (int64_t i = 0; i < D8; ++i) { - simdgroup_store(lo[i], sq + i*8, T, 0, false); + for (int64_t j = 0; j < Q8; ++j) { + for (int64_t i = 0; i < D8; ++i) { + simdgroup_store(lo[j][i], sq + 8*j*T + i*8, T, 0, false); + } } } From 4794821a31d5778b3398b8375d29fa63a539c8c4 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 29 Jan 2024 16:44:55 +0200 Subject: [PATCH 025/100] tests : add ATTN tests --- tests/test-backend-ops.cpp | 70 +++++++++++++++++++++++++++++++++----- 1 file changed, 61 insertions(+), 9 deletions(-) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index c18ff07ea4d21..0ce498e9e7dd4 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1418,6 +1418,48 @@ struct test_flash_attn_ext : public test_case { } }; +// Attention +struct test_attn : public test_case { + const int64_t hs; // head size + const int64_t nh; // num heads + const int64_t kv; // kv size + const int64_t nb; // batch size + + std::string op_desc(ggml_tensor * t) override { + return "ATTN"; + + GGML_UNUSED(t); + } + + std::string vars() override { + return VARS_TO_STR4(hs, nh, kv, nb); + } + + double max_nmse_err() override { + return 5e-4; + } + + test_attn(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8) + : hs(hs), nh(nh), kv(kv), nb(nb) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, nb, nh, 1); + ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); + ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, hs, nh, 1); // transposed + ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv, nb, 1, 1); + + struct ggml_tensor * cur; + + cur = ggml_mul_mat (ctx, k, q); + cur = ggml_soft_max_ext(ctx, cur, mask, 1.0f/sqrtf(hs)); + cur = ggml_mul_mat (ctx, v, cur); + cur = ggml_permute (ctx, cur, 0, 2, 1, 3); + cur = ggml_cont_2d (ctx, cur, hs*nh, nb); + + return cur; + } +}; + // Mixtral MOE struct test_moe : public test_case { const int n_experts; @@ -1684,15 +1726,25 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_pad()); test_cases.emplace_back(new test_leaky_relu()); - test_cases.emplace_back(new test_flash_attn_ext(64, 32, 256*8, 8)); - test_cases.emplace_back(new test_flash_attn_ext(64, 32, 256*8, 7)); - test_cases.emplace_back(new test_flash_attn_ext(64, 32, 256*8, 1)); - test_cases.emplace_back(new test_flash_attn_ext(80, 32, 256*8, 8)); - test_cases.emplace_back(new test_flash_attn_ext(80, 32, 256*8, 7)); - test_cases.emplace_back(new test_flash_attn_ext(80, 32, 256*8, 1)); - test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256*8, 8)); - test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256*8, 7)); - test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256*8, 1)); + test_cases.emplace_back(new test_attn(64, 32, 512, 8)); + test_cases.emplace_back(new test_attn(64, 32, 512, 7)); + test_cases.emplace_back(new test_attn(64, 32, 512, 1)); + test_cases.emplace_back(new test_attn(80, 32, 512, 8)); + test_cases.emplace_back(new test_attn(80, 32, 512, 7)); + test_cases.emplace_back(new test_attn(80, 32, 512, 1)); + test_cases.emplace_back(new test_attn(128, 32, 512, 8)); + test_cases.emplace_back(new test_attn(128, 32, 512, 7)); + test_cases.emplace_back(new test_attn(128, 32, 512, 1)); + + test_cases.emplace_back(new test_flash_attn_ext(64, 32, 512, 8)); + test_cases.emplace_back(new test_flash_attn_ext(64, 32, 512, 7)); + test_cases.emplace_back(new test_flash_attn_ext(64, 32, 512, 1)); + test_cases.emplace_back(new test_flash_attn_ext(80, 32, 512, 8)); + test_cases.emplace_back(new test_flash_attn_ext(80, 32, 512, 7)); + test_cases.emplace_back(new test_flash_attn_ext(80, 32, 512, 1)); + test_cases.emplace_back(new test_flash_attn_ext(128, 32, 512, 8)); + test_cases.emplace_back(new test_flash_attn_ext(128, 32, 512, 7)); + test_cases.emplace_back(new test_flash_attn_ext(128, 32, 512, 1)); #if !defined(__SANITIZE_THREAD__) // FIXME: these tests use too much memory with thread sanitizer From abeaf0d90ee82096a0aba20785f1e37bd1f3aa41 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 29 Jan 2024 18:12:24 +0200 Subject: [PATCH 026/100] metal : disable buffer allocation logs --- ggml-metal.m | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index a0dd1d0df5bcb..a637f04875dbe 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2421,10 +2421,13 @@ GGML_CALL static void ggml_backend_metal_buffer_clear(ggml_backend_buffer_t buff UNUSED(buft); } -static void ggml_backend_metal_log_allocated_size(id device) { +static void ggml_backend_metal_log_allocated_size(id device, size_t size_aligned) { +#ifndef GGML_METAL_NDEBUG #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15) if (@available(macOS 10.12, iOS 16.0, *)) { - GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)", + GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f / %8.2f)", + __func__, + size_aligned / 1024.0 / 1024.0, device.currentAllocatedSize / 1024.0 / 1024.0, device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0); @@ -2434,10 +2437,15 @@ static void ggml_backend_metal_log_allocated_size(id device) { GGML_METAL_LOG_INFO("\n"); } } else { - GGML_METAL_LOG_INFO(", (%8.2f)\n", device.currentAllocatedSize / 1024.0 / 1024.0); + GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f)\n", + __func__, + size_aligned / 1024.0 / 1024.0, + device.currentAllocatedSize / 1024.0 / 1024.0); } +#endif #endif UNUSED(device); + UNUSED(size_aligned); } GGML_CALL static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { @@ -2471,8 +2479,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buff return NULL; } - GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB", __func__, size_aligned / 1024.0 / 1024.0); - ggml_backend_metal_log_allocated_size(device); + ggml_backend_metal_log_allocated_size(device, size_aligned); return ggml_backend_buffer_init(buft, ggml_backend_metal_buffer_i, ctx, size); } @@ -2549,7 +2556,7 @@ GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, return false; } - GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB", __func__, size_aligned / 1024.0 / 1024.0); + ggml_backend_metal_log_allocated_size(device, size_aligned); ++ctx->n_buffers; } else { @@ -2572,7 +2579,8 @@ GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, return false; } - GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, offs = %12ld", __func__, size_step_aligned / 1024.0 / 1024.0, i); + ggml_backend_metal_log_allocated_size(device, size_step_aligned); + if (i + size_step < size) { GGML_METAL_LOG_INFO("\n"); } @@ -2581,8 +2589,6 @@ GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, } } - ggml_backend_metal_log_allocated_size(device); - return ggml_backend_buffer_init(ggml_backend_metal_buffer_type(), ggml_backend_metal_buffer_i, ctx, size); } From c6c1132e5e6658b3c209433ed5ef75067ef31a2f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 29 Jan 2024 18:22:28 +0200 Subject: [PATCH 027/100] tests : more --- ggml-metal.m | 9 +++++++++ ggml-metal.metal | 3 +++ ggml.c | 5 ----- tests/test-backend-ops.cpp | 29 ++++++++++------------------- 4 files changed, 22 insertions(+), 24 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index a637f04875dbe..4b5fd0bb8fc58 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -137,7 +137,10 @@ GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, GGML_METAL_KERNEL_TYPE_CPY_F32_F16, GGML_METAL_KERNEL_TYPE_CPY_F32_F32, GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, @@ -505,7 +508,10 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true); @@ -2166,7 +2172,10 @@ static bool ggml_metal_graph_compute( switch (ne00) { case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break; case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break; + case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break; + case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break; case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break; + case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break; default: { GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00); diff --git a/ggml-metal.metal b/ggml-metal.metal index 08c000cc4c027..be059d78f505a 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2326,7 +2326,10 @@ kernel void kernel_flash_attn_ext_f16( template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 8, 32>; template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 8, 32>; +template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96, 8, 32>; +template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112, 8, 32>; template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 8, 32>; +template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256, 8, 32>; kernel void kernel_cpy_f16_f16( device const half * src0, diff --git a/ggml.c b/ggml.c index e8a5fcfa485c1..57271a1ad43e3 100644 --- a/ggml.c +++ b/ggml.c @@ -13554,11 +13554,9 @@ static void ggml_compute_forward_flash_attn_ext_f16( const int64_t D = neq0; const int64_t N = neq1; - const int64_t P = nek1 - N; GGML_ASSERT(ne0 == D); GGML_ASSERT(ne2 == N); - GGML_ASSERT(P >= 0); GGML_ASSERT(nbq0 == sizeof(float)); GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t)); @@ -13569,7 +13567,6 @@ static void ggml_compute_forward_flash_attn_ext_f16( GGML_ASSERT(nev0 == D); GGML_ASSERT(neq1 == N); - GGML_ASSERT(nek1 == N + P); GGML_ASSERT(nev0 == D); // dst cannot be transposed or permuted @@ -13608,8 +13605,6 @@ static void ggml_compute_forward_flash_attn_ext_f16( float scale = 1.0f; memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); - //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale); - // loop over n_batch and n_head for (int ir = ir0; ir < ir1; ++ir) { // q indices diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 0ce498e9e7dd4..f57e8ab1a853e 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1726,25 +1726,16 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_pad()); test_cases.emplace_back(new test_leaky_relu()); - test_cases.emplace_back(new test_attn(64, 32, 512, 8)); - test_cases.emplace_back(new test_attn(64, 32, 512, 7)); - test_cases.emplace_back(new test_attn(64, 32, 512, 1)); - test_cases.emplace_back(new test_attn(80, 32, 512, 8)); - test_cases.emplace_back(new test_attn(80, 32, 512, 7)); - test_cases.emplace_back(new test_attn(80, 32, 512, 1)); - test_cases.emplace_back(new test_attn(128, 32, 512, 8)); - test_cases.emplace_back(new test_attn(128, 32, 512, 7)); - test_cases.emplace_back(new test_attn(128, 32, 512, 1)); - - test_cases.emplace_back(new test_flash_attn_ext(64, 32, 512, 8)); - test_cases.emplace_back(new test_flash_attn_ext(64, 32, 512, 7)); - test_cases.emplace_back(new test_flash_attn_ext(64, 32, 512, 1)); - test_cases.emplace_back(new test_flash_attn_ext(80, 32, 512, 8)); - test_cases.emplace_back(new test_flash_attn_ext(80, 32, 512, 7)); - test_cases.emplace_back(new test_flash_attn_ext(80, 32, 512, 1)); - test_cases.emplace_back(new test_flash_attn_ext(128, 32, 512, 8)); - test_cases.emplace_back(new test_flash_attn_ext(128, 32, 512, 7)); - test_cases.emplace_back(new test_flash_attn_ext(128, 32, 512, 1)); + for (int hs : { 64, 80, 96, 112, 128, 256, }) { + for (int nh : { 32, }) { + for (int kv : { 512, 1024, 2048, 4096, }) { + for (int nb : { 1, 2, 4, 8, 512, 1024, 2048, }) { + test_cases.emplace_back(new test_attn (hs, nh, kv, nb)); + test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb)); + } + } + } + } #if !defined(__SANITIZE_THREAD__) // FIXME: these tests use too much memory with thread sanitizer From 5fcb9c1c5af108056c8ad51fc1995de9d7707d2f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 29 Jan 2024 19:46:22 +0200 Subject: [PATCH 028/100] metal : faster inner loop for C == 32 --- ggml-metal.metal | 59 ++++++++++++++++++++++++++++++++++-------------- 1 file changed, 42 insertions(+), 17 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index be059d78f505a..db4c7cfde0037 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2048,8 +2048,8 @@ kernel void kernel_flash_attn_ext_f16( const int64_t T4 = T/4; // shared memory size per query in (half4) threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data - threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // scratch buffer for attention - threadgroup half * ss = (threadgroup half *) (shared + sgitg*SH + 1*D); // scratch buffer for diagonal matrix + threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4 + threadgroup half * ss = (threadgroup half *) (shared + sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) simdgroup_half8x8 lo[Q8][D8]; @@ -2164,34 +2164,59 @@ kernel void kernel_flash_attn_ext_f16( half smax = -INFINITY; // online softmax - for (int64_t j = 0; j < Q; ++j) { - const half m = M[j]; + if (C == 32) { + for (int64_t j = 0; j < Q; ++j) { + const int64_t p = tiisg; - for (int64_t p = tiisg; p < C; p += NW) { + const half m = M[j]; const half s = ss[j*T + p]; smax = simd_max(max(smax, s)); M[j] = simd_max(max(M[j], s)); - } - const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]); + const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]); + const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]); - S[j] = S[j]*ms; + S[j] = S[j]*ms + simd_sum(vs); + + // create a QxQ diagonal matrix for rescaling the output + if (p == j) { + ss[j*T + C + j] = ms; + } - // create a QxQ diagonal matrix for rescaling the output - if (tiisg == j) { - ss[j*T + C + j] = ms; + // the P matrix from the paper (Q rows, C columns) + ss[j*T + p] = vs; } + } else { + for (int64_t j = 0; j < Q; ++j) { + const half m = M[j]; - for (int64_t p = tiisg; p < C; p += NW) { - const half s = ss[j*T + p]; + for (int64_t p = tiisg; p < C; p += NW) { + const half s = ss[j*T + p]; - const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]); + smax = simd_max(max(smax, s)); + M[j] = simd_max(max(M[j], s)); + } - S[j] = S[j] + simd_sum(vs); + const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]); - // the P matrix from the paper (Q rows, C columns) - ss[j*T + p] = vs; + S[j] = S[j]*ms; + + // create a QxQ diagonal matrix for rescaling the output + if (tiisg == j) { + ss[j*T + C + j] = ms; + } + + for (int64_t p = tiisg; p < C; p += NW) { + const half s = ss[j*T + p]; + + const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]); + + S[j] = S[j] + simd_sum(vs); + + // the P matrix from the paper (Q rows, C columns) + ss[j*T + p] = vs; + } } } From d073e4f93337560e552f0d3de4b6b07bf13ef3f5 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 30 Jan 2024 21:45:32 +0200 Subject: [PATCH 029/100] metal : fix array initialization --- ggml-metal.metal | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index db4c7cfde0037..41f6169de8abd 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2084,8 +2084,8 @@ kernel void kernel_flash_attn_ext_f16( threadgroup_barrier(mem_flags::mem_threadgroup); { - half S[Q] = { 0.0h }; - half M[Q] = { -INFINITY }; + half S[Q] = { [0 ... Q-1] = 0.0h }; + half M[Q] = { [0 ... Q-1] = -INFINITY }; // assume K and V are same shape const int64_t ne22 = ne12; From 78df5527e4e9eafb181200384fbed80c8116042e Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 30 Jan 2024 21:46:49 +0200 Subject: [PATCH 030/100] tests : ifdef --- tests/test-backend-ops.cpp | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index f57e8ab1a853e..07182c6d8aa63 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1726,6 +1726,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_pad()); test_cases.emplace_back(new test_leaky_relu()); +#if 0 for (int hs : { 64, 80, 96, 112, 128, 256, }) { for (int nh : { 32, }) { for (int kv : { 512, 1024, 2048, 4096, }) { @@ -1736,6 +1737,18 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op } } } +#else + for (int hs : { 128, }) { + for (int nh : { 32, }) { + for (int kv : { 512, 1024, }) { + for (int nb : { 1, 2, 4, 8, 512 }) { + test_cases.emplace_back(new test_attn (hs, nh, kv, nb)); + test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb)); + } + } + } + } +#endif #if !defined(__SANITIZE_THREAD__) // FIXME: these tests use too much memory with thread sanitizer From 8ad92dc1ec9aa6549c68900daa7ab93b57fa3ae5 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 31 Jan 2024 19:17:16 +0200 Subject: [PATCH 031/100] ggml : switch to padded F16 mask for ggml_soft_max, ggml_flash_attn_ext --- ggml-cuda.cu | 20 +++++++++---------- ggml-metal.m | 6 ++++++ ggml-metal.metal | 40 ++++++++++++++++++-------------------- ggml.c | 13 +++++++++---- ggml.h | 12 +++++++----- llama.cpp | 40 ++++++++++++++++++++++---------------- tests/test-backend-ops.cpp | 10 +++++----- 7 files changed, 79 insertions(+), 62 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index e565957421795..c57a031e4060c 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -5917,7 +5917,7 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int } template -static __global__ void soft_max_f16(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) { +static __global__ void soft_max_f16(const float * x, const half * y, float * dst, const int ncols_par, const int nrows_y, const float scale) { #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX const int ncols_data = ncols_template == 0 ? ncols_par : ncols_template; const int ncols_smem = GGML_PAD(ncols_data, 2*WARP_SIZE)/2; @@ -5952,12 +5952,12 @@ static __global__ void soft_max_f16(const float * x, const float * y, float * ds if (need_check && col_data + 0 >= ncols_data) { val.x = -INFINITY; } else { - val.x = x[ix + 0]*scale + (y ? y[iy + 0] : 0.0f); + val.x = x[ix + 0]*scale + (y ? __half2float(y[iy + 0]) : 0.0f); } if (need_check && col_data + WARP_SIZE >= ncols_data) { val.y = -INFINITY; } else { - val.y = x[ix + WARP_SIZE]*scale + (y ? y[iy + WARP_SIZE] : 0.0f); + val.y = x[ix + WARP_SIZE]*scale + (y ? __half2float(y[iy + WARP_SIZE]) : 0.0f); } if (!need_check || col_smem < (vals_smem ? ncols_smem : ncols_data)) { vals[col_smem] = val; @@ -6047,7 +6047,7 @@ static __global__ void soft_max_f16(const float * x, const float * y, float * ds } template -static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) { +static __global__ void soft_max_f32(const float * x, const half * y, float * dst, const int ncols_par, const int nrows_y, const float scale) { const int ncols = ncols_template == 0 ? ncols_par : ncols_template; const int tid = threadIdx.x; @@ -6077,7 +6077,7 @@ static __global__ void soft_max_f32(const float * x, const float * y, float * ds const int ix = rowx*ncols + col; const int iy = rowy*ncols + col; - const float val = x[ix]*scale + (y ? y[iy] : 0.0f); + const float val = x[ix]*scale + (y ? __half2float(y[iy]) : 0.0f); vals[col] = val; max_val = max(max_val, val); } @@ -7585,7 +7585,7 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols diag_mask_inf_f32<<>>(x, dst, ncols_x, rows_per_channel, n_past); } -static void soft_max_f16_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) { +static void soft_max_f16_cuda(const float * x, const half * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) { int nth = WARP_SIZE; while (nth < ncols_x/2 && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2; const dim3 block_dims(nth, 1, 1); @@ -7628,7 +7628,7 @@ static void soft_max_f16_cuda(const float * x, const float * y, float * dst, con } } -static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) { +static void soft_max_f32_cuda(const float * x, const half * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) { int nth = WARP_SIZE; while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2; const dim3 block_dims(nth, 1, 1); @@ -9060,7 +9060,7 @@ static void ggml_cuda_op_soft_max( GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); - GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional + GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16); // src1 contains mask and it is optional const int64_t ne00 = src0->ne[0]; const int64_t nrows_x = ggml_nrows(src0); @@ -9080,9 +9080,9 @@ static void ggml_cuda_op_soft_max( #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && CUDART_VERSION >= CUDART_HMAX if (use_f16_soft_max) { - soft_max_f16_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream); + soft_max_f16_cuda(src0_dd, src1 ? (const half *) src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream); } else { - soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream); + soft_max_f32_cuda(src0_dd, src1 ? (const half *) src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream); } (void) dst; diff --git a/ggml-metal.m b/ggml-metal.m index 15e5568f960f1..e00069624551f 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -1187,6 +1187,8 @@ static bool ggml_metal_graph_compute( } break; case GGML_OP_SOFT_MAX: { + GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16); + int nth = 32; // SIMD width id pipeline = nil; @@ -2213,6 +2215,10 @@ static bool ggml_metal_graph_compute( id id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil; + GGML_ASSERT(!src3 || src3->type == GGML_TYPE_F16); + GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) && + "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big"); + const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30); const int64_t ne31 = src3 ? src3->ne[1] : 0; const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32); diff --git a/ggml-metal.metal b/ggml-metal.metal index b2e40715d4f2d..04c1aaf9cdfb9 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -349,9 +349,9 @@ kernel void kernel_sum_rows( } kernel void kernel_soft_max( - device const float * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, @@ -366,9 +366,9 @@ kernel void kernel_soft_max( const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); - device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr; - device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + device const half * pmask = src1 != src0 ? (device const half *) src1 + i01*ne00 : nullptr; + device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); // parallel max float lmax = -INFINITY; @@ -435,14 +435,14 @@ kernel void kernel_soft_max( } kernel void kernel_soft_max_4( - device const float * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, constant float & scale, - threadgroup float * buf [[threadgroup(0)]], + threadgroup float * buf [[threadgroup(0)]], uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], uint sgitg[[simdgroup_index_in_threadgroup]], @@ -452,15 +452,15 @@ kernel void kernel_soft_max_4( const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); - device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); - device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr; - device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4; + device const half4 * pmask = src1 != src0 ? (device const half4 *) src1 + i01*ne00/4 : nullptr; + device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4; // parallel max float4 lmax4 = -INFINITY; for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)); + lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4) (pmask ? pmask[i00] : 0.0f)); } const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3])); @@ -486,7 +486,7 @@ kernel void kernel_soft_max_4( // parallel sum float4 lsum4 = 0.0f; for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val); + const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4) (pmask ? pmask[i00] : 0.0f)) - max_val); lsum4 += exp_psrc4; pdst4[i00] = exp_psrc4; } @@ -2144,13 +2144,11 @@ kernel void kernel_flash_attn_ext_f16( } } - const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; - // pointer to the mask - device const float * mp = (device const float *) (mask + (ir%ne31)*nb31); + device const half * mp = (device const half *) (mask + iq1*nb31); // prepare diagonal scale matrix - simdgroup_float8x8 mscale(scale); + simdgroup_half8x8 mscale(scale); // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns @@ -2176,8 +2174,8 @@ kernel void kernel_flash_attn_ext_f16( // mqk = mqk*scale + mask for (int64_t j = 0; j < Q8; ++j) { - simdgroup_float8x8 mm; - simdgroup_load(mm, mp + 8*j*(nb31/sizeof(float)) + ic + 8*cc, nb31/sizeof(float), 0, false); + simdgroup_half8x8 mm; + simdgroup_load(mm, mp + 8*j*(nb31/sizeof(half)) + ic + 8*cc, nb31/sizeof(half), 0, false); simdgroup_multiply_accumulate(mqk[j], mqk[j], mscale, mm); simdgroup_store(mqk[j], ss + 8*j*T + 8*cc, T, 0, false); diff --git a/ggml.c b/ggml.c index 466a8cdec3c9d..59a4c05a12ffe 100644 --- a/ggml.c +++ b/ggml.c @@ -5085,6 +5085,7 @@ static struct ggml_tensor * ggml_soft_max_impl( bool inplace) { GGML_ASSERT(ggml_is_contiguous(a)); if (mask) { + GGML_ASSERT(mask->type == GGML_TYPE_F16); GGML_ASSERT(ggml_is_contiguous(mask)); GGML_ASSERT(mask->ne[2] == 1); GGML_ASSERT(mask->ne[3] == 1); @@ -5854,6 +5855,8 @@ struct ggml_tensor * ggml_flash_attn_ext( GGML_ASSERT(ggml_is_contiguous(mask)); GGML_ASSERT(mask->ne[2] == 1); GGML_ASSERT(mask->ne[3] == 1); + GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) && + "the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big"); //GGML_ASSERT(ggml_can_repeat_rows(mask, qk)); } @@ -11552,12 +11555,14 @@ static void ggml_compute_forward_soft_max_f32( float * dp = (float *)((char *) dst->data + i1*dst->nb[1]); // broadcast the mask across rows - float * mp = src1 ? (float *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL; + ggml_fp16_t * mp = src1 ? (ggml_fp16_t *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL; ggml_vec_cpy_f32 (nc, wp, sp); ggml_vec_scale_f32(nc, wp, scale); if (mp) { - ggml_vec_acc_f32(nc, wp, mp); + for (int i = 0; i < nc; ++i) { + wp[i] += GGML_FP16_TO_FP32(mp[i]); + } } #ifndef NDEBUG @@ -13760,7 +13765,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( memset(V16, 0, D*sizeof(ggml_fp16_t)); - const float * mp = mask ? (float *)((char *) mask->data + (ir%mask->ne[1])*mask->nb[1]) : NULL; + const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL; // k indices const int ik3 = iq3 / rk3; @@ -13774,7 +13779,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( // loop over n_kv and n_head_kv // ref: https://arxiv.org/pdf/2112.05682.pdf for (int64_t ic = 0; ic < nek1; ++ic) { - const float mv = mp ? mp[ic] : 0.0f; + const float mv = mp ? GGML_FP16_TO_FP32(mp[ic]) : 0.0f; if (mv == -INFINITY) { continue; } diff --git a/ggml.h b/ggml.h index a83ff8035f9ea..74ce1abd4d500 100644 --- a/ggml.h +++ b/ggml.h @@ -1646,11 +1646,13 @@ extern "C" { struct ggml_tensor * v, bool masked); - // q: [n_embd, n_batch, n_head, 1] - // k: [n_embd, n_kv, n_head_kv, 1] - // v: [n_embd, n_kv, n_head_kv, 1] !! not transposed !! - // mask: [n_kv, n_batch, 1, 1] - // res: [n_embd, n_head, n_batch, 1] !! permuted !! +#define GGML_KQ_MASK_PAD 32 + + // q: [n_embd, n_batch, n_head, 1] + // k: [n_embd, n_kv, n_head_kv, 1] + // v: [n_embd, n_kv, n_head_kv, 1] !! not transposed !! + // mask: [n_kv, n_batch_pad, 1, 1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !! + // res: [n_embd, n_head, n_batch, 1] !! permuted !! GGML_API struct ggml_tensor * ggml_flash_attn_ext( struct ggml_context * ctx, struct ggml_tensor * q, diff --git a/llama.cpp b/llama.cpp index 1f8ecc19b4e0c..fe25839669efc 100644 --- a/llama.cpp +++ b/llama.cpp @@ -4721,7 +4721,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); // shift the entire K-cache if needed @@ -4905,7 +4905,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); // shift the entire K-cache if needed @@ -5026,7 +5026,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); // shift the entire K-cache if needed @@ -5148,7 +5148,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos); @@ -5245,7 +5245,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); if (do_rope_shift) { @@ -5448,7 +5448,7 @@ struct llm_build_context { cb(inpL, "inp_embd", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); for (int il = 0; il < n_layer; ++il) { @@ -5538,7 +5538,7 @@ struct llm_build_context { cb(inpL, "inp_embd", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); inpL = llm_build_norm(ctx0, inpL, hparams, @@ -5631,7 +5631,7 @@ struct llm_build_context { cb(inpL, "inp_embd", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); for (int il = 0; il < n_layer; ++il) { @@ -5731,7 +5731,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); // shift the entire K-cache if needed @@ -5854,7 +5854,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); // shift the entire K-cache if needed @@ -5968,7 +5968,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); // shift the entire K-cache if needed @@ -6089,7 +6089,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); // shift the entire K-cache if needed @@ -6211,7 +6211,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); // shift the entire K-cache if needed @@ -6318,7 +6318,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos); @@ -6416,7 +6416,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); // shift the entire K-cache if needed @@ -6524,7 +6524,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); // shift the entire K-cache if needed @@ -10250,7 +10250,10 @@ struct llama_context * llama_new_context_with_model( const auto & hparams = model->hparams; auto & cparams = ctx->cparams; - cparams.n_batch = params.n_batch; + // the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask + // this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext) + cparams.n_batch = std::max((uint32_t) GGML_KQ_MASK_PAD, params.n_batch); + cparams.n_threads = params.n_threads; cparams.n_threads_batch = params.n_threads_batch; cparams.yarn_ext_factor = params.yarn_ext_factor; @@ -10430,6 +10433,9 @@ struct llama_context * llama_new_context_with_model( ctx->buf_input = ggml_backend_alloc_ctx_tensors_from_buft(ctx->ctx_input, llama_default_buffer_type_cpu(true)); + // zero-out the input buffer to prevent NaNs in padded tensors + ggml_backend_buffer_clear(ctx->buf_input, 0); + LLAMA_LOG_INFO("%s: %10s input buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(ctx->buf_input), ggml_backend_buffer_get_size(ctx->buf_input) / 1024.0 / 1024.0); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 0f31c00f9672c..b1b30b91c9c6b 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1101,7 +1101,7 @@ struct test_soft_max : public test_case { ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); ggml_tensor * b = nullptr; - if (mask) { b = ggml_new_tensor_2d(ctx, type, ne[0], ne[1]); } + if (mask) { b = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, ne[0], ne[1]); } ggml_tensor * out = ggml_soft_max_ext(ctx, a, b, scale); return out; } @@ -1472,7 +1472,7 @@ struct test_flash_attn_ext : public test_case { ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, nb, nh, 1); ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); - ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv, nb, 1, 1); + ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, 1); ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf(hs)); return out; } @@ -1506,7 +1506,7 @@ struct test_attn : public test_case { ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, nb, nh, 1); ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, hs, nh, 1); // transposed - ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv, nb, 1, 1); + ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, nb, 1, 1); struct ggml_tensor * cur; @@ -1793,7 +1793,7 @@ struct test_llama : public test_llm { struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, hp.n_tokens); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hp.n_kv, hp.n_tokens, 1); + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, hp.n_kv, hp.n_tokens, 1); ggml_tensor * k_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400); ggml_tensor * v_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400); @@ -1915,7 +1915,7 @@ struct test_falcon : public test_llm { struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, hp.n_tokens); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hp.n_kv, hp.n_tokens, 1); + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, hp.n_kv, hp.n_tokens, 1); ggml_tensor * k_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400); ggml_tensor * v_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400); From 910b15bb4006409fe24b41da171cc562cdb1f3a4 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 1 Feb 2024 16:41:02 +0200 Subject: [PATCH 032/100] ggml : fix ggml_soft_max mask requirement --- ggml.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml.c b/ggml.c index 59a4c05a12ffe..ebd9c6b341080 100644 --- a/ggml.c +++ b/ggml.c @@ -5089,7 +5089,7 @@ static struct ggml_tensor * ggml_soft_max_impl( GGML_ASSERT(ggml_is_contiguous(mask)); GGML_ASSERT(mask->ne[2] == 1); GGML_ASSERT(mask->ne[3] == 1); - GGML_ASSERT(ggml_can_repeat_rows(mask, a)); + GGML_ASSERT(mask->ne[1] >= a->ne[1]); } bool is_node = false; From 2e460137490a4e002a60a60aed052e90179bb65b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 1 Feb 2024 16:47:20 +0200 Subject: [PATCH 033/100] cuda : fix soft_max to use correct mask size --- ggml-cuda.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index c57a031e4060c..15fc6154f7508 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -9064,7 +9064,7 @@ static void ggml_cuda_op_soft_max( const int64_t ne00 = src0->ne[0]; const int64_t nrows_x = ggml_nrows(src0); - const int64_t nrows_y = src1 ? ggml_nrows(src1) : 1; + const int64_t nrows_y = src1 ? src0->ne[1] : 1; // note: using number of queries since mask can be padded! float scale = 1.0f; memcpy(&scale, dst->op_params, sizeof(float)); From 5a19a9f6d0899becbc71a19454a27c0225edddf7 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 1 Feb 2024 19:47:11 +0200 Subject: [PATCH 034/100] cuda : add flash_attn kernel (wip) --- ggml-cuda.cu | 735 ++++++++++++++++++++++++++++++++++++++++++++++++++- llama.cpp | 3 +- 2 files changed, 735 insertions(+), 3 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 15fc6154f7508..60d228a61660f 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -108,6 +108,7 @@ #include #include #include +#include #if CUDART_VERSION < 11020 #define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED @@ -655,6 +656,19 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) { #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL } +static __device__ __forceinline__ half warp_reduce_sum(half x) { +#if __CUDA_ARCH__ >= CC_VOLTA +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + x = __hadd(__shfl_xor_sync(0xffffffff, x, mask, 32), x); + } + return x; +#else + (void) x; + NO_DEVICE_CODE; +#endif +} + static __device__ __forceinline__ float warp_reduce_max(float x) { #pragma unroll for (int mask = 16; mask > 0; mask >>= 1) { @@ -676,6 +690,18 @@ static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX } +static __device__ __forceinline__ half warp_reduce_max(half x) { +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + x = __hmax(x, __shfl_xor_sync(0xffffffff, x, mask, 32)); + } + return x; +#else + (void) x; +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX +} + static __device__ __forceinline__ float op_repeat(const float a, const float b) { return b; GGML_UNUSED(a); @@ -989,6 +1015,7 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr if (lane_id == 0) { s_sum[warp_id] = tmp; } + __syncthreads(); tmp = s_sum[lane_id]; tmp = warp_reduce_sum(tmp); @@ -6249,6 +6276,528 @@ static __global__ void pool2d_nchw_kernel( o_ptr[cur_oh * ow + cur_ow] = res; } +#define CUDA_FLASH_ATTENTION_BLOCK_SIZE 256 + +template +static __global__ void flash_attn_f32( + const float* __restrict__ q, + const float* __restrict__ k, + const float* __restrict__ v, + float* __restrict__ kqv, + float kq_scale, + int head_dim, int seq_len, int num_heads) { + const int head = blockIdx.x / seq_len; + const int head_size = head_dim * seq_len; + const int s = blockIdx.x % seq_len; + + extern __shared__ char flash_attn_shmem_f32[]; + float* S = (float*)flash_attn_shmem_f32; + float* warp_data = (float*)(flash_attn_shmem_f32 + seq_len * sizeof(float)); + + // QK^T + #pragma unroll + for(int is0 = 0; is0 < k_seq_len; is0 += block_size) { + const int is = threadIdx.x + is0; + if(is >= seq_len) { + break; + } + + const int key_offset = is * head_dim + head * head_size; + const int query_offset = s * head_dim + head * head_size; + + float tmp = 0.0f; + for(int d = 0; d < head_dim; d++) { + tmp += k[key_offset + d] * q[query_offset + d]; + } + S[is] = tmp * kq_scale; + } + __syncthreads(); + + float max_val = -INFINITY; + // get the max + #pragma unroll + for(int is0 = 0; is0 < k_seq_len; is0 += block_size) { + const int is = threadIdx.x + is0; + if(is >= seq_len) { + break; + } + + max_val = fmaxf(max_val , S[is]); + } + + max_val = warp_reduce_max(max_val); + + { // get max from all threads + int warp_id = threadIdx.x / WARP_SIZE; + int lane_id = threadIdx.x % WARP_SIZE; + if (lane_id == 0) { + warp_data[warp_id] = max_val; + } + __syncthreads(); + max_val = warp_data[lane_id]; + max_val = warp_reduce_max(max_val); + } + + // softmax(QK^T) + float sum = 0.0f; + #pragma unroll + for(int is0 = 0; is0 < k_seq_len; is0 += block_size) { + const int is = threadIdx.x + is0; + if(is >= seq_len) { + break; + } + float tmp = expf(S[is] - max_val); + sum += tmp; + S[is] = tmp; + } + __syncthreads(); + + sum = warp_reduce_sum(sum); + { // softmax sum partials + int warp_id = threadIdx.x / WARP_SIZE; + int lane_id = threadIdx.x % WARP_SIZE; + if (lane_id == 0) { + warp_data[warp_id] = sum; + } + __syncthreads(); + sum = warp_data[lane_id]; + sum = warp_reduce_sum(sum); + } + + float inv_sum = 1.0f / sum; + #pragma unroll + for(int is0 = 0; is0 < k_seq_len; is0 += block_size) { + const int is = threadIdx.x + is0; + if(is >= seq_len) { + break; + } + + S[is] *= inv_sum; + } + __syncthreads(); + + // softmax(QK^T)V + #pragma unroll + for (int d0 = threadIdx.x; d0 < k_head_dim; d0 += block_size) { + const int d = threadIdx.x + d0; + if(d >= head_dim) { + break; + } + const int dst_index = d + s * head_dim + head * head_size; + const int value_offset = d * seq_len + head * head_size; + + float temp = 0.0f; + #pragma unroll + for(int ic = 0; ic < k_seq_len;ic++) { + if(ic >= seq_len) { + break; + } + temp += v[value_offset + ic] * S[ic]; + } + kqv[dst_index] = temp; + } +} + +#if __CUDA_ARCH__ >= CC_VOLTA +typedef nvcuda::wmma::fragment half16x16_a; +typedef nvcuda::wmma::fragment half16x16_b; +typedef nvcuda::wmma::fragment half16x16_bT; +typedef nvcuda::wmma::fragment half16x16_acc; +#endif + +// based on metal version +template // D head size, Q queries per block, C cache items per block +static __global__ void flash_attn_ext_f16( + const char* __restrict__ q, + const char* __restrict__ k, + const char* __restrict__ v, + const char* __restrict__ mask, + float* __restrict__ dst, + float scale, + int ne00, + int ne01, + int ne02, + int ne03, + int ne10, + int ne11, + int ne12, + int ne13, + int ne31, + int nb31, + int nb01, + int nb02, + int nb03, + int nb11, + int nb12, + int nb13, + int ne0, + int ne1, + int ne2, + int ne3) { +#if __CUDA_ARCH__ >= CC_VOLTA + const int warp_id = threadIdx.y; + const int lane_id = threadIdx.x; + + const int num_warps = blockDim.y; // number of warps + const int iq3 = blockIdx.z; + const int iq2 = blockIdx.y; + const int iq1 = blockIdx.x * Q; + + const int D2 = D/2; + const int D16 = D/16; + const int Q16 = Q/16; + const int NW = WARP_SIZE; + const int SH = (C + Q); // shared memory per simdgroup in (half) + + const int T = D + num_warps*SH; // shared memory size per query in (half) + const int T2 = T/2; // shared memory size per query in (half2) + + extern __shared__ half __flash_attn_f16_shmem[]; + // pq + half * sq = (half *) (__flash_attn_f16_shmem + 0*D); // holds the query data + half2 * sq2 = (half2 *) (__flash_attn_f16_shmem + 0*D); // same as above but in half2 + half * ss = (half *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // scratch buffer for attention and diagonal matrix + + half16x16_acc zr; + half16x16_acc lo[Q16][D16]; + + // load heads from Q to shared memory + for (int64_t j = warp_id; j < Q; j += num_warps) { + const float2 * q2 = (const float2 *) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); + + for (int64_t i = lane_id; i < D2; i += NW) { + if (iq1 + j < ne01) { + sq2[j*T2 + i] = __float22half2_rn(q2[i]); + } else { + sq2[j*T2 + i] = make_half2(0.0, 0.0); + } + } + } + + nvcuda::wmma::fill_fragment(zr, 0.0); + + // zero out lo + for (int64_t j = 0; j < Q16; ++j) { + for (int64_t i = 0; i < D16; ++i) { + nvcuda::wmma::fill_fragment(lo[j][i], 0.0); + } + } + + // zero out shared memory SH + for (int64_t j = 0; j < Q; ++j) { + for (int64_t i = lane_id; i < SH; i += NW) { + ss[j*T + i] = 0.0; + } + } + + __syncthreads(); + + { + half S[Q]; + half M[Q]; + + for(int i = 0; i < Q; i++) { + S[i] = __float2half(0.0f); + M[i] = __float2half(-INFINITY); + } + + // assume K and V are same shape + const int ne22 = ne12; + const int ne23 = ne13; + + const int nb21 = nb11; + const int nb22 = nb12; + const int nb23 = nb13; + + // broadcast + const int rk2 = ne02/ne12; + const int rk3 = ne03/ne13; + + const int rv2 = ne02/ne22; + const int rv3 = ne03/ne23; + + // k indices + const int ik2 = iq2 / rk2; + const int ik3 = iq3 / rk3; + + // v indices + const int iv2 = iq2 / rv2; + const int iv3 = iq3 / rv3; + + // load the queries from shared memory into local memory + half16x16_a mq[Q16][D16]; + for (int64_t j = 0; j < Q16; ++j) { + for (int64_t i = 0; i < D16; ++i) { + nvcuda::wmma::load_matrix_sync(mq[j][i], sq + 16*j*T + i*16, T); + } + } + + // pointer to the mask + const half * mp = mask ? (const half *) (mask + iq1*nb31) : nullptr; + + // prepare diagonal scale matrix + half16x16_b mscale; + for (int i = 0; i < 16; ++i) { + ss[i*T + i] = __float2half(scale); + } + nvcuda::wmma::load_matrix_sync(mscale, ss, T); + + // loop over the KV cache + // each simdgroup handles blocks of Q rows and C columns + for (int64_t ic = C*warp_id; ic < ne11; ic += C*num_warps) { + // Q*K^T + { + for (int cc = 0; cc < C/16; ++cc) { + half16x16_acc mqk[Q16]; + for (int64_t j = 0; j < Q16; ++j) { + nvcuda::wmma::fill_fragment(mqk[j], 0); + } + + const half * pk = (const half *) ((const char *) k + ((ic + 16*cc)*nb11 + ik2*nb12 + ik3*nb13)); + + for (int64_t i = 0; i < D16; ++i) { + half16x16_bT mk; // transposed key + nvcuda::wmma::load_matrix_sync(mk, pk + i*16, nb11/sizeof(half)); + + for (int64_t j = 0; j < Q16; ++j) { + nvcuda::wmma::mma_sync(mqk[j], mq[j][i], mk, mqk[j]); + } + } + + // mqk = mqk*scale + mask + for (int64_t j = 0; j < Q16; ++j) { + half16x16_a mqka; + half16x16_acc mm; + if(mp) { + nvcuda::wmma::load_matrix_sync(mm, mp + 16*j*(nb31/sizeof(half)) + ic + 16*cc, nb31/sizeof(half), nvcuda::wmma::mem_row_major); + } + + // convert accumulator to matrix_a + nvcuda::wmma::store_matrix_sync( ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major); + nvcuda::wmma::load_matrix_sync (mqka, ss + 16*j*T + 16*cc, T); + + nvcuda::wmma::mma_sync(mqk[j], mqka, mscale, mp ? mm : zr); + nvcuda::wmma::store_matrix_sync(ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major); + } + } + } + + // used to detect blocks full of -INF + half smax = __float2half(-INFINITY); + + // online softmax + if (C == 32) { + for (int64_t j = 0; j < Q; ++j) { + const int64_t p = lane_id; + + const half m = M[j]; + const half s = ss[j*T + p]; + + smax = warp_reduce_max(__hmax(smax, s)); + M[j] = warp_reduce_max(__hmax(M[j], s)); + + const half ms = __hisinf(m) ? __float2half(0.0f) : hexp(m - M[j]); + const half vs = __hisinf(s) ? __float2half(0.0f) : hexp(s - M[j]); + + S[j] = S[j]*ms + warp_reduce_sum(vs); + + // create a QxQ diagonal matrix for rescaling the output + if (p == j) { + ss[j*T + C + j] = ms; + } + + // the P matrix from the paper (Q rows, C columns) + ss[j*T + p] = vs; + } + } else { + for (int64_t j = 0; j < Q; ++j) { + const half m = M[j]; + + for (int64_t p = lane_id; p < C; p += NW) { + const half s = ss[j*T + p]; + + smax = __hmax(smax, s); + M[j] = __hmax(M[j], s); + } + + smax = warp_reduce_max(smax); + M[j] = warp_reduce_max(M[j]); + + const half ms = __hisinf(m) ? __float2half(0.0f) : hexp(m - M[j]); + + // create a QxQ diagonal matrix for rescaling the output + if (lane_id == j) { + ss[j*T + C + j] = ms; + } + + // local sum + half ls = 0.0f; + + for (int64_t p = lane_id; p < C; p += NW) { + const half s = ss[j*T + p]; + + const half vs = __hisinf(s) ? __float2half(0.0f) : hexp(s - M[j]); + + ls += vs; + + // the P matrix from the paper (Q rows, C columns) + ss[j*T + p] = vs; + } + + S[j] = S[j]*ms + warp_reduce_sum(ls); + } + } + + // skip -INF blocks + if (__hisinf(smax)) { + continue; + } + + // O = diag(ms)*O + for (int64_t j = 0; j < Q16; ++j) { + half16x16_a mm; + half16x16_b lob; + + nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T); + + for (int64_t i = 0; i < D16; ++i) { + // convert accumulator to matrix_b + nvcuda::wmma::store_matrix_sync( ss + 16*j*T + C + 16*j, lo[j][i], T, nvcuda::wmma::mem_row_major); + nvcuda::wmma::load_matrix_sync (lob, ss + 16*j*T + C + 16*j, T); + + nvcuda::wmma::fill_fragment(lo[j][i], 0.0); + nvcuda::wmma::mma_sync(lo[j][i], mm, lob, lo[j][i]); + } + + // restore zeros + nvcuda::wmma::store_matrix_sync(ss + 16*j*T + C + 16*j, zr, T, nvcuda::wmma::mem_row_major); + } + + // O = O + (Q*K^T)*V + { + for (int cc = 0; cc < C/16; ++cc) { + const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23)); + + half16x16_b mk[D16]; + for (int64_t i = 0; i < D16; ++i) { + nvcuda::wmma::load_matrix_sync(mk[i], pv + i*16, nb21/sizeof(half)); + } + + half16x16_a mv[Q16]; + for (int64_t j = 0; j < Q16; ++j) { + nvcuda::wmma::load_matrix_sync(mv[j], ss + 16*j*T + 16*cc, T); + } + + for (int64_t j = 0; j < Q16; ++j) { + for (int64_t i = 0; i < D16; ++i) { + nvcuda::wmma::mma_sync(lo[j][i], mv[j], mk[i], lo[j][i]); + } + } + } + } + } + + // these are needed for reducing the results from the simdgroups (reuse the ss buffer) + for (int64_t j = 0; j < Q; ++j) { + if (lane_id == 0) { + ss[j*T + 0] = S[j]; + ss[j*T + 1] = M[j]; + } + } + } + + // reduce the warps sequentially + for (int64_t sg = 1; sg < num_warps; ++sg) { + half S = __float2half(0.0f); + half M = __float2half(-INFINITY); + + __syncthreads(); + + // each simdgroup stores its output to shared memory, reusing sq + if (warp_id == sg) { + for (int64_t j = 0; j < Q16; ++j) { + for (int64_t i = 0; i < D16; ++i) { + nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major); + } + } + } + + __syncthreads(); + + // the first simdgroup accumulates the results from the other simdgroups + if (warp_id == 0) { + for (int64_t j = 0; j < Q; ++j) { + const half S0 = ss[j*T + 0]; + const half S1 = ss[j*T + sg*SH + 0]; + + const half M0 = ss[j*T + 1]; + const half M1 = ss[j*T + sg*SH + 1]; + + M = __hmax(M0, M1); + + const half ms0 = __hisinf(M0) ? __float2half(0.0f) : hexp(M0 - M); + const half ms1 = __hisinf(M1) ? __float2half(0.0f) : hexp(M1 - M); + + S = S0*ms0 + S1*ms1; + + if (lane_id == 0) { + ss[j*T + 0] = S; + ss[j*T + 1] = M; + + ss[j*T + C + j ] = ms0; + ss[j*T + C + j + sg*SH] = ms1; + } + } + + // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 + for (int64_t j = 0; j < Q16; ++j) { + half16x16_a ms0; + half16x16_a ms1; + half16x16_b t; + half16x16_acc t2; + + nvcuda::wmma::load_matrix_sync(ms0, ss + 16*j*T + C + 16*j, T); + nvcuda::wmma::load_matrix_sync(ms1, ss + 16*j*T + C + 16*j + sg*SH, T); + + for (int64_t i = 0; i < D16; ++i) { + nvcuda::wmma::fill_fragment(t2, 0.0); + nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T); + nvcuda::wmma::mma_sync(t2, ms1, t, t2); + + // convert accumulator to matrix_b + nvcuda::wmma::store_matrix_sync( sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major); + nvcuda::wmma::load_matrix_sync (t, sq + 16*j*T + i*16, T); + + nvcuda::wmma::mma_sync(lo[j][i], ms0, t, t2); + } + } + } + } + + // store result to shared memory (reuse sq) + if (warp_id == 0) { + for (int64_t j = 0; j < Q16; ++j) { + for (int64_t i = 0; i < D16; ++i) { + nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major); + } + } + } + + // final rescale with 1/S and store to global memory + if (warp_id == 0) { + for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) { + const half S = ss[j*T + 0]; + + for (int64_t i = lane_id; i < D; i += NW) { + dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i] / S); + } + } + } +#else + NO_DEVICE_CODE; +#endif +} + template static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) { @@ -7682,6 +8231,13 @@ static void im2col_cuda(const float* x, T* dst, im2col_kernel<<>>(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1); } +static void flash_attn_f32_cuda(const float* q, const float* k,const float* v, float* dst, float kq_scale, const int d_head, const int seq_len, const int num_heads, cudaStream_t stream) { + int sram_memory_size = seq_len*sizeof(float) + WARP_SIZE * sizeof(float); + int num_blocks = num_heads * seq_len; + flash_attn_f32<<>>( + q, k, v, dst, kq_scale, d_head, seq_len, num_heads); +} + // buffer pool for cuda #define MAX_CUDA_BUFFERS 256 @@ -8659,7 +9215,7 @@ static void ggml_cuda_op_dequantize_mul_mat_vec( src1_dfloat = src1_dfloat_a.alloc(ne00); ggml_cpy_f32_f16_cuda((const char *) src1_ddf_i, (char *) src1_dfloat, ne00, ne00, 1, sizeof(float), 0, 0, - ne00, 1, sizeof(half), 0, 0, stream); + ne00, 1, sizeof(half), 0, 0, 0, 0, 0, 0, stream); } #else const dfloat * src1_dfloat = (const dfloat *) src1_ddf_i; // dfloat == float, no conversion @@ -10284,6 +10840,170 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s } } +inline void ggml_cuda_flash_attn(const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV) { + GGML_ASSERT(Q->type == GGML_TYPE_F32); + GGML_ASSERT(K->type == GGML_TYPE_F32); + GGML_ASSERT(V->type == GGML_TYPE_F32); + GGML_ASSERT(KQV->type == GGML_TYPE_F32); + + GGML_ASSERT(Q->backend == GGML_BACKEND_GPU); + GGML_ASSERT(K->backend == GGML_BACKEND_GPU); + GGML_ASSERT(V->backend == GGML_BACKEND_GPU); + GGML_ASSERT(KQV->backend == GGML_BACKEND_GPU); + + ggml_cuda_set_device(g_main_device); + const cudaStream_t main_stream = g_cudaStreams[g_main_device][0]; + + ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) Q->extra; + ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) K->extra; + ggml_tensor_extra_gpu * src2_extra = (ggml_tensor_extra_gpu *) V->extra; + ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) KQV->extra; + + const int64_t d_head = Q->ne[0]; + const int64_t sequence_length = Q->ne[1]; + const int64_t num_heads = Q->ne[2]; + + GGML_ASSERT(Q->ne[0] == d_head); + GGML_ASSERT(K->ne[0] == d_head); + GGML_ASSERT(V->ne[1] == d_head); + + GGML_ASSERT(Q->ne[1] == sequence_length); + GGML_ASSERT(K->ne[1] == sequence_length); + GGML_ASSERT(V->ne[0] == sequence_length); + + GGML_ASSERT(Q->ne[2] == num_heads); + GGML_ASSERT(K->ne[2] == num_heads); + GGML_ASSERT(V->ne[2] == num_heads); + + float KQ_scale = 1.0f / sqrtf((float)d_head); + + flash_attn_f32_cuda( + (float *) src0_extra->data_device[g_main_device], // Query + (float *) src1_extra->data_device[g_main_device], // Key + (float *) src2_extra->data_device[g_main_device], // Value + (float *) dst_extra->data_device[g_main_device], // dst + KQ_scale, d_head, sequence_length, num_heads, main_stream); +} + + +inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, const ggml_tensor * mask, ggml_tensor * KQV) { + GGML_ASSERT(Q->type == GGML_TYPE_F32); + GGML_ASSERT(K->type == GGML_TYPE_F16); + GGML_ASSERT(V->type == GGML_TYPE_F16); + GGML_ASSERT(KQV->type == GGML_TYPE_F32); + + GGML_ASSERT(Q->backend == GGML_BACKEND_GPU); + GGML_ASSERT(K->backend == GGML_BACKEND_GPU); + GGML_ASSERT(V->backend == GGML_BACKEND_GPU); + GGML_ASSERT(KQV->backend == GGML_BACKEND_GPU); + + GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16); + GGML_ASSERT(!mask || mask->backend == GGML_BACKEND_GPU); + GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) && + "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big"); + + ggml_cuda_set_device(g_main_device); + const cudaStream_t main_stream = g_cudaStreams[g_main_device][0]; + + ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) Q->extra; + ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) K->extra; + ggml_tensor_extra_gpu * src2_extra = (ggml_tensor_extra_gpu *) V->extra; + ggml_tensor_extra_gpu * src3_extra = mask ? (ggml_tensor_extra_gpu *) mask->extra : nullptr; + ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) KQV->extra; + + float scale; + memcpy(&scale, KQV->op_params, sizeof(float)); + +#define NQPB 16 +#define NCPW 128 + + const int nqpb = NQPB; // queries per block + const int ncpw = NCPW; // cache values per warp (does not work for other values) + + const int nwarps_max = 8; // TODO: we don't want to launch too much warps. how much is too much? + // TODO: produces wrong results for nwarps > 8 (RTX 2060) - not sure why + const int nwarps = Q->ne[1] <= nqpb ? std::max(2, std::min((int) K->ne[1]/ncpw, nwarps_max)) : 2; + + dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]); + dim3 block_dim(32, nwarps, 1); + + const size_t shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2); + + switch (Q->ne[0]) + { + case 16: + flash_attn_ext_f16<16, NQPB, NCPW> + <<>> ( + (const char *) src0_extra->data_device[g_main_device], // Query + (const char *) src1_extra->data_device[g_main_device], // Key + (const char *) src2_extra->data_device[g_main_device], // Value + mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask + (float *) dst_extra->data_device[g_main_device], // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 64: + flash_attn_ext_f16<64, NQPB, NCPW> + <<>> ( + (const char *) src0_extra->data_device[g_main_device], // Query + (const char *) src1_extra->data_device[g_main_device], // Key + (const char *) src2_extra->data_device[g_main_device], // Value + mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask + (float *) dst_extra->data_device[g_main_device], // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 80: + flash_attn_ext_f16<80, NQPB, NCPW> + <<>> ( + (const char *) src0_extra->data_device[g_main_device], // Query + (const char *) src1_extra->data_device[g_main_device], // Key + (const char *) src2_extra->data_device[g_main_device], // Value + mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask + (float *) dst_extra->data_device[g_main_device], // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 128: + flash_attn_ext_f16<128, NQPB, NCPW> + <<>> ( + (const char *) src0_extra->data_device[g_main_device], // Query + (const char *) src1_extra->data_device[g_main_device], // Key + (const char *) src2_extra->data_device[g_main_device], // Value + mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask + (float *) dst_extra->data_device[g_main_device], // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + default: + break; + } +} + static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_scale); } @@ -10573,6 +11293,10 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st case GGML_OP_ARGSORT: func = ggml_cuda_argsort; break; + case GGML_OP_FLASH_ATTN: + break; + case GGML_OP_FLASH_ATTN_EXT: + break; default: return false; } @@ -10587,7 +11311,13 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return true; } - func(tensor->src[0], tensor->src[1], tensor); + if(tensor->op == GGML_OP_FLASH_ATTN) { + ggml_cuda_flash_attn(tensor->src[0], tensor->src[1], tensor->src[2], tensor); + } else if(tensor->op == GGML_OP_FLASH_ATTN_EXT) { + ggml_cuda_flash_attn_ext(tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor); + } else { + func(tensor->src[0], tensor->src[1], tensor); + } return true; } @@ -11403,6 +12133,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_OP_UPSCALE: case GGML_OP_PAD: case GGML_OP_LEAKY_RELU: + case GGML_OP_FLASH_ATTN_EXT: return true; default: return false; diff --git a/llama.cpp b/llama.cpp index fe25839669efc..2330efff57bd3 100644 --- a/llama.cpp +++ b/llama.cpp @@ -6881,7 +6881,8 @@ static int llama_decode_internal( // a heuristic, to avoid attending the full cache if it is not yet utilized // after enough generations, the benefit from this heuristic disappears // if we start defragmenting the cache, the benefit from this will be more important - kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(32, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32))); + // note: we pad the n_kv because certain GPU kernels require it (e.g. ggml_flash_attn_ext) + kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(128, GGML_PAD(llama_kv_cache_cell_max(kv_self), 128))); //kv_self.n = llama_kv_cache_cell_max(kv_self); //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head); From 56e45a239e1d5a871009aa162b7ba99c93c40b62 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 1 Feb 2024 20:16:32 +0200 Subject: [PATCH 035/100] metal : optimize softmax for C > 32 --- ggml-metal.metal | 16 +++++++++++----- tests/test-backend-ops.cpp | 9 +++++++++ 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index 04c1aaf9cdfb9..3d5d762d16d99 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2217,29 +2217,35 @@ kernel void kernel_flash_attn_ext_f16( for (int64_t p = tiisg; p < C; p += NW) { const half s = ss[j*T + p]; - smax = simd_max(max(smax, s)); - M[j] = simd_max(max(M[j], s)); + smax = max(smax, s); + M[j] = max(M[j], s); } - const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]); + smax = simd_max(smax); + M[j] = simd_max(M[j]); - S[j] = S[j]*ms; + const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]); // create a QxQ diagonal matrix for rescaling the output if (tiisg == j) { ss[j*T + C + j] = ms; } + // local sum + half ls = 0.0h; + for (int64_t p = tiisg; p < C; p += NW) { const half s = ss[j*T + p]; const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]); - S[j] = S[j] + simd_sum(vs); + ls += vs; // the P matrix from the paper (Q rows, C columns) ss[j*T + p] = vs; } + + S[j] = S[j]*ms + simd_sum(ls); } } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index b1b30b91c9c6b..2ab5354069853 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -572,9 +572,18 @@ struct test_case { // duplicate the op size_t target_size = ggml_backend_is_cpu(backend) ? 1ULL << 33 : 1ULL << 35; // 8 GB CPU, 32 GB GPU int n_runs = std::min((size_t)gf->size - gf->n_nodes, target_size / op_size(out)) + 1; +#if 1 for (int i = 1; i < n_runs; i++) { gf->nodes[gf->n_nodes++] = out; } +#else + int n_nodes = gf->n_nodes; + for (int i = 1; i < n_runs; i++) { + for (int j = 0; j < n_nodes; j++) { + gf->nodes[gf->n_nodes++] = gf->nodes[j]; + } + } +#endif // calculate memory size_t mem = n_runs * op_size(out); From cda5a60a41c669f233a943e1182cd6625f61a924 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 1 Feb 2024 20:53:29 +0200 Subject: [PATCH 036/100] metal : optimize softmax --- ggml-metal.m | 5 +++-- ggml-metal.metal | 34 +++++++++++++++++++--------------- 2 files changed, 22 insertions(+), 17 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index e00069624551f..2bbb6d17a36ad 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2285,8 +2285,9 @@ static bool ggml_metal_graph_compute( const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! - GGML_ASSERT(nqptg % 8 == 0); - GGML_ASSERT(ncpsg % 32 == 0); + GGML_ASSERT(nqptg <= 32); + GGML_ASSERT(nqptg % 8 == 0); + GGML_ASSERT(ncpsg % 32 == 0); // simdgroups per threadgroup (a.k.a. warps) // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) diff --git a/ggml-metal.metal b/ggml-metal.metal index 3d5d762d16d99..d9a536ae87109 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2188,6 +2188,8 @@ kernel void kernel_flash_attn_ext_f16( // online softmax if (C == 32) { + half ms[Q]; + for (int64_t j = 0; j < Q; ++j) { const int64_t p = tiisg; @@ -2197,20 +2199,22 @@ kernel void kernel_flash_attn_ext_f16( smax = simd_max(max(smax, s)); M[j] = simd_max(max(M[j], s)); - const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]); - const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]); - - S[j] = S[j]*ms + simd_sum(vs); + ms[j] = m == -INFINITY ? 0.0h : exp(m - M[j]); + const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]); - // create a QxQ diagonal matrix for rescaling the output - if (p == j) { - ss[j*T + C + j] = ms; - } + S[j] = S[j]*ms[j] + simd_sum(vs); // the P matrix from the paper (Q rows, C columns) ss[j*T + p] = vs; } + + // create a QxQ diagonal matrix for rescaling the output + if (tiisg < Q) { + ss[tiisg*T + C + tiisg] = ms[tiisg]; + } } else { + half ms[Q]; + for (int64_t j = 0; j < Q; ++j) { const half m = M[j]; @@ -2224,12 +2228,7 @@ kernel void kernel_flash_attn_ext_f16( smax = simd_max(smax); M[j] = simd_max(M[j]); - const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]); - - // create a QxQ diagonal matrix for rescaling the output - if (tiisg == j) { - ss[j*T + C + j] = ms; - } + ms[j] = m == -INFINITY ? 0.0h : exp(m - M[j]); // local sum half ls = 0.0h; @@ -2245,7 +2244,12 @@ kernel void kernel_flash_attn_ext_f16( ss[j*T + p] = vs; } - S[j] = S[j]*ms + simd_sum(ls); + S[j] = S[j]*ms[j] + simd_sum(ls); + } + + // create a QxQ diagonal matrix for rescaling the output + if (tiisg < Q) { + ss[tiisg*T + C + tiisg] = ms[tiisg]; } } From c6769b942229a9e634965e6215651b8d4cf02403 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 1 Feb 2024 21:24:26 +0200 Subject: [PATCH 037/100] tests : minor fix --- tests/test-backend-ops.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 2ab5354069853..727f2ea06a5d7 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -578,6 +578,7 @@ struct test_case { } #else int n_nodes = gf->n_nodes; + n_runs = 1000; for (int i = 1; i < n_runs; i++) { for (int j = 0; j < n_nodes; j++) { gf->nodes[gf->n_nodes++] = gf->nodes[j]; From db1f3c482e256398330d44ad22b498ca2cd03161 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 1 Feb 2024 22:08:37 +0200 Subject: [PATCH 038/100] cuda : avoid zeroing fragments --- ggml-cuda.cu | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 098b55e073c12..7130209e702d2 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6443,11 +6443,11 @@ static __global__ void flash_attn_ext_f16( const int iq2 = blockIdx.y; const int iq1 = blockIdx.x * Q; - const int D2 = D/2; + const int D2 = D/2; const int D16 = D/16; const int Q16 = Q/16; - const int NW = WARP_SIZE; - const int SH = (C + Q); // shared memory per simdgroup in (half) + const int NW = WARP_SIZE; + const int SH = (C + Q); // shared memory per simdgroup in (half) const int T = D + num_warps*SH; // shared memory size per query in (half) const int T2 = T/2; // shared memory size per query in (half2) @@ -6665,8 +6665,7 @@ static __global__ void flash_attn_ext_f16( nvcuda::wmma::store_matrix_sync( ss + 16*j*T + C + 16*j, lo[j][i], T, nvcuda::wmma::mem_row_major); nvcuda::wmma::load_matrix_sync (lob, ss + 16*j*T + C + 16*j, T); - nvcuda::wmma::fill_fragment(lo[j][i], 0.0); - nvcuda::wmma::mma_sync(lo[j][i], mm, lob, lo[j][i]); + nvcuda::wmma::mma_sync(lo[j][i], mm, lob, zr); } // restore zeros @@ -6760,9 +6759,8 @@ static __global__ void flash_attn_ext_f16( nvcuda::wmma::load_matrix_sync(ms1, ss + 16*j*T + C + 16*j + sg*SH, T); for (int64_t i = 0; i < D16; ++i) { - nvcuda::wmma::fill_fragment(t2, 0.0); nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T); - nvcuda::wmma::mma_sync(t2, ms1, t, t2); + nvcuda::wmma::mma_sync(t2, ms1, t, zr); // convert accumulator to matrix_b nvcuda::wmma::store_matrix_sync( sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major); From 12eaa22628740e388789081ccba93159c1b0b412 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 2 Feb 2024 11:55:38 +0200 Subject: [PATCH 039/100] tests : update dims --- ggml-cuda.cu | 180 ++++++++++++++++++++++--------------- tests/test-backend-ops.cpp | 6 +- 2 files changed, 110 insertions(+), 76 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 7130209e702d2..2c050c0c44edc 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6568,7 +6568,8 @@ static __global__ void flash_attn_ext_f16( for (int64_t j = 0; j < Q16; ++j) { half16x16_a mqka; half16x16_acc mm; - if(mp) { + + if (mp) { nvcuda::wmma::load_matrix_sync(mm, mp + 16*j*(nb31/sizeof(half)) + ic + 16*cc, nb31/sizeof(half), nvcuda::wmma::mem_row_major); } @@ -10927,78 +10928,111 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * const size_t shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2); - switch (Q->ne[0]) - { - case 16: - flash_attn_ext_f16<16, NQPB, NCPW> - <<>> ( - (const char *) src0_extra->data_device[g_main_device], // Query - (const char *) src1_extra->data_device[g_main_device], // Key - (const char *) src2_extra->data_device[g_main_device], // Value - mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask - (float *) dst_extra->data_device[g_main_device], // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 64: - flash_attn_ext_f16<64, NQPB, NCPW> - <<>> ( - (const char *) src0_extra->data_device[g_main_device], // Query - (const char *) src1_extra->data_device[g_main_device], // Key - (const char *) src2_extra->data_device[g_main_device], // Value - mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask - (float *) dst_extra->data_device[g_main_device], // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 80: - flash_attn_ext_f16<80, NQPB, NCPW> - <<>> ( - (const char *) src0_extra->data_device[g_main_device], // Query - (const char *) src1_extra->data_device[g_main_device], // Key - (const char *) src2_extra->data_device[g_main_device], // Value - mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask - (float *) dst_extra->data_device[g_main_device], // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 128: - flash_attn_ext_f16<128, NQPB, NCPW> - <<>> ( - (const char *) src0_extra->data_device[g_main_device], // Query - (const char *) src1_extra->data_device[g_main_device], // Key - (const char *) src2_extra->data_device[g_main_device], // Value - mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask - (float *) dst_extra->data_device[g_main_device], // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - default: - break; + switch (Q->ne[0]) { + case 64: + flash_attn_ext_f16<64, NQPB, NCPW> + <<>> ( + (const char *) src0_extra->data_device[g_main_device], // Query + (const char *) src1_extra->data_device[g_main_device], // Key + (const char *) src2_extra->data_device[g_main_device], // Value + mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask + (float *) dst_extra->data_device[g_main_device], // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 80: + flash_attn_ext_f16<80, NQPB, NCPW> + <<>> ( + (const char *) src0_extra->data_device[g_main_device], // Query + (const char *) src1_extra->data_device[g_main_device], // Key + (const char *) src2_extra->data_device[g_main_device], // Value + mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask + (float *) dst_extra->data_device[g_main_device], // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 96: + flash_attn_ext_f16<96, NQPB, NCPW> + <<>> ( + (const char *) src0_extra->data_device[g_main_device], // Query + (const char *) src1_extra->data_device[g_main_device], // Key + (const char *) src2_extra->data_device[g_main_device], // Value + mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask + (float *) dst_extra->data_device[g_main_device], // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 112: + flash_attn_ext_f16<112, NQPB, NCPW> + <<>> ( + (const char *) src0_extra->data_device[g_main_device], // Query + (const char *) src1_extra->data_device[g_main_device], // Key + (const char *) src2_extra->data_device[g_main_device], // Value + mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask + (float *) dst_extra->data_device[g_main_device], // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 128: + flash_attn_ext_f16<128, NQPB, NCPW> + <<>> ( + (const char *) src0_extra->data_device[g_main_device], // Query + (const char *) src1_extra->data_device[g_main_device], // Key + (const char *) src2_extra->data_device[g_main_device], // Value + mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask + (float *) dst_extra->data_device[g_main_device], // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 256: + flash_attn_ext_f16<256, NQPB, NCPW> + <<>> ( + (const char *) src0_extra->data_device[g_main_device], // Query + (const char *) src1_extra->data_device[g_main_device], // Key + (const char *) src2_extra->data_device[g_main_device], // Value + mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask + (float *) dst_extra->data_device[g_main_device], // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + default: + break; } } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 727f2ea06a5d7..9feb5e1fe550e 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -572,7 +572,7 @@ struct test_case { // duplicate the op size_t target_size = ggml_backend_is_cpu(backend) ? 1ULL << 33 : 1ULL << 35; // 8 GB CPU, 32 GB GPU int n_runs = std::min((size_t)gf->size - gf->n_nodes, target_size / op_size(out)) + 1; -#if 1 +#if 0 for (int i = 1; i < n_runs; i++) { gf->nodes[gf->n_nodes++] = out; } @@ -2209,8 +2209,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_pad()); test_cases.emplace_back(new test_leaky_relu()); -#if 0 - for (int hs : { 64, 80, 96, 112, 128, 256, }) { +#if 1 + for (int hs : { 64, 80, 128, }) { for (int nh : { 32, }) { for (int kv : { 512, 1024, 2048, 4096, }) { for (int nb : { 1, 2, 4, 8, 512, 1024, 2048, }) { From b68a112204c58e2bed334273753211c15acc48e2 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 2 Feb 2024 15:12:28 +0200 Subject: [PATCH 040/100] cuda : fix __hisinf() result check --- ggml-cuda.cu | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 2c050c0c44edc..0136fbf28f2a5 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6597,8 +6597,8 @@ static __global__ void flash_attn_ext_f16( smax = warp_reduce_max(__hmax(smax, s)); M[j] = warp_reduce_max(__hmax(M[j], s)); - const half ms = __hisinf(m) ? __float2half(0.0f) : hexp(m - M[j]); - const half vs = __hisinf(s) ? __float2half(0.0f) : hexp(s - M[j]); + const half ms = __hisinf(m) == -1 ? __float2half(0.0f) : hexp(m - M[j]); + const half vs = __hisinf(s) == -1 ? __float2half(0.0f) : hexp(s - M[j]); S[j] = S[j]*ms + warp_reduce_sum(vs); @@ -6624,7 +6624,7 @@ static __global__ void flash_attn_ext_f16( smax = warp_reduce_max(smax); M[j] = warp_reduce_max(M[j]); - const half ms = __hisinf(m) ? __float2half(0.0f) : hexp(m - M[j]); + const half ms = __hisinf(m) == -1 ? __float2half(0.0f) : hexp(m - M[j]); // create a QxQ diagonal matrix for rescaling the output if (lane_id == j) { @@ -6637,7 +6637,7 @@ static __global__ void flash_attn_ext_f16( for (int64_t p = lane_id; p < C; p += NW) { const half s = ss[j*T + p]; - const half vs = __hisinf(s) ? __float2half(0.0f) : hexp(s - M[j]); + const half vs = __hisinf(s) == -1 ? __float2half(0.0f) : hexp(s - M[j]); ls += vs; @@ -6650,7 +6650,7 @@ static __global__ void flash_attn_ext_f16( } // skip -INF blocks - if (__hisinf(smax)) { + if (__hisinf(smax) == -1) { continue; } @@ -6735,8 +6735,8 @@ static __global__ void flash_attn_ext_f16( M = __hmax(M0, M1); - const half ms0 = __hisinf(M0) ? __float2half(0.0f) : hexp(M0 - M); - const half ms1 = __hisinf(M1) ? __float2half(0.0f) : hexp(M1 - M); + const half ms0 = __hisinf(M0) == -1 ? __float2half(0.0f) : hexp(M0 - M); + const half ms1 = __hisinf(M1) == -1 ? __float2half(0.0f) : hexp(M1 - M); S = S0*ms0 + S1*ms1; From b150abe83e6f0f8a0cf552d7fc0d8fe9f0f78569 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 3 Feb 2024 13:17:47 +0200 Subject: [PATCH 041/100] cuda : avoid warp_reduce for smax --- ggml-cuda.cu | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 0136fbf28f2a5..c3f24242b350a 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6621,7 +6621,6 @@ static __global__ void flash_attn_ext_f16( M[j] = __hmax(M[j], s); } - smax = warp_reduce_max(smax); M[j] = warp_reduce_max(M[j]); const half ms = __hisinf(m) == -1 ? __float2half(0.0f) : hexp(m - M[j]); @@ -6649,6 +6648,8 @@ static __global__ void flash_attn_ext_f16( } } + smax = warp_reduce_max(smax); + // skip -INF blocks if (__hisinf(smax) == -1) { continue; From 7c34655b366e14d43f7fc9fa104a9ca7b8f60580 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 3 Feb 2024 13:39:46 +0200 Subject: [PATCH 042/100] cuda : use int instead of int64_t Noticeably improves performance (thanks to Johannes) --- ggml-cuda.cu | 70 ++++++++++++++++++++++++++-------------------------- 1 file changed, 35 insertions(+), 35 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index c3f24242b350a..558ffb8ac7b56 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6462,10 +6462,10 @@ static __global__ void flash_attn_ext_f16( half16x16_acc lo[Q16][D16]; // load heads from Q to shared memory - for (int64_t j = warp_id; j < Q; j += num_warps) { + for (int j = warp_id; j < Q; j += num_warps) { const float2 * q2 = (const float2 *) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); - for (int64_t i = lane_id; i < D2; i += NW) { + for (int i = lane_id; i < D2; i += NW) { if (iq1 + j < ne01) { sq2[j*T2 + i] = __float22half2_rn(q2[i]); } else { @@ -6477,15 +6477,15 @@ static __global__ void flash_attn_ext_f16( nvcuda::wmma::fill_fragment(zr, 0.0); // zero out lo - for (int64_t j = 0; j < Q16; ++j) { - for (int64_t i = 0; i < D16; ++i) { + for (int j = 0; j < Q16; ++j) { + for (int i = 0; i < D16; ++i) { nvcuda::wmma::fill_fragment(lo[j][i], 0.0); } } // zero out shared memory SH - for (int64_t j = 0; j < Q; ++j) { - for (int64_t i = lane_id; i < SH; i += NW) { + for (int j = 0; j < Q; ++j) { + for (int i = lane_id; i < SH; i += NW) { ss[j*T + i] = 0.0; } } @@ -6526,8 +6526,8 @@ static __global__ void flash_attn_ext_f16( // load the queries from shared memory into local memory half16x16_a mq[Q16][D16]; - for (int64_t j = 0; j < Q16; ++j) { - for (int64_t i = 0; i < D16; ++i) { + for (int j = 0; j < Q16; ++j) { + for (int i = 0; i < D16; ++i) { nvcuda::wmma::load_matrix_sync(mq[j][i], sq + 16*j*T + i*16, T); } } @@ -6544,28 +6544,28 @@ static __global__ void flash_attn_ext_f16( // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns - for (int64_t ic = C*warp_id; ic < ne11; ic += C*num_warps) { + for (int ic = C*warp_id; ic < ne11; ic += C*num_warps) { // Q*K^T { for (int cc = 0; cc < C/16; ++cc) { half16x16_acc mqk[Q16]; - for (int64_t j = 0; j < Q16; ++j) { + for (int j = 0; j < Q16; ++j) { nvcuda::wmma::fill_fragment(mqk[j], 0); } const half * pk = (const half *) ((const char *) k + ((ic + 16*cc)*nb11 + ik2*nb12 + ik3*nb13)); - for (int64_t i = 0; i < D16; ++i) { + for (int i = 0; i < D16; ++i) { half16x16_bT mk; // transposed key nvcuda::wmma::load_matrix_sync(mk, pk + i*16, nb11/sizeof(half)); - for (int64_t j = 0; j < Q16; ++j) { + for (int j = 0; j < Q16; ++j) { nvcuda::wmma::mma_sync(mqk[j], mq[j][i], mk, mqk[j]); } } // mqk = mqk*scale + mask - for (int64_t j = 0; j < Q16; ++j) { + for (int j = 0; j < Q16; ++j) { half16x16_a mqka; half16x16_acc mm; @@ -6588,8 +6588,8 @@ static __global__ void flash_attn_ext_f16( // online softmax if (C == 32) { - for (int64_t j = 0; j < Q; ++j) { - const int64_t p = lane_id; + for (int j = 0; j < Q; ++j) { + const int p = lane_id; const half m = M[j]; const half s = ss[j*T + p]; @@ -6611,10 +6611,10 @@ static __global__ void flash_attn_ext_f16( ss[j*T + p] = vs; } } else { - for (int64_t j = 0; j < Q; ++j) { + for (int j = 0; j < Q; ++j) { const half m = M[j]; - for (int64_t p = lane_id; p < C; p += NW) { + for (int p = lane_id; p < C; p += NW) { const half s = ss[j*T + p]; smax = __hmax(smax, s); @@ -6633,7 +6633,7 @@ static __global__ void flash_attn_ext_f16( // local sum half ls = 0.0f; - for (int64_t p = lane_id; p < C; p += NW) { + for (int p = lane_id; p < C; p += NW) { const half s = ss[j*T + p]; const half vs = __hisinf(s) == -1 ? __float2half(0.0f) : hexp(s - M[j]); @@ -6656,13 +6656,13 @@ static __global__ void flash_attn_ext_f16( } // O = diag(ms)*O - for (int64_t j = 0; j < Q16; ++j) { + for (int j = 0; j < Q16; ++j) { half16x16_a mm; half16x16_b lob; nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T); - for (int64_t i = 0; i < D16; ++i) { + for (int i = 0; i < D16; ++i) { // convert accumulator to matrix_b nvcuda::wmma::store_matrix_sync( ss + 16*j*T + C + 16*j, lo[j][i], T, nvcuda::wmma::mem_row_major); nvcuda::wmma::load_matrix_sync (lob, ss + 16*j*T + C + 16*j, T); @@ -6680,17 +6680,17 @@ static __global__ void flash_attn_ext_f16( const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23)); half16x16_b mk[D16]; - for (int64_t i = 0; i < D16; ++i) { + for (int i = 0; i < D16; ++i) { nvcuda::wmma::load_matrix_sync(mk[i], pv + i*16, nb21/sizeof(half)); } half16x16_a mv[Q16]; - for (int64_t j = 0; j < Q16; ++j) { + for (int j = 0; j < Q16; ++j) { nvcuda::wmma::load_matrix_sync(mv[j], ss + 16*j*T + 16*cc, T); } - for (int64_t j = 0; j < Q16; ++j) { - for (int64_t i = 0; i < D16; ++i) { + for (int j = 0; j < Q16; ++j) { + for (int i = 0; i < D16; ++i) { nvcuda::wmma::mma_sync(lo[j][i], mv[j], mk[i], lo[j][i]); } } @@ -6699,7 +6699,7 @@ static __global__ void flash_attn_ext_f16( } // these are needed for reducing the results from the simdgroups (reuse the ss buffer) - for (int64_t j = 0; j < Q; ++j) { + for (int j = 0; j < Q; ++j) { if (lane_id == 0) { ss[j*T + 0] = S[j]; ss[j*T + 1] = M[j]; @@ -6708,7 +6708,7 @@ static __global__ void flash_attn_ext_f16( } // reduce the warps sequentially - for (int64_t sg = 1; sg < num_warps; ++sg) { + for (int sg = 1; sg < num_warps; ++sg) { half S = __float2half(0.0f); half M = __float2half(-INFINITY); @@ -6716,8 +6716,8 @@ static __global__ void flash_attn_ext_f16( // each simdgroup stores its output to shared memory, reusing sq if (warp_id == sg) { - for (int64_t j = 0; j < Q16; ++j) { - for (int64_t i = 0; i < D16; ++i) { + for (int j = 0; j < Q16; ++j) { + for (int i = 0; i < D16; ++i) { nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major); } } @@ -6727,7 +6727,7 @@ static __global__ void flash_attn_ext_f16( // the first simdgroup accumulates the results from the other simdgroups if (warp_id == 0) { - for (int64_t j = 0; j < Q; ++j) { + for (int j = 0; j < Q; ++j) { const half S0 = ss[j*T + 0]; const half S1 = ss[j*T + sg*SH + 0]; @@ -6751,7 +6751,7 @@ static __global__ void flash_attn_ext_f16( } // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 - for (int64_t j = 0; j < Q16; ++j) { + for (int j = 0; j < Q16; ++j) { half16x16_a ms0; half16x16_a ms1; half16x16_b t; @@ -6760,7 +6760,7 @@ static __global__ void flash_attn_ext_f16( nvcuda::wmma::load_matrix_sync(ms0, ss + 16*j*T + C + 16*j, T); nvcuda::wmma::load_matrix_sync(ms1, ss + 16*j*T + C + 16*j + sg*SH, T); - for (int64_t i = 0; i < D16; ++i) { + for (int i = 0; i < D16; ++i) { nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T); nvcuda::wmma::mma_sync(t2, ms1, t, zr); @@ -6776,8 +6776,8 @@ static __global__ void flash_attn_ext_f16( // store result to shared memory (reuse sq) if (warp_id == 0) { - for (int64_t j = 0; j < Q16; ++j) { - for (int64_t i = 0; i < D16; ++i) { + for (int j = 0; j < Q16; ++j) { + for (int i = 0; i < D16; ++i) { nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major); } } @@ -6785,10 +6785,10 @@ static __global__ void flash_attn_ext_f16( // final rescale with 1/S and store to global memory if (warp_id == 0) { - for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) { + for (int j = 0; j < Q && iq1 + j < ne01; ++j) { const half S = ss[j*T + 0]; - for (int64_t i = lane_id; i < D; i += NW) { + for (int i = lane_id; i < D; i += NW) { dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i] / S); } } From 1f8a5924823aecaa6ab1d5c2ac70ddde1d6c27d0 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 3 Feb 2024 14:01:32 +0200 Subject: [PATCH 043/100] cuda : make loops use the same loop values Thanks Johannes again for the tip --- ggml-cuda.cu | 43 +++++++++++++++++++++++++++++++------- tests/test-backend-ops.cpp | 2 +- 2 files changed, 37 insertions(+), 8 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 558ffb8ac7b56..a3a6c6455017b 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6462,10 +6462,20 @@ static __global__ void flash_attn_ext_f16( half16x16_acc lo[Q16][D16]; // load heads from Q to shared memory - for (int j = warp_id; j < Q; j += num_warps) { + for (int j0 = 0; j0 < Q; j0 += num_warps) { + const int j = j0 + warp_id; + if (j >= Q) { + break; + } + const float2 * q2 = (const float2 *) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); - for (int i = lane_id; i < D2; i += NW) { + for (int i0 = 0; i0 < D2; i0 += NW) { + const int i = i0 + lane_id; + if (i >= D2) { + break; + } + if (iq1 + j < ne01) { sq2[j*T2 + i] = __float22half2_rn(q2[i]); } else { @@ -6485,7 +6495,12 @@ static __global__ void flash_attn_ext_f16( // zero out shared memory SH for (int j = 0; j < Q; ++j) { - for (int i = lane_id; i < SH; i += NW) { + for (int i0 = 0; i0 < SH; i0 += NW) { + const int i = i0 + lane_id; + if (i >= SH) { + break; + } + ss[j*T + i] = 0.0; } } @@ -6544,7 +6559,12 @@ static __global__ void flash_attn_ext_f16( // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns - for (int ic = C*warp_id; ic < ne11; ic += C*num_warps) { + for (int ic0 = 0; ic0 < ne11; ic0 += C*num_warps) { + const int ic = ic0 + warp_id*C; + if (ic >= ne11) { + break; + } + // Q*K^T { for (int cc = 0; cc < C/16; ++cc) { @@ -6614,7 +6634,9 @@ static __global__ void flash_attn_ext_f16( for (int j = 0; j < Q; ++j) { const half m = M[j]; - for (int p = lane_id; p < C; p += NW) { + for (int p0 = 0; p0 < C; p0 += NW) { + const int p = p0 + lane_id; + const half s = ss[j*T + p]; smax = __hmax(smax, s); @@ -6633,7 +6655,9 @@ static __global__ void flash_attn_ext_f16( // local sum half ls = 0.0f; - for (int p = lane_id; p < C; p += NW) { + for (int p0 = 0; p0 < C; p0 += NW) { + const int p = p0 + lane_id; + const half s = ss[j*T + p]; const half vs = __hisinf(s) == -1 ? __float2half(0.0f) : hexp(s - M[j]); @@ -6788,7 +6812,12 @@ static __global__ void flash_attn_ext_f16( for (int j = 0; j < Q && iq1 + j < ne01; ++j) { const half S = ss[j*T + 0]; - for (int i = lane_id; i < D; i += NW) { + for (int i0 = 0; i0 < D; i0 += NW) { + const int i = i0 + lane_id; + if (i >= D) { + break; + } + dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i] / S); } } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 9feb5e1fe550e..e4076b49c180d 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -2210,7 +2210,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_leaky_relu()); #if 1 - for (int hs : { 64, 80, 128, }) { + for (int hs : { 128, 64, 80, }) { for (int nh : { 32, }) { for (int kv : { 512, 1024, 2048, 4096, }) { for (int nb : { 1, 2, 4, 8, 512, 1024, 2048, }) { From 92472ea22ca3eed7f65114b0e6b7de1585930759 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 3 Feb 2024 14:10:01 +0200 Subject: [PATCH 044/100] cuda : unroll some of the loops --- ggml-cuda.cu | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index a3a6c6455017b..deda4cc706fdc 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6462,6 +6462,7 @@ static __global__ void flash_attn_ext_f16( half16x16_acc lo[Q16][D16]; // load heads from Q to shared memory +#pragma unroll for (int j0 = 0; j0 < Q; j0 += num_warps) { const int j = j0 + warp_id; if (j >= Q) { @@ -6470,6 +6471,7 @@ static __global__ void flash_attn_ext_f16( const float2 * q2 = (const float2 *) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); +#pragma unroll for (int i0 = 0; i0 < D2; i0 += NW) { const int i = i0 + lane_id; if (i >= D2) { From c51f27c0dbd70fe8eda6182d61371d6a2dea6fb9 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 3 Feb 2024 14:27:36 +0200 Subject: [PATCH 045/100] cuda : avoid __hisinf branches --- ggml-cuda.cu | 83 ++++++++++++++++++---------------------------------- 1 file changed, 29 insertions(+), 54 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index deda4cc706fdc..4d1fb008c3994 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6513,9 +6513,9 @@ static __global__ void flash_attn_ext_f16( half S[Q]; half M[Q]; - for(int i = 0; i < Q; i++) { + for (int i = 0; i < Q; ++i) { S[i] = __float2half(0.0f); - M[i] = __float2half(-INFINITY); + M[i] = CUDART_MIN_DENORM_FP16; } // assume K and V are same shape @@ -6609,69 +6609,44 @@ static __global__ void flash_attn_ext_f16( half smax = __float2half(-INFINITY); // online softmax - if (C == 32) { - for (int j = 0; j < Q; ++j) { - const int p = lane_id; - - const half m = M[j]; - const half s = ss[j*T + p]; - - smax = warp_reduce_max(__hmax(smax, s)); - M[j] = warp_reduce_max(__hmax(M[j], s)); - - const half ms = __hisinf(m) == -1 ? __float2half(0.0f) : hexp(m - M[j]); - const half vs = __hisinf(s) == -1 ? __float2half(0.0f) : hexp(s - M[j]); + for (int j = 0; j < Q; ++j) { + const half m = M[j]; - S[j] = S[j]*ms + warp_reduce_sum(vs); + for (int p0 = 0; p0 < C; p0 += NW) { + const int p = p0 + lane_id; - // create a QxQ diagonal matrix for rescaling the output - if (p == j) { - ss[j*T + C + j] = ms; - } + const half s = ss[j*T + p]; - // the P matrix from the paper (Q rows, C columns) - ss[j*T + p] = vs; + smax = __hmax(smax, s); + M[j] = __hmax(M[j], s); } - } else { - for (int j = 0; j < Q; ++j) { - const half m = M[j]; - - for (int p0 = 0; p0 < C; p0 += NW) { - const int p = p0 + lane_id; - const half s = ss[j*T + p]; + M[j] = warp_reduce_max(M[j]); - smax = __hmax(smax, s); - M[j] = __hmax(M[j], s); - } - - M[j] = warp_reduce_max(M[j]); - - const half ms = __hisinf(m) == -1 ? __float2half(0.0f) : hexp(m - M[j]); + const half ms = hexp(m - M[j]); - // create a QxQ diagonal matrix for rescaling the output - if (lane_id == j) { - ss[j*T + C + j] = ms; - } - - // local sum - half ls = 0.0f; + // create a QxQ diagonal matrix for rescaling the output + if (lane_id == j) { + ss[j*T + C + j] = ms; + } - for (int p0 = 0; p0 < C; p0 += NW) { - const int p = p0 + lane_id; + // local sum + half ls = 0.0f; - const half s = ss[j*T + p]; + for (int p0 = 0; p0 < C; p0 += NW) { + const int p = p0 + lane_id; - const half vs = __hisinf(s) == -1 ? __float2half(0.0f) : hexp(s - M[j]); + const half s = ss[j*T + p]; - ls += vs; + const half vs = hexp(s - M[j]); - // the P matrix from the paper (Q rows, C columns) - ss[j*T + p] = vs; - } + ls += vs; - S[j] = S[j]*ms + warp_reduce_sum(ls); + // the P matrix from the paper (Q rows, C columns) + ss[j*T + p] = vs; } + + S[j] = S[j]*ms + warp_reduce_sum(ls); } smax = warp_reduce_max(smax); @@ -6736,7 +6711,7 @@ static __global__ void flash_attn_ext_f16( // reduce the warps sequentially for (int sg = 1; sg < num_warps; ++sg) { half S = __float2half(0.0f); - half M = __float2half(-INFINITY); + half M = CUDART_MIN_DENORM_FP16; __syncthreads(); @@ -6762,8 +6737,8 @@ static __global__ void flash_attn_ext_f16( M = __hmax(M0, M1); - const half ms0 = __hisinf(M0) == -1 ? __float2half(0.0f) : hexp(M0 - M); - const half ms1 = __hisinf(M1) == -1 ? __float2half(0.0f) : hexp(M1 - M); + const half ms0 = hexp(M0 - M); + const half ms1 = hexp(M1 - M); S = S0*ms0 + S1*ms1; From b958151e3f66e17a9bc5131e446a50c5529b4b81 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 3 Feb 2024 15:00:25 +0200 Subject: [PATCH 046/100] cuda : use half2 in softmax --- ggml-cuda.cu | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 4d1fb008c3994..1fed9d23e2f47 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6451,12 +6451,14 @@ static __global__ void flash_attn_ext_f16( const int T = D + num_warps*SH; // shared memory size per query in (half) const int T2 = T/2; // shared memory size per query in (half2) + const int C2 = C/2; extern __shared__ half __flash_attn_f16_shmem[]; // pq half * sq = (half *) (__flash_attn_f16_shmem + 0*D); // holds the query data half2 * sq2 = (half2 *) (__flash_attn_f16_shmem + 0*D); // same as above but in half2 half * ss = (half *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // scratch buffer for attention and diagonal matrix + half2 * ss2 = (half2 *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // same as above but in half2 half16x16_acc zr; half16x16_acc lo[Q16][D16]; @@ -6606,19 +6608,19 @@ static __global__ void flash_attn_ext_f16( } // used to detect blocks full of -INF - half smax = __float2half(-INFINITY); + half2 smax = make_half2(-INFINITY, -INFINITY); // online softmax for (int j = 0; j < Q; ++j) { const half m = M[j]; - for (int p0 = 0; p0 < C; p0 += NW) { + for (int p0 = 0; p0 < C2; p0 += NW) { const int p = p0 + lane_id; - const half s = ss[j*T + p]; + const half2 s = ss2[j*T2 + p]; - smax = __hmax(smax, s); - M[j] = __hmax(M[j], s); + smax = __hmax2(smax, s); + M[j] = __hmax(M[j], __hmax(s.x, s.y)); } M[j] = warp_reduce_max(M[j]); @@ -6631,28 +6633,31 @@ static __global__ void flash_attn_ext_f16( } // local sum - half ls = 0.0f; + half2 ls = make_half2(0.0f, 0.0f); + half2 M2 = make_half2(M[j], M[j]); - for (int p0 = 0; p0 < C; p0 += NW) { + for (int p0 = 0; p0 < C2; p0 += NW) { const int p = p0 + lane_id; - const half s = ss[j*T + p]; + const half2 s = ss2[j*T2 + p]; - const half vs = hexp(s - M[j]); + const half2 vs = h2exp(s - M2); ls += vs; // the P matrix from the paper (Q rows, C columns) - ss[j*T + p] = vs; + ss2[j*T2 + p] = vs; } - S[j] = S[j]*ms + warp_reduce_sum(ls); + ls = warp_reduce_sum(ls); + + S[j] = S[j]*ms + ls.x + ls.y; } smax = warp_reduce_max(smax); // skip -INF blocks - if (__hisinf(smax) == -1) { + if (__hisinf(smax.x) == -1 || __hisinf(smax.y) == -1) { continue; } From a7b471569bdf4e09e97b2d02c27989b8cb801861 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 3 Feb 2024 15:17:49 +0200 Subject: [PATCH 047/100] cuda : switch to 1 warp for bs > 16 --- ggml-cuda.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 1fed9d23e2f47..c98b551b31e8e 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -10933,7 +10933,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * const int nwarps_max = 8; // TODO: we don't want to launch too much warps. how much is too much? // TODO: produces wrong results for nwarps > 8 (RTX 2060) - not sure why - const int nwarps = Q->ne[1] <= nqpb ? std::max(2, std::min((int) K->ne[1]/ncpw, nwarps_max)) : 2; + const int nwarps = Q->ne[1] <= nqpb ? std::max(2, std::min((int) K->ne[1]/ncpw, nwarps_max)) : 1; dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]); dim3 block_dim(32, nwarps, 1); From 3b1c4e76739031bee3028748e0cd288c148f77b4 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 3 Feb 2024 15:36:05 +0200 Subject: [PATCH 048/100] cuda : speed-up reduce part of the kernel --- ggml-cuda.cu | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index c98b551b31e8e..67541a61ef716 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6715,9 +6715,6 @@ static __global__ void flash_attn_ext_f16( // reduce the warps sequentially for (int sg = 1; sg < num_warps; ++sg) { - half S = __float2half(0.0f); - half M = CUDART_MIN_DENORM_FP16; - __syncthreads(); // each simdgroup stores its output to shared memory, reusing sq @@ -6733,27 +6730,25 @@ static __global__ void flash_attn_ext_f16( // the first simdgroup accumulates the results from the other simdgroups if (warp_id == 0) { - for (int j = 0; j < Q; ++j) { + for (int j = lane_id; j < Q; j += NW) { const half S0 = ss[j*T + 0]; const half S1 = ss[j*T + sg*SH + 0]; const half M0 = ss[j*T + 1]; const half M1 = ss[j*T + sg*SH + 1]; - M = __hmax(M0, M1); + const half M = __hmax(M0, M1); const half ms0 = hexp(M0 - M); const half ms1 = hexp(M1 - M); - S = S0*ms0 + S1*ms1; + const half S = S0*ms0 + S1*ms1; - if (lane_id == 0) { - ss[j*T + 0] = S; - ss[j*T + 1] = M; + ss[j*T + 0] = S; + ss[j*T + 1] = M; - ss[j*T + C + j ] = ms0; - ss[j*T + C + j + sg*SH] = ms1; - } + ss[j*T + C + j ] = ms0; + ss[j*T + C + j + sg*SH] = ms1; } // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 @@ -10931,6 +10926,8 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * const int nqpb = NQPB; // queries per block const int ncpw = NCPW; // cache values per warp (does not work for other values) + GGML_ASSERT(NQPB <= 32); + const int nwarps_max = 8; // TODO: we don't want to launch too much warps. how much is too much? // TODO: produces wrong results for nwarps > 8 (RTX 2060) - not sure why const int nwarps = Q->ne[1] <= nqpb ? std::max(2, std::min((int) K->ne[1]/ncpw, nwarps_max)) : 1; From 5b263dd83a5f906eddd10bc044051d7571097043 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 3 Feb 2024 16:12:20 +0200 Subject: [PATCH 049/100] cuda : unroll Q*K^T loop --- ggml-cuda.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 67541a61ef716..dbd4822396f4e 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6571,6 +6571,7 @@ static __global__ void flash_attn_ext_f16( // Q*K^T { +#pragma unroll for (int cc = 0; cc < C/16; ++cc) { half16x16_acc mqk[Q16]; for (int j = 0; j < Q16; ++j) { From e04ff391819e1875beed3e30d9e7592db45e0e62 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 3 Feb 2024 16:57:46 +0200 Subject: [PATCH 050/100] cuda : fix -INF block check --- ggml-cuda.cu | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index dbd4822396f4e..e51ddc08f764f 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6658,7 +6658,7 @@ static __global__ void flash_attn_ext_f16( smax = warp_reduce_max(smax); // skip -INF blocks - if (__hisinf(smax.x) == -1 || __hisinf(smax.y) == -1) { + if (__hisinf(smax.x) == -1 && __hisinf(smax.y) == -1) { continue; } @@ -6676,8 +6676,10 @@ static __global__ void flash_attn_ext_f16( nvcuda::wmma::mma_sync(lo[j][i], mm, lob, zr); } + } - // restore zeros + // restore zeros + for (int j = 0; j < Q16; ++j) { nvcuda::wmma::store_matrix_sync(ss + 16*j*T + C + 16*j, zr, T, nvcuda::wmma::mem_row_major); } From cfd9732b2e45a442f4f7261ac0b50ec6e0862ab2 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 3 Feb 2024 18:31:55 +0200 Subject: [PATCH 051/100] cuda : simplify softmax --- ggml-cuda.cu | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index e51ddc08f764f..25f810cbea86f 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6512,11 +6512,10 @@ static __global__ void flash_attn_ext_f16( __syncthreads(); { - half S[Q]; + half S = __float2half(0.0f); half M[Q]; for (int i = 0; i < Q; ++i) { - S[i] = __float2half(0.0f); M[i] = CUDART_MIN_DENORM_FP16; } @@ -6626,13 +6625,6 @@ static __global__ void flash_attn_ext_f16( M[j] = warp_reduce_max(M[j]); - const half ms = hexp(m - M[j]); - - // create a QxQ diagonal matrix for rescaling the output - if (lane_id == j) { - ss[j*T + C + j] = ms; - } - // local sum half2 ls = make_half2(0.0f, 0.0f); half2 M2 = make_half2(M[j], M[j]); @@ -6652,7 +6644,14 @@ static __global__ void flash_attn_ext_f16( ls = warp_reduce_sum(ls); - S[j] = S[j]*ms + ls.x + ls.y; + const half ms = hexp(m - M[j]); + + // create a QxQ diagonal matrix for rescaling the output + if (lane_id == j) { + ss[j*T + C + j] = ms; + + S = S*ms + ls.x + ls.y; + } } smax = warp_reduce_max(smax); @@ -6709,8 +6708,8 @@ static __global__ void flash_attn_ext_f16( // these are needed for reducing the results from the simdgroups (reuse the ss buffer) for (int j = 0; j < Q; ++j) { - if (lane_id == 0) { - ss[j*T + 0] = S[j]; + if (lane_id == j) { + ss[j*T + 0] = S; ss[j*T + 1] = M[j]; } } From ef68fac2a8b51e2237234e3d7c6120cade457ce8 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 3 Feb 2024 18:36:58 +0200 Subject: [PATCH 052/100] cuda : fix matrix names --- ggml-cuda.cu | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 25f810cbea86f..d9ab2bd093feb 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6687,19 +6687,19 @@ static __global__ void flash_attn_ext_f16( for (int cc = 0; cc < C/16; ++cc) { const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23)); - half16x16_b mk[D16]; + half16x16_b mv[D16]; for (int i = 0; i < D16; ++i) { - nvcuda::wmma::load_matrix_sync(mk[i], pv + i*16, nb21/sizeof(half)); + nvcuda::wmma::load_matrix_sync(mv[i], pv + i*16, nb21/sizeof(half)); } - half16x16_a mv[Q16]; + half16x16_a ms[Q16]; for (int j = 0; j < Q16; ++j) { - nvcuda::wmma::load_matrix_sync(mv[j], ss + 16*j*T + 16*cc, T); + nvcuda::wmma::load_matrix_sync(ms[j], ss + 16*j*T + 16*cc, T); } for (int j = 0; j < Q16; ++j) { for (int i = 0; i < D16; ++i) { - nvcuda::wmma::mma_sync(lo[j][i], mv[j], mk[i], lo[j][i]); + nvcuda::wmma::mma_sync(lo[j][i], ms[j], mv[i], lo[j][i]); } } } From 1846e92a904ef17a55a3f7e7c2e837f35db2ce7d Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 4 Feb 2024 09:57:58 +0200 Subject: [PATCH 053/100] cuda : minor --- ggml-cuda.cu | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index d9ab2bd093feb..713a6a89acfdc 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6399,10 +6399,10 @@ static __global__ void flash_attn_f32( } #if __CUDA_ARCH__ >= CC_VOLTA -typedef nvcuda::wmma::fragment half16x16_a; -typedef nvcuda::wmma::fragment half16x16_b; -typedef nvcuda::wmma::fragment half16x16_bT; -typedef nvcuda::wmma::fragment half16x16_acc; +typedef nvcuda::wmma::fragment half16x16_a; +typedef nvcuda::wmma::fragment half16x16_b; +typedef nvcuda::wmma::fragment half16x16_bT; +typedef nvcuda::wmma::fragment half16x16_acc; #endif // based on metal version @@ -6443,15 +6443,17 @@ static __global__ void flash_attn_ext_f16( const int iq2 = blockIdx.y; const int iq1 = blockIdx.x * Q; - const int D2 = D/2; const int D16 = D/16; const int Q16 = Q/16; + const int C16 = C/16; + const int NW = WARP_SIZE; const int SH = (C + Q); // shared memory per simdgroup in (half) const int T = D + num_warps*SH; // shared memory size per query in (half) const int T2 = T/2; // shared memory size per query in (half2) const int C2 = C/2; + const int D2 = D/2; extern __shared__ half __flash_attn_f16_shmem[]; // pq @@ -6571,7 +6573,7 @@ static __global__ void flash_attn_ext_f16( // Q*K^T { #pragma unroll - for (int cc = 0; cc < C/16; ++cc) { + for (int cc = 0; cc < C16; ++cc) { half16x16_acc mqk[Q16]; for (int j = 0; j < Q16; ++j) { nvcuda::wmma::fill_fragment(mqk[j], 0); @@ -6684,7 +6686,7 @@ static __global__ void flash_attn_ext_f16( // O = O + (Q*K^T)*V { - for (int cc = 0; cc < C/16; ++cc) { + for (int cc = 0; cc < C16; ++cc) { const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23)); half16x16_b mv[D16]; @@ -6707,11 +6709,9 @@ static __global__ void flash_attn_ext_f16( } // these are needed for reducing the results from the simdgroups (reuse the ss buffer) - for (int j = 0; j < Q; ++j) { - if (lane_id == j) { - ss[j*T + 0] = S; - ss[j*T + 1] = M[j]; - } + if (lane_id < Q) { + ss[lane_id*T + 0] = S; + ss[lane_id*T + 1] = M[lane_id]; } } @@ -10939,6 +10939,10 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * const size_t shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2); + // increase shared memory limit to 96KB + //const size_t shmem_max = 96*1024; + //cudaFuncSetAttribute(flash_attn_ext_f16<128, NQPB, NCPW>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_max); + switch (Q->ne[0]) { case 64: flash_attn_ext_f16<64, NQPB, NCPW> @@ -11045,6 +11049,8 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * default: break; } + + CUDA_CHECK(cudaGetLastError()); } static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { From f249c997a8b3f9b129fe825bebd609a362e9ab9c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 19 Feb 2024 13:10:24 +0200 Subject: [PATCH 054/100] llama : adapt to F16 KQ_pos --- ggml-cuda.cu | 2 +- ggml.c | 2 +- llama.cpp | 15 ++++++++++----- tests/test-backend-ops.cpp | 2 +- 4 files changed, 13 insertions(+), 8 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 2c8af51a66d59..5c6159a83f3cd 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6232,7 +6232,7 @@ static __global__ void soft_max_f32(const float * x, const half * mask, const ha const int ix = rowx*ncols + col; const int iy = rowy*ncols + col; - const float val = x[ix]*scale + (mask ? __half2float(mask[iy]) : 0.0f) + (pos ? __half2float(slope*pos[col]) : 0.0f); + const float val = x[ix]*scale + (mask ? __half2float(mask[iy]) : 0.0f) + (pos ? slope*__half2float(pos[col]) : 0.0f); vals[col] = val; max_val = max(max_val, val); diff --git a/ggml.c b/ggml.c index efc570db698f2..9a2ae62647364 100644 --- a/ggml.c +++ b/ggml.c @@ -5192,7 +5192,7 @@ static struct ggml_tensor * ggml_soft_max_impl( GGML_ASSERT(mask->type == GGML_TYPE_F16); GGML_ASSERT(ggml_is_contiguous(mask)); GGML_ASSERT(ggml_is_matrix(mask)); - GGML_ASSERT(ggml_can_repeat_rows(mask, a)); + GGML_ASSERT(mask->ne[1] >= a->ne[1]); } if (pos) { diff --git a/llama.cpp b/llama.cpp index 2359ed10aebf0..5aa3a508d7080 100644 --- a/llama.cpp +++ b/llama.cpp @@ -102,7 +102,7 @@ #define LLAMA_MAX_NODES 8192 #define LLAMA_MAX_EXPERTS 8 -#define LLAMA_FLASH_ATTN +//#define LLAMA_FLASH_ATTN // // logging @@ -4831,6 +4831,11 @@ static struct ggml_tensor * llm_build_kqv( struct ggml_tensor * cur; #if defined(LLAMA_FLASH_ATTN) + GGML_UNUSED(model); + GGML_UNUSED(n_ctx); + + GGML_ASSERT(kq_pos == nullptr && "ALiBi is not yet supported with Flash Attention"); + // split cached v into n_head heads (not transposed) struct ggml_tensor * v = ggml_view_3d(ctx, kv.v_l[il], @@ -5260,7 +5265,7 @@ struct llm_build_context { cb(KQ_mask, "KQ_mask", -1); // positions of the tokens in the KV cache - struct ggml_tensor * KQ_pos = ggml_view_1d(ctx0, lctx.inp_KQ_pos, n_kv, 0); + struct ggml_tensor * KQ_pos = ggml_cast(ctx0, ggml_view_1d(ctx0, lctx.inp_KQ_pos, n_kv, 0), GGML_TYPE_F16); cb(KQ_pos, "KQ_pos", -1); // shift the entire K-cache if needed @@ -5804,7 +5809,7 @@ struct llm_build_context { cb(KQ_mask, "KQ_mask", -1); // positions of the tokens in the KV cache - struct ggml_tensor * KQ_pos = ggml_view_1d(ctx0, lctx.inp_KQ_pos, n_kv, 0); + struct ggml_tensor * KQ_pos = ggml_cast(ctx0, ggml_view_1d(ctx0, lctx.inp_KQ_pos, n_kv, 0), GGML_TYPE_F16); cb(KQ_pos, "KQ_pos", -1); for (int il = 0; il < n_layer; ++il) { @@ -6043,7 +6048,7 @@ struct llm_build_context { cb(KQ_mask, "KQ_mask", -1); // positions of the tokens in the KV cache - struct ggml_tensor * KQ_pos = ggml_view_1d(ctx0, lctx.inp_KQ_pos, n_kv, 0); + struct ggml_tensor * KQ_pos = ggml_cast(ctx0, ggml_view_1d(ctx0, lctx.inp_KQ_pos, n_kv, 0), GGML_TYPE_F16); cb(KQ_pos, "KQ_pos", -1); inpL = llm_build_norm(ctx0, inpL, hparams, @@ -6140,7 +6145,7 @@ struct llm_build_context { cb(KQ_mask, "KQ_mask", -1); // positions of the tokens in the KV cache - struct ggml_tensor * KQ_pos = ggml_view_1d(ctx0, lctx.inp_KQ_pos, n_kv, 0); + struct ggml_tensor * KQ_pos = ggml_cast(ctx0, ggml_view_1d(ctx0, lctx.inp_KQ_pos, n_kv, 0), GGML_TYPE_F16); cb(KQ_pos, "KQ_pos", -1); for (int il = 0; il < n_layer; ++il) { diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 912223def6e06..278c57299ce88 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1505,7 +1505,7 @@ struct test_attn : public test_case { struct ggml_tensor * cur; cur = ggml_mul_mat (ctx, k, q); - cur = ggml_soft_max_ext(ctx, cur, mask, 1.0f/sqrtf(hs)); + cur = ggml_soft_max_ext(ctx, cur, mask, nullptr, 1.0f/sqrtf(hs), 0.0f); cur = ggml_mul_mat (ctx, v, cur); cur = ggml_permute (ctx, cur, 0, 2, 1, 3); cur = ggml_cont_2d (ctx, cur, hs*nh, nb); From 6aefd11204199c9bd520b8991bab4085cb6fc977 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 3 Mar 2024 13:50:54 +0200 Subject: [PATCH 055/100] llama : adapt new models to F16 KQ_mask --- llama.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/llama.cpp b/llama.cpp index 1a099adcba5dc..f2b224cafb2e3 100644 --- a/llama.cpp +++ b/llama.cpp @@ -7362,7 +7362,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); for (int il = 0; il < n_layer; ++il) { @@ -7489,7 +7489,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); for (int il = 0; il < n_layer; ++il) { @@ -7724,7 +7724,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); for (int il = 0; il < n_layer; ++il) { From 58c7f6167c0f6540f0da0386fc65d940e1a16ea5 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 4 Mar 2024 20:44:57 +0200 Subject: [PATCH 056/100] ggml : fix F16 store (ARM NEON) --- ggml.c | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml.c b/ggml.c index 09ed9da342885..5715b78ec7bbf 100644 --- a/ggml.c +++ b/ggml.c @@ -874,7 +874,7 @@ inline static float vaddvq_f32(float32x4_t v) { #define GGML_F16_VEC_ZERO GGML_F16x8_ZERO #define GGML_F16_VEC_SET1 GGML_F16x8_SET1 #define GGML_F16_VEC_LOAD(p, i) GGML_F16x8_LOAD(p) - #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE(p, r[i]) + #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE((__fp16 *)(p), r[i]) #define GGML_F16_VEC_FMA GGML_F16x8_FMA #define GGML_F16_VEC_ADD GGML_F16x8_ADD #define GGML_F16_VEC_MUL GGML_F16x8_MUL @@ -900,7 +900,7 @@ inline static float vaddvq_f32(float32x4_t v) { #define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO #define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1 #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p) - #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i]) + #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE((__fp16 *)(p), r[i]) #define GGML_F16_VEC_FMA GGML_F32Cx4_FMA #define GGML_F16_VEC_ADD GGML_F32Cx4_ADD #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL From 3a468e6f9f0c7dff9ed78b0f7a5af069da420606 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 22 Mar 2024 17:12:17 +0200 Subject: [PATCH 057/100] llama : fix type of KQ_mask and KQ_pos --- llama.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/llama.cpp b/llama.cpp index 77da94960e65c..b80080daf7506 100644 --- a/llama.cpp +++ b/llama.cpp @@ -5810,20 +5810,20 @@ struct llm_build_context { struct ggml_tensor * build_inp_KQ_mask(bool causal = true) { if (causal) { - lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F16, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); } else { - lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F16, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); } cb(lctx.inp_KQ_mask, "KQ_mask", -1); ggml_set_input(lctx.inp_KQ_mask); - return lctx.inp_KQ_mask; + return ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16); } struct ggml_tensor * build_inp_KQ_pos() { - lctx.inp_KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_F16, n_kv); + lctx.inp_KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_kv); cb(lctx.inp_KQ_pos, "KQ_pos", -1); ggml_set_input(lctx.inp_KQ_pos); - return lctx.inp_KQ_pos; + return ggml_cast(ctx0, lctx.inp_KQ_pos, GGML_TYPE_F16); } struct ggml_tensor * build_inp_mean() { From 09532120e0afc5b0c451fbb79eac558e4561660b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 22 Mar 2024 17:49:42 +0200 Subject: [PATCH 058/100] ggml : fix CPU soft_max --- ggml.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml.c b/ggml.c index 434fb76c3b1b6..7ea1abfcd67c6 100644 --- a/ggml.c +++ b/ggml.c @@ -12302,7 +12302,7 @@ static void ggml_compute_forward_soft_max_f32( float * dp = (float *)((char *) dst->data + i1*dst->nb[1]); // broadcast the mask across rows - ggml_fp16_t * mp = src1 ? (ggml_fp16_t *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL; + ggml_fp16_t * mp = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL; ggml_vec_cpy_f32 (nc, wp, sp); ggml_vec_scale_f32(nc, wp, scale); From e425810bb6c8e9dd8ef9ff6f606313b6cfa1b607 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 24 Mar 2024 12:21:41 +0200 Subject: [PATCH 059/100] tests : add hs=256 --- ggml.c | 2 +- tests/test-backend-ops.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml.c b/ggml.c index 7ea1abfcd67c6..336c3c0ffaade 100644 --- a/ggml.c +++ b/ggml.c @@ -12272,7 +12272,7 @@ static void ggml_compute_forward_soft_max_f32( GGML_TENSOR_UNARY_OP_LOCALS - const int64_t ne11 = src1 ? src1->ne[1] : 1; + //const int64_t ne11 = src1 ? src1->ne[1] : 1; // TODO: is this supposed to be ceil instead of floor? // https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370 diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index b8994cecff4b2..50c4b27294fea 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -2261,7 +2261,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_leaky_relu()); #if 1 - for (int hs : { 128, 64, 80, }) { + for (int hs : { 128, 256, 64, 80, }) { for (int nh : { 32, }) { for (int kv : { 512, 1024, 2048, 4096, }) { for (int nb : { 1, 2, 4, 8, 512, 1024, 2048, }) { From 6be02b5969ed04fb3ce336702ae3949076bb2bf4 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 27 Mar 2024 10:31:52 +0200 Subject: [PATCH 060/100] cuda : fix build --- ggml-cuda.cu | 7 +------ ggml-cuda/fattn.cu | 48 +++++++++++++++++++++++++++++++++++++++------ ggml-cuda/fattn.cuh | 5 +---- 3 files changed, 44 insertions(+), 16 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 963588c4f2fad..31bfc43c1febe 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -2384,17 +2384,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg ggml_cuda_op_argsort(ctx, dst); break; case GGML_OP_FLASH_ATTN_EXT: + ggml_cuda_flash_attn_ext(ctx, dst); break; default: return false; } - if (tensor->op == GGML_OP_FLASH_ATTN_EXT) { - ggml_cuda_flash_attn_ext(ctx, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor); - } else { - func(ctx, tensor->src[0], tensor->src[1], tensor); - } - cudaError_t err = cudaGetLastError(); if (err != cudaSuccess) { fprintf(stderr, "%s: %s failed\n", __func__, ggml_op_desc(dst)); diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index 0f135a184b572..bcf27fd794aaf 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -1,5 +1,33 @@ #include "fattn.cuh" +#include + +static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) { +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32)); + } + return a; +#else + GGML_UNUSED(a); + NO_DEVICE_CODE; +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL +} + +static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32)); + } + return x; +#else + GGML_UNUSED(x); + NO_DEVICE_CODE; +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX +} + #if __CUDA_ARCH__ >= CC_VOLTA typedef nvcuda::wmma::fragment half16x16_a; typedef nvcuda::wmma::fragment half16x16_b; @@ -10,11 +38,11 @@ typedef nvcuda::wmma::fragment // based on metal version template // D head size, Q queries per block, C cache items per block static __global__ void flash_attn_ext_f16( - const char* __restrict__ q, - const char* __restrict__ k, - const char* __restrict__ v, - const char* __restrict__ mask, - float* __restrict__ dst, + const char * __restrict__ q, + const char * __restrict__ k, + const char * __restrict__ v, + const char * __restrict__ mask, + float * __restrict__ dst, float scale, int ne00, int ne01, @@ -408,7 +436,15 @@ static __global__ void flash_attn_ext_f16( #endif } -void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, const ggml_tensor * mask, ggml_tensor * KQV) { +void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + + const ggml_tensor * mask = dst->src[3]; + + ggml_tensor * KQV = dst; + GGML_ASSERT(Q->type == GGML_TYPE_F32); GGML_ASSERT(K->type == GGML_TYPE_F16); GGML_ASSERT(V->type == GGML_TYPE_F16); diff --git a/ggml-cuda/fattn.cuh b/ggml-cuda/fattn.cuh index 1b764bc946f8f..ad3ca7a8d8e4d 100644 --- a/ggml-cuda/fattn.cuh +++ b/ggml-cuda/fattn.cuh @@ -1,6 +1,3 @@ #include "common.cuh" -void ggml_cuda_flash_attn_ext( - ggml_backend_cuda_context & ctx, - const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, - const ggml_tensor * mask, ggml_tensor * KQV); +void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst); From 57c03b78b6bca2049e43905e808666b04304d0fd Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 28 Mar 2024 19:29:06 +0200 Subject: [PATCH 061/100] metal : improve perf via smaller int registers --- ggml-metal.metal | 149 ++++++++++++++++++++++++----------------------- 1 file changed, 77 insertions(+), 72 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index dc6ad31417c51..27eeb3932ff1a 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2061,11 +2061,11 @@ typedef void (flash_attn_ext_f16_t)( constant int64_t & ne3, constant float & scale, threadgroup half * shared, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]); + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]); // ref: https://arxiv.org/pdf/2307.08691.pdf template // head size, queries per threadgroup, cache items per threadgroup @@ -2099,25 +2099,25 @@ kernel void kernel_flash_attn_ext_f16( constant int64_t & ne3, constant float & scale, threadgroup half * shared [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - const uint nsg = ntg.y; // number of simdgroups - - const int64_t iq3 = tgpig[2]; - const int64_t iq2 = tgpig[1]; - const int64_t iq1 = tgpig[0]*Q; - - const int64_t D4 = D/4; - const int64_t D8 = D/8; - const int64_t Q8 = Q/8; - const int64_t NW = N_SIMDWIDTH; - const int64_t SH = (C + Q); // shared memory per simdgroup in (half) - - const int64_t T = D + nsg*SH; // shared memory size per query in (half) - const int64_t T4 = T/4; // shared memory size per query in (half4) + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + const short nsg = ntg.y; // number of simdgroups + + const short iq3 = tgpig[2]; + const short iq2 = tgpig[1]; + const short iq1 = tgpig[0]*Q; + + const short D4 = D/4; + const short D8 = D/8; + const short Q8 = Q/8; + const short NW = N_SIMDWIDTH; + const short SH = (C + Q); // shared memory per simdgroup in (half) + + const short T = D + nsg*SH; // shared memory size per query in (half) + const short T4 = T/4; // shared memory size per query in (half4) threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4 @@ -2127,10 +2127,10 @@ kernel void kernel_flash_attn_ext_f16( simdgroup_half8x8 lo[Q8][D8]; // load heads from Q to shared memory - for (int64_t j = sgitg; j < Q; j += nsg) { + for (short j = sgitg; j < Q; j += nsg) { device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); - for (int64_t i = tiisg; i < D4; i += NW) { + for (short i = tiisg; i < D4; i += NW) { if (iq1 + j < ne01) { sq4[j*T4 + i] = (half4) q4[i]; } else { @@ -2140,15 +2140,15 @@ kernel void kernel_flash_attn_ext_f16( } // zero out lo - for (int64_t j = 0; j < Q8; ++j) { - for (int64_t i = 0; i < D8; ++i) { + for (short j = 0; j < Q8; ++j) { + for (short i = 0; i < D8; ++i) { lo[j][i] = make_filled_simdgroup_matrix(0.0h); } } // zero out shared memory SH - for (int64_t j = 0; j < Q; ++j) { - for (int64_t i = tiisg; i < SH; i += NW) { + for (short j = 0; j < Q; ++j) { + for (short i = tiisg; i < SH; i += NW) { ss[j*T + i] = 0.0h; } } @@ -2160,33 +2160,33 @@ kernel void kernel_flash_attn_ext_f16( half M[Q] = { [0 ... Q-1] = -INFINITY }; // assume K and V are same shape - const int64_t ne22 = ne12; - const int64_t ne23 = ne13; + const short ne22 = ne12; + const short ne23 = ne13; - const uint64_t nb21 = nb11; - const uint64_t nb22 = nb12; - const uint64_t nb23 = nb13; + const uint nb21 = nb11; + const uint nb22 = nb12; + const uint nb23 = nb13; // broadcast - const int64_t rk2 = ne02/ne12; - const int64_t rk3 = ne03/ne13; + const short rk2 = ne02/ne12; + const short rk3 = ne03/ne13; - const int64_t rv2 = ne02/ne22; - const int64_t rv3 = ne03/ne23; + const short rv2 = ne02/ne22; + const short rv3 = ne03/ne23; // k indices - const int64_t ik2 = iq2 / rk2; - const int64_t ik3 = iq3 / rk3; + const short ik2 = iq2 / rk2; + const short ik3 = iq3 / rk3; // v indices - const int64_t iv2 = iq2 / rv2; - const int64_t iv3 = iq3 / rv3; + const short iv2 = iq2 / rv2; + const short iv3 = iq3 / rv3; // load the queries from shared memory into local memory simdgroup_half8x8 mq[Q8][D8]; - for (int64_t j = 0; j < Q8; ++j) { - for (int64_t i = 0; i < D8; ++i) { + for (short j = 0; j < Q8; ++j) { + for (short i = 0; i < D8; ++i) { simdgroup_load(mq[j][i], sq + 8*j*T + i*8, T); } } @@ -2199,28 +2199,33 @@ kernel void kernel_flash_attn_ext_f16( // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns - for (int64_t ic = C*sgitg; ic < ne11; ic += C*nsg) { + for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) { + const int ic = ic0 + C*sgitg; + if (ic >= ne11) { + break; + } + // Q*K^T { - for (int cc = 0; cc < C/8; ++cc) { + for (short cc = 0; cc < C/8; ++cc) { simdgroup_half8x8 mqk[Q8]; - for (int64_t j = 0; j < Q8; ++j) { + for (short j = 0; j < Q8; ++j) { mqk[j] = make_filled_simdgroup_matrix(0.h); } device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); - for (int64_t i = 0; i < D8; ++i) { + for (short i = 0; i < D8; ++i) { simdgroup_half8x8 mk; simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose - for (int64_t j = 0; j < Q8; ++j) { + for (short j = 0; j < Q8; ++j) { simdgroup_multiply_accumulate(mqk[j], mq[j][i], mk, mqk[j]); } } // mqk = mqk*scale + mask - for (int64_t j = 0; j < Q8; ++j) { + for (short j = 0; j < Q8; ++j) { simdgroup_half8x8 mm; simdgroup_load(mm, mp + 8*j*(nb31/sizeof(half)) + ic + 8*cc, nb31/sizeof(half), 0, false); simdgroup_multiply_accumulate(mqk[j], mqk[j], mscale, mm); @@ -2237,8 +2242,8 @@ kernel void kernel_flash_attn_ext_f16( if (C == 32) { half ms[Q]; - for (int64_t j = 0; j < Q; ++j) { - const int64_t p = tiisg; + for (short j = 0; j < Q; ++j) { + const short p = tiisg; const half m = M[j]; const half s = ss[j*T + p]; @@ -2262,10 +2267,10 @@ kernel void kernel_flash_attn_ext_f16( } else { half ms[Q]; - for (int64_t j = 0; j < Q; ++j) { + for (short j = 0; j < Q; ++j) { const half m = M[j]; - for (int64_t p = tiisg; p < C; p += NW) { + for (short p = tiisg; p < C; p += NW) { const half s = ss[j*T + p]; smax = max(smax, s); @@ -2280,7 +2285,7 @@ kernel void kernel_flash_attn_ext_f16( // local sum half ls = 0.0h; - for (int64_t p = tiisg; p < C; p += NW) { + for (short p = tiisg; p < C; p += NW) { const half s = ss[j*T + p]; const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]); @@ -2306,25 +2311,25 @@ kernel void kernel_flash_attn_ext_f16( } // O = diag(ms)*O - for (int64_t j = 0; j < Q8; ++j) { + for (short j = 0; j < Q8; ++j) { simdgroup_half8x8 mm; simdgroup_load(mm, ss + 8*j*T + C + 8*j, T, 0, false); - for (int64_t i = 0; i < D8; ++i) { + for (short i = 0; i < D8; ++i) { simdgroup_multiply(lo[j][i], mm, lo[j][i]); } } // O = O + (Q*K^T)*V { - for (int cc = 0; cc < C/8; ++cc) { + for (short cc = 0; cc < C/8; ++cc) { device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); - for (int64_t i = 0; i < D8; ++i) { + for (short i = 0; i < D8; ++i) { simdgroup_half8x8 mk; simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false); - for (int64_t j = 0; j < Q8; ++j) { + for (short j = 0; j < Q8; ++j) { simdgroup_half8x8 mv; simdgroup_load(mv, ss + 8*j*T + 8*cc, T, 0, false); @@ -2336,7 +2341,7 @@ kernel void kernel_flash_attn_ext_f16( } // these are needed for reducing the results from the simdgroups (reuse the ss buffer) - for (int64_t j = 0; j < Q; ++j) { + for (short j = 0; j < Q; ++j) { if (tiisg == 0) { ss[j*T + 0] = S[j]; ss[j*T + 1] = M[j]; @@ -2345,7 +2350,7 @@ kernel void kernel_flash_attn_ext_f16( } // reduce the warps sequentially - for (int64_t sg = 1; sg < nsg; ++sg) { + for (short sg = 1; sg < nsg; ++sg) { half S = { 0.0h }; half M = { -INFINITY }; @@ -2353,8 +2358,8 @@ kernel void kernel_flash_attn_ext_f16( // each simdgroup stores its output to shared memory, reusing sq if (sgitg == sg) { - for (int64_t j = 0; j < Q8; ++j) { - for (int64_t i = 0; i < D8; ++i) { + for (short j = 0; j < Q8; ++j) { + for (short i = 0; i < D8; ++i) { simdgroup_store(lo[j][i], sq + 8*j*T + i*8, T, 0, false); } } @@ -2364,7 +2369,7 @@ kernel void kernel_flash_attn_ext_f16( // the first simdgroup accumulates the results from the other simdgroups if (sgitg == 0) { - for (int64_t j = 0; j < Q; ++j) { + for (short j = 0; j < Q; ++j) { const half S0 = ss[j*T + 0]; const half S1 = ss[j*T + sg*SH + 0]; @@ -2388,7 +2393,7 @@ kernel void kernel_flash_attn_ext_f16( } // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 - for (int64_t j = 0; j < Q8; ++j) { + for (short j = 0; j < Q8; ++j) { simdgroup_half8x8 t; simdgroup_half8x8 ms0; simdgroup_half8x8 ms1; @@ -2396,7 +2401,7 @@ kernel void kernel_flash_attn_ext_f16( simdgroup_load(ms0, ss + 8*j*T + C + 8*j, T, 0, false); simdgroup_load(ms1, ss + 8*j*T + C + 8*j + sg*SH, T, 0, false); - for (int64_t i = 0; i < D8; ++i) { + for (short i = 0; i < D8; ++i) { simdgroup_load (t, sq + 8*j*T + i*8, T, 0, false); simdgroup_multiply(t, ms1, t); @@ -2408,8 +2413,8 @@ kernel void kernel_flash_attn_ext_f16( // store result to shared memory (reuse sq) if (sgitg == 0) { - for (int64_t j = 0; j < Q8; ++j) { - for (int64_t i = 0; i < D8; ++i) { + for (short j = 0; j < Q8; ++j) { + for (short i = 0; i < D8; ++i) { simdgroup_store(lo[j][i], sq + 8*j*T + i*8, T, 0, false); } } @@ -2419,10 +2424,10 @@ kernel void kernel_flash_attn_ext_f16( // final rescale with 1/S and store to global memory if (sgitg == 0) { - for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) { + for (short j = 0; j < Q && iq1 + j < ne01; ++j) { const half S = ss[j*T + 0]; - for (int64_t i = tiisg; i < D4; i += NW) { + for (short i = tiisg; i < D4; i += NW) { dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sq4[j*T4 + i]/S; } } From 08e69c50081622d8146e749195939157be2a2207 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 28 Mar 2024 19:40:11 +0200 Subject: [PATCH 062/100] cuda : adapt soft_max to F16 mask and pos --- ggml-cuda/softmax.cu | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/ggml-cuda/softmax.cu b/ggml-cuda/softmax.cu index 9bda18e581c75..8f6dca4d0f9bf 100644 --- a/ggml-cuda/softmax.cu +++ b/ggml-cuda/softmax.cu @@ -1,7 +1,7 @@ #include "softmax.cuh" template -static __global__ void soft_max_f32(const float * x, const float * mask, const float * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) { +static __global__ void soft_max_f32(const float * x, const half * mask, const half * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) { const int ncols = ncols_template == 0 ? ncols_par : ncols_template; const int tid = threadIdx.x; @@ -43,7 +43,7 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f const int ix = rowx*ncols + col; const int iy = rowy*ncols + col; - const float val = x[ix]*scale + (mask ? mask[iy] : 0.0f) + (pos ? slope*pos[col] : 0.0f); + const float val = x[ix]*scale + (mask ? __half2float(mask[iy]) : 0.0f) + (pos ? slope*__half2float(pos[col]) : 0.0f); vals[col] = val; max_val = max(max_val, val); @@ -114,7 +114,7 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f } } -static void soft_max_f32_cuda(const float * x, const float * mask, const float * pos, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) { +static void soft_max_f32_cuda(const float * x, const half * mask, const half * pos, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) { int nth = WARP_SIZE; while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2; const dim3 block_dims(nth, 1, 1); @@ -168,14 +168,14 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; const float * src0_d = (const float *)src0->data; - const float * src1_d = src1 ? (const float *)src1->data : nullptr; + const half * src1_d = src1 ? (const half *)src1->data : nullptr; float * dst_d = (float *)dst->data; cudaStream_t stream = ctx.stream(); GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); - GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional + GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16); // src1 contains mask and it is optional const int64_t ne00 = src0->ne[0]; const int64_t nrows_x = ggml_nrows(src0); @@ -188,13 +188,13 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); // positions tensor - float * src2_dd = nullptr; + half * src2_dd = nullptr; ggml_tensor * src2 = dst->src[2]; const bool use_src2 = src2 != nullptr; if (use_src2) { - src2_dd = (float *)src2->data; + src2_dd = (half *)src2->data; } soft_max_f32_cuda(src0_d, src1_d, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream); From 75aa7b4b189a5a2f6518840e84ec489da71e0443 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Fri, 29 Mar 2024 23:02:39 +0100 Subject: [PATCH 063/100] CUDA: faster FlashAttention, kernel for bs == 1 --- ggml-cuda/fattn.cu | 1357 +++++++++++++++++++++++++++++--------------- 1 file changed, 906 insertions(+), 451 deletions(-) diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index bcf27fd794aaf..ccb3c924609a4 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -28,414 +28,416 @@ static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX } -#if __CUDA_ARCH__ >= CC_VOLTA -typedef nvcuda::wmma::fragment half16x16_a; -typedef nvcuda::wmma::fragment half16x16_b; -typedef nvcuda::wmma::fragment half16x16_bT; -typedef nvcuda::wmma::fragment half16x16_acc; -#endif - -// based on metal version -template // D head size, Q queries per block, C cache items per block -static __global__ void flash_attn_ext_f16( - const char * __restrict__ q, - const char * __restrict__ k, - const char * __restrict__ v, +template // D == head size +__launch_bounds__(D, 1) +static __global__ void flash_attn_vec_ext_f16( + const char * __restrict__ Q, + const char * __restrict__ K, + const char * __restrict__ V, const char * __restrict__ mask, float * __restrict__ dst, - float scale, - int ne00, - int ne01, - int ne02, - int ne03, - int ne10, - int ne11, - int ne12, - int ne13, - int ne31, - int nb31, - int nb01, - int nb02, - int nb03, - int nb11, - int nb12, - int nb13, - int ne0, - int ne1, - int ne2, - int ne3) { -#if __CUDA_ARCH__ >= CC_VOLTA - const int warp_id = threadIdx.y; - const int lane_id = threadIdx.x; - - const int num_warps = blockDim.y; // number of warps - const int iq3 = blockIdx.z; - const int iq2 = blockIdx.y; - const int iq1 = blockIdx.x * Q; - - const int D16 = D/16; - const int Q16 = Q/16; - const int C16 = C/16; - - const int NW = WARP_SIZE; - const int SH = (C + Q); // shared memory per simdgroup in (half) - - const int T = D + num_warps*SH; // shared memory size per query in (half) - const int T2 = T/2; // shared memory size per query in (half2) - const int C2 = C/2; - const int D2 = D/2; - - extern __shared__ half __flash_attn_f16_shmem[]; - // pq - half * sq = (half *) (__flash_attn_f16_shmem + 0*D); // holds the query data - half2 * sq2 = (half2 *) (__flash_attn_f16_shmem + 0*D); // same as above but in half2 - half * ss = (half *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // scratch buffer for attention and diagonal matrix - half2 * ss2 = (half2 *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // same as above but in half2 - - half16x16_acc zr; - half16x16_acc lo[Q16][D16]; - - // load heads from Q to shared memory + const float scale, + const int ne00, + const int ne01, + const int ne02, + const int ne03, + const int ne10, + const int ne11, + const int ne12, + const int ne13, + const int ne31, + const int nb31, + const int nb01, + const int nb02, + const int nb03, + const int nb11, + const int nb12, + const int nb13, + const int ne0, + const int ne1, + const int ne2, + const int ne3) { + //In this kernel Q, K, V are matrices while i, j, k are matrix indices. + const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. + const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*blockIdx.x); + const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio)); + const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape + const half * maskh = (const half *) mask + ne31*blockIdx.x; + + const int stride_KV = nb11 / sizeof(half); + const int stride_KV2 = nb11 / sizeof(half2); + + constexpr int nwarps = D/WARP_SIZE; + const int tid = WARP_SIZE*threadIdx.y + threadIdx.x; + __builtin_assume(tid < D); + + __shared__ half KQ[D]; + KQ[tid] = 0.0f; + half2 * KQ2 = (half2 *) KQ; + + half kqmax = -INFINITY; + half kqsum = 0.0f; + + __shared__ half kqmax_shared[WARP_SIZE]; + __shared__ half kqsum_shared[WARP_SIZE]; + if (threadIdx.y == 0) { + kqmax_shared[threadIdx.x] = -INFINITY; + kqsum_shared[threadIdx.x] = 0.0f; + } + + __syncthreads(); + + // Convert Q to half2 and store in registers: + half2 Q_h2[(D/2 + WARP_SIZE - 1) / WARP_SIZE]; #pragma unroll - for (int j0 = 0; j0 < Q; j0 += num_warps) { - const int j = j0 + warp_id; - if (j >= Q) { + for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + if (i0 + WARP_SIZE > D/2 && i >= D/2) { break; } - const float2 * q2 = (const float2 *) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); + Q_h2[i0/WARP_SIZE] = make_half2(scale, scale) * make_half2(Q_f2[i].x, Q_f2[i].y); + } + + half2 VKQ = make_half2(0.0f, 0.0f); // Each thread calculates a single VKQ value. + for (int k_VKQ_0 = 0; k_VKQ_0 < ne11; k_VKQ_0 += D) { + // Calculate KQ tile and keep track of new maximum KQ values: + half kqmax_new = kqmax; #pragma unroll - for (int i0 = 0; i0 < D2; i0 += NW) { - const int i = i0 + lane_id; - if (i >= D2) { + for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) { + const int i_KQ = i_KQ_0 + threadIdx.y; + + if (256 % D != 0 && k_VKQ_0 + i_KQ >= ne11) { break; } - if (iq1 + j < ne01) { - sq2[j*T2 + i] = __float22half2_rn(q2[i]); - } else { - sq2[j*T2 + i] = make_half2(0.0, 0.0); + half2 sum2 = make_half2(0.0f, 0.0f); +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) { + const int k_KQ = k_KQ_0 + threadIdx.x; + if (k_KQ_0 + WARP_SIZE > D/2 && k_KQ >= D/2) { + break; + } + + const half2 K_ik = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ]; + sum2 += K_ik * Q_h2[k_KQ_0/WARP_SIZE]; } - } - } - nvcuda::wmma::fill_fragment(zr, 0.0); + sum2 = warp_reduce_sum(sum2); + half sum = __low2half(sum2) + __high2half(sum2); + sum += mask ? maskh[k_VKQ_0 + i_KQ] : __float2half(0.0f); + kqmax_new = __hmax(kqmax_new, sum); + if (threadIdx.x == 0) { + KQ[i_KQ] = sum; + } + } - // zero out lo - for (int j = 0; j < Q16; ++j) { - for (int i = 0; i < D16; ++i) { - nvcuda::wmma::fill_fragment(lo[j][i], 0.0); + kqmax_new = warp_reduce_max(kqmax_new); + if (threadIdx.x == 0) { + kqmax_shared[threadIdx.y] = kqmax_new; } - } + __syncthreads(); + kqmax_new = kqmax_shared[threadIdx.x]; + kqmax_new = warp_reduce_max(kqmax_new); + + const half KQ_max_scale = hexp(kqmax - kqmax_new); + kqmax = kqmax_new; + + const half val = hexp(KQ[tid] - kqmax); + kqsum = kqsum*KQ_max_scale + val; + KQ[tid] = val; + + VKQ *= __half2half2(KQ_max_scale); + + __syncthreads(); - // zero out shared memory SH - for (int j = 0; j < Q; ++j) { - for (int i0 = 0; i0 < SH; i0 += NW) { - const int i = i0 + lane_id; - if (i >= SH) { +#pragma unroll + for (int k0 = 0; k0 < D; k0 += 2) { + if (256 % D != 0 && k_VKQ_0 + k0 >= ne11) { break; } - ss[j*T + i] = 0.0; + half2 V_k; + reinterpret_cast(V_k.x) = V_h[(k_VKQ_0 + k0 + 0)*stride_KV + tid]; + reinterpret_cast(V_k.y) = V_h[(k_VKQ_0 + k0 + 1)*stride_KV + tid]; + VKQ += V_k*KQ2[k0/2]; } } + kqsum = warp_reduce_sum(kqsum); + if (threadIdx.x == 0) { + kqsum_shared[threadIdx.y] = kqsum; + } __syncthreads(); + kqsum = kqsum_shared[threadIdx.x]; + kqsum = warp_reduce_sum(kqsum); - { - half S = __float2half(0.0f); - half M[Q]; + dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + tid] = (__low2half(VKQ) + __high2half(VKQ)) / kqsum; +} - for (int i = 0; i < Q; ++i) { - M[i] = CUDART_MIN_DENORM_FP16; +template // D == head size +__launch_bounds__(ncols == 8 || D > 128 ? D : 2*D, 1) +static __global__ void flash_attn_ext_f16( + const char * __restrict__ Q, + const char * __restrict__ K, + const char * __restrict__ V, + const char * __restrict__ mask, + float * __restrict__ dst, + const float scale, + const int ne00, + const int ne01, + const int ne02, + const int ne03, + const int ne10, + const int ne11, + const int ne12, + const int ne13, + const int ne31, + const int nb31, + const int nb01, + const int nb02, + const int nb03, + const int nb11, + const int nb12, + const int nb13, + const int ne0, + const int ne1, + const int ne2, + const int ne3) { + //In this kernel Q, K, V are matrices while i, j, k are matrix indices. + static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16."); + constexpr int frag_m = ncols == 8 ? 32 : 16; + constexpr int frag_n = ncols == 8 ? 8 : 16; + static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0."); + typedef nvcuda::wmma::fragment frag_a_K; + typedef nvcuda::wmma::fragment frag_a_V; + typedef nvcuda::wmma::fragment frag_b; + typedef nvcuda::wmma::fragment frag_c; + + constexpr int nwarps = (D <= 128 || ncols == 8 ? D : D/2) / frag_m; + constexpr int nthreads = nwarps*WARP_SIZE; + static_assert(nthreads % D == 0, "nthreads not divisible by D."); + constexpr int tc_vals_per_iter = nwarps*frag_m; + static_assert(D % tc_vals_per_iter == 0, "D not divisible by tensor core vals per iter."); + const int tid = WARP_SIZE*threadIdx.y + threadIdx.x; + __builtin_assume(tid < nthreads); + constexpr int D_padded = D + 8; // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts. + + const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. + const float * Q_f = (const float *) (Q + nb02* blockIdx.y + ncols*nb01*blockIdx.x); + const half * K_h = (const half *) (K + nb12*(blockIdx.y / gqa_ratio)); + const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape + const half2 * mask2 = (half2 *) mask + ncols*ne11*blockIdx.x/2; + + const int stride_Q = nb01 / sizeof(float); + const int stride_KV = nb11 / sizeof(half); + + frag_b Q_b[D/16][ncols/frag_n]; + + __shared__ half KQ[ncols*D_padded]; // Buffer for temporarily holding tiles of KQ. + half2 * KQ2 = (half2 *) KQ; + + half2 KQ_rowsum[(ncols + nwarps - 1) / nwarps] = {{0.0f, 0.0f}}; + half2 KQ_max[(ncols + nwarps - 1) / nwarps] = {{-INFINITY, -INFINITY}}; + half2 KQ_max_scale[(ncols + nwarps - 1) / nwarps] = {{0.0f, 0.0f}}; + + __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice. + half2 * VKQ2 = (half2 *) VKQ; +#pragma unroll + for (int i0 = 0; i0 < ncols*D_padded/2; i0 += nthreads) { + const int i = i0 + tid; + if (i0 + nthreads > ncols*D_padded/2 && i >= ncols*D_padded/2) { + break; } - // assume K and V are same shape - const int ne22 = ne12; - const int ne23 = ne13; - - const int nb21 = nb11; - const int nb22 = nb12; - const int nb23 = nb13; - - // broadcast - const int rk2 = ne02/ne12; - const int rk3 = ne03/ne13; - - const int rv2 = ne02/ne22; - const int rv3 = ne03/ne23; + VKQ2[i] = make_half2(0.0f, 0.0f); + } - // k indices - const int ik2 = iq2 / rk2; - const int ik3 = iq3 / rk3; + // Convert Q to half and apply scale, temporarily store in KQ: +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nthreads/D) { + const int j = j0 + tid/D; + const int i = tid % D; + KQ[j*D_padded + i] = ncols*blockIdx.x + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f; + } - // v indices - const int iv2 = iq2 / rv2; - const int iv3 = iq3 / rv3; + __syncthreads(); - // load the queries from shared memory into local memory - half16x16_a mq[Q16][D16]; - for (int j = 0; j < Q16; ++j) { - for (int i = 0; i < D16; ++i) { - nvcuda::wmma::load_matrix_sync(mq[j][i], sq + 16*j*T + i*16, T); - } + // Load Q into tensor core fragments/registers since it will be used frequently: +#pragma unroll + for (int i0 = 0; i0 < D; i0 += 16) { +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += frag_n) { + nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded); } + } - // pointer to the mask - const half * mp = mask ? (const half *) (mask + iq1*nb31) : nullptr; + __syncthreads(); - // prepare diagonal scale matrix - half16x16_b mscale; - for (int i = 0; i < 16; ++i) { - ss[i*T + i] = __float2half(scale); - } - nvcuda::wmma::load_matrix_sync(mscale, ss, T); + // Iterate over ne11 == previous tokens: + for (int k_VKQ_0 = 0; k_VKQ_0 < ne11; k_VKQ_0 += D) { + const bool has_valid_data = 256 % D == 0 || k_VKQ_0 + frag_m*threadIdx.y < ne11; - // loop over the KV cache - // each simdgroup handles blocks of Q rows and C columns - for (int ic0 = 0; ic0 < ne11; ic0 += C*num_warps) { - const int ic = ic0 + warp_id*C; - if (ic >= ne11) { - break; + // Calculate tile of KQ: +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += tc_vals_per_iter) { + frag_c KQ_c[ncols/frag_n]; +#pragma unroll + for (int j = 0; j < ncols/frag_n; ++j) { + nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f); } - - // Q*K^T - { + if (has_valid_data) { #pragma unroll - for (int cc = 0; cc < C16; ++cc) { - half16x16_acc mqk[Q16]; - for (int j = 0; j < Q16; ++j) { - nvcuda::wmma::fill_fragment(mqk[j], 0); - } - - const half * pk = (const half *) ((const char *) k + ((ic + 16*cc)*nb11 + ik2*nb12 + ik3*nb13)); - - for (int i = 0; i < D16; ++i) { - half16x16_bT mk; // transposed key - nvcuda::wmma::load_matrix_sync(mk, pk + i*16, nb11/sizeof(half)); - - for (int j = 0; j < Q16; ++j) { - nvcuda::wmma::mma_sync(mqk[j], mq[j][i], mk, mqk[j]); - } - } - - // mqk = mqk*scale + mask - for (int j = 0; j < Q16; ++j) { - half16x16_a mqka; - half16x16_acc mm; - - if (mp) { - nvcuda::wmma::load_matrix_sync(mm, mp + 16*j*(nb31/sizeof(half)) + ic + 16*cc, nb31/sizeof(half), nvcuda::wmma::mem_row_major); - } - - // convert accumulator to matrix_a - nvcuda::wmma::store_matrix_sync( ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major); - nvcuda::wmma::load_matrix_sync (mqka, ss + 16*j*T + 16*cc, T); - - nvcuda::wmma::mma_sync(mqk[j], mqka, mscale, mp ? mm : zr); - nvcuda::wmma::store_matrix_sync(ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major); + for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) { + frag_a_K K_a; + nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV); +#pragma unroll + for (int j = 0; j < ncols/frag_n; ++j) { + nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]); } } } +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += frag_n) { + nvcuda::wmma::store_matrix_sync(KQ + j0*D_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], D_padded, nvcuda::wmma::mem_col_major); + } + } - // used to detect blocks full of -INF - half2 smax = make_half2(-INFINITY, -INFINITY); - - // online softmax - for (int j = 0; j < Q; ++j) { - const half m = M[j]; - - for (int p0 = 0; p0 < C2; p0 += NW) { - const int p = p0 + lane_id; + __syncthreads(); - const half2 s = ss2[j*T2 + p]; + // Calculate softmax for each KQ column using the current max. value. + // The divisor is stored in KQ_rowsum and will be applied at the end. +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; + if (j0 + nwarps > ncols && j >= ncols) { + break; + } - smax = __hmax2(smax, s); - M[j] = __hmax(M[j], __hmax(s.x, s.y)); + half2 KQ_max_new = KQ_max[j0/nwarps]; +#pragma unroll + for (int k0 = 0; k0 < D/2; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; + if (k0 + WARP_SIZE > D/2 && k >= D/2) { + break; } + KQ_max_new = __hmax2(KQ_max_new, KQ2[j*(D_padded/2) + k]); + } + KQ_max_new = __half2half2(warp_reduce_max(__hmax(__low2half(KQ_max_new), __high2half(KQ_max_new)))); + KQ_max_scale[j0/nwarps] = h2exp(KQ_max[j0/nwarps] - KQ_max_new); + KQ_max[j0/nwarps] = KQ_max_new; - M[j] = warp_reduce_max(M[j]); - - // local sum - half2 ls = make_half2(0.0f, 0.0f); - half2 M2 = make_half2(M[j], M[j]); - - for (int p0 = 0; p0 < C2; p0 += NW) { - const int p = p0 + lane_id; - - const half2 s = ss2[j*T2 + p]; - - const half2 vs = h2exp(s - M2); - - ls += vs; - - // the P matrix from the paper (Q rows, C columns) - ss2[j*T2 + p] = vs; + half2 KQ_rowsum_add = make_half2(0.0f, 0.0f); +#pragma unroll + for (int k0 = 0; k0 < D/2; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; + if (k0 + WARP_SIZE > D/2 && k >= D/2) { + break; } - - ls = warp_reduce_sum(ls); - - const half ms = hexp(m - M[j]); - - // create a QxQ diagonal matrix for rescaling the output - if (lane_id == j) { - ss[j*T + C + j] = ms; - - S = S*ms + ls.x + ls.y; + if (256 % D != 0 && k_VKQ_0 + 2*k >= ne11) { + break; } - } - smax = warp_reduce_max(smax); - - // skip -INF blocks - if (__hisinf(smax.x) == -1 && __hisinf(smax.y) == -1) { - continue; + half2 val = KQ2[j*(D_padded/2) + k]; + val += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); + val = h2exp(val - KQ_max[j0/nwarps]); + KQ_rowsum_add += val; + KQ2[j*(D_padded/2) + k] = val; } + KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add); - // O = diag(ms)*O - for (int j = 0; j < Q16; ++j) { - half16x16_a mm; - half16x16_b lob; - - nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T); + // Scale previous KQ_rowsum to account for a potential increase in KQ_max: + KQ_rowsum[j0/nwarps] = KQ_max_scale[j0/nwarps]*KQ_rowsum[j0/nwarps] + KQ_rowsum_add; + } - for (int i = 0; i < D16; ++i) { - // convert accumulator to matrix_b - nvcuda::wmma::store_matrix_sync( ss + 16*j*T + C + 16*j, lo[j][i], T, nvcuda::wmma::mem_row_major); - nvcuda::wmma::load_matrix_sync (lob, ss + 16*j*T + C + 16*j, T); + __syncthreads(); - nvcuda::wmma::mma_sync(lo[j][i], mm, lob, zr); - } + frag_b KQ_b[D/16][ncols/frag_n]; +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += frag_n) { +#pragma unroll + for (int k0 = 0; k0 < D; k0 += 16) { + nvcuda::wmma::load_matrix_sync(KQ_b[k0/16][j0/frag_n], KQ + j0*D_padded + k0, D_padded); } + } - // restore zeros - for (int j = 0; j < Q16; ++j) { - nvcuda::wmma::store_matrix_sync(ss + 16*j*T + C + 16*j, zr, T, nvcuda::wmma::mem_row_major); + frag_c VKQ_c[D/tc_vals_per_iter][ncols/frag_n]; +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += tc_vals_per_iter) { + #pragma unroll + for (int j = 0; j < ncols/frag_n; ++j) { + nvcuda::wmma::fill_fragment(VKQ_c[i_KQ_0/tc_vals_per_iter][j], 0.0f); } - // O = O + (Q*K^T)*V - { - for (int cc = 0; cc < C16; ++cc) { - const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23)); - - half16x16_b mv[D16]; - for (int i = 0; i < D16; ++i) { - nvcuda::wmma::load_matrix_sync(mv[i], pv + i*16, nb21/sizeof(half)); - } - - half16x16_a ms[Q16]; - for (int j = 0; j < Q16; ++j) { - nvcuda::wmma::load_matrix_sync(ms[j], ss + 16*j*T + 16*cc, T); - } + #pragma unroll + for (int k0 = 0; k0 < D; k0 += 16) { + if (256 % D != 0 && k_VKQ_0 + k0 >= ne11) { + break; + } - for (int j = 0; j < Q16; ++j) { - for (int i = 0; i < D16; ++i) { - nvcuda::wmma::mma_sync(lo[j][i], ms[j], mv[i], lo[j][i]); - } - } + frag_a_V v_a; + nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k0)*stride_KV + i_KQ_0 + frag_m*threadIdx.y, stride_KV); + #pragma unroll + for (int j = 0; j < ncols/frag_n; ++j) { + nvcuda::wmma::mma_sync(VKQ_c[i_KQ_0/tc_vals_per_iter][j], v_a, KQ_b[k0/16][j], VKQ_c[i_KQ_0/tc_vals_per_iter][j]); } } } - // these are needed for reducing the results from the simdgroups (reuse the ss buffer) - if (lane_id < Q) { - ss[lane_id*T + 0] = S; - ss[lane_id*T + 1] = M[lane_id]; - } - } - - // reduce the warps sequentially - for (int sg = 1; sg < num_warps; ++sg) { __syncthreads(); - // each simdgroup stores its output to shared memory, reusing sq - if (warp_id == sg) { - for (int j = 0; j < Q16; ++j) { - for (int i = 0; i < D16; ++i) { - nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major); - } +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += tc_vals_per_iter) { +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += frag_n) { + nvcuda::wmma::store_matrix_sync( + KQ + j0*D_padded + i_KQ_0 + frag_m*threadIdx.y, + VKQ_c[i_KQ_0/tc_vals_per_iter][j0/frag_n], + D_padded, nvcuda::wmma::mem_col_major); } } __syncthreads(); - // the first simdgroup accumulates the results from the other simdgroups - if (warp_id == 0) { - for (int j = lane_id; j < Q; j += NW) { - const half S0 = ss[j*T + 0]; - const half S1 = ss[j*T + sg*SH + 0]; - - const half M0 = ss[j*T + 1]; - const half M1 = ss[j*T + sg*SH + 1]; - - const half M = __hmax(M0, M1); - - const half ms0 = hexp(M0 - M); - const half ms1 = hexp(M1 - M); - - const half S = S0*ms0 + S1*ms1; - - ss[j*T + 0] = S; - ss[j*T + 1] = M; - - ss[j*T + C + j ] = ms0; - ss[j*T + C + j + sg*SH] = ms1; +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; + if (j0 + nwarps > ncols && j >= ncols) { + break; } - - // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 - for (int j = 0; j < Q16; ++j) { - half16x16_a ms0; - half16x16_a ms1; - half16x16_b t; - half16x16_acc t2; - - nvcuda::wmma::load_matrix_sync(ms0, ss + 16*j*T + C + 16*j, T); - nvcuda::wmma::load_matrix_sync(ms1, ss + 16*j*T + C + 16*j + sg*SH, T); - - for (int i = 0; i < D16; ++i) { - nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T); - nvcuda::wmma::mma_sync(t2, ms1, t, zr); - - // convert accumulator to matrix_b - nvcuda::wmma::store_matrix_sync( sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major); - nvcuda::wmma::load_matrix_sync (t, sq + 16*j*T + i*16, T); - - nvcuda::wmma::mma_sync(lo[j][i], ms0, t, t2); +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + if (i0 + WARP_SIZE > D/2 && i >= D/2) { + break; } + VKQ2[j*(D_padded/2) + i] = KQ_max_scale[j0/nwarps]*VKQ2[j*(D_padded/2) + i] + KQ2[j*(D_padded/2) + i]; } } - } - // store result to shared memory (reuse sq) - if (warp_id == 0) { - for (int j = 0; j < Q16; ++j) { - for (int i = 0; i < D16; ++i) { - nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major); - } - } + __syncthreads(); } - // final rescale with 1/S and store to global memory - if (warp_id == 0) { - for (int j = 0; j < Q && iq1 + j < ne01; ++j) { - const half S = ss[j*T + 0]; - - for (int i0 = 0; i0 < D; i0 += NW) { - const int i = i0 + lane_id; - if (i >= D) { - break; - } - - dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i] / S); +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; + if ((j0 + nwarps > ncols && j >= ncols) || ncols*blockIdx.x + j >= ne01) { + return; + } + const float KQ_rowsum_j = __low2float(KQ_rowsum[j0/nwarps]) + __high2float(KQ_rowsum[j0/nwarps]); +#pragma unroll + for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + if (i0 + WARP_SIZE > D && i >= D) { + break; } + dst[D*gridDim.y*(ncols*blockIdx.x + j) + D*blockIdx.y + i] = __half2float(VKQ[j*D_padded + i]) / KQ_rowsum_j; } } -#else - NO_DEVICE_CODE; -#endif } + void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * Q = dst->src[0]; const ggml_tensor * K = dst->src[1]; @@ -461,133 +463,586 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst float scale; memcpy(&scale, KQV->op_params, sizeof(float)); -#define NQPB 16 -#define NCPW 128 - - const int nqpb = NQPB; // queries per block - const int ncpw = NCPW; // cache values per warp (does not work for other values) - - GGML_ASSERT(NQPB <= 32); - - const int nwarps_max = 8; // TODO: we don't want to launch too much warps. how much is too much? - // TODO: produces wrong results for nwarps > 8 (RTX 2060) - not sure why - const int nwarps = Q->ne[1] <= nqpb ? std::max(2, std::min((int) K->ne[1]/ncpw, nwarps_max)) : 1; - - dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]); - dim3 block_dim(32, nwarps, 1); - - const size_t shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2); + if (Q->ne[0] % WARP_SIZE == 0 && Q->ne[1] == 1) { + const int nwarps = Q->ne[0] / WARP_SIZE; + const dim3 blocks_num(Q->ne[1], Q->ne[2], Q->ne[3]); + const dim3 block_dim(WARP_SIZE, nwarps, 1); + const int shmem = 0; + switch (Q->ne[0]) { + case 64: + flash_attn_vec_ext_f16<64> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + // case 80: + // flash_attn_vec_ext_f16<80> + // <<>> ( + // (const char *) Q->data, // Query + // (const char *) K->data, // Key + // (const char *) V->data, // Value + // mask ? ((const char *) mask->data) : nullptr, // Mask + // (float *) KQV->data, // dst + // scale, + // Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + // K->ne[0], K->ne[1], K->ne[2], K->ne[3], + // mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + // Q->nb[1], Q->nb[2], Q->nb[3], + // K->nb[1], K->nb[2], K->nb[3], + // KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + // ); + // break; + case 96: + flash_attn_vec_ext_f16<96> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + // case 112: + // flash_attn_vec_ext_f16<112> + // <<>> ( + // (const char *) Q->data, // Query + // (const char *) K->data, // Key + // (const char *) V->data, // Value + // mask ? ((const char *) mask->data) : nullptr, // Mask + // (float *) KQV->data, // dst + // scale, + // Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + // K->ne[0], K->ne[1], K->ne[2], K->ne[3], + // mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + // Q->nb[1], Q->nb[2], Q->nb[3], + // K->nb[1], K->nb[2], K->nb[3], + // KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + // ); + // break; + case 128: + flash_attn_vec_ext_f16<128> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 256: + flash_attn_vec_ext_f16<256> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + default: + GGML_ASSERT(false); + break; + } + CUDA_CHECK(cudaGetLastError()); + return; + } - // increase shared memory limit to 96KB - //const size_t shmem_max = 96*1024; - //cudaFuncSetAttribute(flash_attn_ext_f16<128, NQPB, NCPW>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_max); + int cols_per_block; + if (Q->ne[1] >= 128 && Q->ne[0] <= 128) { + cols_per_block = 64; + } else if (Q->ne[1] >= 64) { + cols_per_block = 32; + } else if (Q->ne[1] >= 32 || Q->ne[0] % 32 != 0) { + cols_per_block = 16; + } else { + cols_per_block = 8; + } + const int frag_m = cols_per_block == 8 ? 32 : 16; + const int nwarps = (Q->ne[0] <= 128 || cols_per_block == 8 ? Q->ne[0] : Q->ne[0]/2) / frag_m; + const dim3 blocks_num((Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]); + const dim3 block_dim(WARP_SIZE, nwarps, 1); + const size_t shmem = 0; switch (Q->ne[0]) { - case 64: - flash_attn_ext_f16<64, NQPB, NCPW> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? (const char *) mask->data : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 80: - flash_attn_ext_f16<80, NQPB, NCPW> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? (const char *) mask->data : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 96: - flash_attn_ext_f16<96, NQPB, NCPW> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? (const char *) mask->data : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 112: - flash_attn_ext_f16<112, NQPB, NCPW> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? (const char *) mask->data : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 128: - flash_attn_ext_f16<128, NQPB, NCPW> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? (const char *) mask->data : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 256: - flash_attn_ext_f16<256, NQPB, NCPW> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? (const char *) mask->data : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; + case 64: switch (cols_per_block) { + case 8: + flash_attn_ext_f16<64, 8> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 16: + flash_attn_ext_f16<64, 16> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 32: + flash_attn_ext_f16<64, 32> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 64: + flash_attn_ext_f16<64, 64> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + default: + fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); + GGML_ASSERT(false); + break; + } break; + case 80: switch (cols_per_block) { + // case 8: + // fused_attn_vec_ext_f16<80, 8> + // <<>> ( + // (const char *) Q->data, // Query + // (const char *) K->data, // Key + // (const char *) V->data, // Value + // mask ? ((const char *) mask->data) : nullptr, // Mask + // (float *) KQV->data, // dst + // scale, + // Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + // K->ne[0], K->ne[1], K->ne[2], K->ne[3], + // mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + // Q->nb[1], Q->nb[2], Q->nb[3], + // K->nb[1], K->nb[2], K->nb[3], + // KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + // ); + // break; + case 16: + flash_attn_ext_f16<80, 16> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 32: + flash_attn_ext_f16<80, 32> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 64: + flash_attn_ext_f16<80, 64> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + default: + fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); + GGML_ASSERT(false); + break; + } break; + case 96: switch (cols_per_block) { + case 8: + flash_attn_ext_f16<96, 8> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 16: + flash_attn_ext_f16<96, 16> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 32: + flash_attn_ext_f16<96, 32> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 64: + flash_attn_ext_f16<96, 64> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + default: + fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); + GGML_ASSERT(false); + break; + } break; + case 112: switch (cols_per_block) { + // case 8: + // fused_attn_vec_ext_f16<112, 8> + // <<>> ( + // (const char *) Q->data, // Query + // (const char *) K->data, // Key + // (const char *) V->data, // Value + // mask ? ((const char *) mask->data) : nullptr, // Mask + // (float *) KQV->data, // dst + // scale, + // Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + // K->ne[0], K->ne[1], K->ne[2], K->ne[3], + // mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + // Q->nb[1], Q->nb[2], Q->nb[3], + // K->nb[1], K->nb[2], K->nb[3], + // KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + // ); + // break; + case 16: + flash_attn_ext_f16<112, 16> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 32: + flash_attn_ext_f16<112, 32> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 64: + flash_attn_ext_f16<112, 64> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + default: + fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); + GGML_ASSERT(false); + break; + } break; + case 128: switch (cols_per_block) { + case 8: + flash_attn_ext_f16<128, 8> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 16: + flash_attn_ext_f16<128, 16> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 32: + flash_attn_ext_f16<128, 32> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 64: + flash_attn_ext_f16<128, 64> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + default: + fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); + GGML_ASSERT(false); + break; + } break; + case 256: switch (cols_per_block) { + case 8: + flash_attn_ext_f16<256, 8> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 16: + flash_attn_ext_f16<256, 16> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 32: + flash_attn_ext_f16<256, 32> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + // case 64: + // flash_attn_ext_f16<256, 64> + // <<>> ( + // (const char *) Q->data, // Query + // (const char *) K->data, // Key + // (const char *) V->data, // Value + // mask ? ((const char *) mask->data) : nullptr, // Mask + // (float *) KQV->data, // dst + // scale, + // Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + // K->ne[0], K->ne[1], K->ne[2], K->ne[3], + // mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + // Q->nb[1], Q->nb[2], Q->nb[3], + // K->nb[1], K->nb[2], K->nb[3], + // KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + // ); + // break; + default: + fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); + GGML_ASSERT(false); + break; + } break; default: + GGML_ASSERT(false); break; } - CUDA_CHECK(cudaGetLastError()); } From d59ac670bf92f18ba9db44f37fef93b002528ac8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sat, 30 Mar 2024 09:19:19 +0100 Subject: [PATCH 064/100] 16 cols for Phi-2 --- ggml-cuda/fattn.cu | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index ccb3c924609a4..d34924c3173e6 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -579,15 +579,17 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst return; } - int cols_per_block; - if (Q->ne[1] >= 128 && Q->ne[0] <= 128) { - cols_per_block = 64; - } else if (Q->ne[1] >= 64) { - cols_per_block = 32; - } else if (Q->ne[1] >= 32 || Q->ne[0] % 32 != 0) { - cols_per_block = 16; - } else { - cols_per_block = 8; + int cols_per_block = 16; + if (Q->ne[0] % 32 == 0) { + if (Q->ne[1] >= 128 && Q->ne[0] <= 128) { + cols_per_block = 64; + } else if (Q->ne[1] >= 64) { + cols_per_block = 32; + } else if (Q->ne[1] >= 32 || Q->ne[0] % 32 != 0) { + cols_per_block = 16; + } else { + cols_per_block = 8; + } } const int frag_m = cols_per_block == 8 ? 32 : 16; const int nwarps = (Q->ne[0] <= 128 || cols_per_block == 8 ? Q->ne[0] : Q->ne[0]/2) / frag_m; From 81da919864831948f292aeb0a5bd11eb5868bdb8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sat, 30 Mar 2024 10:34:09 +0100 Subject: [PATCH 065/100] no vec for hs, no hs==256 ncols==32 for Volta --- ggml-cuda/common.cuh | 1 + ggml-cuda/fattn.cu | 72 ++++++++++++++++++++++---------------------- 2 files changed, 37 insertions(+), 36 deletions(-) diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh index 33c8ed1da8d83..c245dd6ac009a 100644 --- a/ggml-cuda/common.cuh +++ b/ggml-cuda/common.cuh @@ -141,6 +141,7 @@ #define CC_PASCAL 600 #define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products #define CC_VOLTA 700 +#define CC_AMPERE 800 #define CC_OFFSET_AMD 1000000 #define CC_RDNA1 (CC_OFFSET_AMD + 1010) #define CC_RDNA2 (CC_OFFSET_AMD + 1030) diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index d34924c3173e6..43b9a9f4a3d11 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -463,29 +463,29 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst float scale; memcpy(&scale, KQV->op_params, sizeof(float)); - if (Q->ne[0] % WARP_SIZE == 0 && Q->ne[1] == 1) { + if (Q->ne[0] % WARP_SIZE == 0 && Q->ne[0] >= 128 && Q->ne[1] == 1) { const int nwarps = Q->ne[0] / WARP_SIZE; const dim3 blocks_num(Q->ne[1], Q->ne[2], Q->ne[3]); const dim3 block_dim(WARP_SIZE, nwarps, 1); const int shmem = 0; switch (Q->ne[0]) { - case 64: - flash_attn_vec_ext_f16<64> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; + // case 64: + // flash_attn_vec_ext_f16<64> + // <<>> ( + // (const char *) Q->data, // Query + // (const char *) K->data, // Key + // (const char *) V->data, // Value + // mask ? ((const char *) mask->data) : nullptr, // Mask + // (float *) KQV->data, // dst + // scale, + // Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + // K->ne[0], K->ne[1], K->ne[2], K->ne[3], + // mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + // Q->nb[1], Q->nb[2], Q->nb[3], + // K->nb[1], K->nb[2], K->nb[3], + // KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + // ); + // break; // case 80: // flash_attn_vec_ext_f16<80> // <<>> ( @@ -503,23 +503,23 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst // KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] // ); // break; - case 96: - flash_attn_vec_ext_f16<96> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; + // case 96: + // flash_attn_vec_ext_f16<96> + // <<>> ( + // (const char *) Q->data, // Query + // (const char *) K->data, // Key + // (const char *) V->data, // Value + // mask ? ((const char *) mask->data) : nullptr, // Mask + // (float *) KQV->data, // dst + // scale, + // Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + // K->ne[0], K->ne[1], K->ne[2], K->ne[3], + // mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + // Q->nb[1], Q->nb[2], Q->nb[3], + // K->nb[1], K->nb[2], K->nb[3], + // KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + // ); + // break; // case 112: // flash_attn_vec_ext_f16<112> // <<>> ( @@ -583,7 +583,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst if (Q->ne[0] % 32 == 0) { if (Q->ne[1] >= 128 && Q->ne[0] <= 128) { cols_per_block = 64; - } else if (Q->ne[1] >= 64) { + } else if (Q->ne[1] >= 64 && (Q->ne[0] <= 128 || ggml_cuda_info().devices[ctx.device].cc >= CC_AMPERE)) { cols_per_block = 32; } else if (Q->ne[1] >= 32 || Q->ne[0] % 32 != 0) { cols_per_block = 16; From 269374ed818dde2b267307d62be7ba59385aebfe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sun, 31 Mar 2024 16:01:27 +0200 Subject: [PATCH 066/100] adjust kernel selection logic --- ggml-cuda/fattn.cu | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index 43b9a9f4a3d11..f2c46008633d9 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -579,17 +579,15 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst return; } - int cols_per_block = 16; - if (Q->ne[0] % 32 == 0) { - if (Q->ne[1] >= 128 && Q->ne[0] <= 128) { - cols_per_block = 64; - } else if (Q->ne[1] >= 64 && (Q->ne[0] <= 128 || ggml_cuda_info().devices[ctx.device].cc >= CC_AMPERE)) { - cols_per_block = 32; - } else if (Q->ne[1] >= 32 || Q->ne[0] % 32 != 0) { - cols_per_block = 16; - } else { - cols_per_block = 8; - } + int cols_per_block; + if (Q->ne[1] >= 128 && Q->ne[0] <= 128 && Q->ne[0] % 32 == 0) { + cols_per_block = 64; + } else if (Q->ne[1] >= 64 && (Q->ne[0] <= 128 || ggml_cuda_info().devices[ctx.device].cc >= CC_AMPERE)) { + cols_per_block = 32; + } else if (Q->ne[1] >= 32 || Q->ne[0] % 32 != 0) { + cols_per_block = 16; + } else { + cols_per_block = 8; } const int frag_m = cols_per_block == 8 ? 32 : 16; const int nwarps = (Q->ne[0] <= 128 || cols_per_block == 8 ? Q->ne[0] : Q->ne[0]/2) / frag_m; From cca6d027a323b071d951f702ab3ede0d1937bb6a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sun, 31 Mar 2024 18:39:02 +0200 Subject: [PATCH 067/100] 4 warps, 256 stride for all D --- ggml-cuda/fattn.cu | 633 +++++++++++---------------------------------- 1 file changed, 147 insertions(+), 486 deletions(-) diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index f2c46008633d9..aa85244fc52ce 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -1,3 +1,4 @@ +#include "common.cuh" #include "fattn.cuh" #include @@ -176,8 +177,10 @@ static __global__ void flash_attn_vec_ext_f16( dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + tid] = (__low2half(VKQ) + __high2half(VKQ)) / kqsum; } -template // D == head size -__launch_bounds__(ncols == 8 || D > 128 ? D : 2*D, 1) +#define FATTN_KQ_STRIDE 256 + +template // D == head size, VKQ_stride == num VKQ rows calculated in parallel +__launch_bounds__(nwarps*WARP_SIZE, 1) static __global__ void flash_attn_ext_f16( const char * __restrict__ Q, const char * __restrict__ K, @@ -206,6 +209,7 @@ static __global__ void flash_attn_ext_f16( const int ne2, const int ne3) { //In this kernel Q, K, V are matrices while i, j, k are matrix indices. + static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE."); static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16."); constexpr int frag_m = ncols == 8 ? 32 : 16; constexpr int frag_n = ncols == 8 ? 8 : 16; @@ -215,14 +219,13 @@ static __global__ void flash_attn_ext_f16( typedef nvcuda::wmma::fragment frag_b; typedef nvcuda::wmma::fragment frag_c; - constexpr int nwarps = (D <= 128 || ncols == 8 ? D : D/2) / frag_m; - constexpr int nthreads = nwarps*WARP_SIZE; - static_assert(nthreads % D == 0, "nthreads not divisible by D."); - constexpr int tc_vals_per_iter = nwarps*frag_m; - static_assert(D % tc_vals_per_iter == 0, "D not divisible by tensor core vals per iter."); - const int tid = WARP_SIZE*threadIdx.y + threadIdx.x; - __builtin_assume(tid < nthreads); - constexpr int D_padded = D + 8; // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts. + constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel. + constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy. + static_assert(VKQ_ratio <= nwarps, "VKQ_ratio must be <= nwarps."); + + // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts: + constexpr int D_padded = D + 8; + constexpr int kqs_padded = FATTN_KQ_STRIDE + 8; const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. const float * Q_f = (const float *) (Q + nb02* blockIdx.y + ncols*nb01*blockIdx.x); @@ -235,31 +238,43 @@ static __global__ void flash_attn_ext_f16( frag_b Q_b[D/16][ncols/frag_n]; - __shared__ half KQ[ncols*D_padded]; // Buffer for temporarily holding tiles of KQ. + // A single buffer for temporarily holding tiles of KQ and VKQ parts: + constexpr int mem_KQ = ncols*kqs_padded; + constexpr int mem_VKQ_parts = VKQ_ratio*ncols*D_padded; + __shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts]; half2 * KQ2 = (half2 *) KQ; - half2 KQ_rowsum[(ncols + nwarps - 1) / nwarps] = {{0.0f, 0.0f}}; - half2 KQ_max[(ncols + nwarps - 1) / nwarps] = {{-INFINITY, -INFINITY}}; - half2 KQ_max_scale[(ncols + nwarps - 1) / nwarps] = {{0.0f, 0.0f}}; + half2 KQ_rowsum[ncols/nwarps] = {{0.0f, 0.0f}}; + half2 KQ_max[ncols/nwarps] = {{-INFINITY, -INFINITY}}; + half2 KQ_max_scale[ncols/nwarps] = {{0.0f, 0.0f}}; __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice. half2 * VKQ2 = (half2 *) VKQ; #pragma unroll - for (int i0 = 0; i0 < ncols*D_padded/2; i0 += nthreads) { - const int i = i0 + tid; - if (i0 + nthreads > ncols*D_padded/2 && i >= ncols*D_padded/2) { - break; + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + if (i0 + WARP_SIZE > D/2 && i >= D/2) { + break; + } + VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f); } - - VKQ2[i] = make_half2(0.0f, 0.0f); } // Convert Q to half and apply scale, temporarily store in KQ: #pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nthreads/D) { - const int j = j0 + tid/D; - const int i = tid % D; - KQ[j*D_padded + i] = ncols*blockIdx.x + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f; + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; +#pragma unroll + for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + if (i0 + WARP_SIZE > D && i >= D) { + break; + } + KQ[j*D_padded + i] = ncols*blockIdx.x + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f; + } } __syncthreads(); @@ -276,31 +291,27 @@ static __global__ void flash_attn_ext_f16( __syncthreads(); // Iterate over ne11 == previous tokens: - for (int k_VKQ_0 = 0; k_VKQ_0 < ne11; k_VKQ_0 += D) { - const bool has_valid_data = 256 % D == 0 || k_VKQ_0 + frag_m*threadIdx.y < ne11; - + for (int k_VKQ_0 = 0; k_VKQ_0 < ne11; k_VKQ_0 += FATTN_KQ_STRIDE) { // Calculate tile of KQ: #pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += tc_vals_per_iter) { + for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) { frag_c KQ_c[ncols/frag_n]; #pragma unroll for (int j = 0; j < ncols/frag_n; ++j) { nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f); } - if (has_valid_data) { #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) { - frag_a_K K_a; - nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV); + for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) { + frag_a_K K_a; + nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV); #pragma unroll - for (int j = 0; j < ncols/frag_n; ++j) { - nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]); - } + for (int j = 0; j < ncols/frag_n; ++j) { + nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]); } } #pragma unroll for (int j0 = 0; j0 < ncols; j0 += frag_n) { - nvcuda::wmma::store_matrix_sync(KQ + j0*D_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], D_padded, nvcuda::wmma::mem_col_major); + nvcuda::wmma::store_matrix_sync(KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major); } } @@ -311,18 +322,12 @@ static __global__ void flash_attn_ext_f16( #pragma unroll for (int j0 = 0; j0 < ncols; j0 += nwarps) { const int j = j0 + threadIdx.y; - if (j0 + nwarps > ncols && j >= ncols) { - break; - } half2 KQ_max_new = KQ_max[j0/nwarps]; #pragma unroll - for (int k0 = 0; k0 < D/2; k0 += WARP_SIZE) { + for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { const int k = k0 + threadIdx.x; - if (k0 + WARP_SIZE > D/2 && k >= D/2) { - break; - } - KQ_max_new = __hmax2(KQ_max_new, KQ2[j*(D_padded/2) + k]); + KQ_max_new = __hmax2(KQ_max_new, KQ2[j*(kqs_padded/2) + k]); } KQ_max_new = __half2half2(warp_reduce_max(__hmax(__low2half(KQ_max_new), __high2half(KQ_max_new)))); KQ_max_scale[j0/nwarps] = h2exp(KQ_max[j0/nwarps] - KQ_max_new); @@ -330,20 +335,14 @@ static __global__ void flash_attn_ext_f16( half2 KQ_rowsum_add = make_half2(0.0f, 0.0f); #pragma unroll - for (int k0 = 0; k0 < D/2; k0 += WARP_SIZE) { + for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { const int k = k0 + threadIdx.x; - if (k0 + WARP_SIZE > D/2 && k >= D/2) { - break; - } - if (256 % D != 0 && k_VKQ_0 + 2*k >= ne11) { - break; - } - half2 val = KQ2[j*(D_padded/2) + k]; + half2 val = KQ2[j*(kqs_padded/2) + k]; val += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); val = h2exp(val - KQ_max[j0/nwarps]); KQ_rowsum_add += val; - KQ2[j*(D_padded/2) + k] = val; + KQ2[j*(kqs_padded/2) + k] = val; } KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add); @@ -353,47 +352,46 @@ static __global__ void flash_attn_ext_f16( __syncthreads(); - frag_b KQ_b[D/16][ncols/frag_n]; + frag_b KQ_b[FATTN_KQ_STRIDE/16][ncols/frag_n]; #pragma unroll for (int j0 = 0; j0 < ncols; j0 += frag_n) { #pragma unroll - for (int k0 = 0; k0 < D; k0 += 16) { - nvcuda::wmma::load_matrix_sync(KQ_b[k0/16][j0/frag_n], KQ + j0*D_padded + k0, D_padded); + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += 16) { + nvcuda::wmma::load_matrix_sync(KQ_b[k0/16][j0/frag_n], KQ + j0*kqs_padded + k0, kqs_padded); } } - frag_c VKQ_c[D/tc_vals_per_iter][ncols/frag_n]; + frag_c VKQ_c[D/VKQ_stride][ncols/frag_n]; +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) { #pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += tc_vals_per_iter) { - #pragma unroll for (int j = 0; j < ncols/frag_n; ++j) { - nvcuda::wmma::fill_fragment(VKQ_c[i_KQ_0/tc_vals_per_iter][j], 0.0f); + nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f); } - #pragma unroll - for (int k0 = 0; k0 < D; k0 += 16) { - if (256 % D != 0 && k_VKQ_0 + k0 >= ne11) { - break; - } +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) { + const int k = k0 + (threadIdx.y % VKQ_ratio)*16; frag_a_V v_a; - nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k0)*stride_KV + i_KQ_0 + frag_m*threadIdx.y, stride_KV); - #pragma unroll + nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV); +#pragma unroll for (int j = 0; j < ncols/frag_n; ++j) { - nvcuda::wmma::mma_sync(VKQ_c[i_KQ_0/tc_vals_per_iter][j], v_a, KQ_b[k0/16][j], VKQ_c[i_KQ_0/tc_vals_per_iter][j]); + nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k/16][j], VKQ_c[i_VKQ_0/VKQ_stride][j]); } } } __syncthreads(); + const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*D_padded); #pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += tc_vals_per_iter) { + for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) { #pragma unroll for (int j0 = 0; j0 < ncols; j0 += frag_n) { nvcuda::wmma::store_matrix_sync( - KQ + j0*D_padded + i_KQ_0 + frag_m*threadIdx.y, - VKQ_c[i_KQ_0/tc_vals_per_iter][j0/frag_n], + KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio), + VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n], D_padded, nvcuda::wmma::mem_col_major); } } @@ -403,16 +401,19 @@ static __global__ void flash_attn_ext_f16( #pragma unroll for (int j0 = 0; j0 < ncols; j0 += nwarps) { const int j = j0 + threadIdx.y; - if (j0 + nwarps > ncols && j >= ncols) { - break; - } #pragma unroll for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; if (i0 + WARP_SIZE > D/2 && i >= D/2) { break; } - VKQ2[j*(D_padded/2) + i] = KQ_max_scale[j0/nwarps]*VKQ2[j*(D_padded/2) + i] + KQ2[j*(D_padded/2) + i]; + + half2 VKQ_add = make_half2(0.0f, 0.0f); +#pragma unroll + for (int l = 0; l < VKQ_ratio; ++l) { + VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i]; + } + VKQ2[j*(D_padded/2) + i] = KQ_max_scale[j0/nwarps]*VKQ2[j*(D_padded/2) + i] + VKQ_add; } } @@ -422,7 +423,7 @@ static __global__ void flash_attn_ext_f16( #pragma unroll for (int j0 = 0; j0 < ncols; j0 += nwarps) { const int j = j0 + threadIdx.y; - if ((j0 + nwarps > ncols && j >= ncols) || ncols*blockIdx.x + j >= ne01) { + if (ncols*blockIdx.x + j >= ne01) { return; } const float KQ_rowsum_j = __low2float(KQ_rowsum[j0/nwarps]) + __high2float(KQ_rowsum[j0/nwarps]); @@ -437,6 +438,50 @@ static __global__ void flash_attn_ext_f16( } } +constexpr int get_max_power_of_2(int x) { + return x % 2 == 0 ? 2*get_max_power_of_2(x/2) : 1; +} + +static_assert(get_max_power_of_2(1) == 1, "Test failed."); +static_assert(get_max_power_of_2(2) == 2, "Test failed."); +static_assert(get_max_power_of_2(4) == 4, "Test failed."); +static_assert(get_max_power_of_2(6) == 2, "Test failed."); + +// Number of VKQ rows calculated in parallel: +constexpr int get_VKQ_stride(int D, int nwarps, int frag_m) { + return (get_max_power_of_2(D/frag_m) < nwarps ? get_max_power_of_2(D/frag_m) : nwarps)*frag_m; +} + +static_assert(get_VKQ_stride(128, 1, 32) == 32, "Test failed."); +static_assert(get_VKQ_stride(128, 2, 32) == 64, "Test failed."); +static_assert(get_VKQ_stride(128, 4, 32) == 128, "Test failed."); +static_assert(get_VKQ_stride( 64, 1, 32) == 32, "Test failed."); +static_assert(get_VKQ_stride( 64, 2, 32) == 64, "Test failed."); +static_assert(get_VKQ_stride( 64, 4, 32) == 64, "Test failed."); +static_assert(get_VKQ_stride( 80, 1, 16) == 16, "Test failed."); +static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed."); +static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed."); + +#define FATTN_SWITCH_CASE(D, ncols, nwarps) \ + case ncols: { \ + constexpr int frag_m = (ncols) == 8 && (D) % 32 == 0 ? 32 : 16; \ + flash_attn_ext_f16 \ + <<>> ( \ + (const char *) Q->data, \ + (const char *) K->data, \ + (const char *) V->data, \ + mask ? ((const char *) mask->data) : nullptr, \ + (float *) KQV->data, \ + scale, \ + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], \ + K->ne[0], K->ne[1], K->ne[2], K->ne[3], \ + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, \ + Q->nb[1], Q->nb[2], Q->nb[3], \ + K->nb[1], K->nb[2], K->nb[3], \ + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] \ + ); \ + } \ + break; \ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * Q = dst->src[0]; @@ -580,7 +625,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst } int cols_per_block; - if (Q->ne[1] >= 128 && Q->ne[0] <= 128 && Q->ne[0] % 32 == 0) { + if (false && Q->ne[1] >= 128 && Q->ne[0] <= 128 && Q->ne[0] % 32 == 0) { cols_per_block = 64; } else if (Q->ne[1] >= 64 && (Q->ne[0] <= 128 || ggml_cuda_info().devices[ctx.device].cc >= CC_AMPERE)) { cols_per_block = 32; @@ -590,451 +635,67 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst cols_per_block = 8; } const int frag_m = cols_per_block == 8 ? 32 : 16; - const int nwarps = (Q->ne[0] <= 128 || cols_per_block == 8 ? Q->ne[0] : Q->ne[0]/2) / frag_m; + const int nwarps = 4; const dim3 blocks_num((Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]); const dim3 block_dim(WARP_SIZE, nwarps, 1); const size_t shmem = 0; switch (Q->ne[0]) { case 64: switch (cols_per_block) { - case 8: - flash_attn_ext_f16<64, 8> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 16: - flash_attn_ext_f16<64, 16> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 32: - flash_attn_ext_f16<64, 32> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 64: - flash_attn_ext_f16<64, 64> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; + FATTN_SWITCH_CASE(64, 8, nwarps); + FATTN_SWITCH_CASE(64, 16, nwarps); + FATTN_SWITCH_CASE(64, 32, nwarps); + FATTN_SWITCH_CASE(64, 64, nwarps); default: fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); GGML_ASSERT(false); break; } break; case 80: switch (cols_per_block) { - // case 8: - // fused_attn_vec_ext_f16<80, 8> - // <<>> ( - // (const char *) Q->data, // Query - // (const char *) K->data, // Key - // (const char *) V->data, // Value - // mask ? ((const char *) mask->data) : nullptr, // Mask - // (float *) KQV->data, // dst - // scale, - // Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - // K->ne[0], K->ne[1], K->ne[2], K->ne[3], - // mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - // Q->nb[1], Q->nb[2], Q->nb[3], - // K->nb[1], K->nb[2], K->nb[3], - // KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - // ); - // break; - case 16: - flash_attn_ext_f16<80, 16> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 32: - flash_attn_ext_f16<80, 32> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 64: - flash_attn_ext_f16<80, 64> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; + // FATTN_SWITCH_CASE(80, 8, nwarps); + FATTN_SWITCH_CASE(80, 16, nwarps); + FATTN_SWITCH_CASE(80, 32, nwarps); + // FATTN_SWITCH_CASE(80, 64, nwarps); default: fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); GGML_ASSERT(false); break; } break; case 96: switch (cols_per_block) { - case 8: - flash_attn_ext_f16<96, 8> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 16: - flash_attn_ext_f16<96, 16> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 32: - flash_attn_ext_f16<96, 32> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 64: - flash_attn_ext_f16<96, 64> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; + FATTN_SWITCH_CASE(96, 8, nwarps); + FATTN_SWITCH_CASE(96, 16, nwarps); + FATTN_SWITCH_CASE(96, 32, nwarps); + FATTN_SWITCH_CASE(96, 64, nwarps); default: fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); GGML_ASSERT(false); break; } break; case 112: switch (cols_per_block) { - // case 8: - // fused_attn_vec_ext_f16<112, 8> - // <<>> ( - // (const char *) Q->data, // Query - // (const char *) K->data, // Key - // (const char *) V->data, // Value - // mask ? ((const char *) mask->data) : nullptr, // Mask - // (float *) KQV->data, // dst - // scale, - // Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - // K->ne[0], K->ne[1], K->ne[2], K->ne[3], - // mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - // Q->nb[1], Q->nb[2], Q->nb[3], - // K->nb[1], K->nb[2], K->nb[3], - // KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - // ); - // break; - case 16: - flash_attn_ext_f16<112, 16> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 32: - flash_attn_ext_f16<112, 32> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 64: - flash_attn_ext_f16<112, 64> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; + // FATTN_SWITCH_CASE(112, 8, nwarps); + FATTN_SWITCH_CASE(112, 16, nwarps); + FATTN_SWITCH_CASE(112, 32, nwarps); + // FATTN_SWITCH_CASE(112, 64, nwarps); default: fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); GGML_ASSERT(false); break; } break; case 128: switch (cols_per_block) { - case 8: - flash_attn_ext_f16<128, 8> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 16: - flash_attn_ext_f16<128, 16> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 32: - flash_attn_ext_f16<128, 32> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 64: - flash_attn_ext_f16<128, 64> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; + FATTN_SWITCH_CASE(128, 8, nwarps); + FATTN_SWITCH_CASE(128, 16, nwarps); + FATTN_SWITCH_CASE(128, 32, nwarps); + // FATTN_SWITCH_CASE(128, 64, nwarps); default: fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); GGML_ASSERT(false); break; } break; case 256: switch (cols_per_block) { - case 8: - flash_attn_ext_f16<256, 8> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 16: - flash_attn_ext_f16<256, 16> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 32: - flash_attn_ext_f16<256, 32> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - // case 64: - // flash_attn_ext_f16<256, 64> - // <<>> ( - // (const char *) Q->data, // Query - // (const char *) K->data, // Key - // (const char *) V->data, // Value - // mask ? ((const char *) mask->data) : nullptr, // Mask - // (float *) KQV->data, // dst - // scale, - // Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - // K->ne[0], K->ne[1], K->ne[2], K->ne[3], - // mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - // Q->nb[1], Q->nb[2], Q->nb[3], - // K->nb[1], K->nb[2], K->nb[3], - // KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - // ); - // break; + FATTN_SWITCH_CASE(256, 8, nwarps); + FATTN_SWITCH_CASE(256, 16, nwarps); + FATTN_SWITCH_CASE(256, 32, nwarps); + // FATTN_SWITCH_CASE(256, 64, nwarps); default: fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); GGML_ASSERT(false); From 68d793bee816e44876b22232891ee7bab51ee5e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Mon, 1 Apr 2024 15:54:50 +0200 Subject: [PATCH 068/100] no ncols == 64 --- ggml-cuda/fattn.cu | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index aa85244fc52ce..19108044e762f 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -625,9 +625,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst } int cols_per_block; - if (false && Q->ne[1] >= 128 && Q->ne[0] <= 128 && Q->ne[0] % 32 == 0) { - cols_per_block = 64; - } else if (Q->ne[1] >= 64 && (Q->ne[0] <= 128 || ggml_cuda_info().devices[ctx.device].cc >= CC_AMPERE)) { + if (Q->ne[1] >= 64 && (Q->ne[0] <= 128 || ggml_cuda_info().devices[ctx.device].cc >= CC_AMPERE)) { cols_per_block = 32; } else if (Q->ne[1] >= 32 || Q->ne[0] % 32 != 0) { cols_per_block = 16; @@ -645,7 +643,6 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst FATTN_SWITCH_CASE(64, 8, nwarps); FATTN_SWITCH_CASE(64, 16, nwarps); FATTN_SWITCH_CASE(64, 32, nwarps); - FATTN_SWITCH_CASE(64, 64, nwarps); default: fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); GGML_ASSERT(false); @@ -655,7 +652,6 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst // FATTN_SWITCH_CASE(80, 8, nwarps); FATTN_SWITCH_CASE(80, 16, nwarps); FATTN_SWITCH_CASE(80, 32, nwarps); - // FATTN_SWITCH_CASE(80, 64, nwarps); default: fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); GGML_ASSERT(false); @@ -665,7 +661,6 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst FATTN_SWITCH_CASE(96, 8, nwarps); FATTN_SWITCH_CASE(96, 16, nwarps); FATTN_SWITCH_CASE(96, 32, nwarps); - FATTN_SWITCH_CASE(96, 64, nwarps); default: fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); GGML_ASSERT(false); @@ -675,7 +670,6 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst // FATTN_SWITCH_CASE(112, 8, nwarps); FATTN_SWITCH_CASE(112, 16, nwarps); FATTN_SWITCH_CASE(112, 32, nwarps); - // FATTN_SWITCH_CASE(112, 64, nwarps); default: fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); GGML_ASSERT(false); @@ -685,7 +679,6 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst FATTN_SWITCH_CASE(128, 8, nwarps); FATTN_SWITCH_CASE(128, 16, nwarps); FATTN_SWITCH_CASE(128, 32, nwarps); - // FATTN_SWITCH_CASE(128, 64, nwarps); default: fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); GGML_ASSERT(false); @@ -695,7 +688,6 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst FATTN_SWITCH_CASE(256, 8, nwarps); FATTN_SWITCH_CASE(256, 16, nwarps); FATTN_SWITCH_CASE(256, 32, nwarps); - // FATTN_SWITCH_CASE(256, 64, nwarps); default: fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); GGML_ASSERT(false); From 3f777acf06a0c21780e2e6b1d5b0b8e9a2fbd922 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Mon, 1 Apr 2024 16:41:56 +0200 Subject: [PATCH 069/100] Multiple parallel blocks for batch size 1 --- ggml-cuda/fattn.cu | 417 ++++++++++++++++++++++++++++++--------------- 1 file changed, 283 insertions(+), 134 deletions(-) diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index 19108044e762f..4b51f1b747cf3 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -29,14 +29,17 @@ static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX } -template // D == head size -__launch_bounds__(D, 1) +#define FATTN_KQ_STRIDE 256 + +template // D == head size +__launch_bounds__(((D + WARP_SIZE - 1) / WARP_SIZE)*WARP_SIZE, 1) static __global__ void flash_attn_vec_ext_f16( const char * __restrict__ Q, const char * __restrict__ K, const char * __restrict__ V, const char * __restrict__ mask, float * __restrict__ dst, + half2 * __restrict__ dst_meta, const float scale, const int ne00, const int ne01, @@ -60,20 +63,25 @@ static __global__ void flash_attn_vec_ext_f16( const int ne3) { //In this kernel Q, K, V are matrices while i, j, k are matrix indices. const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*blockIdx.x); + const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y); const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio)); const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape - const half * maskh = (const half *) mask + ne31*blockIdx.x; + const half * maskh = (const half *) mask; + + if (parallel_blocks == 1) { + Q_f2 += blockIdx.x*nb01/sizeof(float2); + maskh += blockIdx.x*ne11; + } const int stride_KV = nb11 / sizeof(half); const int stride_KV2 = nb11 / sizeof(half2); - constexpr int nwarps = D/WARP_SIZE; + constexpr int nwarps = (D + WARP_SIZE - 1) / WARP_SIZE; const int tid = WARP_SIZE*threadIdx.y + threadIdx.x; - __builtin_assume(tid < D); + __builtin_assume(tid < nwarps*WARP_SIZE); - __shared__ half KQ[D]; - KQ[tid] = 0.0f; + __shared__ half KQ[nwarps*WARP_SIZE]; + KQ[tid] = -INFINITY; half2 * KQ2 = (half2 *) KQ; half kqmax = -INFINITY; @@ -85,7 +93,6 @@ static __global__ void flash_attn_vec_ext_f16( kqmax_shared[threadIdx.x] = -INFINITY; kqsum_shared[threadIdx.x] = 0.0f; } - __syncthreads(); // Convert Q to half2 and store in registers: @@ -102,14 +109,15 @@ static __global__ void flash_attn_vec_ext_f16( half2 VKQ = make_half2(0.0f, 0.0f); // Each thread calculates a single VKQ value. - for (int k_VKQ_0 = 0; k_VKQ_0 < ne11; k_VKQ_0 += D) { + const int k_start = parallel_blocks == 1 ? 0 : blockIdx.x*D; + for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) { // Calculate KQ tile and keep track of new maximum KQ values: half kqmax_new = kqmax; #pragma unroll for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) { const int i_KQ = i_KQ_0 + threadIdx.y; - if (256 % D != 0 && k_VKQ_0 + i_KQ >= ne11) { + if ((i_KQ_0 + nwarps > D && i_KQ >= D) || (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + i_KQ >= ne11)) { break; } @@ -153,19 +161,25 @@ static __global__ void flash_attn_vec_ext_f16( __syncthreads(); + if (tid < D) { #pragma unroll - for (int k0 = 0; k0 < D; k0 += 2) { - if (256 % D != 0 && k_VKQ_0 + k0 >= ne11) { - break; - } + for (int k0 = 0; k0 < D; k0 += 2) { + if (256 % D != 0 && k_VKQ_0 + k0 >= ne11) { + break; + } - half2 V_k; - reinterpret_cast(V_k.x) = V_h[(k_VKQ_0 + k0 + 0)*stride_KV + tid]; - reinterpret_cast(V_k.y) = V_h[(k_VKQ_0 + k0 + 1)*stride_KV + tid]; - VKQ += V_k*KQ2[k0/2]; + half2 V_k; + reinterpret_cast(V_k.x) = V_h[(k_VKQ_0 + k0 + 0)*stride_KV + tid]; + reinterpret_cast(V_k.y) = V_h[(k_VKQ_0 + k0 + 1)*stride_KV + tid]; + VKQ += V_k*KQ2[k0/2]; + } } } + if (tid >= D) { + kqsum = 0.0f; + } + kqsum = warp_reduce_sum(kqsum); if (threadIdx.x == 0) { kqsum_shared[threadIdx.y] = kqsum; @@ -174,12 +188,22 @@ static __global__ void flash_attn_vec_ext_f16( kqsum = kqsum_shared[threadIdx.x]; kqsum = warp_reduce_sum(kqsum); - dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + tid] = (__low2half(VKQ) + __high2half(VKQ)) / kqsum; -} + if (tid >= D) { + return; + } -#define FATTN_KQ_STRIDE 256 + if (parallel_blocks == 1) { + dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + tid] = (__low2half(VKQ) + __high2half(VKQ)) / kqsum; + } else { + dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + tid] = (__low2half(VKQ) + __high2half(VKQ)); -template // D == head size, VKQ_stride == num VKQ rows calculated in parallel + if (tid == 0) { + dst_meta[blockIdx.y*parallel_blocks + blockIdx.x] = make_half2(kqmax, kqsum); + } + } +} + +template // D == head size, VKQ_stride == num VKQ rows calculated in parallel __launch_bounds__(nwarps*WARP_SIZE, 1) static __global__ void flash_attn_ext_f16( const char * __restrict__ Q, @@ -187,6 +211,7 @@ static __global__ void flash_attn_ext_f16( const char * __restrict__ V, const char * __restrict__ mask, float * __restrict__ dst, + half2 * __restrict__ dst_meta, const float scale, const int ne00, const int ne01, @@ -228,10 +253,15 @@ static __global__ void flash_attn_ext_f16( constexpr int kqs_padded = FATTN_KQ_STRIDE + 8; const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - const float * Q_f = (const float *) (Q + nb02* blockIdx.y + ncols*nb01*blockIdx.x); + const float * Q_f = (const float *) (Q + nb02* blockIdx.y); const half * K_h = (const half *) (K + nb12*(blockIdx.y / gqa_ratio)); const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape - const half2 * mask2 = (half2 *) mask + ncols*ne11*blockIdx.x/2; + const half2 * mask2 = (half2 *) mask; + + if (parallel_blocks == 1) { + Q_f += blockIdx.x * ncols*nb01/sizeof(float); + mask2 += blockIdx.x * ncols*ne11/2; + } const int stride_Q = nb01 / sizeof(float); const int stride_KV = nb11 / sizeof(half); @@ -273,7 +303,11 @@ static __global__ void flash_attn_ext_f16( if (i0 + WARP_SIZE > D && i >= D) { break; } - KQ[j*D_padded + i] = ncols*blockIdx.x + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f; + if (parallel_blocks == 1) { + KQ[j*D_padded + i] = ncols*blockIdx.x + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f; + } else { + KQ[j*D_padded + i] = j == 0 ? Q_f[j*stride_Q + i] * scale : 0.0f; + } } } @@ -291,7 +325,8 @@ static __global__ void flash_attn_ext_f16( __syncthreads(); // Iterate over ne11 == previous tokens: - for (int k_VKQ_0 = 0; k_VKQ_0 < ne11; k_VKQ_0 += FATTN_KQ_STRIDE) { + const int k_start = parallel_blocks == 1 ? 0 : blockIdx.x*FATTN_KQ_STRIDE; + for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) { // Calculate tile of KQ: #pragma unroll for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) { @@ -420,22 +455,75 @@ static __global__ void flash_attn_ext_f16( __syncthreads(); } + if (parallel_blocks == 1) { #pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j = j0 + threadIdx.y; - if (ncols*blockIdx.x + j >= ne01) { - return; - } - const float KQ_rowsum_j = __low2float(KQ_rowsum[j0/nwarps]) + __high2float(KQ_rowsum[j0/nwarps]); + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; + if (ncols*blockIdx.x + j >= ne01) { + return; + } + const float KQ_rowsum_j = __low2float(KQ_rowsum[j0/nwarps]) + __high2float(KQ_rowsum[j0/nwarps]); #pragma unroll - for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - if (i0 + WARP_SIZE > D && i >= D) { - break; + for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + if (i0 + WARP_SIZE > D && i >= D) { + break; + } + dst[D*gridDim.y*(ncols*blockIdx.x + j) + D*blockIdx.y + i] = __half2float(VKQ[j*D_padded + i]) / KQ_rowsum_j; } - dst[D*gridDim.y*(ncols*blockIdx.x + j) + D*blockIdx.y + i] = __half2float(VKQ[j*D_padded + i]) / KQ_rowsum_j; } + return; + } + +#pragma unroll + for (int i0 = 0; i0 < D; i0 += nwarps*WARP_SIZE) { + const int i = i0 + threadIdx.y*WARP_SIZE + threadIdx.x; + if (i0 + nwarps*WARP_SIZE > D && i >= D) { + return; + } + dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + i] = VKQ[i]; + } + + if (threadIdx.y == 0 && threadIdx.x == 0) { + dst_meta[blockIdx.y*parallel_blocks + blockIdx.x] = make_half2( + __low2half(KQ_max[0]), __low2half(KQ_rowsum[0]) + __high2half(KQ_rowsum[0])); + } +} + +template // D == head size +__launch_bounds__(D, 1) +static __global__ void flash_attn_combine_results( + const float * __restrict__ VKQ_parts, + const half2 * __restrict__ VKQ_meta, + float * __restrict__ dst) { + + const int tid = threadIdx.x; + __builtin_assume(tid < D); + + __shared__ half2 meta[parallel_blocks]; + if (tid < parallel_blocks) { + meta[threadIdx.x] = VKQ_meta[blockIdx.y*parallel_blocks + tid]; } + + __syncthreads(); + + half kqmax = __low2half(meta[0]); +#pragma unroll + for (int l = 1; l < parallel_blocks; ++l) { + kqmax = __hmax(kqmax, __low2half(meta[l])); + } + + float VKQ_numerator = 0.0f; + float VKQ_denominator = 0.0f; +#pragma unroll + for (int l = 0; l < parallel_blocks; ++l) { + float KQ_max_scale = hexp(__low2half(meta[l]) - kqmax); + + VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid]; + VKQ_denominator += KQ_max_scale * __high2float(meta[l]); + } + + dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator; } constexpr int get_max_power_of_2(int x) { @@ -462,26 +550,26 @@ static_assert(get_VKQ_stride( 80, 1, 16) == 16, "Test failed."); static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed."); static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed."); -#define FATTN_SWITCH_CASE(D, ncols, nwarps) \ - case ncols: { \ - constexpr int frag_m = (ncols) == 8 && (D) % 32 == 0 ? 32 : 16; \ - flash_attn_ext_f16 \ - <<>> ( \ - (const char *) Q->data, \ - (const char *) K->data, \ - (const char *) V->data, \ - mask ? ((const char *) mask->data) : nullptr, \ - (float *) KQV->data, \ - scale, \ - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], \ - K->ne[0], K->ne[1], K->ne[2], K->ne[3], \ - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, \ - Q->nb[1], Q->nb[2], Q->nb[3], \ - K->nb[1], K->nb[2], K->nb[3], \ - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] \ - ); \ - } \ - break; \ +#define FATTN_SWITCH_CASE(D, ncols, nwarps) \ + case ncols: { \ + constexpr int frag_m = (ncols) == 8 && (D) % 32 == 0 ? 32 : 16; \ + flash_attn_ext_f16 \ + <<>> ( \ + (const char *) Q->data, \ + (const char *) K->data, \ + (const char *) V->data, \ + mask ? ((const char *) mask->data) : nullptr, \ + (float *) KQV->data, nullptr, \ + scale, \ + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], \ + K->ne[0], K->ne[1], K->ne[2], K->ne[3], \ + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, \ + Q->nb[1], Q->nb[2], Q->nb[3], \ + K->nb[1], K->nb[2], K->nb[3], \ + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] \ + ); \ + } \ + break; \ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * Q = dst->src[0]; @@ -508,88 +596,135 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst float scale; memcpy(&scale, KQV->op_params, sizeof(float)); - if (Q->ne[0] % WARP_SIZE == 0 && Q->ne[0] >= 128 && Q->ne[1] == 1) { - const int nwarps = Q->ne[0] / WARP_SIZE; - const dim3 blocks_num(Q->ne[1], Q->ne[2], Q->ne[3]); + if (Q->ne[1] == 1) { + constexpr int parallel_blocks = 4; + + ggml_cuda_pool_alloc dst_tmp(ctx.pool()); + ggml_cuda_pool_alloc dst_tmp_meta(ctx.pool()); + + const int nwarps = (Q->ne[0] + WARP_SIZE - 1) / WARP_SIZE; + const dim3 blocks_num(parallel_blocks*Q->ne[1], Q->ne[2], Q->ne[3]); const dim3 block_dim(WARP_SIZE, nwarps, 1); const int shmem = 0; + + // Performance of the vector kernel is very bad for head sizes 80 and 112, use the tensor core kernel instead: + constexpr int nwarps_tc = 4; + constexpr dim3 block_dim_tc(WARP_SIZE, nwarps_tc, 1); + + const dim3 blocks_num_combine(1, blocks_num.y, blocks_num.z); + const dim3 block_dim_combine(Q->ne[0], 1, 1); + const int shmem_combine = 0; + + if (parallel_blocks > 1) { + dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); + dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV)); + } + switch (Q->ne[0]) { - // case 64: - // flash_attn_vec_ext_f16<64> - // <<>> ( - // (const char *) Q->data, // Query - // (const char *) K->data, // Key - // (const char *) V->data, // Value - // mask ? ((const char *) mask->data) : nullptr, // Mask - // (float *) KQV->data, // dst - // scale, - // Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - // K->ne[0], K->ne[1], K->ne[2], K->ne[3], - // mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - // Q->nb[1], Q->nb[2], Q->nb[3], - // K->nb[1], K->nb[2], K->nb[3], - // KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - // ); - // break; - // case 80: - // flash_attn_vec_ext_f16<80> - // <<>> ( - // (const char *) Q->data, // Query - // (const char *) K->data, // Key - // (const char *) V->data, // Value - // mask ? ((const char *) mask->data) : nullptr, // Mask - // (float *) KQV->data, // dst - // scale, - // Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - // K->ne[0], K->ne[1], K->ne[2], K->ne[3], - // mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - // Q->nb[1], Q->nb[2], Q->nb[3], - // K->nb[1], K->nb[2], K->nb[3], - // KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - // ); - // break; - // case 96: - // flash_attn_vec_ext_f16<96> - // <<>> ( - // (const char *) Q->data, // Query - // (const char *) K->data, // Key - // (const char *) V->data, // Value - // mask ? ((const char *) mask->data) : nullptr, // Mask - // (float *) KQV->data, // dst - // scale, - // Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - // K->ne[0], K->ne[1], K->ne[2], K->ne[3], - // mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - // Q->nb[1], Q->nb[2], Q->nb[3], - // K->nb[1], K->nb[2], K->nb[3], - // KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - // ); - // break; - // case 112: - // flash_attn_vec_ext_f16<112> - // <<>> ( - // (const char *) Q->data, // Query - // (const char *) K->data, // Key - // (const char *) V->data, // Value - // mask ? ((const char *) mask->data) : nullptr, // Mask - // (float *) KQV->data, // dst - // scale, - // Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - // K->ne[0], K->ne[1], K->ne[2], K->ne[3], - // mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - // Q->nb[1], Q->nb[2], Q->nb[3], - // K->nb[1], K->nb[2], K->nb[3], - // KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - // ); - // break; + case 64: + flash_attn_vec_ext_f16<64, parallel_blocks> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + if (parallel_blocks == 1) { + break; + } + CUDA_CHECK(cudaGetLastError()); + flash_attn_combine_results<64, parallel_blocks> + <<>> + (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); + break; + case 80: + flash_attn_ext_f16<80, 16, nwarps_tc, get_VKQ_stride(80, nwarps_tc, 16), parallel_blocks> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + if (parallel_blocks == 1) { + break; + } + CUDA_CHECK(cudaGetLastError()); + flash_attn_combine_results<80, parallel_blocks> + <<>> + (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); + break; + case 96: + flash_attn_vec_ext_f16<96, parallel_blocks> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + if (parallel_blocks == 1) { + break; + } + CUDA_CHECK(cudaGetLastError()); + flash_attn_combine_results<96, parallel_blocks> + <<>> + (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); + break; + case 112: + flash_attn_vec_ext_f16<112, parallel_blocks> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + if (parallel_blocks == 1) { + break; + } + CUDA_CHECK(cudaGetLastError()); + flash_attn_combine_results<112, parallel_blocks> + <<>> + (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); + break; case 128: - flash_attn_vec_ext_f16<128> + flash_attn_vec_ext_f16<128, parallel_blocks> <<>> ( (const char *) Q->data, // Query (const char *) K->data, // Key (const char *) V->data, // Value mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst + parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, scale, Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], K->ne[0], K->ne[1], K->ne[2], K->ne[3], @@ -598,15 +733,22 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst K->nb[1], K->nb[2], K->nb[3], KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] ); + if (parallel_blocks == 1) { + break; + } + CUDA_CHECK(cudaGetLastError()); + flash_attn_combine_results<128, parallel_blocks> + <<>> + (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); break; case 256: - flash_attn_vec_ext_f16<256> + flash_attn_vec_ext_f16<256, parallel_blocks> <<>> ( (const char *) Q->data, // Query (const char *) K->data, // Key (const char *) V->data, // Value mask ? ((const char *) mask->data) : nullptr, // Mask - (float *) KQV->data, // dst + parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, scale, Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], K->ne[0], K->ne[1], K->ne[2], K->ne[3], @@ -615,6 +757,13 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst K->nb[1], K->nb[2], K->nb[3], KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] ); + if (parallel_blocks == 1) { + break; + } + CUDA_CHECK(cudaGetLastError()); + flash_attn_combine_results<256, parallel_blocks> + <<>> + (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); break; default: GGML_ASSERT(false); @@ -633,7 +782,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst cols_per_block = 8; } const int frag_m = cols_per_block == 8 ? 32 : 16; - const int nwarps = 4; + constexpr int nwarps = 4; const dim3 blocks_num((Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]); const dim3 block_dim(WARP_SIZE, nwarps, 1); const size_t shmem = 0; From e1ecd3b1290adf0086b7b5a45fe97d22afe6a963 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Tue, 2 Apr 2024 10:27:34 +0200 Subject: [PATCH 070/100] fix compile warnings --- ggml-cuda/fattn.cu | 48 ++++++++++++++++++++++------------------------ 1 file changed, 23 insertions(+), 25 deletions(-) diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index 4b51f1b747cf3..d1d018dc7a7e8 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -16,18 +16,18 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) { #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL } -static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { -#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32)); - } - return x; -#else - GGML_UNUSED(x); - NO_DEVICE_CODE; -#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX -} +// static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { +// #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX +// #pragma unroll +// for (int mask = 16; mask > 0; mask >>= 1) { +// x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32)); +// } +// return x; +// #else +// GGML_UNUSED(x); +// NO_DEVICE_CODE; +// #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX +// } #define FATTN_KQ_STRIDE 256 @@ -472,21 +472,20 @@ static __global__ void flash_attn_ext_f16( dst[D*gridDim.y*(ncols*blockIdx.x + j) + D*blockIdx.y + i] = __half2float(VKQ[j*D_padded + i]) / KQ_rowsum_j; } } - return; - } - + } else { #pragma unroll - for (int i0 = 0; i0 < D; i0 += nwarps*WARP_SIZE) { - const int i = i0 + threadIdx.y*WARP_SIZE + threadIdx.x; - if (i0 + nwarps*WARP_SIZE > D && i >= D) { - return; + for (int i0 = 0; i0 < D; i0 += nwarps*WARP_SIZE) { + const int i = i0 + threadIdx.y*WARP_SIZE + threadIdx.x; + if (i0 + nwarps*WARP_SIZE > D && i >= D) { + return; + } + dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + i] = VKQ[i]; } - dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + i] = VKQ[i]; - } - if (threadIdx.y == 0 && threadIdx.x == 0) { - dst_meta[blockIdx.y*parallel_blocks + blockIdx.x] = make_half2( - __low2half(KQ_max[0]), __low2half(KQ_rowsum[0]) + __high2half(KQ_rowsum[0])); + if (threadIdx.y == 0 && threadIdx.x == 0) { + dst_meta[blockIdx.y*parallel_blocks + blockIdx.x] = make_half2( + __low2half(KQ_max[0]), __low2half(KQ_rowsum[0]) + __high2half(KQ_rowsum[0])); + } } } @@ -781,7 +780,6 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst } else { cols_per_block = 8; } - const int frag_m = cols_per_block == 8 ? 32 : 16; constexpr int nwarps = 4; const dim3 blocks_num((Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]); const dim3 block_dim(WARP_SIZE, nwarps, 1); From bb0d51accd7e99c151088c11f1b6774a753dc05c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Tue, 2 Apr 2024 11:13:46 +0200 Subject: [PATCH 071/100] fix excessive KQ_b loads --- ggml-cuda/fattn.cu | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index d1d018dc7a7e8..dcd2129a2cfc9 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -387,12 +387,16 @@ static __global__ void flash_attn_ext_f16( __syncthreads(); - frag_b KQ_b[FATTN_KQ_STRIDE/16][ncols/frag_n]; + frag_b KQ_b[FATTN_KQ_STRIDE/(VKQ_ratio*16)][ncols/frag_n]; #pragma unroll for (int j0 = 0; j0 < ncols; j0 += frag_n) { #pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += 16) { - nvcuda::wmma::load_matrix_sync(KQ_b[k0/16][j0/frag_n], KQ + j0*kqs_padded + k0, kqs_padded); + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) { + const int k = k0 + (threadIdx.y % VKQ_ratio)*16; + nvcuda::wmma::load_matrix_sync( + KQ_b[k0/(VKQ_ratio*16)][j0/frag_n], + KQ + j0*kqs_padded + k, + kqs_padded); } } @@ -412,7 +416,7 @@ static __global__ void flash_attn_ext_f16( nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV); #pragma unroll for (int j = 0; j < ncols/frag_n; ++j) { - nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k/16][j], VKQ_c[i_VKQ_0/VKQ_stride][j]); + nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]); } } } From c63dfdf765c48a0f78e162d0c02c7d69cbbc3083 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Tue, 2 Apr 2024 11:58:59 +0200 Subject: [PATCH 072/100] fix cmake build --- ggml-cuda/common.cuh | 26 ++++++++++++-------------- ggml-cuda/fattn.cu | 38 ++++++++++++-------------------------- 2 files changed, 24 insertions(+), 40 deletions(-) diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh index c245dd6ac009a..510ca6281471e 100644 --- a/ggml-cuda/common.cuh +++ b/ggml-cuda/common.cuh @@ -271,7 +271,6 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) { return a; } -#ifdef GGML_CUDA_F16 static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) { #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL #pragma unroll @@ -284,7 +283,6 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) { NO_DEVICE_CODE; #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL } -#endif // GGML_CUDA_F16 static __device__ __forceinline__ float warp_reduce_max(float x) { #pragma unroll @@ -294,18 +292,18 @@ static __device__ __forceinline__ float warp_reduce_max(float x) { return x; } -//static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { -//#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX -//#pragma unroll -// for (int mask = 16; mask > 0; mask >>= 1) { -// x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32)); -// } -// return x; -//#else -// GGML_UNUSED(x); -// NO_DEVICE_CODE; -//#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX -//} +static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32)); + } + return x; +#else + GGML_UNUSED(x); + NO_DEVICE_CODE; +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX +} #if defined(GGML_USE_HIPBLAS) diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index dcd2129a2cfc9..1d29346c7453b 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -3,32 +3,6 @@ #include -static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) { -#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32)); - } - return a; -#else - GGML_UNUSED(a); - NO_DEVICE_CODE; -#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL -} - -// static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { -// #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX -// #pragma unroll -// for (int mask = 16; mask > 0; mask >>= 1) { -// x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32)); -// } -// return x; -// #else -// GGML_UNUSED(x); -// NO_DEVICE_CODE; -// #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX -// } - #define FATTN_KQ_STRIDE 256 template // D == head size @@ -61,6 +35,7 @@ static __global__ void flash_attn_vec_ext_f16( const int ne1, const int ne2, const int ne3) { +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL //In this kernel Q, K, V are matrices while i, j, k are matrix indices. const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y); @@ -201,6 +176,9 @@ static __global__ void flash_attn_vec_ext_f16( dst_meta[blockIdx.y*parallel_blocks + blockIdx.x] = make_half2(kqmax, kqsum); } } +#else + NO_DEVICE_CODE; +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL } template // D == head size, VKQ_stride == num VKQ rows calculated in parallel @@ -233,6 +211,7 @@ static __global__ void flash_attn_ext_f16( const int ne1, const int ne2, const int ne3) { +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA //In this kernel Q, K, V are matrices while i, j, k are matrix indices. static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE."); static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16."); @@ -491,6 +470,9 @@ static __global__ void flash_attn_ext_f16( __low2half(KQ_max[0]), __low2half(KQ_rowsum[0]) + __high2half(KQ_rowsum[0])); } } +#else + NO_DEVICE_CODE; +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA } template // D == head size @@ -499,6 +481,7 @@ static __global__ void flash_attn_combine_results( const float * __restrict__ VKQ_parts, const half2 * __restrict__ VKQ_meta, float * __restrict__ dst) { +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL const int tid = threadIdx.x; __builtin_assume(tid < D); @@ -527,6 +510,9 @@ static __global__ void flash_attn_combine_results( } dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator; +#else + NO_DEVICE_CODE; +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL } constexpr int get_max_power_of_2(int x) { From ee19a4ab7eba53347bde5fa3b3d38e38f7427f55 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Tue, 2 Apr 2024 17:26:22 +0200 Subject: [PATCH 073/100] fix KV cache padding, NaN from INFINITY (#6438) --- ggml-cuda/fattn.cu | 15 +++++++++------ llama.cpp | 4 ++-- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index 1d29346c7453b..91ef5551e025a 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -4,6 +4,7 @@ #include #define FATTN_KQ_STRIDE 256 +#define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction. template // D == head size __launch_bounds__(((D + WARP_SIZE - 1) / WARP_SIZE)*WARP_SIZE, 1) @@ -59,13 +60,13 @@ static __global__ void flash_attn_vec_ext_f16( KQ[tid] = -INFINITY; half2 * KQ2 = (half2 *) KQ; - half kqmax = -INFINITY; + half kqmax = -HALF_MAX_HALF; half kqsum = 0.0f; __shared__ half kqmax_shared[WARP_SIZE]; __shared__ half kqsum_shared[WARP_SIZE]; if (threadIdx.y == 0) { - kqmax_shared[threadIdx.x] = -INFINITY; + kqmax_shared[threadIdx.x] = -HALF_MAX_HALF; kqsum_shared[threadIdx.x] = 0.0f; } __syncthreads(); @@ -139,7 +140,7 @@ static __global__ void flash_attn_vec_ext_f16( if (tid < D) { #pragma unroll for (int k0 = 0; k0 < D; k0 += 2) { - if (256 % D != 0 && k_VKQ_0 + k0 >= ne11) { + if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k0 >= ne11) { break; } @@ -253,9 +254,9 @@ static __global__ void flash_attn_ext_f16( __shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts]; half2 * KQ2 = (half2 *) KQ; - half2 KQ_rowsum[ncols/nwarps] = {{0.0f, 0.0f}}; - half2 KQ_max[ncols/nwarps] = {{-INFINITY, -INFINITY}}; - half2 KQ_max_scale[ncols/nwarps] = {{0.0f, 0.0f}}; + half2 KQ_rowsum[ncols/nwarps] = {{ 0.0f, 0.0f}}; + half2 KQ_max[ncols/nwarps] = {{-HALF_MAX_HALF, -HALF_MAX_HALF}}; + half2 KQ_max_scale[ncols/nwarps] = {{ 0.0f, 0.0f}}; __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice. half2 * VKQ2 = (half2 *) VKQ; @@ -578,6 +579,8 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) && "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big"); + GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding."); + ggml_cuda_set_device(ctx.device); const cudaStream_t main_stream = ctx.stream(); diff --git a/llama.cpp b/llama.cpp index 9ea9886fe3b89..b50588e4467b0 100644 --- a/llama.cpp +++ b/llama.cpp @@ -9973,7 +9973,7 @@ static int llama_decode_internal( // a heuristic, to avoid attending the full cache if it is not yet utilized // after enough generations, the benefit from this heuristic disappears // if we start defragmenting the cache, the benefit from this will be more important - kv_self.n = std::min(kv_self.size, std::max(128u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 128))); + kv_self.n = std::min(kv_self.size, std::max(256u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 256))); //kv_self.n = llama_kv_cache_cell_max(kv_self); } } @@ -13909,7 +13909,7 @@ struct llama_context * llama_new_context_with_model( cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale; // this is necessary due to kv_self.n being padded later during inference - cparams.n_ctx = GGML_PAD(cparams.n_ctx, 32); + cparams.n_ctx = GGML_PAD(cparams.n_ctx, 256); // with causal attention, the batch size is limited by the context size cparams.n_batch = hparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch; From 4fbd8098e63670c6ae11a8adc350f5ba191cfda3 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Tue, 16 Apr 2024 08:13:13 +0200 Subject: [PATCH 074/100] gguf : add special tokens metadata for FIM/Infill (#6689) This commit adds special token metadata for Fill-In-the-Middle (FIM)/Infill to the GGUF model. The motivation for this is that currently there is support for CodeLlama but other models exist now like CodeGemma, but the different models use different token ids for the special tokens and this commit allows for supporting multiple models. Signed-off-by: Daniel Bevenius --- convert-hf-to-gguf.py | 15 ++++++++++ gguf-py/gguf/constants.py | 9 ++++++ gguf-py/gguf/gguf_writer.py | 12 ++++++++ llama.cpp | 58 ++++++++++++++++++++++++++++++------- 4 files changed, 83 insertions(+), 11 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index b51d683075316..6d28ab5e4919c 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -1221,6 +1221,14 @@ def set_vocab(self): except FileNotFoundError: self._set_vocab_llama_hf() + special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False, + special_token_types = ['prefix', 'suffix', 'middle', 'eot']) + special_vocab._set_special_token("prefix", 32007) + special_vocab._set_special_token("suffix", 32008) + special_vocab._set_special_token("middle", 32009) + special_vocab._set_special_token("eot", 32010) + special_vocab.add_to_gguf(self.gguf_writer) + def set_gguf_parameters(self): super().set_gguf_parameters() hparams = self.hparams @@ -2240,6 +2248,13 @@ class GemmaModel(Model): def set_vocab(self): self._set_vocab_sentencepiece() + special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False, + special_token_types = ['prefix', 'suffix', 'middle', 'eot']) + special_vocab._set_special_token("prefix", 67) + special_vocab._set_special_token("suffix", 69) + special_vocab._set_special_token("middle", 68) + special_vocab._set_special_token("eot", 70) + special_vocab.add_to_gguf(self.gguf_writer) def set_gguf_parameters(self): hparams = self.hparams diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 2566b2fb8f296..1358206a3e569 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -90,6 +90,11 @@ class Tokenizer: HF_JSON = "tokenizer.huggingface.json" RWKV = "tokenizer.rwkv.world" CHAT_TEMPLATE = "tokenizer.chat_template" + # FIM/Infill special tokens constants + PREFIX_ID = "tokenizer.ggml.prefix_token_id" + SUFFIX_ID = "tokenizer.ggml.suffix_token_id" + MIDDLE_ID = "tokenizer.ggml.middle_token_id" + EOT_ID = "tokenizer.ggml.eot_token_id" # @@ -885,3 +890,7 @@ def get_type(val: Any) -> GGUFValueType: KEY_TOKENIZER_MASK_ID = Keys.Tokenizer.MASK_ID KEY_TOKENIZER_HF_JSON = Keys.Tokenizer.HF_JSON KEY_TOKENIZER_RWKV = Keys.Tokenizer.RWKV +KEY_TOKENIZER_PRIFIX_ID = Keys.Tokenizer.PREFIX_ID +KEY_TOKENIZER_SUFFIX_ID = Keys.Tokenizer.SUFFIX_ID +KEY_TOKENIZER_MIDDLE_ID = Keys.Tokenizer.MIDDLE_ID +KEY_TOKENIZER_EOT_ID = Keys.Tokenizer.EOT_ID diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index f4c4407667bda..ff9326d59c717 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -469,6 +469,18 @@ def add_add_space_prefix(self, value: bool) -> None: def add_chat_template(self, value: str) -> None: self.add_string(Keys.Tokenizer.CHAT_TEMPLATE, value) + def add_prefix_token_id(self, id: int) -> None: + self.add_uint32(Keys.Tokenizer.PREFIX_ID, id) + + def add_suffix_token_id(self, id: int) -> None: + self.add_uint32(Keys.Tokenizer.SUFFIX_ID, id) + + def add_middle_token_id(self, id: int) -> None: + self.add_uint32(Keys.Tokenizer.MIDDLE_ID, id) + + def add_eot_token_id(self, id: int) -> None: + self.add_uint32(Keys.Tokenizer.EOT_ID, id) + def _pack(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> bytes: pack_prefix = '' if not skip_pack_prefix: diff --git a/llama.cpp b/llama.cpp index a5ef2fd8fa575..38e5936254e4a 100644 --- a/llama.cpp +++ b/llama.cpp @@ -327,6 +327,10 @@ enum llm_kv { LLM_KV_TOKENIZER_ADD_PREFIX, LLM_KV_TOKENIZER_HF_JSON, LLM_KV_TOKENIZER_RWKV, + LLM_KV_TOKENIZER_PREFIX_ID, + LLM_KV_TOKENIZER_SUFFIX_ID, + LLM_KV_TOKENIZER_MIDDLE_ID, + LLM_KV_TOKENIZER_EOT_ID, }; static const std::map LLM_KV_NAMES = { @@ -399,6 +403,10 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_TOKENIZER_ADD_PREFIX, "tokenizer.ggml.add_space_prefix" }, { LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" }, { LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" }, + { LLM_KV_TOKENIZER_PREFIX_ID, "tokenizer.ggml.prefix_token_id" }, + { LLM_KV_TOKENIZER_SUFFIX_ID, "tokenizer.ggml.suffix_token_id" }, + { LLM_KV_TOKENIZER_MIDDLE_ID, "tokenizer.ggml.middle_token_id" }, + { LLM_KV_TOKENIZER_EOT_ID, "tokenizer.ggml.eot_token_id" }, }; struct LLM_KV { @@ -2055,10 +2063,10 @@ struct llama_vocab { int special_add_eos = -1; // -1 unknown, 1 add, 0 don't add. id linefeed_id = 13; - id special_prefix_id = 32007; - id special_middle_id = 32009; - id special_suffix_id = 32008; - id special_eot_id = 32010; + id special_prefix_id = -1; + id special_suffix_id = -1; + id special_middle_id = -1; + id special_eot_id = -1; bool add_space_prefix = true; @@ -4072,6 +4080,30 @@ static void llm_load_vocab( vocab.special_cls_id = -1; vocab.special_mask_id = -1; + // For Fill-In-the-Middle (FIM)/infill models which where converted + // prior to support of FIM special tokens in GGUF, the following + // will allow those models to continue to work. The general names + // of the known models are currently CodeLlama (LLM_ARCH_LLAMA) and + // CodeGemma (LLM_ARCH_GEMMA). This can potentially be removed once + // new versions of these models have been published. + std::string gen_name; + ml.get_key(LLM_KV_GENERAL_NAME, gen_name); + std::transform(gen_name.begin(), gen_name.end(), gen_name.begin(), + [](unsigned char c){ return std::tolower(c); }); + if (gen_name.find("code") != std::string::npos) { + if (model.arch == LLM_ARCH_LLAMA) { + vocab.special_prefix_id = 32007; + vocab.special_suffix_id = 32008; + vocab.special_middle_id = 32009; + vocab.special_eot_id = 32010; + } else if (model.arch == LLM_ARCH_GEMMA) { + vocab.special_prefix_id = 67; + vocab.special_suffix_id = 69; + vocab.special_middle_id = 68; + vocab.special_eot_id = 70; + } + } + const int add_space_prefix_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_ADD_PREFIX).c_str()); if (add_space_prefix_keyidx != -1) { vocab.add_space_prefix = gguf_get_val_bool(ctx, add_space_prefix_keyidx); @@ -4185,13 +4217,17 @@ static void llm_load_vocab( // special tokens { const std::vector> special_token_types = { - { LLM_KV_TOKENIZER_BOS_ID, vocab.special_bos_id }, - { LLM_KV_TOKENIZER_EOS_ID, vocab.special_eos_id }, - { LLM_KV_TOKENIZER_UNK_ID, vocab.special_unk_id }, - { LLM_KV_TOKENIZER_SEP_ID, vocab.special_sep_id }, - { LLM_KV_TOKENIZER_PAD_ID, vocab.special_pad_id }, - { LLM_KV_TOKENIZER_CLS_ID, vocab.special_cls_id }, - { LLM_KV_TOKENIZER_MASK_ID, vocab.special_mask_id }, + { LLM_KV_TOKENIZER_BOS_ID, vocab.special_bos_id }, + { LLM_KV_TOKENIZER_EOS_ID, vocab.special_eos_id }, + { LLM_KV_TOKENIZER_UNK_ID, vocab.special_unk_id }, + { LLM_KV_TOKENIZER_SEP_ID, vocab.special_sep_id }, + { LLM_KV_TOKENIZER_PAD_ID, vocab.special_pad_id }, + { LLM_KV_TOKENIZER_CLS_ID, vocab.special_cls_id }, + { LLM_KV_TOKENIZER_MASK_ID, vocab.special_mask_id }, + { LLM_KV_TOKENIZER_PREFIX_ID, vocab.special_prefix_id }, + { LLM_KV_TOKENIZER_SUFFIX_ID, vocab.special_suffix_id }, + { LLM_KV_TOKENIZER_MIDDLE_ID, vocab.special_middle_id }, + { LLM_KV_TOKENIZER_EOT_ID, vocab.special_eot_id }, }; for (const auto & it : special_token_types) { const std::string & key = kv(std::get<0>(it)); From 58227ffdeb4fb89cacb0cffaadf76b1914324ad3 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 16 Apr 2024 09:28:33 +0300 Subject: [PATCH 075/100] perplexity : require positive --ctx-size arg (#6695) --- examples/perplexity/perplexity.cpp | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index bab79aaea89ca..9a3374fdc13f2 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -1852,12 +1852,20 @@ int main(int argc, char ** argv) { const int32_t n_ctx = params.n_ctx; + if (n_ctx <= 0) { + fprintf(stderr, "%s: perplexity tool requires '--ctx-size' > 0\n", __func__); + return 1; + } + const bool ppl = !params.hellaswag && !params.winogrande && !params.multiple_choice && !params.kl_divergence; + if (ppl) { - int n_seq = std::max(1, params.n_batch / n_ctx); - int32_t n_kv = n_seq * n_ctx; + const int32_t n_seq = std::max(1, params.n_batch / n_ctx); + const int32_t n_kv = n_seq * n_ctx; + params.n_parallel = n_seq; - params.n_ctx = n_kv; + params.n_ctx = n_kv; + params.n_batch = std::min(params.n_batch, n_kv); } else { params.n_batch = std::min(params.n_batch, params.n_ctx); From 8a56075b07a8b571bf95a912ffdce4c928c2b414 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Tue, 16 Apr 2024 08:34:06 +0200 Subject: [PATCH 076/100] gritlm : add --outdir option to hf.sh script (#6699) This commit updates the hf.sh script usage to include the --outdir option and specifies the models directory as the output directory. The motivation for this is to avoid cluttering the root directory with model files. Signed-off-by: Daniel Bevenius --- examples/gritlm/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/gritlm/README.md b/examples/gritlm/README.md index 64cc192040c43..a3a3c1389ca27 100644 --- a/examples/gritlm/README.md +++ b/examples/gritlm/README.md @@ -21,12 +21,12 @@ not have to be performed at all. ### Running the example Download a Grit model: ```console -$ scripts/hf.sh --repo cohesionet/GritLM-7B_gguf --file gritlm-7b_q4_1.gguf +$ scripts/hf.sh --repo cohesionet/GritLM-7B_gguf --file gritlm-7b_q4_1.gguf --outdir models ``` Run the example using the downloaded model: ```console -$ ./gritlm -m gritlm-7b_q4_1.gguf +$ ./gritlm -m models/gritlm-7b_q4_1.gguf Cosine similarity between "Bitcoin: A Peer-to-Peer Electronic Cash System" and "A purely peer-to-peer version of electronic cash w" is: 0.605 Cosine similarity between "Bitcoin: A Peer-to-Peer Electronic Cash System" and "All text-based language problems can be reduced to" is: 0.103 From f4dea7da1841a92d2788b0535063abf2f0e28461 Mon Sep 17 00:00:00 2001 From: Shijie <821898965@qq.com> Date: Tue, 16 Apr 2024 23:40:48 +0800 Subject: [PATCH 077/100] llama : add qwen2moe (#6074) * support qwen2moe * fix-review * metal : support unary ops for nelements % 4 != 0 * metal : require contiguousness for float4 unary kernels * metal : require contiguousness for float4 unary kernels (cont) * fix-review * names : for brevity "SHARED_EXP" -> "SHEXP" * llama : reuse build_moe_ffn() * llama : add model type name --------- Co-authored-by: Georgi Gerganov --- convert-hf-to-gguf.py | 99 +++++++++++++ ggml-metal.m | 57 ++++++-- ggml-metal.metal | 26 ++++ gguf-py/gguf/constants.py | 169 +++++++++++++--------- gguf-py/gguf/tensor_mapping.py | 34 ++++- llama.cpp | 252 +++++++++++++++++++++++++++++++-- tests/test-backend-ops.cpp | 1 + 7 files changed, 537 insertions(+), 101 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 6d28ab5e4919c..a93b0666c6b86 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -1700,6 +1700,105 @@ class Qwen2Model(Model): model_arch = gguf.MODEL_ARCH.QWEN2 +@Model.register("Qwen2MoeForCausalLM") +class Qwen2MoeModel(Model): + model_arch = gguf.MODEL_ARCH.QWEN2MOE + + def set_gguf_parameters(self): + super().set_gguf_parameters() + if (n_experts := self.hparams.get("num_experts")) is not None: + self.gguf_writer.add_expert_count(n_experts) + + def write_tensors(self): + block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer"))) + tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count) + n_experts = self.hparams.get("num_experts") + experts = dict() + for name, data_torch in self.get_tensors(): + # we don't need these + if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")): + continue + + old_dtype = data_torch.dtype + + # convert any unsupported data types to float32 + if data_torch.dtype not in (torch.float16, torch.float32): + data_torch = data_torch.to(torch.float32) + + data = data_torch.squeeze().numpy() + + # process the experts separately + if name.find("experts") != -1: + experts[name] = data + if len(experts) >= n_experts * 3: + # merge the experts into a single 3d tensor + for bid in range(block_count): + for w_name in ["down_proj", "gate_proj", "up_proj"]: + full = True + for xid in range(n_experts): + ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight" + if ename not in experts: + full = False + break + if not full: + continue + + datas = [] + for xid in range(n_experts): + ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight" + datas.append(experts[ename]) + del experts[ename] + + data = np.stack(datas, axis=0) + data_dtype = data.dtype + + if self.ftype == 0 and data_dtype == np.float16: + data = data.astype(np.float32) + + if self.ftype == 1 and data_dtype == np.float32: + data = data.astype(np.float16) + + merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight" + + new_name = tensor_map.get_name(merged_name, try_suffixes=(".weight", ".bias")) + if new_name is None: + print(f"Can not map tensor {name!r}") + sys.exit() + + print(f"{new_name}, n_dims = {len(data.shape)}, shape = {data.shape} --> {data.dtype}") + + self.gguf_writer.add_tensor(new_name, data) + continue + + # map tensor names + new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias")) + if new_name is None: + print(f"Can not map tensor {name!r}") + sys.exit() + + n_dims = len(data.shape) + data_dtype = data.dtype + + # if f32 desired, convert any float16 to float32 + if self.ftype == 0 and data_dtype == np.float16: + data = data.astype(np.float32) + + # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32 + if self.ftype == 1 and data_dtype == np.float16 and (n_dims == 1 or new_name.endswith("_norm.weight")): + data = data.astype(np.float32) + + # if f16 desired, convert any float32 2-dim weight tensors to float16 + if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2: + data = data.astype(np.float16) + + print(f"{new_name}, n_dims = {n_dims}, shape = {data.shape}, {old_dtype} --> {data.dtype}") + + self.gguf_writer.add_tensor(new_name, data) + + if len(experts) > 0: + raise ValueError(f"Unprocessed experts: {experts.keys()}") + + @Model.register("GPT2LMHeadModel") class GPT2Model(Model): model_arch = gguf.MODEL_ARCH.GPT2 diff --git a/ggml-metal.m b/ggml-metal.m index 0207b787a0255..ae6ddeacd347f 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -41,8 +41,11 @@ GGML_METAL_KERNEL_TYPE_TANH, GGML_METAL_KERNEL_TYPE_RELU, GGML_METAL_KERNEL_TYPE_GELU, + GGML_METAL_KERNEL_TYPE_GELU_4, GGML_METAL_KERNEL_TYPE_GELU_QUICK, + GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, GGML_METAL_KERNEL_TYPE_SILU, + GGML_METAL_KERNEL_TYPE_SILU_4, GGML_METAL_KERNEL_TYPE_SOFT_MAX, GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, @@ -473,8 +476,11 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX, soft_max, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, soft_max_4, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true); @@ -1178,6 +1184,9 @@ static enum ggml_status ggml_metal_graph_compute( } break; case GGML_OP_UNARY: switch (ggml_get_unary_op(gf->nodes[i])) { + // we are not taking into account the strides, so for now require contiguous tensors + GGML_ASSERT(ggml_is_contiguous(src0)); + case GGML_UNARY_OP_TANH: { id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TANH].pipeline; @@ -1204,42 +1213,60 @@ static enum ggml_status ggml_metal_graph_compute( } break; case GGML_UNARY_OP_GELU: { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline; + int64_t n = ggml_nelements(dst); + + id pipeline = nil; + + if (n % 4 == 0) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_4].pipeline; + n /= 4; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline; + } [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - const int64_t n = ggml_nelements(dst); - GGML_ASSERT(n % 4 == 0); - - [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; case GGML_UNARY_OP_GELU_QUICK: { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline; + int64_t n = ggml_nelements(dst); + + id pipeline = nil; + + if (n % 4 == 0) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK_4].pipeline; + n /= 4; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline; + } [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - const int64_t n = ggml_nelements(dst); - GGML_ASSERT(n % 4 == 0); - - [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; case GGML_UNARY_OP_SILU: { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline; + int64_t n = ggml_nelements(dst); + + id pipeline = nil; + + if (n % 4 == 0) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU_4].pipeline; + n /= 4; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline; + } [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - const int64_t n = ggml_nelements(dst); - GGML_ASSERT(n % 4 == 0); - - [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; default: { diff --git a/ggml-metal.metal b/ggml-metal.metal index 56748166ca4ca..82a8cad93c8bc 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -242,6 +242,15 @@ constant float GELU_QUICK_COEF = -1.702f; constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; kernel void kernel_gelu( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + + dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); +} + +kernel void kernel_gelu_4( device const float4 * src0, device float4 * dst, uint tpig[[thread_position_in_grid]]) { @@ -255,6 +264,15 @@ kernel void kernel_gelu( } kernel void kernel_gelu_quick( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + + dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x))); +} + +kernel void kernel_gelu_quick_4( device const float4 * src0, device float4 * dst, uint tpig[[thread_position_in_grid]]) { @@ -264,6 +282,14 @@ kernel void kernel_gelu_quick( } kernel void kernel_silu( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + dst[tpig] = x / (1.0f + exp(-x)); +} + +kernel void kernel_silu_4( device const float4 * src0, device float4 * dst, uint tpig[[thread_position_in_grid]]) { diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 1358206a3e569..df861164fe2f1 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -120,6 +120,7 @@ class MODEL_ARCH(IntEnum): STABLELM = auto() QWEN = auto() QWEN2 = auto() + QWEN2MOE = auto() PHI2 = auto() PLAMO = auto() CODESHELL = auto() @@ -135,41 +136,45 @@ class MODEL_ARCH(IntEnum): class MODEL_TENSOR(IntEnum): - TOKEN_EMBD = auto() - TOKEN_EMBD_NORM = auto() - TOKEN_TYPES = auto() - POS_EMBD = auto() - OUTPUT = auto() - OUTPUT_NORM = auto() - ROPE_FREQS = auto() - ATTN_Q = auto() - ATTN_K = auto() - ATTN_V = auto() - ATTN_QKV = auto() - ATTN_OUT = auto() - ATTN_NORM = auto() - ATTN_NORM_2 = auto() - ATTN_OUT_NORM = auto() - ATTN_ROT_EMBD = auto() - FFN_GATE_INP = auto() - FFN_NORM = auto() - FFN_GATE = auto() - FFN_DOWN = auto() - FFN_UP = auto() - FFN_ACT = auto() - FFN_GATE_EXP = auto() - FFN_DOWN_EXP = auto() - FFN_UP_EXP = auto() - ATTN_Q_NORM = auto() - ATTN_K_NORM = auto() - LAYER_OUT_NORM = auto() - SSM_IN = auto() - SSM_CONV1D = auto() - SSM_X = auto() - SSM_DT = auto() - SSM_A = auto() - SSM_D = auto() - SSM_OUT = auto() + TOKEN_EMBD = auto() + TOKEN_EMBD_NORM = auto() + TOKEN_TYPES = auto() + POS_EMBD = auto() + OUTPUT = auto() + OUTPUT_NORM = auto() + ROPE_FREQS = auto() + ATTN_Q = auto() + ATTN_K = auto() + ATTN_V = auto() + ATTN_QKV = auto() + ATTN_OUT = auto() + ATTN_NORM = auto() + ATTN_NORM_2 = auto() + ATTN_OUT_NORM = auto() + ATTN_ROT_EMBD = auto() + FFN_GATE_INP = auto() + FFN_GATE_INP_SHEXP = auto() + FFN_NORM = auto() + FFN_GATE = auto() + FFN_DOWN = auto() + FFN_UP = auto() + FFN_ACT = auto() + FFN_GATE_EXP = auto() + FFN_DOWN_EXP = auto() + FFN_UP_EXP = auto() + FFN_GATE_SHEXP = auto() + FFN_DOWN_SHEXP = auto() + FFN_UP_SHEXP = auto() + ATTN_Q_NORM = auto() + ATTN_K_NORM = auto() + LAYER_OUT_NORM = auto() + SSM_IN = auto() + SSM_CONV1D = auto() + SSM_X = auto() + SSM_DT = auto() + SSM_A = auto() + SSM_D = auto() + SSM_OUT = auto() MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { @@ -190,6 +195,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.STABLELM: "stablelm", MODEL_ARCH.QWEN: "qwen", MODEL_ARCH.QWEN2: "qwen2", + MODEL_ARCH.QWEN2MOE: "qwen2moe", MODEL_ARCH.PHI2: "phi2", MODEL_ARCH.PLAMO: "plamo", MODEL_ARCH.CODESHELL: "codeshell", @@ -205,41 +211,45 @@ class MODEL_TENSOR(IntEnum): } TENSOR_NAMES: dict[MODEL_TENSOR, str] = { - MODEL_TENSOR.TOKEN_EMBD: "token_embd", - MODEL_TENSOR.TOKEN_EMBD_NORM: "token_embd_norm", - MODEL_TENSOR.TOKEN_TYPES: "token_types", - MODEL_TENSOR.POS_EMBD: "position_embd", - MODEL_TENSOR.OUTPUT_NORM: "output_norm", - MODEL_TENSOR.OUTPUT: "output", - MODEL_TENSOR.ROPE_FREQS: "rope_freqs", - MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm", - MODEL_TENSOR.ATTN_NORM_2: "blk.{bid}.attn_norm_2", - MODEL_TENSOR.ATTN_QKV: "blk.{bid}.attn_qkv", - MODEL_TENSOR.ATTN_Q: "blk.{bid}.attn_q", - MODEL_TENSOR.ATTN_K: "blk.{bid}.attn_k", - MODEL_TENSOR.ATTN_V: "blk.{bid}.attn_v", - MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output", - MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd", - MODEL_TENSOR.ATTN_Q_NORM: "blk.{bid}.attn_q_norm", - MODEL_TENSOR.ATTN_K_NORM: "blk.{bid}.attn_k_norm", - MODEL_TENSOR.ATTN_OUT_NORM: "blk.{bid}.attn_output_norm", - MODEL_TENSOR.FFN_GATE_INP: "blk.{bid}.ffn_gate_inp", - MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm", - MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate", - MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down", - MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up", - MODEL_TENSOR.FFN_ACT: "blk.{bid}.ffn", - MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate_exps", - MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down_exps", - MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps", - MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm", - MODEL_TENSOR.SSM_IN: "blk.{bid}.ssm_in", - MODEL_TENSOR.SSM_CONV1D: "blk.{bid}.ssm_conv1d", - MODEL_TENSOR.SSM_X: "blk.{bid}.ssm_x", - MODEL_TENSOR.SSM_DT: "blk.{bid}.ssm_dt", - MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a", - MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d", - MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out", + MODEL_TENSOR.TOKEN_EMBD: "token_embd", + MODEL_TENSOR.TOKEN_EMBD_NORM: "token_embd_norm", + MODEL_TENSOR.TOKEN_TYPES: "token_types", + MODEL_TENSOR.POS_EMBD: "position_embd", + MODEL_TENSOR.OUTPUT_NORM: "output_norm", + MODEL_TENSOR.OUTPUT: "output", + MODEL_TENSOR.ROPE_FREQS: "rope_freqs", + MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm", + MODEL_TENSOR.ATTN_NORM_2: "blk.{bid}.attn_norm_2", + MODEL_TENSOR.ATTN_QKV: "blk.{bid}.attn_qkv", + MODEL_TENSOR.ATTN_Q: "blk.{bid}.attn_q", + MODEL_TENSOR.ATTN_K: "blk.{bid}.attn_k", + MODEL_TENSOR.ATTN_V: "blk.{bid}.attn_v", + MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output", + MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd", + MODEL_TENSOR.ATTN_Q_NORM: "blk.{bid}.attn_q_norm", + MODEL_TENSOR.ATTN_K_NORM: "blk.{bid}.attn_k_norm", + MODEL_TENSOR.ATTN_OUT_NORM: "blk.{bid}.attn_output_norm", + MODEL_TENSOR.FFN_GATE_INP: "blk.{bid}.ffn_gate_inp", + MODEL_TENSOR.FFN_GATE_INP_SHEXP: "blk.{bid}.ffn_gate_inp_shexp", + MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm", + MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate", + MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down", + MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up", + MODEL_TENSOR.FFN_GATE_SHEXP: "blk.{bid}.ffn_gate_shexp", + MODEL_TENSOR.FFN_DOWN_SHEXP: "blk.{bid}.ffn_down_shexp", + MODEL_TENSOR.FFN_UP_SHEXP: "blk.{bid}.ffn_up_shexp", + MODEL_TENSOR.FFN_ACT: "blk.{bid}.ffn", + MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate_exps", + MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down_exps", + MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps", + MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm", + MODEL_TENSOR.SSM_IN: "blk.{bid}.ssm_in", + MODEL_TENSOR.SSM_CONV1D: "blk.{bid}.ssm_conv1d", + MODEL_TENSOR.SSM_X: "blk.{bid}.ssm_x", + MODEL_TENSOR.SSM_DT: "blk.{bid}.ssm_dt", + MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a", + MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d", + MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out", } MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { @@ -474,6 +484,25 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, ], + MODEL_ARCH.QWEN2MOE: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.FFN_GATE_INP_SHEXP, + MODEL_TENSOR.FFN_GATE_SHEXP, + MODEL_TENSOR.FFN_DOWN_SHEXP, + MODEL_TENSOR.FFN_UP_SHEXP, + ], MODEL_ARCH.PLAMO: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index ec6fcbb838425..10de36fa883ee 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -208,10 +208,15 @@ class TensorNameMap: MODEL_TENSOR.FFN_GATE_INP: ( "layers.{bid}.feed_forward.gate", # mixtral "model.layers.{bid}.block_sparse_moe.gate", # mixtral + "model.layers.{bid}.mlp.gate", # qwen2moe "transformer.decoder_layer.{bid}.router", # Grok "transformer.blocks.{bid}.ffn.router.layer", # dbrx ), + MODEL_TENSOR.FFN_GATE_INP_SHEXP: ( + "model.layers.{bid}.mlp.shared_expert_gate", # qwen2moe + ), + # Feed-forward up MODEL_TENSOR.FFN_UP: ( "gpt_neox.layers.{bid}.mlp.dense_h_to_4h", # gptneox @@ -236,9 +241,14 @@ class TensorNameMap: ), MODEL_TENSOR.FFN_UP_EXP: ( - "layers.{bid}.feed_forward.experts.w3", # mixtral (merged) - "transformer.decoder_layer.{bid}.moe.linear_v", # Grok (merged) - "transformer.blocks.{bid}.ffn.experts.mlp.v1", # dbrx + "layers.{bid}.feed_forward.experts.w3", # mixtral (merged) + "transformer.decoder_layer.{bid}.moe.linear_v", # Grok (merged) + "transformer.blocks.{bid}.ffn.experts.mlp.v1", # dbrx + "model.layers.{bid}.mlp.experts.up_proj", # qwen2moe (merged) + ), + + MODEL_TENSOR.FFN_UP_SHEXP: ( + "model.layers.{bid}.mlp.shared_expert.up_proj", # qwen2moe ), # AWQ-activation gate @@ -260,6 +270,11 @@ class TensorNameMap: "layers.{bid}.feed_forward.experts.w1", # mixtral (merged) "transformer.decoder_layer.{bid}.moe.linear", # Grok (merged) "transformer.blocks.{bid}.ffn.experts.mlp.w1", # dbrx + "model.layers.{bid}.mlp.experts.gate_proj", # qwen2moe (merged) + ), + + MODEL_TENSOR.FFN_GATE_SHEXP: ( + "model.layers.{bid}.mlp.shared_expert.gate_proj", # qwen2moe ), # Feed-forward down @@ -285,9 +300,14 @@ class TensorNameMap: ), MODEL_TENSOR.FFN_DOWN_EXP: ( - "layers.{bid}.feed_forward.experts.w2", # mixtral (merged) - "transformer.decoder_layer.{bid}.moe.linear_1", # Grok (merged) - "transformer.blocks.{bid}.ffn.experts.mlp.w2", # dbrx + "layers.{bid}.feed_forward.experts.w2", # mixtral (merged) + "transformer.decoder_layer.{bid}.moe.linear_1", # Grok (merged) + "transformer.blocks.{bid}.ffn.experts.mlp.w2", # dbrx + "model.layers.{bid}.mlp.experts.down_proj", # qwen2moe (merged) + ), + + MODEL_TENSOR.FFN_DOWN_SHEXP: ( + "model.layers.{bid}.mlp.shared_expert.down_proj", # qwen2moe ), MODEL_TENSOR.ATTN_Q_NORM: ( @@ -366,7 +386,7 @@ def __init__(self, arch: MODEL_ARCH, n_blocks: int): if tensor not in MODEL_TENSORS[arch]: continue # TODO: make this configurable - n_experts = 8 + n_experts = 60 for xid in range(n_experts): tensor_name = TENSOR_NAMES[tensor].format(bid = bid, xid = xid) self.mapping[tensor_name] = (tensor, tensor_name) diff --git a/llama.cpp b/llama.cpp index 38e5936254e4a..340e68fde3ac9 100644 --- a/llama.cpp +++ b/llama.cpp @@ -105,7 +105,7 @@ #endif #define LLAMA_MAX_NODES 8192 -#define LLAMA_MAX_EXPERTS 16 +#define LLAMA_MAX_EXPERTS 60 // @@ -209,6 +209,7 @@ enum llm_arch { LLM_ARCH_STABLELM, LLM_ARCH_QWEN, LLM_ARCH_QWEN2, + LLM_ARCH_QWEN2MOE, LLM_ARCH_PHI2, LLM_ARCH_PLAMO, LLM_ARCH_CODESHELL, @@ -242,6 +243,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_STABLELM, "stablelm" }, { LLM_ARCH_QWEN, "qwen" }, { LLM_ARCH_QWEN2, "qwen2" }, + { LLM_ARCH_QWEN2MOE, "qwen2moe" }, { LLM_ARCH_PHI2, "phi2" }, { LLM_ARCH_PLAMO, "plamo" }, { LLM_ARCH_CODESHELL, "codeshell" }, @@ -437,6 +439,7 @@ enum llm_tensor { LLM_TENSOR_ATTN_OUT_NORM, LLM_TENSOR_ATTN_ROT_EMBD, LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_GATE_INP_SHEXP, LLM_TENSOR_FFN_NORM, LLM_TENSOR_FFN_GATE, LLM_TENSOR_FFN_DOWN, @@ -448,6 +451,9 @@ enum llm_tensor { LLM_TENSOR_FFN_DOWN_EXPS, // merged experts LLM_TENSOR_FFN_GATE_EXPS, LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_DOWN_SHEXP, + LLM_TENSOR_FFN_GATE_SHEXP, + LLM_TENSOR_FFN_UP_SHEXP, LLM_TENSOR_ATTN_Q_NORM, LLM_TENSOR_ATTN_K_NORM, LLM_TENSOR_LAYER_OUT_NORM, @@ -745,6 +751,28 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_QWEN2MOE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" }, + { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + }, + }, { LLM_ARCH_PHI2, { @@ -1731,6 +1759,7 @@ enum e_model { MODEL_MEDIUM, MODEL_LARGE, MODEL_XL, + MODEL_A2_7B, MODEL_8x7B, MODEL_8x22B, MODEL_16x12B, @@ -1917,6 +1946,12 @@ struct llama_layer { struct ggml_tensor * ffn_down_exps; struct ggml_tensor * ffn_up_exps ; + // ff shared expert (shexp) + struct ggml_tensor * ffn_gate_inp_shexp; + struct ggml_tensor * ffn_gate_shexp; + struct ggml_tensor * ffn_down_shexp; + struct ggml_tensor * ffn_up_shexp; + // ff bias struct ggml_tensor * ffn_down_b; // b2 struct ggml_tensor * ffn_up_b; // b3 @@ -3587,6 +3622,7 @@ static const char * llama_model_type_name(e_model type) { case MODEL_MEDIUM: return "0.4B"; case MODEL_LARGE: return "0.8B"; case MODEL_XL: return "1.5B"; + case MODEL_A2_7B: return "A2.7B"; case MODEL_8x7B: return "8x7B"; case MODEL_8x22B: return "8x22B"; case MODEL_16x12B: return "16x12B"; @@ -3886,6 +3922,14 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_QWEN2MOE: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 24: model.type = e_model::MODEL_A2_7B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; case LLM_ARCH_PHI2: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); @@ -5156,6 +5200,54 @@ static bool llm_load_tensors( layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); } } break; + case LLM_ARCH_QWEN2MOE: + { + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + + // output + { + model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + } + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + + layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + + // optional bias tensors + layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}); + layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}); + layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}); + + layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + + layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}); + + GGML_ASSERT(hparams.n_expert > 0); + GGML_ASSERT(hparams.n_expert_used > 0); + + // MoE branch + auto n_ff_exp = n_ff / hparams.n_expert_used; + layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}); + layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}); + layer.ffn_up_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}); + + // Shared expert branch + layer.ffn_gate_inp_shexp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), {n_embd}); + layer.ffn_gate_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff}); + layer.ffn_down_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff, n_embd}); + layer.ffn_up_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff}); + } + } break; case LLM_ARCH_PHI2: { model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); @@ -6532,7 +6624,7 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(cur, "ffn_norm", il); - cur = build_moe_ffn(cur, n_tokens, LLM_FFN_SILU, il); + cur = build_moe_ffn(cur, n_tokens, LLM_FFN_SILU, true, il); } cur = ggml_add(ctx0, cur, ffn_inp); @@ -6565,7 +6657,7 @@ struct llm_build_context { } // REVIEW: will be replaced by https://github.com/ggerganov/llama.cpp/pull/6505 - ggml_tensor * build_moe_ffn(ggml_tensor * cur, int32_t n_tokens, llm_ffn_op_type type_op, int il) { + ggml_tensor * build_moe_ffn(ggml_tensor * cur, int32_t n_tokens, llm_ffn_op_type type_op, bool norm_w, int il) { ggml_tensor * logits = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp, cur); // [n_tokens, num_experts] cb(logits, "ffn_moe_logits", il); @@ -6582,11 +6674,13 @@ struct llm_build_context { weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); // [n_tokens, num_experts_per_tok] - ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); - cb(weights_sum, "ffn_moe_weights_sum", il); + if (norm_w) { + ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); + cb(weights_sum, "ffn_moe_weights_sum", il); - weights = ggml_div(ctx0, weights, weights_sum); // [n_tokens, num_experts_per_tok] - cb(weights, "ffn_moe_weights_norm", il); + weights = ggml_div(ctx0, weights, weights_sum); // [n_tokens, num_experts_per_tok] + cb(weights, "ffn_moe_weights_norm", il); + } // compute expert outputs ggml_tensor * moe_out = nullptr; @@ -7083,7 +7177,7 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(cur, "ffn_norm", il); - cur = build_moe_ffn(cur, n_tokens, LLM_FFN_GELU, il); + cur = build_moe_ffn(cur, n_tokens, LLM_FFN_GELU, true, il); // Grok // if layer_out_norm is present then apply it before adding the input @@ -7219,7 +7313,7 @@ struct llm_build_context { LLM_NORM, cb, il); cb(cur, "attn_out_norm", il); - cur = build_moe_ffn(cur, n_tokens, LLM_FFN_SILU, il); + cur = build_moe_ffn(cur, n_tokens, LLM_FFN_SILU, true, il); cur = ggml_add(ctx0, cur, ffn_inp); cb(cur, "ffn_out", il); @@ -8434,6 +8528,141 @@ struct llm_build_context { return gf; } + struct ggml_cgraph * build_qwen2moe() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + + // mutable variable, needed during the last layer of the computation to skip unused tokens + int32_t n_tokens = this->n_tokens; + + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = build_inp_pos(); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + // norm + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + + // self_attention + { + // compute Q and K and RoPE them + struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + + struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + + struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + + Qcur = ggml_rope_custom( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Qcur, "Qcur", il); + + Kcur = ggml_rope_custom( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Kcur, "Kcur", il); + + cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + model.layers[il].wo, model.layers[il].bo, + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + n_tokens = n_outputs; + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // MoE branch + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + ggml_tensor * moe_out = build_moe_ffn(cur, n_tokens, LLM_FFN_SILU, false, il); + + // FFN shared expert + { + ggml_tensor * cur_gate_inp = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp_shexp, cur); + cb(cur_gate_inp, "ffn_shexp_gate_inp", il); + + // sigmoid + ggml_tensor * cur_gate = ggml_div(ctx0, ggml_silu(ctx0, cur_gate_inp), cur_gate_inp); + cb(cur_gate, "ffn_shexp_gate", il); + + ggml_tensor * cur_ffn = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up_shexp, NULL, + model.layers[il].ffn_gate_shexp, NULL, + model.layers[il].ffn_down_shexp, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(cur_ffn, "ffn_shexp", il); + + ggml_tensor * ffn_shexp_out = ggml_mul(ctx0, cur_ffn, cur_gate); + cb(ffn_shexp_out, "ffn_shexp_out", il); + + moe_out = ggml_add(ctx0, moe_out, ffn_shexp_out); + cb(moe_out, "ffn_out", il); + + cur = moe_out; + } + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm, NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = ggml_mul_mat(ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } + struct ggml_cgraph * build_phi2() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); @@ -9917,6 +10146,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_qwen2(); } break; + case LLM_ARCH_QWEN2MOE: + { + result = llm.build_qwen2moe(); + } break; case LLM_ARCH_PHI2: { result = llm.build_phi2(); @@ -14834,6 +15067,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_STABLELM: case LLM_ARCH_QWEN: case LLM_ARCH_QWEN2: + case LLM_ARCH_QWEN2MOE: case LLM_ARCH_PHI2: case LLM_ARCH_GEMMA: case LLM_ARCH_STARCODER2: diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index b50675952b51d..21adba42e3148 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1878,6 +1878,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op // unary ops for (int op = 0; op < GGML_UNARY_OP_COUNT; op++) { test_cases.emplace_back(new test_unary((ggml_unary_op) op)); + test_cases.emplace_back(new test_unary((ggml_unary_op) op, GGML_TYPE_F32, { 7, 13, 19, 23 })); } test_cases.emplace_back(new test_get_rows(GGML_TYPE_F32, 1, 8, 2, 1, false)); From dbceec87c0221ec952e69448df6a71f1372a7487 Mon Sep 17 00:00:00 2001 From: Ashish <1856117+ashishdatta@users.noreply.github.com> Date: Tue, 16 Apr 2024 08:48:35 -0700 Subject: [PATCH 078/100] llama : add StableLM2 12B (#6635) * StableLM2 12B support for huggingface -> GGUF * StableLM12 tensormapping and constants * StableLM-2-12b model support * fix * Added 12B support * Removed autoformatting; resolved bug where model_arch was not selecting StableLM2 * Formatting * Do QK norm stacking in model conversion step * Converge StableLM and StableLM2 code to simplify graph construction * Fix accidental removal * Removed warnings * Revert formatter * Move QK norm stack to private function so it's easier to read * refactor stablelm graph builder to support 1.6, 3b and 12b more efficiently * Proper check for None type for new_name to avoid crash; formatting; revert change to base class `write_tensors()` * Format * Formatting * format Co-authored-by: compilade * Fix incorrect check for K norm * space after commas; Keep indentation multiple of 4 spaces * Flake8 format * Removed unnecessary conditional branches * Removed unused comment * Fixed incorrect tensor passing * Format --------- Co-authored-by: compilade --- convert-hf-to-gguf.py | 82 +++++++++++++++++++++++++++++++++++++++ gguf-py/gguf/constants.py | 2 + llama.cpp | 62 +++++++++++++++++++++++------ 3 files changed, 134 insertions(+), 12 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index a93b0666c6b86..f321d77de11f8 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -1207,9 +1207,91 @@ def set_gguf_parameters(self): rotary_factor = self.find_hparam(["partial_rotary_factor", "rope_pct"]) self.gguf_writer.add_rope_dimension_count(int(rotary_factor * (hparams["hidden_size"] // hparams["num_attention_heads"]))) self.gguf_writer.add_head_count(hparams["num_attention_heads"]) + self.gguf_writer.add_head_count_kv(hparams["num_key_value_heads"]) self.gguf_writer.add_parallel_residual(hparams["use_parallel_residual"] if "use_parallel_residual" in hparams else True) self.gguf_writer.add_layer_norm_eps(self.find_hparam(["layer_norm_eps", "norm_eps"])) + def write_tensors(self): + block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer"))) + tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count) + n_head = self.hparams.get("num_attention_heads") + n_kv_head = self.hparams.get("num_key_value_heads") + q_norms = dict() + k_norms = dict() + for name, data_torch in self.get_tensors(): + # we don't need these + if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")): + continue + + old_dtype = data_torch.dtype + + # convert any unsupported data types to float32 + if data_torch.dtype not in (torch.float16, torch.float32): + data_torch = data_torch.to(torch.float32) + + data = data_torch.squeeze().numpy() + n_dims = len(data.shape) + if name.find("q_layernorm.norms") != -1: + q_norms[name] = data + if len(q_norms) >= (block_count * n_head): + self._stack_qk_norm(block_count, name, tensor_map, n_head, q_norms, n_dims, layer_name="q_layernorm") + continue + if name.find("k_layernorm.norms") != -1: + k_norms[name] = data + if len(k_norms) >= (block_count * n_kv_head): + self._stack_qk_norm(block_count, name, tensor_map, n_kv_head, k_norms, n_dims, layer_name="k_layernorm") + continue + + # map tensor names + new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias")) + if new_name is None: + print(f"Can not map tensor {name!r}") + sys.exit() + + n_dims = len(data.shape) + data_dtype = data.dtype + + # if f32 desired, convert any float16 to float32 + if self.ftype == 0 and data_dtype == np.float16: + data = data.astype(np.float32) + + # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32 + if self.ftype == 1 and data_dtype == np.float16 and (n_dims == 1 or new_name.endswith("_norm.weight")): + data = data.astype(np.float32) + + # if f16 desired, convert any float32 2-dim weight tensors to float16 + if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and not new_name.endswith("_norm.weight") and n_dims == 2: + data = data.astype(np.float16) + + print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}") + + self.gguf_writer.add_tensor(new_name, data) + + def _stack_qk_norm(self, block_count, name, tensor_map, n_head, norms, n_dims, layer_name="q_layernorm"): + for bid in range(block_count): + datas = [] + for xid in range(n_head): + ename = f"model.layers.{bid}.self_attn.{layer_name}.norms.{xid}.weight" + datas.append(norms[ename]) + del norms[ename] + data = np.stack(datas, axis=0) + data_dtype = data.dtype + merged_name = f"model.layers.{bid}.self_attn.{layer_name}.weight" + new_name = tensor_map.get_name(merged_name, try_suffixes=(".weight", ".bias")) + if new_name is None: + print(f"Can not map tensor {name!r}") + sys.exit() + if self.ftype == 1 and data_dtype == np.float16 and (n_dims == 1 or new_name.endswith("_norm.weight")): + data = data.astype(np.float32) + + # if f16 desired, convert any float32 2-dim weight tensors to float16 + if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and not new_name.endswith("_norm.weight") and n_dims == 2: + data = data.astype(np.float16) + + print(f"{new_name}, n_dims = {len(data.shape)}, shape = {data.shape} --> {data.dtype}") + + self.gguf_writer.add_tensor(new_name, data) + @Model.register("LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM") class LlamaModel(Model): diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index df861164fe2f1..4b0b6c4c6d008 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -455,6 +455,8 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_GATE, MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K_NORM, ], MODEL_ARCH.QWEN: [ MODEL_TENSOR.TOKEN_EMBD, diff --git a/llama.cpp b/llama.cpp index 340e68fde3ac9..579986d1a5e2e 100644 --- a/llama.cpp +++ b/llama.cpp @@ -716,6 +716,8 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, }, }, { @@ -1744,6 +1746,7 @@ enum e_model { MODEL_4B, MODEL_7B, MODEL_8B, + MODEL_12B, MODEL_13B, MODEL_14B, MODEL_15B, @@ -3607,6 +3610,7 @@ static const char * llama_model_type_name(e_model type) { case MODEL_3B: return "3B"; case MODEL_7B: return "7B"; case MODEL_8B: return "8B"; + case MODEL_12B: return "12B"; case MODEL_13B: return "13B"; case MODEL_14B: return "14B"; case MODEL_15B: return "15B"; @@ -3898,6 +3902,7 @@ static void llm_load_hparams( switch (hparams.n_layer) { case 24: model.type = e_model::MODEL_1B; break; case 32: model.type = e_model::MODEL_3B; break; + case 40: model.type = e_model::MODEL_12B; break; default: model.type = e_model::MODEL_UNKNOWN; } } break; @@ -5128,8 +5133,13 @@ static bool llm_load_tensors( layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, false); layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, false); - layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); - layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}); + // optional q and k layernorms, present in StableLM 2 12B + layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {hparams.n_embd_head_k, hparams.n_head}, false); + layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {hparams.n_embd_head_k, hparams.n_head_kv}, false); + + // optional FFN norm, not present in StableLM 2 12B which uses parallel residual + layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, false); + layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, false); layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); @@ -8197,7 +8207,7 @@ struct llm_build_context { struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); for (int il = 0; il < n_layer; ++il) { - struct ggml_tensor * inpSA = inpL; + // norm cur = llm_build_norm(ctx0, inpL, hparams, @@ -8206,6 +8216,8 @@ struct llm_build_context { LLM_NORM, cb, il); cb(cur, "attn_norm", il); + struct ggml_tensor * inpSA = cur; + // self-attention { // compute Q and K and RoPE them @@ -8230,15 +8242,36 @@ struct llm_build_context { cb(Vcur, "Vcur", il); } + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + cb(Qcur, "Qcur", il); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + cb(Kcur, "Kcur", il); + + if (model.layers[il].attn_q_norm) { + Qcur = llm_build_norm(ctx0, Qcur, hparams, + model.layers[il].attn_q_norm, + NULL, + LLM_NORM, cb, il); + cb(Qcur, "Qcur", il); + } + if (model.layers[il].attn_k_norm) { + Kcur = llm_build_norm(ctx0, Kcur, hparams, + model.layers[il].attn_k_norm, + NULL, + LLM_NORM, cb, il); + cb(Kcur, "Kcur", il); + } + + Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + ctx0, Qcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + ctx0, Kcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); @@ -8253,20 +8286,25 @@ struct llm_build_context { // skip computing output for unused tokens struct ggml_tensor * inp_out_ids = build_inp_out_ids(); cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } - struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); cb(ffn_inp, "ffn_inp", il); // feed-forward network { - cur = llm_build_norm(ctx0, ffn_inp, hparams, - model.layers[il].ffn_norm, - model.layers[il].ffn_norm_b, - LLM_NORM, cb, il); - cb(cur, "ffn_norm", il); - + if (model.layers[il].ffn_norm) { + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, + model.layers[il].ffn_norm_b, + LLM_NORM, cb, il); + cb(cur, "ffn_norm", il); + } else { + // parallel residual + cur = inpSA; + } cur = llm_build_ffn(ctx0, cur, model.layers[il].ffn_up, NULL, model.layers[il].ffn_gate, NULL, From 8cc91dc63c0df397d644a581b2cbeea74eb51ae0 Mon Sep 17 00:00:00 2001 From: Justine Tunney Date: Tue, 16 Apr 2024 14:55:30 -0400 Subject: [PATCH 079/100] ggml : add llamafile sgemm (#6414) This change upstreams llamafile's cpu matrix multiplication kernels which improve image and prompt evaluation speed. For starters, Q4_0 and Q8_0 weights should go ~40% faster on CPU. The biggest benefits are with data types like f16 / f32, which process prompts 2x faster thus making them faster than quantized data types for prompt evals. This change also introduces bona fide AVX512 support since tinyBLAS is able to exploit the larger register file. For example, on my CPU llama.cpp llava-cli processes an image prompt at 305 tokens/second, using the Q4_K and Q4_0 types, which has always been faster than if we used f16 LLaVA weights, which at HEAD go 188 tokens/second. With this change, f16 LLaVA performance leap frogs to 464 tokens/second. On Intel Core i9-14900K this change improves F16 prompt perf by 5x. For example, using llama.cpp at HEAD with Mistral 7b f16 to process a 215 token prompt will go 13 tok/sec. This change has fixes making it go 52 tok/sec. It's mostly thanks to my vectorized outer product kernels but also because I added support for correctly counting the number of cores on Alderlake, so the default thread count discounts Intel's new efficiency cores. Only Linux right now can count cores. This work was sponsored by Mozilla who's given permission to change the license of this code from Apache 2.0 to MIT. To read more about what's improved, and how it works, see: https://justine.lol/matmul/ --- CMakeLists.txt | 2 + Makefile | 10 +- Package.swift | 1 + build.zig | 15 +- common/common.cpp | 73 ++ common/common.h | 3 +- examples/llama-bench/llama-bench.cpp | 2 +- ggml-impl.h | 2 +- ggml-quants.c | 2 +- ggml.c | 54 ++ sgemm.cpp | 1148 ++++++++++++++++++++++++++ sgemm.h | 12 + 12 files changed, 1312 insertions(+), 12 deletions(-) create mode 100644 sgemm.cpp create mode 100644 sgemm.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 19fdfa46ca4f1..158174c2094d2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1151,6 +1151,8 @@ add_library(ggml OBJECT ggml-backend.h ggml-quants.c ggml-quants.h + sgemm.cpp + sgemm.h ${GGML_SOURCES_CUDA} ${GGML_HEADERS_CUDA} ${GGML_SOURCES_OPENCL} ${GGML_HEADERS_OPENCL} ${GGML_SOURCES_METAL} ${GGML_HEADERS_METAL} diff --git a/Makefile b/Makefile index 8f3e17da4c57b..928fb14cad628 100644 --- a/Makefile +++ b/Makefile @@ -219,6 +219,11 @@ ifdef LLAMA_DISABLE_LOGS MK_CPPFLAGS += -DLOG_DISABLE_LOGS endif # LLAMA_DISABLE_LOGS +# disable ggml.c's use of sgemm.cpp +ifdef LLAMA_NO_LLAMAFILE + MK_CPPFLAGS += -DGGML_USE_LLAMAFILE=0 +endif + # warnings WARN_FLAGS = -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function MK_CFLAGS += $(WARN_FLAGS) -Wshadow -Wstrict-prototypes -Wpointer-arith -Wmissing-prototypes -Werror=implicit-int \ @@ -676,13 +681,16 @@ ggml-backend.o: ggml-backend.c ggml.h ggml-backend.h ggml-quants.o: ggml-quants.c ggml.h ggml-quants.h ggml-common.h $(CC) $(CFLAGS) -c $< -o $@ +sgemm.o: sgemm.cpp sgemm.h ggml.h + $(CXX) $(CXXFLAGS) -c $< -o $@ + unicode.o: unicode.cpp unicode.h $(CXX) $(CXXFLAGS) -c $< -o $@ unicode-data.o: unicode-data.cpp unicode-data.h $(CXX) $(CXXFLAGS) -c $< -o $@ -OBJS += ggml-alloc.o ggml-backend.o ggml-quants.o unicode.o unicode-data.o +OBJS += ggml-alloc.o ggml-backend.o ggml-quants.o unicode.o unicode-data.o sgemm.o llama.o: llama.cpp unicode.h ggml.h ggml-alloc.h ggml-backend.h ggml-cuda.h ggml-metal.h llama.h $(CXX) $(CXXFLAGS) -c $< -o $@ diff --git a/Package.swift b/Package.swift index fbd0973be7f62..183e647575b42 100644 --- a/Package.swift +++ b/Package.swift @@ -4,6 +4,7 @@ import PackageDescription var sources = [ "ggml.c", + "sgemm.cpp", "llama.cpp", "unicode.cpp", "unicode-data.cpp", diff --git a/build.zig b/build.zig index e05ca2120ba4c..c35e801f88129 100644 --- a/build.zig +++ b/build.zig @@ -112,6 +112,7 @@ pub fn build(b: *std.build.Builder) !void { make.enable_lto = b.option(bool, "lto", "Enable LTO optimization, (default: false)") orelse false; const ggml = make.obj("ggml", "ggml.c"); + const sgemm = make.obj("sgemm", "sgemm.cpp"); const ggml_alloc = make.obj("ggml-alloc", "ggml-alloc.c"); const ggml_backend = make.obj("ggml-backend", "ggml-backend.c"); const ggml_quants = make.obj("ggml-quants", "ggml-quants.c"); @@ -128,14 +129,14 @@ pub fn build(b: *std.build.Builder) !void { const clip = make.obj("clip", "examples/llava/clip.cpp"); const llava = make.obj("llava", "examples/llava/llava.cpp"); - _ = make.exe("main", "examples/main/main.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo, sampling, console, grammar_parser }); - _ = make.exe("quantize", "examples/quantize/quantize.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo }); - _ = make.exe("perplexity", "examples/perplexity/perplexity.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo }); - _ = make.exe("embedding", "examples/embedding/embedding.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo }); - _ = make.exe("finetune", "examples/finetune/finetune.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo, train }); - _ = make.exe("train-text-from-scratch", "examples/train-text-from-scratch/train-text-from-scratch.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo, train }); + _ = make.exe("main", "examples/main/main.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo, sampling, console, grammar_parser }); + _ = make.exe("quantize", "examples/quantize/quantize.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo }); + _ = make.exe("perplexity", "examples/perplexity/perplexity.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo }); + _ = make.exe("embedding", "examples/embedding/embedding.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo }); + _ = make.exe("finetune", "examples/finetune/finetune.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo, train }); + _ = make.exe("train-text-from-scratch", "examples/train-text-from-scratch/train-text-from-scratch.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo, train }); - const server = make.exe("server", "examples/server/server.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo, sampling, grammar_parser, clip, llava }); + const server = make.exe("server", "examples/server/server.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo, sampling, grammar_parser, clip, llava }); if (server.target.isWindows()) { server.linkSystemLibrary("ws2_32"); } diff --git a/common/common.cpp b/common/common.cpp index 52576cba37bdd..cf69535e2d1f5 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -108,6 +108,79 @@ int32_t get_num_physical_cores() { return n_threads > 0 ? (n_threads <= 4 ? n_threads : n_threads / 2) : 4; } +#if defined(__x86_64__) && defined(__linux__) +#include + +static void cpuid(unsigned leaf, unsigned subleaf, + unsigned *eax, unsigned *ebx, unsigned *ecx, unsigned *edx) { + __asm__("movq\t%%rbx,%%rsi\n\t" + "cpuid\n\t" + "xchgq\t%%rbx,%%rsi" + : "=a"(*eax), "=S"(*ebx), "=c"(*ecx), "=d"(*edx) + : "0"(leaf), "2"(subleaf)); +} + +static int pin_cpu(int cpu) { + cpu_set_t mask; + CPU_ZERO(&mask); + CPU_SET(cpu, &mask); + return pthread_setaffinity_np(pthread_self(), sizeof(mask), &mask); +} + +static bool is_hybrid_cpu(void) { + unsigned eax, ebx, ecx, edx; + cpuid(7, 0, &eax, &ebx, &ecx, &edx); + return !!(edx & (1u << 15)); +} + +static bool is_running_on_efficiency_core(void) { + unsigned eax, ebx, ecx, edx; + cpuid(0x1a, 0, &eax, &ebx, &ecx, &edx); + int intel_atom = 0x20; + int core_type = (eax & 0xff000000u) >> 24; + return core_type == intel_atom; +} + +static int count_math_cpus(int cpu_count) { + int result = 0; + for (int cpu = 0; cpu < cpu_count; ++cpu) { + if (pin_cpu(cpu)) { + return -1; + } + if (is_running_on_efficiency_core()) { + continue; // efficiency cores harm lockstep threading + } + ++cpu; // hyperthreading isn't useful for linear algebra + ++result; + } + return result; +} + +#endif // __x86_64__ && __linux__ + +/** + * Returns number of CPUs on system that are useful for math. + */ +int get_math_cpu_count() { +#if defined(__x86_64__) && defined(__linux__) + int cpu_count = sysconf(_SC_NPROCESSORS_ONLN); + if (cpu_count < 1) { + return get_num_physical_cores(); + } + if (is_hybrid_cpu()) { + cpu_set_t affinity; + if (!pthread_getaffinity_np(pthread_self(), sizeof(affinity), &affinity)) { + int result = count_math_cpus(cpu_count); + pthread_setaffinity_np(pthread_self(), sizeof(affinity), &affinity); + if (result > 0) { + return result; + } + } + } +#endif + return get_num_physical_cores(); +} + void process_escapes(std::string & input) { std::size_t input_len = input.length(); std::size_t output_idx = 0; diff --git a/common/common.h b/common/common.h index 65272b0baaa41..cca44268e6df5 100644 --- a/common/common.h +++ b/common/common.h @@ -39,6 +39,7 @@ extern char const *LLAMA_BUILD_TARGET; struct llama_control_vector_load_info; +int get_math_cpu_count(); int32_t get_num_physical_cores(); // @@ -48,7 +49,7 @@ int32_t get_num_physical_cores(); struct gpt_params { uint32_t seed = LLAMA_DEFAULT_SEED; // RNG seed - int32_t n_threads = get_num_physical_cores(); + int32_t n_threads = get_math_cpu_count(); int32_t n_threads_draft = -1; int32_t n_threads_batch = -1; // number of threads to use for batch processing (-1 = use n_threads) int32_t n_threads_batch_draft = -1; diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 27e113203f1ea..8b532c8b6a98a 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -190,7 +190,7 @@ static const cmd_params cmd_params_defaults = { /* n_ubatch */ {512}, /* type_k */ {GGML_TYPE_F16}, /* type_v */ {GGML_TYPE_F16}, - /* n_threads */ {get_num_physical_cores()}, + /* n_threads */ {get_math_cpu_count()}, /* n_gpu_layers */ {99}, /* split_mode */ {LLAMA_SPLIT_MODE_LAYER}, /* main_gpu */ {0}, diff --git a/ggml-impl.h b/ggml-impl.h index e68b728775c41..0c997d3ed521f 100644 --- a/ggml-impl.h +++ b/ggml-impl.h @@ -88,7 +88,7 @@ typedef uint16_t ggml_fp16_internal_t; #if defined(_MSC_VER) || defined(__MINGW32__) #include #else -#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__) +#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__) || defined(__SSE__) #if !defined(__riscv) #include #endif diff --git a/ggml-quants.c b/ggml-quants.c index 32e84434a8c1b..32360a1f1be50 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -132,7 +132,7 @@ static inline __m256 sum_i16_pairs_float(const __m256i x) { } static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) { -#if defined(__AVXVNNI__) || defined(__AVX512VNNI__) +#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__)) const __m256i zero = _mm256_setzero_si256(); const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy); return _mm256_cvtepi32_ps(summed_pairs); diff --git a/ggml.c b/ggml.c index 14288d29ded6b..119686be68d9c 100644 --- a/ggml.c +++ b/ggml.c @@ -4,6 +4,7 @@ #include "ggml-impl.h" #include "ggml-quants.h" #include "ggml.h" +#include "sgemm.h" #if defined(_MSC_VER) || defined(__MINGW32__) #include // using malloc.h with MSC/MINGW @@ -32,6 +33,14 @@ #include #endif +#ifndef GGML_USE_LLAMAFILE +#ifdef __ARM_FEATURE_MATMUL_INT8 +#define GGML_USE_LLAMAFILE 0 +#else +#define GGML_USE_LLAMAFILE 1 +#endif +#endif + #if defined(_MSC_VER) // disable "possible loss of data" to avoid hundreds of casts // we should just be careful :) @@ -10810,6 +10819,28 @@ static void ggml_compute_forward_mul_mat( } #endif +#if GGML_USE_LLAMAFILE + if (nb10 == ggml_type_size(src1->type)) { + for (int64_t i13 = 0; i13 < ne13; i13++) + for (int64_t i12 = 0; i12 < ne12; i12++) + if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type), + (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, + nb01/ggml_type_size(src0->type), + (const char *)src1->data + i12*nb12 + i13*nb13, + nb11/ggml_type_size(src1->type), + (char *)dst->data + i12*nb2 + i13*nb3, + nb1/ggml_type_size(dst->type), + ith, nth, + params->type, + src0->type, + src1->type, + dst->type)) + goto UseGgmlGemm1; + return; + } +UseGgmlGemm1:; +#endif + if (params->type == GGML_TASK_TYPE_INIT) { if (ith != 0) { return; @@ -10841,6 +10872,29 @@ static void ggml_compute_forward_mul_mat( const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; const size_t row_size = ggml_row_size(vec_dot_type, ne10); +#if GGML_USE_LLAMAFILE + if (nb10 == ggml_type_size(src1->type) || src1->type != vec_dot_type) { + for (int64_t i13 = 0; i13 < ne13; i13++) + for (int64_t i12 = 0; i12 < ne12; i12++) + if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type), + (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, + nb01/ggml_type_size(src0->type), + (const char *)wdata + (nb12/ggml_type_size(src1->type)*ggml_type_size(vec_dot_type)*i12 + + nb13/ggml_type_size(src1->type)*ggml_type_size(vec_dot_type)*i13), + row_size/ggml_type_size(vec_dot_type), + (char *)dst->data + i12*nb2 + i13*nb3, + nb1/ggml_type_size(dst->type), + ith, nth, + params->type, + src0->type, + vec_dot_type, + dst->type)) + goto UseGgmlGemm2; + return; + } +UseGgmlGemm2:; +#endif + const int64_t nr0 = ne01; // src0 rows const int64_t nr1 = ne1*ne12*ne13; // src1 rows diff --git a/sgemm.cpp b/sgemm.cpp new file mode 100644 index 0000000000000..6900f04cfb242 --- /dev/null +++ b/sgemm.cpp @@ -0,0 +1,1148 @@ +// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*- +// vi: set et ft=c++ ts=4 sts=4 sw=4 fenc=utf-8 :vi +// +// Copyright 2024 Mozilla Foundation +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files (the +// "Software"), to deal in the Software without restriction, including +// without limitation the rights to use, copy, modify, merge, publish, +// distribute, sublicense, and/or sell copies of the Software, and to +// permit persons to whom the Software is furnished to do so, subject to +// the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS +// BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN +// ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +// +// _ _ ___ _ _ ___ +// | |_(_)_ _ _ _| _ ) | /_\ / __| +// | _| | ' \ || | _ \ |__ / _ \\__ \. +// \__|_|_||_\_, |___/____/_/ \_\___/ +// |__/ +// +// BASIC LINEAR ALGEBRA SUBPROGRAMS +// +// +// This file implements multithreaded CPU matrix multiplication for the +// common contiguous use case C = Aᵀ * B. These kernels are designed to +// have excellent performance[1] for matrices that fit in the CPU cache +// without imposing any overhead such as cache filling or malloc calls. +// +// This implementation does not guarantee any upper bound with rounding +// errors, which grow along with k. Our goal's to maximally exploit the +// hardware for performance, and then use whatever resources remain for +// improving numerical accuracy. +// +// [1] J. Tunney, ‘LLaMA Now Goes Faster on CPUs’, Mar. 2024. [Online]. +// Available: https://justine.lol/matmul/. [Accessed: 29-Mar-2024]. + +#pragma GCC diagnostic ignored "-Wpedantic" +#pragma GCC diagnostic ignored "-Wignored-attributes" + +#include "sgemm.h" +#include "ggml-impl.h" +#include "ggml-quants.h" + +#ifdef _MSC_VER +#define NOINLINE __declspec(noinline) +#else +#define NOINLINE __attribute__((__noinline__)) +#endif + +#if defined(__ARM_NEON) || defined(__AVX512F__) +#define VECTOR_REGISTERS 32 +#else +#define VECTOR_REGISTERS 16 +#endif + +// there will be blocks +#define BEGIN_KERNEL(RM, RN) \ + int ytiles = (m - m0) / RM; \ + int xtiles = (n - n0) / RN; \ + int tiles = ytiles * xtiles; \ + int duty = (tiles + nth - 1) / nth; \ + int start = duty * ith; \ + int end = start + duty; \ + if (end > tiles) \ + end = tiles; \ + for (int job = start; job < end; ++job) { \ + int i = m0 + job / xtiles * RM; \ + int j = n0 + job % xtiles * RN; + +#define END_KERNEL() } + +#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1) + +namespace { + +inline float unhalf(ggml_fp16_t d) { + return GGML_FP16_TO_FP32(d); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// VECTORIZED ARITHMETIC OPERATIONS + +#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) +inline __m128 add(__m128 x, __m128 y) { return _mm_add_ps(x, y); } +inline __m128 sub(__m128 x, __m128 y) { return _mm_sub_ps(x, y); } +inline __m128 mul(__m128 x, __m128 y) { return _mm_mul_ps(x, y); } +#endif // __SSE__ + +#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) +inline __m256 add(__m256 x, __m256 y) { return _mm256_add_ps(x, y); } +inline __m256 sub(__m256 x, __m256 y) { return _mm256_sub_ps(x, y); } +inline __m256 mul(__m256 x, __m256 y) { return _mm256_mul_ps(x, y); } +#endif // __AVX__ + +#if defined(__AVX512F__) +inline __m512 add(__m512 x, __m512 y) { return _mm512_add_ps(x, y); } +inline __m512 sub(__m512 x, __m512 y) { return _mm512_sub_ps(x, y); } +inline __m512 mul(__m512 x, __m512 y) { return _mm512_mul_ps(x, y); } +#endif // __AVX512F__ + +#if defined(__ARM_NEON) +inline float32x4_t add(float32x4_t x, float32x4_t y) { return vaddq_f32(x, y); } +inline float32x4_t sub(float32x4_t x, float32x4_t y) { return vsubq_f32(x, y); } +inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vmulq_f32(x, y); } +#endif // __ARM_NEON + +#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) +inline float16x8_t add(float16x8_t x, float16x8_t y) { return vaddq_f16(x, y); } +inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); } +inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); } +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// VECTORIZED HORIZONTAL SUM + +#if defined(__ARM_NEON) +inline float hsum(float32x4_t x) { + return vaddvq_f32(x); +} +#endif // __ARM_NEON + +#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER) +inline float hsum(float16x8_t x) { + return vaddvq_f32(vaddq_f32(vcvt_f32_f16(vget_low_f16(x)), + vcvt_f32_f16(vget_high_f16(x)))); +} +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + +#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) +inline float hsum(__m128 x) { +#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) + x = _mm_add_ps(x, _mm_movehl_ps(x, x)); + x = _mm_add_ss(x, _mm_movehdup_ps(x)); +#else + __m128 t; + t = _mm_shuffle_ps(x, x, _MM_SHUFFLE(2, 3, 0, 1)); + x = _mm_add_ps(x, t); + t = _mm_movehl_ps(t, x); + x = _mm_add_ss(x, t); +#endif + return _mm_cvtss_f32(x); +} +#endif + +#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) +inline float hsum(__m256 x) { + return hsum(_mm_add_ps(_mm256_extractf128_ps(x, 1), + _mm256_castps256_ps128(x))); +} +#endif // __AVX__ + +#if defined(__AVX512F__) +inline float hsum(__m512 x) { + return _mm512_reduce_add_ps(x); +} +#endif // __AVX512F__ + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// VECTORIZED MEMORY LOADING + +template T load(const U *); + +#if defined(__ARM_NEON) +template <> inline float32x4_t load(const float *p) { + return vld1q_f32(p); +} +#if !defined(_MSC_VER) +template <> inline float16x8_t load(const ggml_fp16_t *p) { + return vld1q_f16((const float16_t *)p); +} +template <> inline float32x4_t load(const ggml_fp16_t *p) { + return vcvt_f32_f16(vld1_f16((const float16_t *)p)); +} +#endif // _MSC_VER +#endif // __ARM_NEON + +#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) +template <> inline __m128 load(const float *p) { + return _mm_loadu_ps(p); +} +#endif // __SSE__ + +#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) +template <> inline __m256 load(const float *p) { + return _mm256_loadu_ps(p); +} +#endif // __AVX__ + +#if defined(__F16C__) +template <> inline __m256 load(const ggml_fp16_t *p) { + return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)p)); +} +#endif // __F16C__ + +#if defined(__AVX512F__) +template <> inline __m512 load(const float *p) { + return _mm512_loadu_ps(p); +} +template <> inline __m512 load(const ggml_fp16_t *p) { + return _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)p)); +} +#endif // __AVX512F__ + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// ABSTRACTIONS + +/** + * Computes a * b + c. + * + * This operation will become fused into a single arithmetic instruction + * if the hardware has support for this feature, e.g. Intel Haswell+ (c. + * 2013), AMD Bulldozer+ (c. 2011), etc. + */ +template +inline U madd(T a, T b, U c) { + return add(mul(a, b), c); +} + +/** + * Computes a * b + c with error correction. + * + * @see W. Kahan, "Further remarks on reducing truncation errors," + * Communications of the ACM, vol. 8, no. 1, p. 40, Jan. 1965, + * doi: 10.1145/363707.363723. + */ +template +inline U madder(T a, T b, U c, U *e) { + U y = sub(mul(a, b), *e); + U t = add(c, y); + *e = sub(sub(t, c), y); + return t; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// FLOATING POINT MATRIX MULTIPLICATION + +template +class tinyBLAS { + public: + tinyBLAS(int k, + const TA *A, int lda, + const TB *B, int ldb, + TC *C, int ldc, + int ith, int nth) + : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { + } + + void matmul(int m, int n, int task) { + if (task == GGML_TASK_TYPE_COMPUTE) + mnpack(0, m, 0, n); + } + + private: + NOINLINE void mnpack(int m0, int m, int n0, int n) { + int mc, nc, mp, np; + if (m - m0 <= 0 || n - n0 <= 0) + return; + if (VECTOR_REGISTERS >= 32 && n - n0 >= 5 && m - m0 >= 5) { + mc = 5; + nc = 5; + gemm5x5(m0, m, n0, n); + } else if (n - n0 >= 4 && m - m0 >= 3) { + mc = 3; + nc = 4; + gemm3x4(m0, m, n0, n); + } else if (n - n0 >= 4) { + mc = 1; + nc = 4; + gemm1x4(m0, m, n0, n); + } else if (m - m0 >= 4) { + mc = 4; + nc = 1; + gemm4x1(m0, m, n0, n); + } else { + mc = 1; + nc = 1; + gemm1x1(m0, m, n0, n); + } + mp = m0 + (m - m0) / mc * mc; + np = n0 + (n - n0) / nc * nc; + mnpack(mp, m, n0, np); + mnpack(m0, mp, np, n); + mnpack(mp, m, np, n); + } + + NOINLINE void gemm5x5(int m0, int m, int n0, int n) { + BEGIN_KERNEL(5, 5) + D c00 = {0}; + D c01 = {0}; + D c02 = {0}; + D c03 = {0}; + D c04 = {0}; + D c10 = {0}; + D c11 = {0}; + D c12 = {0}; + D c13 = {0}; + D c14 = {0}; + D c20 = {0}; + D c21 = {0}; + D c22 = {0}; + D c23 = {0}; + D c24 = {0}; + D c30 = {0}; + D c31 = {0}; + D c32 = {0}; + D c33 = {0}; + D c34 = {0}; + D c40 = {0}; + D c41 = {0}; + D c42 = {0}; + D c43 = {0}; + D c44 = {0}; + for (int l = 0; l < k; l += KN) { + V k0 = load(B + ldb * (j + 0) + l); + V k1 = load(B + ldb * (j + 1) + l); + V k2 = load(B + ldb * (j + 2) + l); + V k3 = load(B + ldb * (j + 3) + l); + V k4 = load(B + ldb * (j + 4) + l); + V a0 = load(A + lda * (i + 0) + l); + c00 = madd(a0, k0, c00); + c01 = madd(a0, k1, c01); + c02 = madd(a0, k2, c02); + c03 = madd(a0, k3, c03); + c04 = madd(a0, k4, c04); + V a1 = load(A + lda * (i + 1) + l); + c10 = madd(a1, k0, c10); + c11 = madd(a1, k1, c11); + c12 = madd(a1, k2, c12); + c13 = madd(a1, k3, c13); + c14 = madd(a1, k4, c14); + V a2 = load(A + lda * (i + 2) + l); + c20 = madd(a2, k0, c20); + c21 = madd(a2, k1, c21); + c22 = madd(a2, k2, c22); + c23 = madd(a2, k3, c23); + c24 = madd(a2, k4, c24); + V a3 = load(A + lda * (i + 3) + l); + c30 = madd(a3, k0, c30); + c31 = madd(a3, k1, c31); + c32 = madd(a3, k2, c32); + c33 = madd(a3, k3, c33); + c34 = madd(a3, k4, c34); + V a4 = load(A + lda * (i + 4) + l); + c40 = madd(a4, k0, c40); + c41 = madd(a4, k1, c41); + c42 = madd(a4, k2, c42); + c43 = madd(a4, k3, c43); + c44 = madd(a4, k4, c44); + } + C[ldc * (j + 0) + (i + 0)] = hsum(c00); + C[ldc * (j + 0) + (i + 1)] = hsum(c10); + C[ldc * (j + 0) + (i + 2)] = hsum(c20); + C[ldc * (j + 0) + (i + 3)] = hsum(c30); + C[ldc * (j + 0) + (i + 4)] = hsum(c40); + C[ldc * (j + 1) + (i + 0)] = hsum(c01); + C[ldc * (j + 1) + (i + 1)] = hsum(c11); + C[ldc * (j + 1) + (i + 2)] = hsum(c21); + C[ldc * (j + 1) + (i + 3)] = hsum(c31); + C[ldc * (j + 1) + (i + 4)] = hsum(c41); + C[ldc * (j + 2) + (i + 0)] = hsum(c02); + C[ldc * (j + 2) + (i + 1)] = hsum(c12); + C[ldc * (j + 2) + (i + 2)] = hsum(c22); + C[ldc * (j + 2) + (i + 3)] = hsum(c32); + C[ldc * (j + 2) + (i + 4)] = hsum(c42); + C[ldc * (j + 3) + (i + 0)] = hsum(c03); + C[ldc * (j + 3) + (i + 1)] = hsum(c13); + C[ldc * (j + 3) + (i + 2)] = hsum(c23); + C[ldc * (j + 3) + (i + 3)] = hsum(c33); + C[ldc * (j + 3) + (i + 4)] = hsum(c43); + C[ldc * (j + 4) + (i + 0)] = hsum(c04); + C[ldc * (j + 4) + (i + 1)] = hsum(c14); + C[ldc * (j + 4) + (i + 2)] = hsum(c24); + C[ldc * (j + 4) + (i + 3)] = hsum(c34); + C[ldc * (j + 4) + (i + 4)] = hsum(c44); + END_KERNEL() + } + + NOINLINE void gemm3x4(int m0, int m, int n0, int n) { + BEGIN_KERNEL(3, 4) + D c00 = {0}; + D c01 = {0}; + D c02 = {0}; + D c03 = {0}; + D c10 = {0}; + D c11 = {0}; + D c12 = {0}; + D c13 = {0}; + D c20 = {0}; + D c21 = {0}; + D c22 = {0}; + D c23 = {0}; + for (int l = 0; l < k; l += KN) { + V k0 = load(B + ldb * (j + 0) + l); + V k1 = load(B + ldb * (j + 1) + l); + V k2 = load(B + ldb * (j + 2) + l); + V k3 = load(B + ldb * (j + 3) + l); + V a0 = load(A + lda * (i + 0) + l); + c00 = madd(a0, k0, c00); + c01 = madd(a0, k1, c01); + c02 = madd(a0, k2, c02); + c03 = madd(a0, k3, c03); + V a1 = load(A + lda * (i + 1) + l); + c10 = madd(a1, k0, c10); + c11 = madd(a1, k1, c11); + c12 = madd(a1, k2, c12); + c13 = madd(a1, k3, c13); + V a2 = load(A + lda * (i + 2) + l); + c20 = madd(a2, k0, c20); + c21 = madd(a2, k1, c21); + c22 = madd(a2, k2, c22); + c23 = madd(a2, k3, c23); + } + C[ldc * (j + 0) + (i + 0)] = hsum(c00); + C[ldc * (j + 0) + (i + 1)] = hsum(c10); + C[ldc * (j + 0) + (i + 2)] = hsum(c20); + C[ldc * (j + 1) + (i + 0)] = hsum(c01); + C[ldc * (j + 1) + (i + 1)] = hsum(c11); + C[ldc * (j + 1) + (i + 2)] = hsum(c21); + C[ldc * (j + 2) + (i + 0)] = hsum(c02); + C[ldc * (j + 2) + (i + 1)] = hsum(c12); + C[ldc * (j + 2) + (i + 2)] = hsum(c22); + C[ldc * (j + 3) + (i + 0)] = hsum(c03); + C[ldc * (j + 3) + (i + 1)] = hsum(c13); + C[ldc * (j + 3) + (i + 2)] = hsum(c23); + END_KERNEL() + } + + NOINLINE void gemm1x4(int m0, int m, int n0, int n) { + BEGIN_KERNEL(1, 4) + D c00 = {0}, e00 = {0}; + D c01 = {0}, e01 = {0}; + D c02 = {0}, e02 = {0}; + D c03 = {0}, e03 = {0}; + for (int l = 0; l < k; l += KN) { + V a = load(A + lda * (i + 0) + l); + c00 = madder(a, load(B + ldb * (j + 0) + l), c00, &e00); + c01 = madder(a, load(B + ldb * (j + 1) + l), c01, &e01); + c02 = madder(a, load(B + ldb * (j + 2) + l), c02, &e02); + c03 = madder(a, load(B + ldb * (j + 3) + l), c03, &e03); + } + C[ldc * (j + 0) + (i + 0)] = hsum(c00); + C[ldc * (j + 1) + (i + 0)] = hsum(c01); + C[ldc * (j + 2) + (i + 0)] = hsum(c02); + C[ldc * (j + 3) + (i + 0)] = hsum(c03); + END_KERNEL() + } + + NOINLINE void gemm4x1(int m0, int m, int n0, int n) { + BEGIN_KERNEL(4, 1) + D c00 = {0}, e00 = {0}; + D c10 = {0}, e10 = {0}; + D c20 = {0}, e20 = {0}; + D c30 = {0}, e30 = {0}; + for (int l = 0; l < k; l += KN) { + V b = load(B + ldb * (j + 0) + l); + c00 = madder(load(A + lda * (i + 0) + l), b, c00, &e00); + c10 = madder(load(A + lda * (i + 1) + l), b, c10, &e10); + c20 = madder(load(A + lda * (i + 2) + l), b, c20, &e20); + c30 = madder(load(A + lda * (i + 3) + l), b, c30, &e30); + } + C[ldc * (j + 0) + (i + 0)] = hsum(c00); + C[ldc * (j + 0) + (i + 1)] = hsum(c10); + C[ldc * (j + 0) + (i + 2)] = hsum(c20); + C[ldc * (j + 0) + (i + 3)] = hsum(c30); + END_KERNEL() + } + + NOINLINE void gemm1x1(int m0, int m, int n0, int n) { + BEGIN_KERNEL(1, 1) + D c = {0}, e = {0}; + for (int l = 0; l < k; l += KN) + c = madder(load(A + lda * i + l), + load(B + ldb * j + l), c, &e); + C[ldc * j + i] = hsum(c); + END_KERNEL() + } + + const TA *const A; + const TB *const B; + TC *const C; + const int k; + const int lda; + const int ldb; + const int ldc; + const int ith; + const int nth; +}; + +////////////////////////////////////////////////////////////////////////////////////////// +// QUANT ZERO MATRIX MULTIPLICATION + +#if defined(__ARM_FEATURE_DOTPROD) +template +class tinyBLAS_Q0_ARM { + public: + tinyBLAS_Q0_ARM(int k, + const TA *A, int lda, + const block_q8_0 *B, int ldb, + float *C, int ldc, + int ith, int nth) + : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { + } + + void matmul(int m, int n, int task) { + if (task == GGML_TASK_TYPE_COMPUTE) + mnpack(0, m, 0, n); + } + + private: + NOINLINE void mnpack(int m0, int m, int n0, int n) { + int mc, nc, mp, np; + if (m - m0 <= 0 || n - n0 <= 0) + return; + if (m - m0 >= 3 && n - n0 >= 3) { + mc = 3; + nc = 3; + gemm3x3(m0, m, n0, n); + } else { + mc = 1; + nc = 1; + gemm1x1(m0, m, n0, n); + } + mp = m0 + (m - m0) / mc * mc; + np = n0 + (n - n0) / nc * nc; + mnpack(mp, m, n0, np); + mnpack(m0, mp, np, n); + mnpack(mp, m, np, n); + } + + NOINLINE void gemm3x3(int m0, int m, int n0, int n) { + BEGIN_KERNEL(3, 3) + int32x4_t zero = vdupq_n_s32(0); + float32x4_t c00 = vdupq_n_f32(0.f); + float32x4_t c01 = vdupq_n_f32(0.f); + float32x4_t c02 = vdupq_n_f32(0.f); + float32x4_t c10 = vdupq_n_f32(0.f); + float32x4_t c11 = vdupq_n_f32(0.f); + float32x4_t c12 = vdupq_n_f32(0.f); + float32x4_t c20 = vdupq_n_f32(0.f); + float32x4_t c21 = vdupq_n_f32(0.f); + float32x4_t c22 = vdupq_n_f32(0.f); + const TA *Ap0 = A + lda * (i + 0); + const TA *Ap1 = A + lda * (i + 1); + const TA *Ap2 = A + lda * (i + 2); + const block_q8_0 *Bp0 = B + ldb * (j + 0); + const block_q8_0 *Bp1 = B + ldb * (j + 1); + const block_q8_0 *Bp2 = B + ldb * (j + 2); + for (int l = 0; l < k; ++l) { + c00 = vmlaq_n_f32( + c00, + vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, load_lo(Ap0 + l), load_lo(Bp0 + l)), + load_hi(Ap0 + l), load_hi(Bp0 + l))), + unhalf(Ap0[l].d) * unhalf(Bp0[l].d)); + c01 = vmlaq_n_f32( + c01, + vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, load_lo(Ap0 + l), load_lo(Bp1 + l)), + load_hi(Ap0 + l), load_hi(Bp1 + l))), + unhalf(Ap0[l].d) * unhalf(Bp1[l].d)); + c02 = vmlaq_n_f32( + c02, + vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, load_lo(Ap0 + l), load_lo(Bp2 + l)), + load_hi(Ap0 + l), load_hi(Bp2 + l))), + unhalf(Ap0[l].d) * unhalf(Bp2[l].d)); + c10 = vmlaq_n_f32( + c10, + vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, load_lo(Ap1 + l), load_lo(Bp0 + l)), + load_hi(Ap1 + l), load_hi(Bp0 + l))), + unhalf(Ap1[l].d) * unhalf(Bp0[l].d)); + c11 = vmlaq_n_f32( + c11, + vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, load_lo(Ap1 + l), load_lo(Bp1 + l)), + load_hi(Ap1 + l), load_hi(Bp1 + l))), + unhalf(Ap1[l].d) * unhalf(Bp1[l].d)); + c12 = vmlaq_n_f32( + c12, + vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, load_lo(Ap1 + l), load_lo(Bp2 + l)), + load_hi(Ap1 + l), load_hi(Bp2 + l))), + unhalf(Ap1[l].d) * unhalf(Bp2[l].d)); + c20 = vmlaq_n_f32( + c20, + vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, load_lo(Ap2 + l), load_lo(Bp0 + l)), + load_hi(Ap2 + l), load_hi(Bp0 + l))), + unhalf(Ap2[l].d) * unhalf(Bp0[l].d)); + c21 = vmlaq_n_f32( + c21, + vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, load_lo(Ap2 + l), load_lo(Bp1 + l)), + load_hi(Ap2 + l), load_hi(Bp1 + l))), + unhalf(Ap2[l].d) * unhalf(Bp1[l].d)); + c22 = vmlaq_n_f32( + c22, + vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, load_lo(Ap2 + l), load_lo(Bp2 + l)), + load_hi(Ap2 + l), load_hi(Bp2 + l))), + unhalf(Ap2[l].d) * unhalf(Bp2[l].d)); + } + C[ldc * (j + 0) + (i + 0)] = hsum(c00); + C[ldc * (j + 0) + (i + 1)] = hsum(c10); + C[ldc * (j + 0) + (i + 2)] = hsum(c20); + C[ldc * (j + 1) + (i + 0)] = hsum(c01); + C[ldc * (j + 1) + (i + 1)] = hsum(c11); + C[ldc * (j + 1) + (i + 2)] = hsum(c21); + C[ldc * (j + 2) + (i + 0)] = hsum(c02); + C[ldc * (j + 2) + (i + 1)] = hsum(c12); + C[ldc * (j + 2) + (i + 2)] = hsum(c22); + END_KERNEL() + } + + NOINLINE void gemm1x1(int m0, int m, int n0, int n) { + BEGIN_KERNEL(1, 1) + float32x4_t acc = vdupq_n_f32(0.f); + const TA *Ap = A + lda * i; + const block_q8_0 *Bp = B + ldb * j; + for (int l = 0; l < k; ++l) { + acc = vmlaq_n_f32(acc, + vcvtq_f32_s32(vdotq_s32( + vdotq_s32(vdupq_n_s32(0), load_lo(Ap + l), load_lo(Bp + l)), + load_hi(Ap + l), load_hi(Bp + l))), + unhalf(Ap[l].d) * unhalf(Bp[l].d)); + } + C[ldc * j + i] = hsum(acc); + END_KERNEL() + } + + inline int8x16_t load_lo(const block_q8_0 *b) { + return vld1q_s8(b->qs); + } + inline int8x16_t load_hi(const block_q8_0 *b) { + return vld1q_s8(b->qs + 16); + } + + inline int8x16_t load_lo(const block_q4_0 *b) { + return vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vld1q_u8(b->qs), + vdupq_n_u8(0x0f))), + vdupq_n_s8(0x8)); + } + inline int8x16_t load_hi(const block_q4_0 *b) { + return vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(vld1q_u8(b->qs), 4)), + vdupq_n_s8(0x8)); + } + + const TA *const A; + const block_q8_0 *const B; + float *const C; + const int k; + const int lda; + const int ldb; + const int ldc; + const int ith; + const int nth; +}; +#endif // __ARM_FEATURE_DOTPROD + +#if defined(__AVX2__) || defined(__AVX512F__) +template +class tinyBLAS_Q0_AVX2 { + public: + tinyBLAS_Q0_AVX2(int k, + const TA *A, int lda, + const TB *B, int ldb, + TC *C, int ldc, + int ith, int nth) + : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { + } + + void matmul(int m, int n, int task) { + if (task == GGML_TASK_TYPE_COMPUTE) + mnpack(0, m, 0, n); + } + + private: + NOINLINE void mnpack(int m0, int m, int n0, int n) { + int mc, nc, mp, np; + if (m - m0 <= 0 || n - n0 <= 0) + return; + if (m - m0 >= 4 && n - n0 >= 3) { + mc = 4; + nc = 3; + gemm4x3(m0, m, n0, n); + } else if (m - m0 >= 4 && n - n0 >= 1) { + mc = 4; + nc = 1; + gemm4x1(m0, m, n0, n); + } else if (m - m0 >= 1 && n - n0 >= 4) { + mc = 1; + nc = 4; + gemm1x4(m0, m, n0, n); + } else { + mc = 1; + nc = 1; + gemm1x1(m0, m, n0, n); + } + mp = m0 + (m - m0) / mc * mc; + np = n0 + (n - n0) / nc * nc; + mnpack(mp, m, n0, np); + mnpack(m0, mp, np, n); + mnpack(mp, m, np, n); + } + + NOINLINE void gemm4x3(int m0, int m, int n0, int n) { + BEGIN_KERNEL(4, 3) + __m256 c00 = _mm256_setzero_ps(); + __m256 c10 = _mm256_setzero_ps(); + __m256 c20 = _mm256_setzero_ps(); + __m256 c30 = _mm256_setzero_ps(); + __m256 c01 = _mm256_setzero_ps(); + __m256 c11 = _mm256_setzero_ps(); + __m256 c21 = _mm256_setzero_ps(); + __m256 c31 = _mm256_setzero_ps(); + __m256 c02 = _mm256_setzero_ps(); + __m256 c12 = _mm256_setzero_ps(); + __m256 c22 = _mm256_setzero_ps(); + __m256 c32 = _mm256_setzero_ps(); + const TA *Ap0 = A + lda * (i + 0); + const TA *Ap1 = A + lda * (i + 1); + const TA *Ap2 = A + lda * (i + 2); + const TA *Ap3 = A + lda * (i + 3); + const TB *Bp0 = B + ldb * (j + 0); + const TB *Bp1 = B + ldb * (j + 1); + const TB *Bp2 = B + ldb * (j + 2); + for (int l = 0; l < k; ++l) { + float da0 = unhalf(Ap0[l].d); + float da1 = unhalf(Ap1[l].d); + float da2 = unhalf(Ap2[l].d); + float da3 = unhalf(Ap3[l].d); + __m256i e0 = load(Ap0 + l); + __m256i e1 = load(Ap1 + l); + __m256i e2 = load(Ap2 + l); + __m256i e3 = load(Ap3 + l); + float db0 = unhalf(Bp0[l].d); + __m256 d00 = _mm256_set1_ps(da0 * db0); + __m256 d10 = _mm256_set1_ps(da1 * db0); + __m256 d20 = _mm256_set1_ps(da2 * db0); + __m256 d30 = _mm256_set1_ps(da3 * db0); + __m256i f0 = load(Bp0 + l); + __m256i u0 = _mm256_sign_epi8(f0, f0); + __m256i s00 = _mm256_sign_epi8(e0, f0); + __m256i s10 = _mm256_sign_epi8(e1, f0); + __m256i s20 = _mm256_sign_epi8(e2, f0); + __m256i s30 = _mm256_sign_epi8(e3, f0); + c00 = madd(d00, updot(u0, s00), c00); + c10 = madd(d10, updot(u0, s10), c10); + c20 = madd(d20, updot(u0, s20), c20); + c30 = madd(d30, updot(u0, s30), c30); + float db1 = unhalf(Bp1[l].d); + __m256 d01 = _mm256_set1_ps(da0 * db1); + __m256 d11 = _mm256_set1_ps(da1 * db1); + __m256 d21 = _mm256_set1_ps(da2 * db1); + __m256 d31 = _mm256_set1_ps(da3 * db1); + __m256i f1 = load(Bp1 + l); + __m256i u1 = _mm256_sign_epi8(f1, f1); + __m256i s01 = _mm256_sign_epi8(e0, f1); + __m256i s11 = _mm256_sign_epi8(e1, f1); + __m256i s21 = _mm256_sign_epi8(e2, f1); + __m256i s31 = _mm256_sign_epi8(e3, f1); + c01 = madd(d01, updot(u1, s01), c01); + c11 = madd(d11, updot(u1, s11), c11); + c21 = madd(d21, updot(u1, s21), c21); + c31 = madd(d31, updot(u1, s31), c31); + float db2 = unhalf(Bp2[l].d); + __m256 d02 = _mm256_set1_ps(da0 * db2); + __m256 d12 = _mm256_set1_ps(da1 * db2); + __m256 d22 = _mm256_set1_ps(da2 * db2); + __m256 d32 = _mm256_set1_ps(da3 * db2); + __m256i f2 = load(Bp2 + l); + __m256i u2 = _mm256_sign_epi8(f2, f2); + __m256i s02 = _mm256_sign_epi8(e0, f2); + __m256i s12 = _mm256_sign_epi8(e1, f2); + __m256i s22 = _mm256_sign_epi8(e2, f2); + __m256i s32 = _mm256_sign_epi8(e3, f2); + c02 = madd(d02, updot(u2, s02), c02); + c12 = madd(d12, updot(u2, s12), c12); + c22 = madd(d22, updot(u2, s22), c22); + c32 = madd(d32, updot(u2, s32), c32); + } + C[ldc * (j + 0) + (i + 0)] = hsum(c00); + C[ldc * (j + 0) + (i + 1)] = hsum(c10); + C[ldc * (j + 0) + (i + 2)] = hsum(c20); + C[ldc * (j + 0) + (i + 3)] = hsum(c30); + C[ldc * (j + 1) + (i + 0)] = hsum(c01); + C[ldc * (j + 1) + (i + 1)] = hsum(c11); + C[ldc * (j + 1) + (i + 2)] = hsum(c21); + C[ldc * (j + 1) + (i + 3)] = hsum(c31); + C[ldc * (j + 2) + (i + 0)] = hsum(c02); + C[ldc * (j + 2) + (i + 1)] = hsum(c12); + C[ldc * (j + 2) + (i + 2)] = hsum(c22); + C[ldc * (j + 2) + (i + 3)] = hsum(c32); + END_KERNEL() + } + + NOINLINE void gemm4x1(int m0, int m, int n0, int n) { + BEGIN_KERNEL(4, 1) + __m256 c0 = _mm256_setzero_ps(); + __m256 c1 = _mm256_setzero_ps(); + __m256 c2 = _mm256_setzero_ps(); + __m256 c3 = _mm256_setzero_ps(); + const TA *Ap0 = A + lda * (i + 0); + const TA *Ap1 = A + lda * (i + 1); + const TA *Ap2 = A + lda * (i + 2); + const TA *Ap3 = A + lda * (i + 3); + const TB *Bp = B + ldb * j; + for (int l = 0; l < k; ++l) { + float db0 = unhalf(Bp[l].d); + __m256i f = load(Bp + l); + __m256i u = _mm256_sign_epi8(f, f); + __m256 d0 = _mm256_set1_ps(unhalf(Ap0[l].d) * db0); + __m256 d1 = _mm256_set1_ps(unhalf(Ap1[l].d) * db0); + __m256 d2 = _mm256_set1_ps(unhalf(Ap2[l].d) * db0); + __m256 d3 = _mm256_set1_ps(unhalf(Ap3[l].d) * db0); + __m256i e0 = load(Ap0 + l); + __m256i e1 = load(Ap1 + l); + __m256i e2 = load(Ap2 + l); + __m256i e3 = load(Ap3 + l); + __m256i s0 = _mm256_sign_epi8(e0, f); + __m256i s1 = _mm256_sign_epi8(e1, f); + __m256i s2 = _mm256_sign_epi8(e2, f); + __m256i s3 = _mm256_sign_epi8(e3, f); + __m256 g0 = updot(u, s0); + __m256 g1 = updot(u, s1); + __m256 g2 = updot(u, s2); + __m256 g3 = updot(u, s3); + c0 = madd(d0, g0, c0); + c1 = madd(d1, g1, c1); + c2 = madd(d2, g2, c2); + c3 = madd(d3, g3, c3); + } + C[ldc * j + (i + 0)] = hsum(c0); + C[ldc * j + (i + 1)] = hsum(c1); + C[ldc * j + (i + 2)] = hsum(c2); + C[ldc * j + (i + 3)] = hsum(c3); + END_KERNEL() + } + + NOINLINE void gemm1x4(int m0, int m, int n0, int n) { + BEGIN_KERNEL(1, 4) + __m256 c0 = _mm256_setzero_ps(); + __m256 c1 = _mm256_setzero_ps(); + __m256 c2 = _mm256_setzero_ps(); + __m256 c3 = _mm256_setzero_ps(); + const TB *Bp0 = B + ldb * (j + 0); + const TB *Bp1 = B + ldb * (j + 1); + const TB *Bp2 = B + ldb * (j + 2); + const TB *Bp3 = B + ldb * (j + 3); + const TA *Ap = A + lda * i; + for (int l = 0; l < k; ++l) { + float da0 = unhalf(Ap[l].d); + __m256i f = load(Ap + l); + __m256i u = _mm256_sign_epi8(f, f); + __m256 d0 = _mm256_set1_ps(unhalf(Bp0[l].d) * da0); + __m256 d1 = _mm256_set1_ps(unhalf(Bp1[l].d) * da0); + __m256 d2 = _mm256_set1_ps(unhalf(Bp2[l].d) * da0); + __m256 d3 = _mm256_set1_ps(unhalf(Bp3[l].d) * da0); + __m256 g0 = updot(u, _mm256_sign_epi8(load(Bp0 + l), f)); + __m256 g1 = updot(u, _mm256_sign_epi8(load(Bp1 + l), f)); + __m256 g2 = updot(u, _mm256_sign_epi8(load(Bp2 + l), f)); + __m256 g3 = updot(u, _mm256_sign_epi8(load(Bp3 + l), f)); + c0 = madd(d0, g0, c0); + c1 = madd(d1, g1, c1); + c2 = madd(d2, g2, c2); + c3 = madd(d3, g3, c3); + } + C[ldc * (j + 0) + i] = hsum(c0); + C[ldc * (j + 1) + i] = hsum(c1); + C[ldc * (j + 2) + i] = hsum(c2); + C[ldc * (j + 3) + i] = hsum(c3); + END_KERNEL() + } + + NOINLINE void gemm1x1(int m0, int m, int n0, int n) { + BEGIN_KERNEL(1, 1) + __m256 c = _mm256_setzero_ps(); + const TA *Ap = A + lda * i; + const TB *Bp = B + ldb * j; + for (int l = 0; l < k; ++l) { + __m256 d = _mm256_set1_ps(unhalf(Ap[l].d) * unhalf(Bp[l].d)); + __m256i e = load(Ap + l); + __m256i f = load(Bp + l); + __m256 g = updot(_mm256_sign_epi8(e, e), _mm256_sign_epi8(f, e)); + c = madd(d, g, c); + } + C[ldc * j + i] = hsum(c); + END_KERNEL() + } + + inline __m256i load(const block_q8_0 *b) { + return _mm256_loadu_si256((const __m256i *)b->qs); + } + + inline __m256i load(const block_q4_0 *b) { + return _mm256_sub_epi8(denibble(b->qs), _mm256_set1_epi8(8)); + } + + inline __m256 updot(__m256i u, __m256i s) { + __m256i res; +#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__)) + res = _mm256_dpbusd_epi32(_mm256_setzero_si256(), u, s); +#else + res = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(u, s)); +#endif + return _mm256_cvtepi32_ps(res); + } + + static inline __m256i denibble(const uint8_t *p) { + const __m128i tmp = _mm_loadu_si128((const __m128i *)p); + const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp); + const __m256i lowMask = _mm256_set1_epi8(15); + return _mm256_and_si256(lowMask, bytes); + } + + const TA *const A; + const TB *const B; + TC *const C; + const int k; + const int lda; + const int ldb; + const int ldc; + const int ith; + const int nth; +}; +#endif // __AVX2__ + +} // namespace + +/** + * Performs optimized matrix multiplication on CPU. + * + * This subroutine may compute C = Aᵀ * B with column major ordering. + * Despite its name, this isn't a generalized implementation. Work is + * only performed when a handwritten kernel is written and available. + * Otherwise the caller should fall back to a general matmul routine. + * + * For example, for single-threaded single-precision GEMM you can say + * + * llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc, + * 0, 1, GGML_TASK_TYPE_COMPUTE, + * GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32); + * + * @param m is rows in `A` and `C` + * @param n is cols in `B` and `C` + * @param k is cols in `A` and rows in `B` + * @param A is first input matrix (always transposed) + * @param lda is row stride of `A` + * @param B is second input matrix (never transposed) + * @param ldb is row stride of `B` + * @param C is input/output array of output matrices + * @param ldc is row stride of `C` + * @param ith is thread id (must be less than `nth`) + * @param nth is number of threads (must be greater than zero) + * @param task is GGML task type + * @param Atype is GGML data type of `A` + * @param Btype is GGML data type of `B` + * @param Ctype is GGML data type of `C` + * @return true if this function was able to service the matmul request + */ +bool llamafile_sgemm(int m, int n, int k, const void *A, int lda, const void *B, int ldb, void *C, + int ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype) { + + assert(m >= 0); + assert(n >= 0); + assert(k >= 0); + assert(lda >= k); + assert(ldb >= k); + assert(ldc >= m); + assert(nth > 0); + assert(ith < nth); + assert(1ll * lda * m <= 0x7fffffff); + assert(1ll * ldb * n <= 0x7fffffff); + assert(1ll * ldc * n <= 0x7fffffff); + + if (Ctype != GGML_TYPE_F32) + return false; + + switch (Atype) { + + case GGML_TYPE_F32: { + if (Btype != GGML_TYPE_F32) + return false; +#if defined(__AVX512F__) + if (k % 16) + return false; + tinyBLAS<16, __m512, __m512, float, float, float> tb{ + k, (const float *)A, lda, + (const float *)B, ldb, + (float *)C, ldc, + ith, nth}; + tb.matmul(m, n, task); + return true; +#elif defined(__AVX__) || defined(__AVX2__) + if (k % 8) + return false; + tinyBLAS<8, __m256, __m256, float, float, float> tb{ + k, (const float *)A, lda, + (const float *)B, ldb, + (float *)C, ldc, + ith, nth}; + tb.matmul(m, n, task); + return true; +#elif defined(__ARM_NEON) + if (n < 4) + return false; + if (k % 4) + return false; + tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{ + k, (const float *)A, lda, + (const float *)B, ldb, + (float *)C, ldc, + ith, nth}; + tb.matmul(m, n, task); + return true; +#else + return false; +#endif + } + + case GGML_TYPE_F16: { +#if defined(__AVX512F__) + if (k % 16) + return false; + if (Btype != GGML_TYPE_F32) + return false; + tinyBLAS<16, __m512, __m512, ggml_fp16_t, float, float> tb{ + k, (const ggml_fp16_t *)A, lda, + (const float *)B, ldb, + (float *)C, ldc, + ith, nth}; + tb.matmul(m, n, task); + return true; +#elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__) + if (k % 8) + return false; + if (Btype != GGML_TYPE_F32) + return false; + tinyBLAS<8, __m256, __m256, ggml_fp16_t, float, float> tb{ + k, (const ggml_fp16_t *)A, lda, + (const float *)B, ldb, + (float *)C, ldc, + ith, nth}; + tb.matmul(m, n, task); + return true; +#elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER) + if (n < 8) + return false; + if (k % 8) + return false; + if (Btype != GGML_TYPE_F16) + return false; + tinyBLAS<8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, float> tb{ + k, (const ggml_fp16_t *)A, lda, + (const ggml_fp16_t *)B, ldb, + (float *)C, ldc, + ith, nth}; + tb.matmul(m, n, task); + return true; +#elif defined(__ARM_NEON) && !defined(_MSC_VER) + if (k % 4) + return false; + if (Btype != GGML_TYPE_F32) + return false; + tinyBLAS<4, float32x4_t, float32x4_t, ggml_fp16_t, float, float> tb{ + k, (const ggml_fp16_t *)A, lda, + (const float *)B, ldb, + (float *)C, ldc, + ith, nth}; + tb.matmul(m, n, task); + return true; +#else + return false; +#endif + } + + case GGML_TYPE_Q8_0: { + if (Btype != GGML_TYPE_Q8_0) + return false; +#if defined(__AVX2__) || defined(__AVX512F__) + tinyBLAS_Q0_AVX2 tb{ + k, (const block_q8_0 *)A, lda, + (const block_q8_0 *)B, ldb, + (float *)C, ldc, + ith, nth}; + tb.matmul(m, n, task); + return true; +#elif defined(__ARM_FEATURE_DOTPROD) + tinyBLAS_Q0_ARM tb{ + k, (const block_q8_0 *)A, lda, + (const block_q8_0 *)B, ldb, + (float *)C, ldc, + ith, nth}; + tb.matmul(m, n, task); + return true; +#else + return false; +#endif + } + + case GGML_TYPE_Q4_0: { + if (Btype != GGML_TYPE_Q8_0) + return false; +#if defined(__AVX2__) || defined(__AVX512F__) + tinyBLAS_Q0_AVX2 tb{ + k, (const block_q4_0 *)A, lda, + (const block_q8_0 *)B, ldb, + (float *)C, ldc, + ith, nth}; + tb.matmul(m, n, task); + return true; +#elif defined(__ARM_FEATURE_DOTPROD) + tinyBLAS_Q0_ARM tb{ + k, (const block_q4_0 *)A, lda, + (const block_q8_0 *)B, ldb, + (float *)C, ldc, + ith, nth}; + tb.matmul(m, n, task); + return true; +#else + return false; +#endif + } + + default: + return false; + } + + (void)m; + (void)n; + (void)k; + (void)A; + (void)lda; + (void)B; + (void)ldb; + (void)C; + (void)ldc; + (void)ith; + (void)nth; + (void)task; + (void)Atype; + (void)Btype; + (void)Ctype; +} diff --git a/sgemm.h b/sgemm.h new file mode 100644 index 0000000000000..da23b209c4dd5 --- /dev/null +++ b/sgemm.h @@ -0,0 +1,12 @@ +#pragma once +#include +#ifdef __cplusplus +extern "C" { +#endif + +bool llamafile_sgemm(int, int, int, const void *, int, const void *, int, + void *, int, int, int, int, int, int, int); + +#ifdef __cplusplus +} +#endif From 666867b799ddd9da7dfdc905ece291ecf286effa Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 16 Apr 2024 23:50:22 +0300 Subject: [PATCH 080/100] ggml : fix llamafile sgemm wdata offsets (#6710) ggml-ci --- CMakeLists.txt | 6 ++++++ Makefile | 2 ++ ggml.c | 11 ++++------- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 158174c2094d2..2cc0df3fb2e67 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -88,6 +88,7 @@ endif() # 3rd party libs option(LLAMA_ACCELERATE "llama: enable Accelerate framework" ON) option(LLAMA_BLAS "llama: use BLAS" OFF) +option(LLAMA_LLAMAFILE "llama: use llamafile SGEMM" ON) set(LLAMA_BLAS_VENDOR "Generic" CACHE STRING "llama: BLAS library vendor") option(LLAMA_CUDA "llama: use CUDA" OFF) option(LLAMA_CUBLAS "llama: use CUDA (deprecated, use LLAMA_CUDA)" OFF) @@ -286,6 +287,7 @@ if (LLAMA_METAL) ${METALKIT_FRAMEWORK} ) endif() + if (LLAMA_BLAS) if (LLAMA_STATIC) set(BLA_STATIC ON) @@ -368,6 +370,10 @@ if (LLAMA_BLAS) endif() endif() +if (LLAMA_LLAMAFILE) + add_compile_definitions(GGML_USE_LLAMAFILE) +endif() + if (LLAMA_QKK_64) add_compile_definitions(GGML_QKK_64) endif() diff --git a/Makefile b/Makefile index 928fb14cad628..9a711743ba94b 100644 --- a/Makefile +++ b/Makefile @@ -222,6 +222,8 @@ endif # LLAMA_DISABLE_LOGS # disable ggml.c's use of sgemm.cpp ifdef LLAMA_NO_LLAMAFILE MK_CPPFLAGS += -DGGML_USE_LLAMAFILE=0 +else + MK_CPPFLAGS += -DGGML_USE_LLAMAFILE=1 endif # warnings diff --git a/ggml.c b/ggml.c index 119686be68d9c..593c603f493be 100644 --- a/ggml.c +++ b/ggml.c @@ -33,12 +33,8 @@ #include #endif -#ifndef GGML_USE_LLAMAFILE #ifdef __ARM_FEATURE_MATMUL_INT8 -#define GGML_USE_LLAMAFILE 0 -#else -#define GGML_USE_LLAMAFILE 1 -#endif +#undef GGML_USE_LLAMAFILE #endif #if defined(_MSC_VER) @@ -10879,8 +10875,9 @@ UseGgmlGemm1:; if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type), (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01/ggml_type_size(src0->type), - (const char *)wdata + (nb12/ggml_type_size(src1->type)*ggml_type_size(vec_dot_type)*i12 + - nb13/ggml_type_size(src1->type)*ggml_type_size(vec_dot_type)*i13), + (const char *)wdata + ggml_row_size(vec_dot_type, + nb12/ggml_type_size(src1->type)*i12 + + nb13/ggml_type_size(src1->type)*i13), row_size/ggml_type_size(vec_dot_type), (char *)dst->data + i12*nb2 + i13*nb3, nb1/ggml_type_size(dst->type), From 532c1737a14bb4b99747e6f460874947df37e450 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 16 Apr 2024 23:50:38 +0300 Subject: [PATCH 081/100] llama : make general.name optional (#6709) --- llama.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/llama.cpp b/llama.cpp index 579986d1a5e2e..f4f4063cf6062 100644 --- a/llama.cpp +++ b/llama.cpp @@ -4136,9 +4136,11 @@ static void llm_load_vocab( // CodeGemma (LLM_ARCH_GEMMA). This can potentially be removed once // new versions of these models have been published. std::string gen_name; - ml.get_key(LLM_KV_GENERAL_NAME, gen_name); + ml.get_key(LLM_KV_GENERAL_NAME, gen_name, false); + std::transform(gen_name.begin(), gen_name.end(), gen_name.begin(), [](unsigned char c){ return std::tolower(c); }); + if (gen_name.find("code") != std::string::npos) { if (model.arch == LLM_ARCH_LLAMA) { vocab.special_prefix_id = 32007; From facb8b56f8fd3bb10a693bf0943ae9d69d0828ef Mon Sep 17 00:00:00 2001 From: "Zheng.Deng" <32841220+dengzheng-cloud@users.noreply.github.com> Date: Wed, 17 Apr 2024 04:51:07 +0800 Subject: [PATCH 082/100] convert : fix autoawq gemma (#6704) * fix autoawq quantized gemma model convert error using autoawq to quantize gemma model will include a lm_head.weight tensor in model-00001-of-00002.safetensors. it result in this situation that convert-hf-to-gguf.py can't map lm_head.weight. skip loading this tensor could prevent this error. * change code to full string match and print necessary message change code to full string match and print a short message to inform users that lm_head.weight has been skipped. --------- Co-authored-by: Zheng.Deng <32841220+CUGfred@users.noreply.github.com> --- convert-hf-to-gguf.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index f321d77de11f8..c14186abbc2a6 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -2458,6 +2458,12 @@ def write_tensors(self): tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count) for name, data_torch in self.get_tensors(): + # lm_head is not used in llama.cpp, while autoawq will include this tensor in model + # To prevent errors, skip loading lm_head.weight. + if name == "lm_head.weight": + print(f"Skipping get tensor {name!r} in safetensors so that convert can end normally.") + continue + old_dtype = data_torch.dtype # convert any unsupported data types to float32 From 599ce84a71512b72bf4fd6a248e7725f646eb1a8 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 17 Apr 2024 12:00:35 +0300 Subject: [PATCH 083/100] llama : flash_attn cparam + fix defrag --- common/common.cpp | 6 + common/common.h | 1 + llama.cpp | 345 +++++++++++++++++++++++++--------------------- llama.h | 1 + 4 files changed, 194 insertions(+), 159 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index cf69535e2d1f5..fbff8cf13effc 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -900,6 +900,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.cont_batching = true; return true; } + if (arg == "-fa" || arg == "--flash-attn") { + params.flash_attn = true; + return true; + } if (arg == "--color") { params.use_color = true; return true; @@ -1836,6 +1840,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param cparams.cb_eval = params.cb_eval; cparams.cb_eval_user_data = params.cb_eval_user_data; cparams.offload_kqv = !params.no_kv_offload; + cparams.flash_attn = params.flash_attn; cparams.type_k = kv_cache_type_from_str(params.cache_type_k); cparams.type_v = kv_cache_type_from_str(params.cache_type_v); @@ -2673,6 +2678,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l fprintf(stream, "seed: %u # default: -1 (random seed)\n", params.seed); fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false"); fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false"); + fprintf(stream, "flash_attn: %s # default: false\n", params.flash_attn ? "true" : "false"); fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp); const std::vector tensor_split_vector(params.tensor_split, params.tensor_split + llama_max_devices()); diff --git a/common/common.h b/common/common.h index cca44268e6df5..78a1a34021b60 100644 --- a/common/common.h +++ b/common/common.h @@ -148,6 +148,7 @@ struct gpt_params { bool multiline_input = false; // reverse the usage of `\` bool simple_io = false; // improves compatibility with subprocesses and limited consoles bool cont_batching = true; // insert new sequences for decoding on-the-fly + bool flash_attn = false; // flash attention bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix bool ignore_eos = false; // ignore generated EOS tokens diff --git a/llama.cpp b/llama.cpp index 1081899fbaa3f..0b5ee9ef569c6 100644 --- a/llama.cpp +++ b/llama.cpp @@ -107,8 +107,6 @@ #define LLAMA_MAX_NODES 8192 #define LLAMA_MAX_EXPERTS 60 -#define LLAMA_FLASH_ATTN - // // logging // @@ -1899,6 +1897,7 @@ struct llama_cparams { bool embeddings; bool causal_attn; bool offload_kqv; + bool flash_attn; enum llama_pooling_type pooling_type; @@ -5938,15 +5937,17 @@ static struct ggml_tensor * llm_build_inp_embd( static void llm_build_kv_store( struct ggml_context * ctx, const llama_hparams & hparams, + const llama_cparams & cparams, const llama_kv_cache & kv, struct ggml_cgraph * graph, struct ggml_tensor * k_cur, struct ggml_tensor * v_cur, - int64_t n_ctx, int32_t n_tokens, int32_t kv_head, const llm_build_cb & cb, int64_t il) { + const int64_t n_ctx = cparams.n_ctx; + const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(); const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(); @@ -5959,28 +5960,26 @@ static void llm_build_kv_store( // important: storing RoPE-ed version of K in the KV cache! ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view)); -#if defined(LLAMA_FLASH_ATTN) - // NOTE: the V cache is not transposed when using FLASH attention !! - struct ggml_tensor * v_cache_view = ggml_view_1d(ctx, kv.v_l[il], n_tokens*n_embd_v_gqa, - (ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa))*kv_head); - cb(v_cache_view, "v_cache_view", il); - - ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur, v_cache_view)); + if (cparams.flash_attn) { + // NOTE: the V cache is not transposed when using FLASH attention !! + struct ggml_tensor * v_cache_view = ggml_view_1d(ctx, kv.v_l[il], n_tokens*n_embd_v_gqa, + (ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa))*kv_head); + cb(v_cache_view, "v_cache_view", il); - GGML_UNUSED(n_ctx); -#else - // compute the transposed [n_tokens, n_embd] V matrix - //struct ggml_tensor * v_cur_t = ggml_transpose(ctx, ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens)); - assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens); - struct ggml_tensor * v_cur_t = ggml_transpose(ctx, v_cur); // TODO: reshape above is likely not needed - cb(v_cur_t, "v_cur_t", il); + ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur, v_cache_view)); + } else { + // compute the transposed [n_tokens, n_embd] V matrix + //struct ggml_tensor * v_cur_t = ggml_transpose(ctx, ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens)); + assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens); + struct ggml_tensor * v_cur_t = ggml_transpose(ctx, v_cur); + cb(v_cur_t, "v_cur_t", il); - struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_v_gqa, - ( n_ctx)*ggml_element_size(kv.v_l[il]), - (kv_head)*ggml_element_size(kv.v_l[il])); + struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_v_gqa, + ( n_ctx)*ggml_element_size(kv.v_l[il]), + (kv_head)*ggml_element_size(kv.v_l[il])); - ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur_t, v_cache_view)); -#endif + ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur_t, v_cache_view)); + } } static struct ggml_tensor * llm_build_norm( @@ -6111,6 +6110,7 @@ static struct ggml_tensor * llm_build_kqv( struct ggml_context * ctx, const llama_model & model, const llama_hparams & hparams, + const llama_cparams & cparams, const llama_kv_cache & kv, struct ggml_cgraph * graph, struct ggml_tensor * wo, @@ -6118,12 +6118,12 @@ static struct ggml_tensor * llm_build_kqv( struct ggml_tensor * q_cur, struct ggml_tensor * kq_mask, struct ggml_tensor * kq_pos, - int64_t n_ctx, int32_t n_tokens, int32_t n_kv, float kq_scale, const llm_build_cb & cb, int il) { + const int64_t n_ctx = cparams.n_ctx; const int64_t n_head = hparams.n_head; const int64_t n_head_kv = hparams.n_head_kv; const int64_t n_embd_head_k = hparams.n_embd_head_k; @@ -6143,97 +6143,100 @@ static struct ggml_tensor * llm_build_kqv( struct ggml_tensor * cur; -#if defined(LLAMA_FLASH_ATTN) - GGML_UNUSED(model); - GGML_UNUSED(n_ctx); + if (cparams.flash_attn) { + GGML_UNUSED(model); + GGML_UNUSED(n_ctx); - GGML_ASSERT(kq_pos == nullptr && "ALiBi is not yet supported with Flash Attention"); + GGML_ASSERT(kq_pos == nullptr && "ALiBi is not yet supported with Flash Attention"); - // split cached v into n_head heads (not transposed) - struct ggml_tensor * v = - ggml_view_3d(ctx, kv.v_l[il], - n_embd_head_v, n_kv, n_head_kv, - ggml_row_size(kv.v_l[il]->type, n_embd_k_gqa), - ggml_row_size(kv.v_l[il]->type, n_embd_head_k), - 0); - cb(v, "v", il); + // split cached v into n_head heads (not transposed) + struct ggml_tensor * v = + ggml_view_3d(ctx, kv.v_l[il], + n_embd_head_v, n_kv, n_head_kv, + ggml_row_size(kv.v_l[il]->type, n_embd_k_gqa), + ggml_row_size(kv.v_l[il]->type, n_embd_head_k), + 0); + cb(v, "v", il); - cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale); - ggml_flash_attn_ext_set_prec(cur, GGML_PREC_DEFAULT); - //printf("q: %4d %4d %4d %4d\n", q->ne[0], q->ne[1], q->ne[2], q->ne[3]); - //printf("k: %4d %4d %4d %4d\n", k->ne[0], k->ne[1], k->ne[2], k->ne[3]); - //printf("v: %4d %4d %4d %4d\n", v->ne[0], v->ne[1], v->ne[2], v->ne[3]); - //printf("m: %4d %4d %4d %4d\n", kq_mask->ne[0], kq_mask->ne[1], kq_mask->ne[2], kq_mask->ne[3]); - //printf("r: %4d %4d %4d %4d\n", kqv->ne[0], kqv->ne[1], kqv->ne[2], kqv->ne[3]); + cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale); - cur = ggml_reshape_2d(ctx, cur, n_embd_head_k*n_head, n_tokens); -#else - struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); - cb(kq, "kq", il); + if (model.arch == LLM_ARCH_PHI2) { + ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32); + } + //printf("q: %4d %4d %4d %4d\n", q->ne[0], q->ne[1], q->ne[2], q->ne[3]); + //printf("k: %4d %4d %4d %4d\n", k->ne[0], k->ne[1], k->ne[2], k->ne[3]); + //printf("v: %4d %4d %4d %4d\n", v->ne[0], v->ne[1], v->ne[2], v->ne[3]); + //printf("m: %4d %4d %4d %4d\n", kq_mask->ne[0], kq_mask->ne[1], kq_mask->ne[2], kq_mask->ne[3]); + //printf("r: %4d %4d %4d %4d\n", kqv->ne[0], kqv->ne[1], kqv->ne[2], kqv->ne[3]); - if (model.arch == LLM_ARCH_PHI2) { - // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs - // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847 - ggml_mul_mat_set_prec(kq, GGML_PREC_F32); - } + cur = ggml_reshape_2d(ctx, cur, n_embd_head_k*n_head, n_tokens); + } else { + struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); + cb(kq, "kq", il); + + if (model.arch == LLM_ARCH_PHI2) { + // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs + // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847 + ggml_mul_mat_set_prec(kq, GGML_PREC_F32); + } - if (model.arch == LLM_ARCH_GROK) { - // need to do the following: - // multiply by attn_output_multiplyer of 0.08838834764831845 - // and then : - // kq = 30 * tanh(kq / 30) - // before the softmax below + if (model.arch == LLM_ARCH_GROK) { + // need to do the following: + // multiply by attn_output_multiplyer of 0.08838834764831845 + // and then : + // kq = 30 * tanh(kq / 30) + // before the softmax below - //try from phi2 - //ggml_mul_mat_set_prec(kq, GGML_PREC_F32); + //try from phi2 + //ggml_mul_mat_set_prec(kq, GGML_PREC_F32); - kq = ggml_tanh(ctx, ggml_scale(ctx, kq, 0.08838834764831845f/30.0f)); - kq = ggml_scale(ctx, kq, 30); - } + kq = ggml_tanh(ctx, ggml_scale(ctx, kq, 0.08838834764831845f/30.0f)); + kq = ggml_scale(ctx, kq, 30); + } #if defined(GGML_USE_KOMPUTE) #pragma message("TODO: ALiBi support in ggml_soft_max_ext is not implemented for Kompute") #pragma message(" Falling back to ggml_alibi(). Will become an error in Mar 2024") #pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5488") - if (hparams.f_max_alibi_bias > 0.0f) { - kq = ggml_scale(ctx, kq, kq_scale); - cb(kq, "kq_scaled", il); + if (hparams.f_max_alibi_bias > 0.0f) { + kq = ggml_scale(ctx, kq, kq_scale); + cb(kq, "kq_scaled", il); - kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, hparams.f_max_alibi_bias); - cb(kq, "kq_scaled_alibi", il); + kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, hparams.f_max_alibi_bias); + cb(kq, "kq_scaled_alibi", il); - kq = ggml_add(ctx, kq, kq_mask); - cb(kq, "kq_masked", il); + kq = ggml_add(ctx, kq, kq_mask); + cb(kq, "kq_masked", il); - kq = ggml_soft_max(ctx, kq); - cb(kq, "kq_soft_max", il); - } else + kq = ggml_soft_max(ctx, kq); + cb(kq, "kq_soft_max", il); + } else #endif - { - kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_pos, kq_scale, hparams.f_max_alibi_bias); - cb(kq, "kq_soft_max_ext", il); - } + { + kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_pos, kq_scale, hparams.f_max_alibi_bias); + cb(kq, "kq_soft_max_ext", il); + } - GGML_ASSERT(kv.size == n_ctx); + GGML_ASSERT(kv.size == n_ctx); - // split cached v into n_head heads - struct ggml_tensor * v = - ggml_view_3d(ctx, kv.v_l[il], - n_kv, n_embd_head_v, n_head_kv, - ggml_element_size(kv.v_l[il])*n_ctx, - ggml_element_size(kv.v_l[il])*n_ctx*n_embd_head_v, - 0); - cb(v, "v", il); + // split cached v into n_head heads + struct ggml_tensor * v = + ggml_view_3d(ctx, kv.v_l[il], + n_kv, n_embd_head_v, n_head_kv, + ggml_element_size(kv.v_l[il])*n_ctx, + ggml_element_size(kv.v_l[il])*n_ctx*n_embd_head_v, + 0); + cb(v, "v", il); - struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq); - cb(kqv, "kqv", il); + struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq); + cb(kqv, "kqv", il); - struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3); - cb(kqv_merged, "kqv_merged", il); + struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3); + cb(kqv_merged, "kqv_merged", il); - cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_k*n_head, n_tokens); - cb(cur, "kqv_merged_cont", il); -#endif + cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_k*n_head, n_tokens); + cb(cur, "kqv_merged_cont", il); + } ggml_build_forward_expand(graph, cur); @@ -6253,6 +6256,7 @@ static struct ggml_tensor * llm_build_kv( struct ggml_context * ctx, const llama_model & model, const llama_hparams & hparams, + const llama_cparams & cparams, const llama_kv_cache & kv, struct ggml_cgraph * graph, struct ggml_tensor * wo, @@ -6262,7 +6266,6 @@ static struct ggml_tensor * llm_build_kv( struct ggml_tensor * q_cur, struct ggml_tensor * kq_mask, struct ggml_tensor * kq_pos, - int64_t n_ctx, int32_t n_tokens, int32_t kv_head, int32_t n_kv, @@ -6276,12 +6279,12 @@ static struct ggml_tensor * llm_build_kv( ggml_build_forward_expand(graph, k_cur); ggml_build_forward_expand(graph, v_cur); - llm_build_kv_store(ctx, hparams, kv, graph, k_cur, v_cur, n_ctx, n_tokens, kv_head, cb, il); + llm_build_kv_store(ctx, hparams, cparams, kv, graph, k_cur, v_cur, n_tokens, kv_head, cb, il); struct ggml_tensor * cur; - cur = llm_build_kqv(ctx, model, hparams, kv, graph, wo, wo_b, - q_cur, kq_mask, kq_pos, n_ctx, n_tokens, n_kv, kq_scale, cb, il); + cur = llm_build_kqv(ctx, model, hparams, cparams, kv, graph, wo, wo_b, + q_cur, kq_mask, kq_pos, n_tokens, n_kv, kq_scale, cb, il); cb(cur, "kqv_out", il); return cur; @@ -6323,6 +6326,8 @@ struct llm_build_context { const int32_t kv_head; // index of where we store new KV data in the cache const int32_t n_orig_ctx; + const bool flash_attn; + const enum llama_pooling_type pooling_type; const enum llama_rope_type rope_type; @@ -6369,6 +6374,7 @@ struct llm_build_context { n_outputs (worst_case ? n_tokens : lctx.n_outputs), kv_head (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head), n_orig_ctx (cparams.n_yarn_orig_ctx), + flash_attn (cparams.flash_attn), pooling_type (cparams.pooling_type), rope_type (hparams.rope_type), cb (cb), @@ -6483,15 +6489,31 @@ struct llm_build_context { ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa), ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*id)); - ggml_tensor * view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il], - nm, n_embd_v_gqa, - ggml_row_size(kv_self.v_l[il]->type, kv_self.size), - ggml_row_size(kv_self.v_l[il]->type, i)); + ggml_tensor * view_v_src; + ggml_tensor * view_v_dst; + + if (flash_attn) { + // NOTE: the V cache is not transposed when using flash attention + view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il], + n_embd_v_gqa, nm, + ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa), + ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*i)); - ggml_tensor * view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il], - nm, n_embd_v_gqa, - ggml_row_size(kv_self.v_l[il]->type, kv_self.size), - ggml_row_size(kv_self.v_l[il]->type, id)); + view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il], + n_embd_v_gqa, nm, + ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa), + ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*id)); + } else { + view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il], + nm, n_embd_v_gqa, + ggml_row_size(kv_self.v_l[il]->type, kv_self.size), + ggml_row_size(kv_self.v_l[il]->type, i)); + + view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il], + nm, n_embd_v_gqa, + ggml_row_size(kv_self.v_l[il]->type, kv_self.size), + ggml_row_size(kv_self.v_l[il]->type, id)); + } ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst)); ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_v_src, view_v_dst)); @@ -6640,9 +6662,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -6846,9 +6868,9 @@ struct llm_build_context { cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -6953,9 +6975,9 @@ struct llm_build_context { ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -7073,9 +7095,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -7198,9 +7220,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f, cb, il); } if (il == n_layer - 1) { @@ -7343,9 +7365,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, - model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + model.layers[il].wo, NULL, + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -7447,9 +7469,9 @@ struct llm_build_context { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -7651,9 +7673,9 @@ struct llm_build_context { ); cb(Vcur, "Vcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Q, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Q, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -7747,9 +7769,9 @@ struct llm_build_context { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); cb(Qcur, "Qcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -8040,9 +8062,9 @@ struct llm_build_context { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -8171,14 +8193,15 @@ struct llm_build_context { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, - model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + model.layers[il].wo, model.layers[il].bo, + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } else { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } } @@ -8320,9 +8343,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -8438,9 +8461,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -8557,9 +8580,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -8671,9 +8694,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -8817,9 +8840,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f, cb, il); } if (il == n_layer - 1) { @@ -8919,9 +8942,9 @@ struct llm_build_context { ext_factor, attn_factor, beta_fast, beta_slow); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } struct ggml_tensor * sa_out = cur; @@ -9022,9 +9045,9 @@ struct llm_build_context { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -9129,9 +9152,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -9245,9 +9268,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -9362,9 +9385,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -9492,9 +9515,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -9613,9 +9636,9 @@ struct llm_build_context { ext_factor, attn_factor, beta_fast, beta_slow); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f, cb, il); } if (il == n_layer - 1) { @@ -9732,9 +9755,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -10022,9 +10045,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -11016,7 +11039,9 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) { // each move requires 6*n_layer tensors (see build_defrag) // - source view, destination view, copy operation // - x2 for keys and values - const uint32_t max_moves = LLAMA_MAX_NODES/(6*n_layer); + //const uint32_t max_moves = LLAMA_MAX_NODES/(6*n_layer); + // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516 + const uint32_t max_moves = (LLAMA_MAX_NODES - 2*n_layer)/(6*n_layer); // determine which KV cells to move where // @@ -14626,6 +14651,7 @@ struct llama_context_params llama_context_default_params() { /*.logits_all =*/ false, /*.embeddings =*/ false, /*.offload_kqv =*/ true, + /*.flash_attn =*/ false, /*.abort_callback =*/ nullptr, /*.abort_callback_data =*/ nullptr, }; @@ -14795,6 +14821,7 @@ struct llama_context * llama_new_context_with_model( cparams.defrag_thold = params.defrag_thold; cparams.embeddings = params.embeddings; cparams.offload_kqv = params.offload_kqv; + cparams.flash_attn = params.flash_attn; cparams.pooling_type = params.pooling_type; cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx; diff --git a/llama.h b/llama.h index b5da686f7b7e5..77c288eb25543 100644 --- a/llama.h +++ b/llama.h @@ -270,6 +270,7 @@ extern "C" { bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead) bool embeddings; // if true, extract embeddings (together with logits) bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU + bool flash_attn; // whether to use flash attention // Abort callback // if it returns true, execution of llama_decode() will be aborted From 405385726ef7432b65b9e63dd7a63c18765eb376 Mon Sep 17 00:00:00 2001 From: Pierrick HYMBERT Date: Wed, 17 Apr 2024 14:05:02 +0200 Subject: [PATCH 084/100] server: support flash_attn param --- examples/server/server.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 634e653ada284..f1754b60b7fe8 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2722,6 +2722,8 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams, params.embedding = true; } else if (arg == "-cb" || arg == "--cont-batching") { params.cont_batching = true; + } else if (arg == "-fa" || arg == "--flash-attn") { + params.flash_attn = true; } else if (arg == "-np" || arg == "--parallel") { if (++i >= argc) { invalid_param = true; From 8dd1ec8b3ffbfa2d26e82e672cea89f5eeb2f141 Mon Sep 17 00:00:00 2001 From: Yaroslav Date: Wed, 17 Apr 2024 14:47:50 +0200 Subject: [PATCH 085/100] readme : add UI (#6724) * Update README.md * Update README.md --------- Co-authored-by: Georgi Gerganov --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index fa2a4db834df2..e3eae60bb015b 100644 --- a/README.md +++ b/README.md @@ -189,6 +189,8 @@ Unless otherwise noted these projects are open-source with permissive licensing: - [MindMac](https://mindmac.app) (proprietary) - [KodiBot](https://github.com/firatkiral/kodibot) (GPL) - [eva](https://github.com/ylsdamxssjxxdd/eva) (MIT) +- [AI Sublime Text plugin](https://github.com/yaroslavyaroslav/OpenAI-sublime-text) (MIT) + *(to have a project listed here, it should clearly state that it depends on `llama.cpp`)* --- From 3b8f1ec4b18770531d0b1d792f3edf08254e4f0c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 17 Apr 2024 23:58:26 +0300 Subject: [PATCH 086/100] llamafile : tmp disable + build sgemm.o when needed (#6716) * build : sgemm.o only when needed ggml-ci * llamafile : tmp disable due to MoE bug ggml-ci --- CMakeLists.txt | 38 ++++++++++++++++++++++++++------------ Makefile | 28 +++++++++++++++------------- 2 files changed, 41 insertions(+), 25 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 2cc0df3fb2e67..f134a153bb4ff 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -43,6 +43,18 @@ else() set(LLAMA_METAL_DEFAULT OFF) endif() +# TODO: fix this for Android CI +# https://github.com/ggerganov/llama.cpp/pull/6716#issuecomment-2061509191 +#if (CMAKE_SYSTEM_NAME MATCHES "ANDROID") +# set(LLAMA_LLAMAFILE_DEFAULT OFF) +#else() +# set(LLAMA_LLAMAFILE_DEFAULT ON) +#endif() + +# TODO: temporary disable until MoE is fixed +# https://github.com/ggerganov/llama.cpp/pull/6716 +set(LLAMA_LLAMAFILE_DEFAULT OFF) + # general option(BUILD_SHARED_LIBS "build shared libraries" OFF) option(LLAMA_STATIC "llama: static link libraries" OFF) @@ -88,7 +100,7 @@ endif() # 3rd party libs option(LLAMA_ACCELERATE "llama: enable Accelerate framework" ON) option(LLAMA_BLAS "llama: use BLAS" OFF) -option(LLAMA_LLAMAFILE "llama: use llamafile SGEMM" ON) +option(LLAMA_LLAMAFILE "llama: use llamafile SGEMM" ${LLAMA_LLAMAFILE_DEFAULT}) set(LLAMA_BLAS_VENDOR "Generic" CACHE STRING "llama: BLAS library vendor") option(LLAMA_CUDA "llama: use CUDA" OFF) option(LLAMA_CUBLAS "llama: use CUDA (deprecated, use LLAMA_CUDA)" OFF) @@ -372,6 +384,9 @@ endif() if (LLAMA_LLAMAFILE) add_compile_definitions(GGML_USE_LLAMAFILE) + + set(GGML_HEADERS_LLAMAFILE sgemm.h) + set(GGML_SOURCES_LLAMAFILE sgemm.cpp) endif() if (LLAMA_QKK_64) @@ -1157,17 +1172,16 @@ add_library(ggml OBJECT ggml-backend.h ggml-quants.c ggml-quants.h - sgemm.cpp - sgemm.h - ${GGML_SOURCES_CUDA} ${GGML_HEADERS_CUDA} - ${GGML_SOURCES_OPENCL} ${GGML_HEADERS_OPENCL} - ${GGML_SOURCES_METAL} ${GGML_HEADERS_METAL} - ${GGML_SOURCES_MPI} ${GGML_HEADERS_MPI} - ${GGML_SOURCES_EXTRA} ${GGML_HEADERS_EXTRA} - ${GGML_SOURCES_SYCL} ${GGML_HEADERS_SYCL} - ${GGML_SOURCES_KOMPUTE} ${GGML_HEADERS_KOMPUTE} - ${GGML_SOURCES_VULKAN} ${GGML_HEADERS_VULKAN} - ${GGML_SOURCES_ROCM} ${GGML_HEADERS_ROCM} + ${GGML_SOURCES_CUDA} ${GGML_HEADERS_CUDA} + ${GGML_SOURCES_OPENCL} ${GGML_HEADERS_OPENCL} + ${GGML_SOURCES_METAL} ${GGML_HEADERS_METAL} + ${GGML_SOURCES_MPI} ${GGML_HEADERS_MPI} + ${GGML_SOURCES_EXTRA} ${GGML_HEADERS_EXTRA} + ${GGML_SOURCES_SYCL} ${GGML_HEADERS_SYCL} + ${GGML_SOURCES_KOMPUTE} ${GGML_HEADERS_KOMPUTE} + ${GGML_SOURCES_VULKAN} ${GGML_HEADERS_VULKAN} + ${GGML_SOURCES_ROCM} ${GGML_HEADERS_ROCM} + ${GGML_SOURCES_LLAMAFILE} ${GGML_HEADERS_LLAMAFILE} ) target_include_directories(ggml PUBLIC . ${LLAMA_EXTRA_INCLUDES}) diff --git a/Makefile b/Makefile index 9a711743ba94b..760015f294729 100644 --- a/Makefile +++ b/Makefile @@ -219,13 +219,6 @@ ifdef LLAMA_DISABLE_LOGS MK_CPPFLAGS += -DLOG_DISABLE_LOGS endif # LLAMA_DISABLE_LOGS -# disable ggml.c's use of sgemm.cpp -ifdef LLAMA_NO_LLAMAFILE - MK_CPPFLAGS += -DGGML_USE_LLAMAFILE=0 -else - MK_CPPFLAGS += -DGGML_USE_LLAMAFILE=1 -endif - # warnings WARN_FLAGS = -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function MK_CFLAGS += $(WARN_FLAGS) -Wshadow -Wstrict-prototypes -Wpointer-arith -Wmissing-prototypes -Werror=implicit-int \ @@ -391,6 +384,15 @@ ifdef LLAMA_OPENBLAS MK_LDFLAGS += $(shell pkg-config --libs openblas) endif # LLAMA_OPENBLAS +# TODO: temporary disable until MoE is fixed +# https://github.com/ggerganov/llama.cpp/pull/6716 +LLAMA_NO_LLAMAFILE := 1 + +ifndef LLAMA_NO_LLAMAFILE + MK_CPPFLAGS += -DGGML_USE_LLAMAFILE + OBJS += sgemm.o +endif + ifdef LLAMA_BLIS MK_CPPFLAGS += -DGGML_USE_OPENBLAS -I/usr/local/include/blis -I/usr/include/blis MK_LDFLAGS += -lblis -L/usr/local/lib @@ -487,11 +489,9 @@ ggml-cuda/%.o: ggml-cuda/%.cu ggml-cuda/%.cuh ggml.h ggml-common.h ggml-cuda/com ggml-cuda.o: ggml-cuda.cu ggml-cuda.h ggml.h ggml-backend.h ggml-backend-impl.h ggml-common.h $(wildcard ggml-cuda/*.cuh) $(NVCC_COMPILE) - endif # LLAMA_CUDA ifdef LLAMA_CLBLAST - MK_CPPFLAGS += -DGGML_USE_CLBLAST $(shell pkg-config --cflags-only-I clblast OpenCL) MK_CFLAGS += $(shell pkg-config --cflags-only-other clblast OpenCL) MK_CXXFLAGS += $(shell pkg-config --cflags-only-other clblast OpenCL) @@ -610,6 +610,11 @@ ggml-mpi.o: ggml-mpi.c ggml-mpi.h $(CC) $(CFLAGS) -c $< -o $@ endif # LLAMA_MPI +ifndef LLAMA_NO_LLAMAFILE +sgemm.o: sgemm.cpp sgemm.h ggml.h + $(CXX) $(CXXFLAGS) -c $< -o $@ +endif + GF_CC := $(CC) include scripts/get-flags.mk @@ -683,16 +688,13 @@ ggml-backend.o: ggml-backend.c ggml.h ggml-backend.h ggml-quants.o: ggml-quants.c ggml.h ggml-quants.h ggml-common.h $(CC) $(CFLAGS) -c $< -o $@ -sgemm.o: sgemm.cpp sgemm.h ggml.h - $(CXX) $(CXXFLAGS) -c $< -o $@ - unicode.o: unicode.cpp unicode.h $(CXX) $(CXXFLAGS) -c $< -o $@ unicode-data.o: unicode-data.cpp unicode-data.h $(CXX) $(CXXFLAGS) -c $< -o $@ -OBJS += ggml-alloc.o ggml-backend.o ggml-quants.o unicode.o unicode-data.o sgemm.o +OBJS += ggml-alloc.o ggml-backend.o ggml-quants.o unicode.o unicode-data.o llama.o: llama.cpp unicode.h ggml.h ggml-alloc.h ggml-backend.h ggml-cuda.h ggml-metal.h llama.h $(CXX) $(CXXFLAGS) -c $< -o $@ From 5668c79ea092b7bff95e1fce96e3de717c31349d Mon Sep 17 00:00:00 2001 From: Pierrick HYMBERT Date: Wed, 17 Apr 2024 23:26:29 +0200 Subject: [PATCH 087/100] server: bench: enable flash_attn param --- examples/server/bench/bench.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/server/bench/bench.py b/examples/server/bench/bench.py index 6ca637bddc3a1..86c5de101445c 100644 --- a/examples/server/bench/bench.py +++ b/examples/server/bench/bench.py @@ -268,6 +268,7 @@ def start_server_background(args): server_args.extend(['--defrag-thold', "0.1"]) server_args.append('--cont-batching') server_args.append('--metrics') + server_args.append('--flash-attn') server_args.extend(['--log-format', "text"]) args = [str(arg) for arg in [server_path, *server_args]] print(f"bench: starting server with: {' '.join(args)}") From c71bfd736ee99a56e697697b39240f2ee06ed26d Mon Sep 17 00:00:00 2001 From: slaren Date: Thu, 18 Apr 2024 09:04:47 +0200 Subject: [PATCH 088/100] llama : fix compatibility with old 2 expert models (#6735) --- llama.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama.cpp b/llama.cpp index f4f4063cf6062..8c1446296fe35 100644 --- a/llama.cpp +++ b/llama.cpp @@ -4592,7 +4592,7 @@ static bool llm_load_tensors( size_t ctx_size = ggml_tensor_overhead()*(ml.n_tensors + 1); // +1 for models where tok_embd is duplicated as output // for moe merged tensors - ctx_size += ggml_tensor_overhead()*hparams.n_expert*n_layer; + ctx_size += ggml_tensor_overhead()*n_layer*3; std::map ctx_map; for (auto & it : buft_layer_count) { From 34f93bbb39965dd40fe8ad717902d6d109b64afb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Tue, 9 Apr 2024 11:39:16 +0200 Subject: [PATCH 089/100] CUDA: refactor host code, dyn. par. blocks --- ggml-cuda.cu | 1 + ggml-cuda/common.cuh | 6 + ggml-cuda/fattn.cu | 542 +++++++++++++++++++------------------------ 3 files changed, 248 insertions(+), 301 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 11adbabd655e1..2cf6c8d98bd89 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -141,6 +141,7 @@ static ggml_cuda_device_info ggml_cuda_init() { info.devices[id].cc = 100*prop.major + 10*prop.minor; #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) info.devices[id].smpb = prop.sharedMemPerBlock; + info.devices[id].nsm = prop.multiProcessorCount; } for (int id = 0; id < info.device_count; ++id) { diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh index b0149b7be22b3..989780dbce88c 100644 --- a/ggml-cuda/common.cuh +++ b/ggml-cuda/common.cuh @@ -390,6 +390,11 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) { } #endif // defined(GGML_USE_HIPBLAS) +#define FP16_AVAILABLE defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) ? \ + defined(RDNA1) || defined(RDNA2) || defined(RDNA3) : __CUDA_ARCH__ >= CC_PASCAL +#define FP16_MMA_AVAILABLE defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) ? \ + defined(RDNA3) : __CUDA_ARCH__ >= CC_VOLTA + // TODO: move to ggml-common.h static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113}; @@ -403,6 +408,7 @@ struct ggml_cuda_device_info { struct cuda_device_info { int cc; // compute capability + int nsm; // number of streaming multiprocessors size_t smpb; // max. shared memory per block bool vmm; // virtual memory support size_t vmm_granularity; // granularity of virtual memory diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index 91ef5551e025a..5f1345a7fe94f 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -36,18 +36,17 @@ static __global__ void flash_attn_vec_ext_f16( const int ne1, const int ne2, const int ne3) { -#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL +#if FP16_AVAILABLE //In this kernel Q, K, V are matrices while i, j, k are matrix indices. + + const int ic = blockIdx.x / parallel_blocks; // Index of the Q/QKV column to work on. + const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel. + const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y); + const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic); const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio)); const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape - const half * maskh = (const half *) mask; - - if (parallel_blocks == 1) { - Q_f2 += blockIdx.x*nb01/sizeof(float2); - maskh += blockIdx.x*ne11; - } + const half * maskh = (const half *) mask + ne11*ic; const int stride_KV = nb11 / sizeof(half); const int stride_KV2 = nb11 / sizeof(half2); @@ -85,7 +84,7 @@ static __global__ void flash_attn_vec_ext_f16( half2 VKQ = make_half2(0.0f, 0.0f); // Each thread calculates a single VKQ value. - const int k_start = parallel_blocks == 1 ? 0 : blockIdx.x*D; + const int k_start = parallel_blocks == 1 ? 0 : ip*D; for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) { // Calculate KQ tile and keep track of new maximum KQ values: half kqmax_new = kqmax; @@ -168,18 +167,19 @@ static __global__ void flash_attn_vec_ext_f16( return; } + half dst_val = (__low2half(VKQ) + __high2half(VKQ)); if (parallel_blocks == 1) { - dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + tid] = (__low2half(VKQ) + __high2half(VKQ)) / kqsum; - } else { - dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + tid] = (__low2half(VKQ) + __high2half(VKQ)); + dst_val /= kqsum; + } + dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + tid] = dst_val; - if (tid == 0) { - dst_meta[blockIdx.y*parallel_blocks + blockIdx.x] = make_half2(kqmax, kqsum); - } + if (parallel_blocks == 1 || tid != 0) { + return; } + dst_meta[ic*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_half2(kqmax, kqsum); #else NO_DEVICE_CODE; -#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL +#endif // FP16_AVAILABLE } template // D == head size, VKQ_stride == num VKQ rows calculated in parallel @@ -212,8 +212,12 @@ static __global__ void flash_attn_ext_f16( const int ne1, const int ne2, const int ne3) { -#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA +#if FP16_MMA_AVAILABLE //In this kernel Q, K, V are matrices while i, j, k are matrix indices. + + const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on. + const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel. + static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE."); static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16."); constexpr int frag_m = ncols == 8 ? 32 : 16; @@ -233,15 +237,10 @@ static __global__ void flash_attn_ext_f16( constexpr int kqs_padded = FATTN_KQ_STRIDE + 8; const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - const float * Q_f = (const float *) (Q + nb02* blockIdx.y); + const float * Q_f = (const float *) (Q + nb02* blockIdx.y + nb01*ic0); const half * K_h = (const half *) (K + nb12*(blockIdx.y / gqa_ratio)); const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape - const half2 * mask2 = (half2 *) mask; - - if (parallel_blocks == 1) { - Q_f += blockIdx.x * ncols*nb01/sizeof(float); - mask2 += blockIdx.x * ncols*ne11/2; - } + const half2 * mask2 = (const half2 *) mask + ne11*(ic0/2); const int stride_Q = nb01 / sizeof(float); const int stride_KV = nb11 / sizeof(half); @@ -283,11 +282,7 @@ static __global__ void flash_attn_ext_f16( if (i0 + WARP_SIZE > D && i >= D) { break; } - if (parallel_blocks == 1) { - KQ[j*D_padded + i] = ncols*blockIdx.x + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f; - } else { - KQ[j*D_padded + i] = j == 0 ? Q_f[j*stride_Q + i] * scale : 0.0f; - } + KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f; } } @@ -305,8 +300,7 @@ static __global__ void flash_attn_ext_f16( __syncthreads(); // Iterate over ne11 == previous tokens: - const int k_start = parallel_blocks == 1 ? 0 : blockIdx.x*FATTN_KQ_STRIDE; - for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) { + for (int k_VKQ_0 = ip*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) { // Calculate tile of KQ: #pragma unroll for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) { @@ -439,41 +433,39 @@ static __global__ void flash_attn_ext_f16( __syncthreads(); } - if (parallel_blocks == 1) { -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j = j0 + threadIdx.y; - if (ncols*blockIdx.x + j >= ne01) { - return; - } - const float KQ_rowsum_j = __low2float(KQ_rowsum[j0/nwarps]) + __high2float(KQ_rowsum[j0/nwarps]); #pragma unroll - for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - if (i0 + WARP_SIZE > D && i >= D) { - break; - } - dst[D*gridDim.y*(ncols*blockIdx.x + j) + D*blockIdx.y + i] = __half2float(VKQ[j*D_padded + i]) / KQ_rowsum_j; - } + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j_VKQ = j0 + threadIdx.y; + if (ic0 + j_VKQ >= ne01) { + return; } - } else { + const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip; + + const half KQ_rowsum_j = __low2half(KQ_rowsum[j0/nwarps]) + __high2half(KQ_rowsum[j0/nwarps]); #pragma unroll - for (int i0 = 0; i0 < D; i0 += nwarps*WARP_SIZE) { - const int i = i0 + threadIdx.y*WARP_SIZE + threadIdx.x; - if (i0 + nwarps*WARP_SIZE > D && i >= D) { - return; + for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + if (i0 + WARP_SIZE > D && i >= D) { + break; + } + half dst_val = VKQ[j_VKQ*D_padded + i]; + if (parallel_blocks == 1) { + dst_val /= KQ_rowsum_j; } - dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + i] = VKQ[i]; + dst[j_dst*gridDim.y*D + blockIdx.y*D + i] = dst_val; } - if (threadIdx.y == 0 && threadIdx.x == 0) { - dst_meta[blockIdx.y*parallel_blocks + blockIdx.x] = make_half2( - __low2half(KQ_max[0]), __low2half(KQ_rowsum[0]) + __high2half(KQ_rowsum[0])); + if (parallel_blocks == 1 || threadIdx.x != 0) { + continue; } + + half2 dst_meta_val = KQ_max[j0/nwarps]; + reinterpret_cast(dst_meta_val.y) = KQ_rowsum_j; + dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val; } #else NO_DEVICE_CODE; -#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA +#endif // FP16_MMA_AVAILABLE } template // D == head size @@ -482,7 +474,10 @@ static __global__ void flash_attn_combine_results( const float * __restrict__ VKQ_parts, const half2 * __restrict__ VKQ_meta, float * __restrict__ dst) { -#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL +#if FP16_AVAILABLE + VKQ_parts += parallel_blocks*D * gridDim.y*blockIdx.x; + VKQ_meta += parallel_blocks * gridDim.y*blockIdx.x; + dst += D * gridDim.y*blockIdx.x; const int tid = threadIdx.x; __builtin_assume(tid < D); @@ -513,7 +508,7 @@ static __global__ void flash_attn_combine_results( dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator; #else NO_DEVICE_CODE; -#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL +#endif // FP16_AVAILABLE } constexpr int get_max_power_of_2(int x) { @@ -540,26 +535,124 @@ static_assert(get_VKQ_stride( 80, 1, 16) == 16, "Test failed."); static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed."); static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed."); -#define FATTN_SWITCH_CASE(D, ncols, nwarps) \ - case ncols: { \ - constexpr int frag_m = (ncols) == 8 && (D) % 32 == 0 ? 32 : 16; \ - flash_attn_ext_f16 \ - <<>> ( \ - (const char *) Q->data, \ - (const char *) K->data, \ - (const char *) V->data, \ - mask ? ((const char *) mask->data) : nullptr, \ - (float *) KQV->data, nullptr, \ - scale, \ - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], \ - K->ne[0], K->ne[1], K->ne[2], K->ne[3], \ - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, \ - Q->nb[1], Q->nb[2], Q->nb[3], \ - K->nb[1], K->nb[2], K->nb[3], \ - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] \ - ); \ - } \ - break; \ +template void launch_fattn_vec_f16( + const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask, + ggml_cuda_pool & pool, cudaStream_t main_stream +) { + ggml_cuda_pool_alloc dst_tmp(pool); + ggml_cuda_pool_alloc dst_tmp_meta(pool); + + if (parallel_blocks > 1) { + dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); + dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV)); + } + + constexpr int nwarps = ((D) + WARP_SIZE - 1) / WARP_SIZE; + constexpr dim3 block_dim(WARP_SIZE, nwarps, 1); + const dim3 blocks_num(parallel_blocks*Q->ne[1], Q->ne[2], Q->ne[3]); + const int shmem = 0; + + float scale; + memcpy(&scale, KQV->op_params, sizeof(float)); + + flash_attn_vec_ext_f16 + <<>> ( + (const char *) Q->data, + (const char *) K->data, + (const char *) V->data, + mask ? ((const char *) mask->data) : nullptr, + (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + CUDA_CHECK(cudaGetLastError()); + + if ((parallel_blocks) == 1) { + return; + } + + constexpr dim3 block_dim_combine(D, 1, 1); + const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z); + const int shmem_combine = 0; + + flash_attn_combine_results + <<>> + (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); + CUDA_CHECK(cudaGetLastError()); +} + +template void launch_fattn_f16_impl( + const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask, + ggml_cuda_pool & pool, cudaStream_t main_stream +) { + ggml_cuda_pool_alloc dst_tmp(pool); + ggml_cuda_pool_alloc dst_tmp_meta(pool); + + if (parallel_blocks > 1) { + dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); + dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV)); + } + + constexpr int frag_m = (cols_per_block) == 8 && (D) % 32 == 0 ? 32 : 16; + constexpr dim3 block_dim(WARP_SIZE, nwarps, 1); + const dim3 blocks_num(parallel_blocks*(Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]); + const int shmem = 0; + + float scale; + memcpy(&scale, KQV->op_params, sizeof(float)); + + flash_attn_ext_f16 + <<>> ( + (const char *) Q->data, + (const char *) K->data, + (const char *) V->data, + mask ? ((const char *) mask->data) : nullptr, + (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + CUDA_CHECK(cudaGetLastError()); + + if ((parallel_blocks) == 1) { + return; + } + + constexpr dim3 block_dim_combine(D, 1, 1); + const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z); + const int shmem_combine = 0; + + flash_attn_combine_results + <<>> + (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); + CUDA_CHECK(cudaGetLastError()); +} + +template void launch_fattn_f16( + const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask, + const int nsm, ggml_cuda_pool & pool, cudaStream_t main_stream +) { + const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3]; + + if (4*blocks_num_pb1 < 2*nsm) { + launch_fattn_f16_impl(Q, K, V, KQV, mask, pool, main_stream); + return; + } + if (2*blocks_num_pb1 < 2*nsm) { + launch_fattn_f16_impl(Q, K, V, KQV, mask, pool, main_stream); + return; + } + launch_fattn_f16_impl(Q, K, V, KQV, mask, pool, main_stream); +} void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * Q = dst->src[0]; @@ -583,259 +676,106 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst ggml_cuda_set_device(ctx.device); - const cudaStream_t main_stream = ctx.stream(); - - float scale; - memcpy(&scale, KQV->op_params, sizeof(float)); - - if (Q->ne[1] == 1) { + if (Q->ne[1] == 1 && Q->ne[0] % WARP_SIZE == 0) { constexpr int parallel_blocks = 4; - - ggml_cuda_pool_alloc dst_tmp(ctx.pool()); - ggml_cuda_pool_alloc dst_tmp_meta(ctx.pool()); - - const int nwarps = (Q->ne[0] + WARP_SIZE - 1) / WARP_SIZE; - const dim3 blocks_num(parallel_blocks*Q->ne[1], Q->ne[2], Q->ne[3]); - const dim3 block_dim(WARP_SIZE, nwarps, 1); - const int shmem = 0; - - // Performance of the vector kernel is very bad for head sizes 80 and 112, use the tensor core kernel instead: - constexpr int nwarps_tc = 4; - constexpr dim3 block_dim_tc(WARP_SIZE, nwarps_tc, 1); - - const dim3 blocks_num_combine(1, blocks_num.y, blocks_num.z); - const dim3 block_dim_combine(Q->ne[0], 1, 1); - const int shmem_combine = 0; - - if (parallel_blocks > 1) { - dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); - dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV)); - } - switch (Q->ne[0]) { case 64: - flash_attn_vec_ext_f16<64, parallel_blocks> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - if (parallel_blocks == 1) { - break; - } - CUDA_CHECK(cudaGetLastError()); - flash_attn_combine_results<64, parallel_blocks> - <<>> - (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); - break; - case 80: - flash_attn_ext_f16<80, 16, nwarps_tc, get_VKQ_stride(80, nwarps_tc, 16), parallel_blocks> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - if (parallel_blocks == 1) { - break; - } - CUDA_CHECK(cudaGetLastError()); - flash_attn_combine_results<80, parallel_blocks> - <<>> - (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); + launch_fattn_vec_f16< 64, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); break; case 96: - flash_attn_vec_ext_f16<96, parallel_blocks> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - if (parallel_blocks == 1) { - break; - } - CUDA_CHECK(cudaGetLastError()); - flash_attn_combine_results<96, parallel_blocks> - <<>> - (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); - break; - case 112: - flash_attn_vec_ext_f16<112, parallel_blocks> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - if (parallel_blocks == 1) { - break; - } - CUDA_CHECK(cudaGetLastError()); - flash_attn_combine_results<112, parallel_blocks> - <<>> - (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); + launch_fattn_vec_f16< 96, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); break; case 128: - flash_attn_vec_ext_f16<128, parallel_blocks> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - if (parallel_blocks == 1) { - break; - } - CUDA_CHECK(cudaGetLastError()); - flash_attn_combine_results<128, parallel_blocks> - <<>> - (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); + launch_fattn_vec_f16<128, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); break; case 256: - flash_attn_vec_ext_f16<256, parallel_blocks> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? ((const char *) mask->data) : nullptr, // Mask - parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - if (parallel_blocks == 1) { - break; - } - CUDA_CHECK(cudaGetLastError()); - flash_attn_combine_results<256, parallel_blocks> - <<>> - (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); + launch_fattn_vec_f16<256, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); break; default: GGML_ASSERT(false); break; } - CUDA_CHECK(cudaGetLastError()); return; } - int cols_per_block; - if (Q->ne[1] >= 64 && (Q->ne[0] <= 128 || ggml_cuda_info().devices[ctx.device].cc >= CC_AMPERE)) { - cols_per_block = 32; - } else if (Q->ne[1] >= 32 || Q->ne[0] % 32 != 0) { - cols_per_block = 16; - } else { - cols_per_block = 8; - } - constexpr int nwarps = 4; - const dim3 blocks_num((Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]); - const dim3 block_dim(WARP_SIZE, nwarps, 1); - const size_t shmem = 0; + const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm; - switch (Q->ne[0]) { - case 64: switch (cols_per_block) { - FATTN_SWITCH_CASE(64, 8, nwarps); - FATTN_SWITCH_CASE(64, 16, nwarps); - FATTN_SWITCH_CASE(64, 32, nwarps); - default: - fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); - GGML_ASSERT(false); + if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) { + constexpr int cols_per_block = 8; + constexpr int nwarps = 4; + switch (Q->ne[0]) { + case 64: + launch_fattn_f16< 64, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; - } break; - case 80: switch (cols_per_block) { - // FATTN_SWITCH_CASE(80, 8, nwarps); - FATTN_SWITCH_CASE(80, 16, nwarps); - FATTN_SWITCH_CASE(80, 32, nwarps); - default: - fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); - GGML_ASSERT(false); + case 96: + launch_fattn_f16< 96, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; - } break; - case 96: switch (cols_per_block) { - FATTN_SWITCH_CASE(96, 8, nwarps); - FATTN_SWITCH_CASE(96, 16, nwarps); - FATTN_SWITCH_CASE(96, 32, nwarps); - default: - fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); - GGML_ASSERT(false); + case 128: + launch_fattn_f16<128, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; - } break; - case 112: switch (cols_per_block) { - // FATTN_SWITCH_CASE(112, 8, nwarps); - FATTN_SWITCH_CASE(112, 16, nwarps); - FATTN_SWITCH_CASE(112, 32, nwarps); - default: - fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); - GGML_ASSERT(false); + case 256: + launch_fattn_f16<256, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; - } break; - case 128: switch (cols_per_block) { - FATTN_SWITCH_CASE(128, 8, nwarps); - FATTN_SWITCH_CASE(128, 16, nwarps); - FATTN_SWITCH_CASE(128, 32, nwarps); default: - fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); GGML_ASSERT(false); break; - } break; - case 256: switch (cols_per_block) { - FATTN_SWITCH_CASE(256, 8, nwarps); - FATTN_SWITCH_CASE(256, 16, nwarps); - FATTN_SWITCH_CASE(256, 32, nwarps); + } + return; + } + + if (Q->ne[1] <= 32) { + constexpr int cols_per_block = 16; + constexpr int nwarps = 4; + switch (Q->ne[0]) { + case 64: + launch_fattn_f16< 64, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 80: + launch_fattn_f16< 80, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 96: + launch_fattn_f16< 96, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 112: + launch_fattn_f16<112, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 128: + launch_fattn_f16<128, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 256: + launch_fattn_f16<256, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; default: - fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); GGML_ASSERT(false); break; - } break; + } + return; + } + + constexpr int cols_per_block = 32; + constexpr int nwarps = 4; + switch (Q->ne[0]) { + case 64: + launch_fattn_f16< 64, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 80: + launch_fattn_f16< 80, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 96: + launch_fattn_f16< 96, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 112: + launch_fattn_f16<112, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 128: + launch_fattn_f16<128, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 256: + launch_fattn_f16<256, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; default: GGML_ASSERT(false); break; } - CUDA_CHECK(cudaGetLastError()); + return; } From 6a3b84236de279f0fe012cfca0c168472526b696 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sat, 13 Apr 2024 22:05:43 +0200 Subject: [PATCH 090/100] fix flash_attn_vec_f16 race condition --- ggml-cuda/fattn.cu | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index 5f1345a7fe94f..36479b2170979 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -149,6 +149,8 @@ static __global__ void flash_attn_vec_ext_f16( VKQ += V_k*KQ2[k0/2]; } } + + __syncthreads(); } if (tid >= D) { @@ -547,7 +549,7 @@ template void launch_fattn_vec_f16( dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV)); } - constexpr int nwarps = ((D) + WARP_SIZE - 1) / WARP_SIZE; + constexpr int nwarps = (D + WARP_SIZE - 1) / WARP_SIZE; constexpr dim3 block_dim(WARP_SIZE, nwarps, 1); const dim3 blocks_num(parallel_blocks*Q->ne[1], Q->ne[2], Q->ne[3]); const int shmem = 0; @@ -561,7 +563,7 @@ template void launch_fattn_vec_f16( (const char *) K->data, (const char *) V->data, mask ? ((const char *) mask->data) : nullptr, - (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, + parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, scale, Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], K->ne[0], K->ne[1], K->ne[2], K->ne[3], @@ -572,7 +574,7 @@ template void launch_fattn_vec_f16( ); CUDA_CHECK(cudaGetLastError()); - if ((parallel_blocks) == 1) { + if (parallel_blocks == 1) { return; } From ef9e1593f33df5dbc8f89b927a8a7bd9dfc9e6dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Mon, 15 Apr 2024 16:05:07 +0200 Subject: [PATCH 091/100] flush softmax exp below threshold to 0 --- ggml-cuda/fattn.cu | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index 36479b2170979..f6289822e0ea0 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -3,8 +3,9 @@ #include -#define FATTN_KQ_STRIDE 256 -#define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction. +#define FATTN_KQ_STRIDE 256 +#define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction. +#define SOFTMAX_FTZ_THRESHOLD -20.0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs. template // D == head size __launch_bounds__(((D + WARP_SIZE - 1) / WARP_SIZE)*WARP_SIZE, 1) @@ -338,10 +339,16 @@ static __global__ void flash_attn_ext_f16( #pragma unroll for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { const int k = k0 + threadIdx.x; - KQ_max_new = __hmax2(KQ_max_new, KQ2[j*(kqs_padded/2) + k]); + half2 val = KQ2[j*(kqs_padded/2) + k]; + val += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); + KQ_max_new = __hmax2(KQ_max_new, val); + KQ2[j*(kqs_padded/2) + k] = val; } KQ_max_new = __half2half2(warp_reduce_max(__hmax(__low2half(KQ_max_new), __high2half(KQ_max_new)))); - KQ_max_scale[j0/nwarps] = h2exp(KQ_max[j0/nwarps] - KQ_max_new); + const half2 diff = KQ_max[j0/nwarps] - KQ_max_new; + KQ_max_scale[j0/nwarps] = h2exp(diff); + const uint ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); + *((uint *) &KQ_max_scale[j0/nwarps]) &= ftz_mask; KQ_max[j0/nwarps] = KQ_max_new; half2 KQ_rowsum_add = make_half2(0.0f, 0.0f); @@ -350,8 +357,10 @@ static __global__ void flash_attn_ext_f16( const int k = k0 + threadIdx.x; half2 val = KQ2[j*(kqs_padded/2) + k]; - val += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); - val = h2exp(val - KQ_max[j0/nwarps]); + const half2 diff = val - KQ_max[j0/nwarps]; + val = h2exp(diff); + const uint ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); + *((uint *) &val) &= ftz_mask; KQ_rowsum_add += val; KQ2[j*(kqs_padded/2) + k] = val; } @@ -501,7 +510,10 @@ static __global__ void flash_attn_combine_results( float VKQ_denominator = 0.0f; #pragma unroll for (int l = 0; l < parallel_blocks; ++l) { - float KQ_max_scale = hexp(__low2half(meta[l]) - kqmax); + const half diff = __low2half(meta[l]) - kqmax; + float KQ_max_scale = hexp(diff); + const uint ftz_mask = 0xFFFFFFFF * (diff > __float2half(SOFTMAX_FTZ_THRESHOLD)); + *((uint *) &KQ_max_scale) &= ftz_mask; VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid]; VKQ_denominator += KQ_max_scale * __high2float(meta[l]); From a5b0e2dea018cfac5ee478aac0d780eef391b30b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Tue, 16 Apr 2024 15:58:21 +0200 Subject: [PATCH 092/100] store temp KQ in registers --- ggml-cuda/fattn.cu | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index f6289822e0ea0..b889cdb3b9b01 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -335,14 +335,21 @@ static __global__ void flash_attn_ext_f16( for (int j0 = 0; j0 < ncols; j0 += nwarps) { const int j = j0 + threadIdx.y; + half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)]; +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; + + KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k]; + } + half2 KQ_max_new = KQ_max[j0/nwarps]; #pragma unroll for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { const int k = k0 + threadIdx.x; - half2 val = KQ2[j*(kqs_padded/2) + k]; - val += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); - KQ_max_new = __hmax2(KQ_max_new, val); - KQ2[j*(kqs_padded/2) + k] = val; + + KQ2_tmp[k0/WARP_SIZE] += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); + KQ_max_new = __hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]); } KQ_max_new = __half2half2(warp_reduce_max(__hmax(__low2half(KQ_max_new), __high2half(KQ_max_new)))); const half2 diff = KQ_max[j0/nwarps] - KQ_max_new; @@ -356,13 +363,12 @@ static __global__ void flash_attn_ext_f16( for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { const int k = k0 + threadIdx.x; - half2 val = KQ2[j*(kqs_padded/2) + k]; - const half2 diff = val - KQ_max[j0/nwarps]; - val = h2exp(diff); + const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max[j0/nwarps]; + KQ2_tmp[k0/WARP_SIZE] = h2exp(diff); const uint ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); - *((uint *) &val) &= ftz_mask; - KQ_rowsum_add += val; - KQ2[j*(kqs_padded/2) + k] = val; + *((uint *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask; + KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE]; + KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE]; } KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add); From 0bc67dd1c81c15f04096985f9a85d81b431767b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Tue, 16 Apr 2024 16:22:29 +0200 Subject: [PATCH 093/100] Calculate KQ as FP32 if KQV has GGML_PREC_F32 --- ggml-cuda/fattn.cu | 286 +++++++++++++++++++++++++++++++++------------ 1 file changed, 213 insertions(+), 73 deletions(-) diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index b889cdb3b9b01..dda344531335c 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -1,6 +1,7 @@ #include "common.cuh" #include "fattn.cuh" +#include #include #define FATTN_KQ_STRIDE 256 @@ -185,7 +186,8 @@ static __global__ void flash_attn_vec_ext_f16( #endif // FP16_AVAILABLE } -template // D == head size, VKQ_stride == num VKQ rows calculated in parallel +// D == head size, VKQ_stride == num VKQ rows calculated in parallel: +template __launch_bounds__(nwarps*WARP_SIZE, 1) static __global__ void flash_attn_ext_f16( const char * __restrict__ Q, @@ -229,7 +231,8 @@ static __global__ void flash_attn_ext_f16( typedef nvcuda::wmma::fragment frag_a_K; typedef nvcuda::wmma::fragment frag_a_V; typedef nvcuda::wmma::fragment frag_b; - typedef nvcuda::wmma::fragment frag_c; + typedef nvcuda::wmma::fragment frag_c_KQ; + typedef nvcuda::wmma::fragment frag_c_VKQ; constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel. constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy. @@ -238,12 +241,14 @@ static __global__ void flash_attn_ext_f16( // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts: constexpr int D_padded = D + 8; constexpr int kqs_padded = FATTN_KQ_STRIDE + 8; + constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half); const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. const float * Q_f = (const float *) (Q + nb02* blockIdx.y + nb01*ic0); const half * K_h = (const half *) (K + nb12*(blockIdx.y / gqa_ratio)); const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape - const half2 * mask2 = (const half2 *) mask + ne11*(ic0/2); + const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0; + const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2); const int stride_Q = nb01 / sizeof(float); const int stride_KV = nb11 / sizeof(half); @@ -251,14 +256,29 @@ static __global__ void flash_attn_ext_f16( frag_b Q_b[D/16][ncols/frag_n]; // A single buffer for temporarily holding tiles of KQ and VKQ parts: - constexpr int mem_KQ = ncols*kqs_padded; + constexpr int mem_KQ = ncols*kqs_padded*kqar; constexpr int mem_VKQ_parts = VKQ_ratio*ncols*D_padded; __shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts]; + float * KQ_f = (float *) KQ; half2 * KQ2 = (half2 *) KQ; - half2 KQ_rowsum[ncols/nwarps] = {{ 0.0f, 0.0f}}; - half2 KQ_max[ncols/nwarps] = {{-HALF_MAX_HALF, -HALF_MAX_HALF}}; - half2 KQ_max_scale[ncols/nwarps] = {{ 0.0f, 0.0f}}; + float KQ_rowsum_f[ncols/nwarps] = {0.0f}; + float KQ_max_f[ncols/nwarps]; + float KQ_max_scale_f[ncols/nwarps] = {0.0f}; + +#pragma unroll + for (int j = 0; j < ncols/nwarps; ++j) { + KQ_max_f[j] = -FLT_MAX/2.0f; + } + + half2 KQ_rowsum_h2[ncols/nwarps] = {{0.0f, 0.0f}}; + half2 KQ_max_h2[ncols/nwarps]; + half2 KQ_max_scale_h2[ncols/nwarps] = {{0.0f, 0.0f}}; + +#pragma unroll + for (int j = 0; j < ncols/nwarps; ++j) { + KQ_max_h2[j] = make_half2(-HALF_MAX_HALF, -HALF_MAX_HALF); + } __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice. half2 * VKQ2 = (half2 *) VKQ; @@ -307,7 +327,7 @@ static __global__ void flash_attn_ext_f16( // Calculate tile of KQ: #pragma unroll for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) { - frag_c KQ_c[ncols/frag_n]; + frag_c_KQ KQ_c[ncols/frag_n]; #pragma unroll for (int j = 0; j < ncols/frag_n; ++j) { nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f); @@ -323,7 +343,7 @@ static __global__ void flash_attn_ext_f16( } #pragma unroll for (int j0 = 0; j0 < ncols; j0 += frag_n) { - nvcuda::wmma::store_matrix_sync(KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major); + nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major); } } @@ -335,45 +355,90 @@ static __global__ void flash_attn_ext_f16( for (int j0 = 0; j0 < ncols; j0 += nwarps) { const int j = j0 + threadIdx.y; - half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)]; + if (std::is_same::value) { + float KQ_f_tmp[FATTN_KQ_STRIDE / WARP_SIZE]; #pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { - const int k = k0 + threadIdx.x; + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; - KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k]; - } + KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k]; + } - half2 KQ_max_new = KQ_max[j0/nwarps]; + float KQ_max_new = KQ_max_f[j0/nwarps]; #pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { - const int k = k0 + threadIdx.x; + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; - KQ2_tmp[k0/WARP_SIZE] += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); - KQ_max_new = __hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]); - } - KQ_max_new = __half2half2(warp_reduce_max(__hmax(__low2half(KQ_max_new), __high2half(KQ_max_new)))); - const half2 diff = KQ_max[j0/nwarps] - KQ_max_new; - KQ_max_scale[j0/nwarps] = h2exp(diff); - const uint ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); - *((uint *) &KQ_max_scale[j0/nwarps]) &= ftz_mask; - KQ_max[j0/nwarps] = KQ_max_new; + KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f; + KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]); + } + KQ_max_new = warp_reduce_max(KQ_max_new); + + const float diff = KQ_max_f[j0/nwarps] - KQ_max_new; + KQ_max_scale_f[j0/nwarps] = expf(diff); + if (diff <= SOFTMAX_FTZ_THRESHOLD) { + KQ_max_scale_f[j0/nwarps] = 0.0f; + } + KQ_max_f[j0/nwarps] = KQ_max_new; + + float KQ_rowsum_add = 0.0f; +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; - half2 KQ_rowsum_add = make_half2(0.0f, 0.0f); + const float diff = KQ_f_tmp[k0/WARP_SIZE] - KQ_max_f[j0/nwarps]; + KQ_f_tmp[k0/WARP_SIZE] = expf(diff); + if (diff <= SOFTMAX_FTZ_THRESHOLD) { + KQ_f_tmp[k0/WARP_SIZE] = 0.0f; + } + KQ_rowsum_add += KQ_f_tmp[k0/WARP_SIZE]; + KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/WARP_SIZE]; + } + KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add); + + // Scale previous KQ_rowsum to account for a potential increase in KQ_max: + KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add; + } else { + half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)]; +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; + + KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k]; + } + + half2 KQ_max_new = KQ_max_h2[j0/nwarps]; #pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { - const int k = k0 + threadIdx.x; + for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; - const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max[j0/nwarps]; - KQ2_tmp[k0/WARP_SIZE] = h2exp(diff); + KQ2_tmp[k0/WARP_SIZE] += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); + KQ_max_new = __hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]); + } + KQ_max_new = __half2half2(warp_reduce_max(__hmax(__low2half(KQ_max_new), __high2half(KQ_max_new)))); + const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new; + KQ_max_scale_h2[j0/nwarps] = h2exp(diff); const uint ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); - *((uint *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask; - KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE]; - KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE]; - } - KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add); + *((uint *) &KQ_max_scale_h2[j0/nwarps]) &= ftz_mask; + KQ_max_h2[j0/nwarps] = KQ_max_new; - // Scale previous KQ_rowsum to account for a potential increase in KQ_max: - KQ_rowsum[j0/nwarps] = KQ_max_scale[j0/nwarps]*KQ_rowsum[j0/nwarps] + KQ_rowsum_add; + half2 KQ_rowsum_add = make_half2(0.0f, 0.0f); +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; + + const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max_h2[j0/nwarps]; + KQ2_tmp[k0/WARP_SIZE] = h2exp(diff); + const uint ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); + *((uint *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask; + KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE]; + KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE]; + } + KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add); + + // Scale previous KQ_rowsum to account for a potential increase in KQ_max: + KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add; + } } __syncthreads(); @@ -386,12 +451,12 @@ static __global__ void flash_attn_ext_f16( const int k = k0 + (threadIdx.y % VKQ_ratio)*16; nvcuda::wmma::load_matrix_sync( KQ_b[k0/(VKQ_ratio*16)][j0/frag_n], - KQ + j0*kqs_padded + k, - kqs_padded); + KQ + j0*(kqar*kqs_padded) + k, + kqar*kqs_padded); } } - frag_c VKQ_c[D/VKQ_stride][ncols/frag_n]; + frag_c_VKQ VKQ_c[D/VKQ_stride][ncols/frag_n]; #pragma unroll for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) { #pragma unroll @@ -431,6 +496,14 @@ static __global__ void flash_attn_ext_f16( #pragma unroll for (int j0 = 0; j0 < ncols; j0 += nwarps) { const int j = j0 + threadIdx.y; + + half2 VKQ_scale; + if (std::is_same::value) { + VKQ_scale = make_half2(KQ_max_scale_f[j0/nwarps], KQ_max_scale_f[j0/nwarps]); + } else { + VKQ_scale = KQ_max_scale_h2[j0/nwarps]; + } + #pragma unroll for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; @@ -443,7 +516,7 @@ static __global__ void flash_attn_ext_f16( for (int l = 0; l < VKQ_ratio; ++l) { VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i]; } - VKQ2[j*(D_padded/2) + i] = KQ_max_scale[j0/nwarps]*VKQ2[j*(D_padded/2) + i] + VKQ_add; + VKQ2[j*(D_padded/2) + i] = VKQ_scale*VKQ2[j*(D_padded/2) + i] + VKQ_add; } } @@ -458,14 +531,20 @@ static __global__ void flash_attn_ext_f16( } const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip; - const half KQ_rowsum_j = __low2half(KQ_rowsum[j0/nwarps]) + __high2half(KQ_rowsum[j0/nwarps]); + float KQ_rowsum_j; + if (std::is_same::value) { + KQ_rowsum_j = KQ_rowsum_f[j0/nwarps]; + } else { + KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]); + } + #pragma unroll for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; if (i0 + WARP_SIZE > D && i >= D) { break; } - half dst_val = VKQ[j_VKQ*D_padded + i]; + float dst_val = VKQ[j_VKQ*D_padded + i]; if (parallel_blocks == 1) { dst_val /= KQ_rowsum_j; } @@ -476,7 +555,12 @@ static __global__ void flash_attn_ext_f16( continue; } - half2 dst_meta_val = KQ_max[j0/nwarps]; + half2 dst_meta_val; + if (std::is_same::value) { + reinterpret_cast(dst_meta_val.x) = KQ_max_f[j0/nwarps]; + } else { + dst_meta_val = KQ_max_h2[j0/nwarps]; + } reinterpret_cast(dst_meta_val.y) = KQ_rowsum_j; dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val; } @@ -606,7 +690,7 @@ template void launch_fattn_vec_f16( CUDA_CHECK(cudaGetLastError()); } -template void launch_fattn_f16_impl( +template void launch_fattn_f16_impl( const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask, ggml_cuda_pool & pool, cudaStream_t main_stream ) { @@ -626,7 +710,7 @@ template void launc float scale; memcpy(&scale, KQV->op_params, sizeof(float)); - flash_attn_ext_f16 + flash_attn_ext_f16 <<>> ( (const char *) Q->data, (const char *) K->data, @@ -657,21 +741,21 @@ template void launc CUDA_CHECK(cudaGetLastError()); } -template void launch_fattn_f16( +template void launch_fattn_f16( const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask, const int nsm, ggml_cuda_pool & pool, cudaStream_t main_stream ) { const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3]; if (4*blocks_num_pb1 < 2*nsm) { - launch_fattn_f16_impl(Q, K, V, KQV, mask, pool, main_stream); + launch_fattn_f16_impl(Q, K, V, KQV, mask, pool, main_stream); return; } if (2*blocks_num_pb1 < 2*nsm) { - launch_fattn_f16_impl(Q, K, V, KQV, mask, pool, main_stream); + launch_fattn_f16_impl(Q, K, V, KQV, mask, pool, main_stream); return; } - launch_fattn_f16_impl(Q, K, V, KQV, mask, pool, main_stream); + launch_fattn_f16_impl(Q, K, V, KQV, mask, pool, main_stream); } void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { @@ -696,15 +780,73 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst ggml_cuda_set_device(ctx.device); - if (Q->ne[1] == 1 && Q->ne[0] % WARP_SIZE == 0) { + const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm; + + const int32_t precision = KQV->op_params[1]; + + if (precision != GGML_PREC_DEFAULT) { + if (Q->ne[1] <= 32 || Q->ne[0] > 128) { + constexpr int cols_per_block = 16; + constexpr int nwarps = 4; + switch (Q->ne[0]) { + case 64: + launch_fattn_f16< 64, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 80: + launch_fattn_f16< 80, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 96: + launch_fattn_f16< 96, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 112: + launch_fattn_f16<112, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 128: + launch_fattn_f16<128, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 256: + launch_fattn_f16<256, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + default: + GGML_ASSERT(false); + break; + } + } else { + constexpr int cols_per_block = 32; + constexpr int nwarps = 4; + switch (Q->ne[0]) { + case 64: + launch_fattn_f16< 64, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 80: + launch_fattn_f16< 80, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 96: + launch_fattn_f16< 96, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 112: + launch_fattn_f16<112, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 128: + launch_fattn_f16<128, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + // case 256: + // launch_fattn_f16<256, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + // break; + default: + GGML_ASSERT(false); + break; + } + } + return; + } + + if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) { constexpr int parallel_blocks = 4; switch (Q->ne[0]) { case 64: launch_fattn_vec_f16< 64, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); break; - case 96: - launch_fattn_vec_f16< 96, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); - break; case 128: launch_fattn_vec_f16<128, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); break; @@ -718,23 +860,21 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst return; } - const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm; - if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) { constexpr int cols_per_block = 8; constexpr int nwarps = 4; switch (Q->ne[0]) { case 64: - launch_fattn_f16< 64, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + launch_fattn_f16< 64, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; case 96: - launch_fattn_f16< 96, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + launch_fattn_f16< 96, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; case 128: - launch_fattn_f16<128, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + launch_fattn_f16<128, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; case 256: - launch_fattn_f16<256, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + launch_fattn_f16<256, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; default: GGML_ASSERT(false); @@ -748,22 +888,22 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst constexpr int nwarps = 4; switch (Q->ne[0]) { case 64: - launch_fattn_f16< 64, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + launch_fattn_f16< 64, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; case 80: - launch_fattn_f16< 80, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + launch_fattn_f16< 80, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; case 96: - launch_fattn_f16< 96, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + launch_fattn_f16< 96, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; case 112: - launch_fattn_f16<112, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + launch_fattn_f16<112, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; case 128: - launch_fattn_f16<128, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + launch_fattn_f16<128, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; case 256: - launch_fattn_f16<256, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + launch_fattn_f16<256, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; default: GGML_ASSERT(false); @@ -776,22 +916,22 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst constexpr int nwarps = 4; switch (Q->ne[0]) { case 64: - launch_fattn_f16< 64, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + launch_fattn_f16< 64, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; case 80: - launch_fattn_f16< 80, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + launch_fattn_f16< 80, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; case 96: - launch_fattn_f16< 96, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + launch_fattn_f16< 96, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; case 112: - launch_fattn_f16<112, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + launch_fattn_f16<112, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; case 128: - launch_fattn_f16<128, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + launch_fattn_f16<128, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; case 256: - launch_fattn_f16<256, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + launch_fattn_f16<256, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); break; default: GGML_ASSERT(false); From 2f538b9547ec2c2c67be0d41ed96d33c141354fa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Wed, 17 Apr 2024 16:29:28 +0200 Subject: [PATCH 094/100] Add __hgt2_mask implementation for CUDA 11 --- ggml-cuda/common.cuh | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh index 989780dbce88c..ac6de643d668e 100644 --- a/ggml-cuda/common.cuh +++ b/ggml-cuda/common.cuh @@ -306,6 +306,13 @@ static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX } +#if CUDART_VERSION < 12000 +static __device__ __forceinline__ uint __hgt2_mask(const half2 a, const half2 b) { + const uint mask_low = 0x0000FFFF * ( __low2half(a) > __low2half(b)); + const uint mask_high = 0xFFFF0000 * (__high2half(a) > __high2half(b)); + return mask_low | mask_high; +} +#endif // CUDART_VERSION < 12000 #if defined(GGML_USE_HIPBLAS) #define __CUDA_ARCH__ 1300 From 87968de9a99d9820d20aeb5a15211f4ab5efde83 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Wed, 17 Apr 2024 17:31:03 +0200 Subject: [PATCH 095/100] fix KQ FP32 precision fpr parallel_blocks > 1 --- ggml-cuda/fattn.cu | 48 +++++++++++++++++++++++----------------------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index dda344531335c..4cf2907e8d10c 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -15,8 +15,8 @@ static __global__ void flash_attn_vec_ext_f16( const char * __restrict__ K, const char * __restrict__ V, const char * __restrict__ mask, - float * __restrict__ dst, - half2 * __restrict__ dst_meta, + float * __restrict__ dst, + float2 * __restrict__ dst_meta, const float scale, const int ne00, const int ne01, @@ -180,7 +180,7 @@ static __global__ void flash_attn_vec_ext_f16( if (parallel_blocks == 1 || tid != 0) { return; } - dst_meta[ic*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_half2(kqmax, kqsum); + dst_meta[ic*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax, kqsum); #else NO_DEVICE_CODE; #endif // FP16_AVAILABLE @@ -194,8 +194,8 @@ static __global__ void flash_attn_ext_f16( const char * __restrict__ K, const char * __restrict__ V, const char * __restrict__ mask, - float * __restrict__ dst, - half2 * __restrict__ dst_meta, + float * __restrict__ dst, + float2 * __restrict__ dst_meta, const float scale, const int ne00, const int ne01, @@ -555,13 +555,13 @@ static __global__ void flash_attn_ext_f16( continue; } - half2 dst_meta_val; + float2 dst_meta_val; if (std::is_same::value) { - reinterpret_cast(dst_meta_val.x) = KQ_max_f[j0/nwarps]; + dst_meta_val.x = KQ_max_f[j0/nwarps]; } else { - dst_meta_val = KQ_max_h2[j0/nwarps]; + dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]); } - reinterpret_cast(dst_meta_val.y) = KQ_rowsum_j; + dst_meta_val.y = KQ_rowsum_j; dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val; } #else @@ -572,8 +572,8 @@ static __global__ void flash_attn_ext_f16( template // D == head size __launch_bounds__(D, 1) static __global__ void flash_attn_combine_results( - const float * __restrict__ VKQ_parts, - const half2 * __restrict__ VKQ_meta, + const float * __restrict__ VKQ_parts, + const float2 * __restrict__ VKQ_meta, float * __restrict__ dst) { #if FP16_AVAILABLE VKQ_parts += parallel_blocks*D * gridDim.y*blockIdx.x; @@ -583,30 +583,30 @@ static __global__ void flash_attn_combine_results( const int tid = threadIdx.x; __builtin_assume(tid < D); - __shared__ half2 meta[parallel_blocks]; - if (tid < parallel_blocks) { - meta[threadIdx.x] = VKQ_meta[blockIdx.y*parallel_blocks + tid]; + __shared__ float2 meta[parallel_blocks]; + if (tid < 2*parallel_blocks) { + ((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.y*(2*parallel_blocks) + tid]; } __syncthreads(); - half kqmax = __low2half(meta[0]); + float kqmax = meta[0].x; #pragma unroll for (int l = 1; l < parallel_blocks; ++l) { - kqmax = __hmax(kqmax, __low2half(meta[l])); + kqmax = max(kqmax, meta[l].x); } float VKQ_numerator = 0.0f; float VKQ_denominator = 0.0f; #pragma unroll for (int l = 0; l < parallel_blocks; ++l) { - const half diff = __low2half(meta[l]) - kqmax; - float KQ_max_scale = hexp(diff); - const uint ftz_mask = 0xFFFFFFFF * (diff > __float2half(SOFTMAX_FTZ_THRESHOLD)); + const float diff = meta[l].x - kqmax; + const float KQ_max_scale = expf(diff); + const uint ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD); *((uint *) &KQ_max_scale) &= ftz_mask; VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid]; - VKQ_denominator += KQ_max_scale * __high2float(meta[l]); + VKQ_denominator += KQ_max_scale * meta[l].y; } dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator; @@ -643,8 +643,8 @@ template void launch_fattn_vec_f16( const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask, ggml_cuda_pool & pool, cudaStream_t main_stream ) { - ggml_cuda_pool_alloc dst_tmp(pool); - ggml_cuda_pool_alloc dst_tmp_meta(pool); + ggml_cuda_pool_alloc dst_tmp(pool); + ggml_cuda_pool_alloc dst_tmp_meta(pool); if (parallel_blocks > 1) { dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); @@ -694,8 +694,8 @@ template dst_tmp(pool); - ggml_cuda_pool_alloc dst_tmp_meta(pool); + ggml_cuda_pool_alloc dst_tmp(pool); + ggml_cuda_pool_alloc dst_tmp_meta(pool); if (parallel_blocks > 1) { dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); From 260cdb2d082d1658d1c6b693c4cbf77754873886 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 18 Apr 2024 14:28:19 +0300 Subject: [PATCH 096/100] llama-bench : add -fa,--flash-attn arg --- examples/llama-bench/llama-bench.cpp | 30 +++++++++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 8b532c8b6a98a..95c3095dd04da 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -174,6 +174,7 @@ struct cmd_params { std::vector split_mode; std::vector main_gpu; std::vector no_kv_offload; + std::vector flash_attn; std::vector> tensor_split; std::vector use_mmap; std::vector embeddings; @@ -195,6 +196,7 @@ static const cmd_params cmd_params_defaults = { /* split_mode */ {LLAMA_SPLIT_MODE_LAYER}, /* main_gpu */ {0}, /* no_kv_offload */ {false}, + /* flash_attn */ {false}, /* tensor_split */ {std::vector(llama_max_devices(), 0.0f)}, /* use_mmap */ {true}, /* embeddings */ {false}, @@ -220,6 +222,7 @@ static void print_usage(int /* argc */, char ** argv) { printf(" -sm, --split-mode (default: %s)\n", join(transform_to_str(cmd_params_defaults.split_mode, split_mode_str), ",").c_str()); printf(" -mg, --main-gpu (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str()); printf(" -nkvo, --no-kv-offload <0|1> (default: %s)\n", join(cmd_params_defaults.no_kv_offload, ",").c_str()); + printf(" -fa, --flash-attn <0|1> (default: %s)\n", join(cmd_params_defaults.flash_attn, ",").c_str()); printf(" -mmp, --mmap <0|1> (default: %s)\n", join(cmd_params_defaults.use_mmap, ",").c_str()); printf(" -embd, --embeddings <0|1> (default: %s)\n", join(cmd_params_defaults.embeddings, ",").c_str()); printf(" -ts, --tensor-split (default: 0)\n"); @@ -393,6 +396,13 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { } auto p = split(argv[i], split_delim); params.no_kv_offload.insert(params.no_kv_offload.end(), p.begin(), p.end()); + } else if (arg == "-fa" || arg == "--flash-attn") { + if (++i >= argc) { + invalid_param = true; + break; + } + auto p = split(argv[i], split_delim); + params.flash_attn.insert(params.flash_attn.end(), p.begin(), p.end()); } else if (arg == "-mmp" || arg == "--mmap") { if (++i >= argc) { invalid_param = true; @@ -477,6 +487,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { if (params.split_mode.empty()) { params.split_mode = cmd_params_defaults.split_mode; } if (params.main_gpu.empty()) { params.main_gpu = cmd_params_defaults.main_gpu; } if (params.no_kv_offload.empty()){ params.no_kv_offload = cmd_params_defaults.no_kv_offload; } + if (params.flash_attn.empty()) { params.flash_attn = cmd_params_defaults.flash_attn; } if (params.tensor_split.empty()) { params.tensor_split = cmd_params_defaults.tensor_split; } if (params.use_mmap.empty()) { params.use_mmap = cmd_params_defaults.use_mmap; } if (params.embeddings.empty()) { params.embeddings = cmd_params_defaults.embeddings; } @@ -498,6 +509,7 @@ struct cmd_params_instance { llama_split_mode split_mode; int main_gpu; bool no_kv_offload; + bool flash_attn; std::vector tensor_split; bool use_mmap; bool embeddings; @@ -532,6 +544,7 @@ struct cmd_params_instance { cparams.type_k = type_k; cparams.type_v = type_v; cparams.offload_kqv = !no_kv_offload; + cparams.flash_attn = flash_attn; cparams.embeddings = embeddings; return cparams; @@ -554,6 +567,7 @@ static std::vector get_cmd_params_instances(const cmd_param for (const auto & tk : params.type_k) for (const auto & tv : params.type_v) for (const auto & nkvo : params.no_kv_offload) + for (const auto & fa : params.flash_attn) for (const auto & nt : params.n_threads) { for (const auto & n_prompt : params.n_prompt) { if (n_prompt == 0) { @@ -572,6 +586,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .split_mode = */ sm, /* .main_gpu = */ mg, /* .no_kv_offload= */ nkvo, + /* .flash_attn = */ fa, /* .tensor_split = */ ts, /* .use_mmap = */ mmp, /* .embeddings = */ embd, @@ -596,6 +611,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .split_mode = */ sm, /* .main_gpu = */ mg, /* .no_kv_offload= */ nkvo, + /* .flash_attn = */ fa, /* .tensor_split = */ ts, /* .use_mmap = */ mmp, /* .embeddings = */ embd, @@ -633,6 +649,7 @@ struct test { llama_split_mode split_mode; int main_gpu; bool no_kv_offload; + bool flash_attn; std::vector tensor_split; bool use_mmap; bool embeddings; @@ -657,6 +674,7 @@ struct test { split_mode = inst.split_mode; main_gpu = inst.main_gpu; no_kv_offload = inst.no_kv_offload; + flash_attn = inst.flash_attn; tensor_split = inst.tensor_split; use_mmap = inst.use_mmap; embeddings = inst.embeddings; @@ -731,7 +749,7 @@ struct test { "n_batch", "n_ubatch", "n_threads", "type_k", "type_v", "n_gpu_layers", "split_mode", - "main_gpu", "no_kv_offload", + "main_gpu", "no_kv_offload", "flash_attn", "tensor_split", "use_mmap", "embeddings", "n_prompt", "n_gen", "test_time", "avg_ns", "stddev_ns", @@ -753,7 +771,7 @@ struct test { } if (field == "cuda" || field == "opencl" || field == "vulkan" || field == "kompute" || field == "metal" || field == "gpu_blas" || field == "blas" || field == "sycl" ||field == "f16_kv" || field == "no_kv_offload" || - field == "use_mmap" || field == "embeddings") { + field == "flash_attn" || field == "use_mmap" || field == "embeddings") { return BOOL; } if (field == "avg_ts" || field == "stddev_ts") { @@ -787,7 +805,7 @@ struct test { std::to_string(n_batch), std::to_string(n_ubatch), std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v), std::to_string(n_gpu_layers), split_mode_str(split_mode), - std::to_string(main_gpu), std::to_string(no_kv_offload), + std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn), tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings), std::to_string(n_prompt), std::to_string(n_gen), test_time, std::to_string(avg_ns()), std::to_string(stdev_ns()), @@ -955,6 +973,9 @@ struct markdown_printer : public printer { if (field == "no_kv_offload") { return "nkvo"; } + if (field == "flash_attn") { + return "fa"; + } if (field == "use_mmap") { return "mmap"; } @@ -1001,6 +1022,9 @@ struct markdown_printer : public printer { if (params.no_kv_offload.size() > 1 || params.no_kv_offload != cmd_params_defaults.no_kv_offload) { fields.emplace_back("no_kv_offload"); } + if (params.flash_attn.size() > 1 || params.flash_attn != cmd_params_defaults.flash_attn) { + fields.emplace_back("flash_attn"); + } if (params.tensor_split.size() > 1 || params.tensor_split != cmd_params_defaults.tensor_split) { fields.emplace_back("tensor_split"); } From 105332cc17b8bd8f3989606b59489ed85eefe04f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 18 Apr 2024 14:33:07 +0300 Subject: [PATCH 097/100] metal : add BS=1 kernel for flash attention (#6508) * metal : add BS=1 kernel for flash attention (wip) * metal : support more than 1 warps * metal : opts * metal : opt * metal : switch to parallel reduce * metal : reduce registers * metal : simplify * metal : initial FA vec kernel --- ggml-metal.m | 119 ++++++++++++++------ ggml-metal.metal | 274 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 361 insertions(+), 32 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index e8613dbee04d2..407f94eb224cf 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -183,6 +183,8 @@ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, GGML_METAL_KERNEL_TYPE_CPY_F32_F16, GGML_METAL_KERNEL_TYPE_CPY_F32_F32, GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, @@ -621,12 +623,14 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true); @@ -2563,19 +2567,32 @@ static enum ggml_status ggml_metal_graph_compute( id pipeline = nil; - switch (ne00) { - case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break; - case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break; - case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break; - case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break; - case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break; - case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break; - default: - { - GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_METAL_LOG_ERROR("add template specialization for this size\n"); - GGML_ASSERT(false && "add template specialization for this size"); - } + if (ne01 > 1 || (ne00%128 != 0)) { + switch (ne00) { + case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break; + case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break; + case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break; + case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break; + case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break; + case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break; + default: + { + GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00); + GGML_METAL_LOG_ERROR("add template specialization for this size\n"); + GGML_ASSERT(false && "add template specialization for this size"); + } + } + } else { + switch (ne00) { + case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break; + case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break; + default: + { + GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00); + GGML_METAL_LOG_ERROR("add template specialization for this size\n"); + GGML_ASSERT(false && "add template specialization for this size"); + } + } } // TODO: extend if necessary @@ -2609,24 +2626,62 @@ static enum ggml_status ggml_metal_graph_compute( [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; [encoder setBytes:&scale length:sizeof( float) atIndex:27]; - const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! - const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! + // half8x8 kernel + if (ne01 > 1 || (ne00%128 != 0)) { + const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! + const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! - GGML_ASSERT(nqptg <= 32); - GGML_ASSERT(nqptg % 8 == 0); - GGML_ASSERT(ncpsg % 32 == 0); + GGML_ASSERT(nqptg <= 32); + GGML_ASSERT(nqptg % 8 == 0); + GGML_ASSERT(ncpsg % 32 == 0); - // simdgroups per threadgroup (a.k.a. warps) - // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) - const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)) : 4; + // simdgroups per threadgroup (a.k.a. warps) + // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) + const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)) : 4; - const size_t smem = nqptg*(ne00 + nsg*(ncpsg + nqptg))*(sizeof(float)/2); + const size_t smem = nqptg*(ne00 + nsg*(ncpsg + nqptg))*(sizeof(float)/2); - //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); - GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); - [encoder setThreadgroupMemoryLength:smem atIndex:0]; + //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); + GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); + [encoder setThreadgroupMemoryLength:smem atIndex:0]; + + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + } else { + // half1x4 kernel + const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !! + const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + GGML_ASSERT(nqptg <= 32); + GGML_ASSERT(nqptg % 1 == 0); + GGML_ASSERT(ncpsg % 32 == 0); + + // simdgroups per threadgroup (a.k.a. warps) + // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) + const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)); + + int64_t nsg = 1; + while (nsg <= nsgt) { + nsg *= 2; + } + nsg /= 2; + + // require power of 2 + //{ + // int64_t nsgm = 1; + // while (nsgm < nsg) { + // nsgm *= 2; + // } + // GGML_ASSERT(nsg == nsgm); + //} + + const size_t smem = (nqptg*(ne00 + nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2); + + //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); + GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); + [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0]; + + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + } } break; case GGML_OP_DUP: case GGML_OP_CPY: diff --git a/ggml-metal.metal b/ggml-metal.metal index be47db86ebbd3..de6072e93470a 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2494,6 +2494,280 @@ template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 8, 32>; template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256, 8, 32>; +#define HALF_MAX_HALF half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction. + +template // head size, queries per threadgroup, cache items per threadgroup +kernel void kernel_flash_attn_ext_vec_f16( + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant int64_t & ne31, + constant uint64_t & nb31, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant float & scale, + threadgroup half * shared [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + const short nsg = ntg.y; // number of simdgroups + + const short iq3 = tgpig[2]; + const short iq2 = tgpig[1]; + const short iq1 = tgpig[0]; + + const short D4 = D/4; + const short D8 = D/8; + const short NW = N_SIMDWIDTH; + const short SH = (C + 1); // shared memory per simdgroup in (half) + + const short T = D + nsg*SH; // shared memory size per query in (half) + const short T4 = T/4; // shared memory size per query in (half4) + + threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data + threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4 + threadgroup half * ss = (threadgroup half *) (shared + sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix + threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*SH + 1*D); // same as above but in half4 + threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + 1*T); // scratch buffer for the results + + // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) + half4 lo[D4/NW]; + + // load heads from Q to shared memory + device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)); + + for (short i = tiisg; i < D4; i += NW) { + if (iq1 < ne01) { + sq4[i] = (half4) q4[i]; + } else { + sq4[i] = 0.0h; + } + } + + // zero out lo + for (short i = tiisg; i < D4; i += NW) { + lo[i/NW] = 0.0h; + } + + // zero out shared memory SH + for (short i = tiisg; i < SH/4; i += NW) { + ss4[i] = 0.0h; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + { + half S = { 0.0h }; + half M = { -HALF_MAX_HALF }; + + // assume K and V are same shape + const short ne22 = ne12; + const short ne23 = ne13; + + const uint nb21 = nb11; + const uint nb22 = nb12; + const uint nb23 = nb13; + + // broadcast + const short rk2 = ne02/ne12; + const short rk3 = ne03/ne13; + + const short rv2 = ne02/ne22; + const short rv3 = ne03/ne23; + + // k indices + const short ik2 = iq2 / rk2; + const short ik3 = iq3 / rk3; + + // v indices + const short iv2 = iq2 / rv2; + const short iv3 = iq3 / rv3; + + // load the queries from shared memory into local memory + half4 mq[D4]; + + for (short ii = 0; ii < D4; ii += NW) { + short i = ii + tiisg; + mq[i] = sq4[i]; + } + + // pointer to the mask + device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31); + + // loop over the KV cache + // each simdgroup handles blocks of Q rows and C columns + for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) { + const int ic = ic0 + C*sgitg; + if (ic >= ne11) { + break; + } + + // Q*K^T + { +#pragma unroll + for (short cc = 0; cc < C/4; ++cc) { + half4 mqk = { 0.0h }; + + device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13)); + +#pragma unroll + for (short ii = 0; ii < D4; ii += NW) { + const short i = ii + tiisg; + + half4x4 mk; + mk[0] = pk4[i + 0*(nb11/8)]; + mk[1] = pk4[i + 1*(nb11/8)]; + mk[2] = pk4[i + 2*(nb11/8)]; + mk[3] = pk4[i + 3*(nb11/8)]; + + mqk += mq[i] * mk; + } + + // reduce the results from the threads in the simdgroup + mqk += simd_shuffle_down(mqk, 16); + mqk += simd_shuffle_down(mqk, 8); + mqk += simd_shuffle_down(mqk, 4); + mqk += simd_shuffle_down(mqk, 2); + mqk += simd_shuffle_down(mqk, 1); + + // mqk = mqk*scale + mask + if (tiisg == 0) { + half4 mm = mp4[ic/4 + cc]; + mqk = mqk*scale + mm; + + ss4[cc] = mqk; + } + } + } + + // online softmax + { + const short p = tiisg; + + const half m = M; + const half s = ss[p]; + + M = simd_max(max(M, s)); + + const half ms = exp(m - M); + const half vs = exp(s - M); + + S = S*ms + simd_sum(vs); + + // the P matrix from the paper (Q rows, C columns) + ss[p] = vs; + + // O = diag(ms)*O +#pragma unroll + for (short ii = 0; ii < D4; ii += NW) { + const short i = ii + tiisg; + lo[i/NW] *= ms; + } + } + + // O = O + (Q*K^T)*V + { +#pragma unroll + for (short cc = 0; cc < C/4; ++cc) { + device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + 4*cc)*nb21 + iv2*nb22 + iv3*nb23)); + +#pragma unroll + for (short ii = 0; ii < D4; ii += NW) { + const short i = ii + tiisg; + lo[i/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0]; + lo[i/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1]; + lo[i/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2]; + lo[i/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3]; + } + } + } + + } + + // these are needed for reducing the results from the simdgroups (reuse the ss buffer) + if (tiisg == 0) { + ss[0] = S; + ss[1] = M; + } + } + + // store results to shared memory + for (short ii = 0; ii < D4; ii += NW) { + short i = ii + tiisg; + sr4[i] = lo[ii/NW]; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // parallel reduce + for (short r = nsg/2; r > 0; r >>= 1) { + if (sgitg < r) { + const half S0 = ss[ 0]; + const half S1 = ss[r*SH + 0]; + + const half M0 = ss[ 1]; + const half M1 = ss[r*SH + 1]; + + const half M = max(M0, M1); + + const half ms0 = exp(M0 - M); + const half ms1 = exp(M1 - M); + + const half S = S0*ms0 + S1*ms1; + + if (tiisg == 0) { + ss[0] = S; + ss[1] = M; + } + + // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 + for (short ii = 0; ii < D4; ii += NW) { + short i = ii + tiisg; + sr4[i] = sr4[i]*ms0 + sr4[i + r*D4]*ms1; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + device float4 * dst4 = (device float4 *) dst; + + // final rescale with 1/S and store to global memory + if (sgitg == 0) { + const half S = ss[0]; + + for (short ii = 0; ii < D4; ii += NW) { + short i = ii + tiisg; + dst4[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D4 + i] = (float4) sr4[i]/S; + } + } +} + +template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 32>; +template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256, 32>; + kernel void kernel_cpy_f16_f16( device const half * src0, device half * dst, From e11b2e6e1e18522ca7cf129600875a0f6fb9307d Mon Sep 17 00:00:00 2001 From: Ren Xuancheng Date: Thu, 18 Apr 2024 19:38:04 +0800 Subject: [PATCH 098/100] Qwen2 : assume tied weights if lm_head/output weights is missing (#6738) --- llama.cpp | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/llama.cpp b/llama.cpp index 8c1446296fe35..05094fec197d6 100644 --- a/llama.cpp +++ b/llama.cpp @@ -5184,7 +5184,13 @@ static bool llm_load_tensors( // output { model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, false); + // if output is NULL, init from the input tok embed + if (model.output == NULL) { + model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + ml.n_created--; // artificial tensor + ml.size_data += ggml_nbytes(model.output); + } } for (int i = 0; i < n_layer; ++i) { From c16a7c26882669a0d2ed7ef592cde3b0227248f2 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 18 Apr 2024 20:08:52 +0300 Subject: [PATCH 099/100] metal : use F32 attention accumulators --- ggml-metal.m | 15 +---- ggml-metal.metal | 156 +++++++++++++++++++++++------------------------ ggml.c | 3 +- 3 files changed, 81 insertions(+), 93 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 407f94eb224cf..f4a831b52a9f9 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2636,10 +2636,9 @@ static enum ggml_status ggml_metal_graph_compute( GGML_ASSERT(ncpsg % 32 == 0); // simdgroups per threadgroup (a.k.a. warps) - // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)) : 4; - const size_t smem = nqptg*(ne00 + nsg*(ncpsg + nqptg))*(sizeof(float)/2); + const size_t smem = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2); //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); @@ -2656,7 +2655,6 @@ static enum ggml_status ggml_metal_graph_compute( GGML_ASSERT(ncpsg % 32 == 0); // simdgroups per threadgroup (a.k.a. warps) - // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)); int64_t nsg = 1; @@ -2665,16 +2663,7 @@ static enum ggml_status ggml_metal_graph_compute( } nsg /= 2; - // require power of 2 - //{ - // int64_t nsgm = 1; - // while (nsgm < nsg) { - // nsgm *= 2; - // } - // GGML_ASSERT(nsg == nsgm); - //} - - const size_t smem = (nqptg*(ne00 + nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2); + const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2); //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); diff --git a/ggml-metal.metal b/ggml-metal.metal index de6072e93470a..36b87b2f0750a 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2169,12 +2169,13 @@ kernel void kernel_flash_attn_ext_f16( const short NW = N_SIMDWIDTH; const short SH = (C + Q); // shared memory per simdgroup in (half) - const short T = D + nsg*SH; // shared memory size per query in (half) + const short T = D + 2*nsg*SH; // shared memory size per query in (half) + const short TF = T/2; // shared memory size per query in (float) const short T4 = T/4; // shared memory size per query in (half4) - threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data - threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4 - threadgroup half * ss = (threadgroup half *) (shared + sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix + threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data + threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4 + threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) simdgroup_half8x8 lo[Q8][D8]; @@ -2202,15 +2203,15 @@ kernel void kernel_flash_attn_ext_f16( // zero out shared memory SH for (short j = 0; j < Q; ++j) { for (short i = tiisg; i < SH; i += NW) { - ss[j*T + i] = 0.0h; + ss[j*TF + i] = 0.0f; } } threadgroup_barrier(mem_flags::mem_threadgroup); { - half S[Q] = { [0 ... Q-1] = 0.0h }; - half M[Q] = { [0 ... Q-1] = -INFINITY }; + float S[Q] = { [0 ... Q-1] = 0.0h }; + float M[Q] = { [0 ... Q-1] = -FLT_MAX/2 }; // assume K and V are same shape const short ne22 = ne12; @@ -2248,7 +2249,7 @@ kernel void kernel_flash_attn_ext_f16( device const half * mp = (device const half *) (mask + iq1*nb31); // prepare diagonal scale matrix - simdgroup_half8x8 mscale(scale); + simdgroup_float8x8 mscale(scale); // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns @@ -2261,9 +2262,9 @@ kernel void kernel_flash_attn_ext_f16( // Q*K^T { for (short cc = 0; cc < C/8; ++cc) { - simdgroup_half8x8 mqk[Q8]; + simdgroup_float8x8 mqk[Q8]; for (short j = 0; j < Q8; ++j) { - mqk[j] = make_filled_simdgroup_matrix(0.h); + mqk[j] = make_filled_simdgroup_matrix(0.h); } device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); @@ -2283,48 +2284,48 @@ kernel void kernel_flash_attn_ext_f16( simdgroup_load(mm, mp + 8*j*(nb31/sizeof(half)) + ic + 8*cc, nb31/sizeof(half), 0, false); simdgroup_multiply_accumulate(mqk[j], mqk[j], mscale, mm); - simdgroup_store(mqk[j], ss + 8*j*T + 8*cc, T, 0, false); + simdgroup_store(mqk[j], ss + 8*j*TF + 8*cc, TF, 0, false); } } } // used to detect blocks full of -INF - half smax = -INFINITY; + float smax = -INFINITY; // online softmax if (C == 32) { - half ms[Q]; + float ms[Q]; for (short j = 0; j < Q; ++j) { const short p = tiisg; - const half m = M[j]; - const half s = ss[j*T + p]; + const float m = M[j]; + const float s = ss[j*TF + p]; smax = simd_max(max(smax, s)); M[j] = simd_max(max(M[j], s)); - ms[j] = m == -INFINITY ? 0.0h : exp(m - M[j]); - const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]); + ms[j] = exp(m - M[j]); + const float vs = exp(s - M[j]); S[j] = S[j]*ms[j] + simd_sum(vs); // the P matrix from the paper (Q rows, C columns) - ss[j*T + p] = vs; + ss[j*TF + p] = vs; } // create a QxQ diagonal matrix for rescaling the output if (tiisg < Q) { - ss[tiisg*T + C + tiisg] = ms[tiisg]; + ss[tiisg*TF + C + tiisg] = ms[tiisg]; } } else { - half ms[Q]; + float ms[Q]; for (short j = 0; j < Q; ++j) { - const half m = M[j]; + const float m = M[j]; for (short p = tiisg; p < C; p += NW) { - const half s = ss[j*T + p]; + const float s = ss[j*TF + p]; smax = max(smax, s); M[j] = max(M[j], s); @@ -2333,20 +2334,20 @@ kernel void kernel_flash_attn_ext_f16( smax = simd_max(smax); M[j] = simd_max(M[j]); - ms[j] = m == -INFINITY ? 0.0h : exp(m - M[j]); + ms[j] = exp(m - M[j]); // local sum - half ls = 0.0h; + float ls = 0.0h; for (short p = tiisg; p < C; p += NW) { - const half s = ss[j*T + p]; + const float s = ss[j*TF + p]; - const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]); + const float vs = exp(s - M[j]); ls += vs; // the P matrix from the paper (Q rows, C columns) - ss[j*T + p] = vs; + ss[j*TF + p] = vs; } S[j] = S[j]*ms[j] + simd_sum(ls); @@ -2354,7 +2355,7 @@ kernel void kernel_flash_attn_ext_f16( // create a QxQ diagonal matrix for rescaling the output if (tiisg < Q) { - ss[tiisg*T + C + tiisg] = ms[tiisg]; + ss[tiisg*TF + C + tiisg] = ms[tiisg]; } } @@ -2365,8 +2366,8 @@ kernel void kernel_flash_attn_ext_f16( // O = diag(ms)*O for (short j = 0; j < Q8; ++j) { - simdgroup_half8x8 mm; - simdgroup_load(mm, ss + 8*j*T + C + 8*j, T, 0, false); + simdgroup_float8x8 mm; + simdgroup_load(mm, ss + 8*j*TF + C + 8*j, TF, 0, false); for (short i = 0; i < D8; ++i) { simdgroup_multiply(lo[j][i], mm, lo[j][i]); @@ -2383,8 +2384,8 @@ kernel void kernel_flash_attn_ext_f16( simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false); for (short j = 0; j < Q8; ++j) { - simdgroup_half8x8 mv; - simdgroup_load(mv, ss + 8*j*T + 8*cc, T, 0, false); + simdgroup_float8x8 mv; + simdgroup_load(mv, ss + 8*j*TF + 8*cc, TF, 0, false); simdgroup_multiply_accumulate(lo[j][i], mv, mk, lo[j][i]); } @@ -2396,16 +2397,16 @@ kernel void kernel_flash_attn_ext_f16( // these are needed for reducing the results from the simdgroups (reuse the ss buffer) for (short j = 0; j < Q; ++j) { if (tiisg == 0) { - ss[j*T + 0] = S[j]; - ss[j*T + 1] = M[j]; + ss[j*TF + 0] = S[j]; + ss[j*TF + 1] = M[j]; } } } // reduce the warps sequentially for (short sg = 1; sg < nsg; ++sg) { - half S = { 0.0h }; - half M = { -INFINITY }; + float S = { 0.0h }; + float M = { -FLT_MAX/2 }; threadgroup_barrier(mem_flags::mem_threadgroup); @@ -2423,36 +2424,36 @@ kernel void kernel_flash_attn_ext_f16( // the first simdgroup accumulates the results from the other simdgroups if (sgitg == 0) { for (short j = 0; j < Q; ++j) { - const half S0 = ss[j*T + 0]; - const half S1 = ss[j*T + sg*SH + 0]; + const float S0 = ss[j*TF + 0]; + const float S1 = ss[j*TF + sg*SH + 0]; - const half M0 = ss[j*T + 1]; - const half M1 = ss[j*T + sg*SH + 1]; + const float M0 = ss[j*TF + 1]; + const float M1 = ss[j*TF + sg*SH + 1]; M = max(M0, M1); - const half ms0 = M0 == -INFINITY ? 0.0h : exp(M0 - M); - const half ms1 = M1 == -INFINITY ? 0.0h : exp(M1 - M); + const float ms0 = exp(M0 - M); + const float ms1 = exp(M1 - M); S = S0*ms0 + S1*ms1; if (tiisg == 0) { - ss[j*T + 0] = S; - ss[j*T + 1] = M; + ss[j*TF + 0] = S; + ss[j*TF + 1] = M; - ss[j*T + C + j ] = ms0; - ss[j*T + C + j + sg*SH] = ms1; + ss[j*TF + C + j ] = ms0; + ss[j*TF + C + j + sg*SH] = ms1; } } // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 for (short j = 0; j < Q8; ++j) { simdgroup_half8x8 t; - simdgroup_half8x8 ms0; - simdgroup_half8x8 ms1; + simdgroup_float8x8 ms0; + simdgroup_float8x8 ms1; - simdgroup_load(ms0, ss + 8*j*T + C + 8*j, T, 0, false); - simdgroup_load(ms1, ss + 8*j*T + C + 8*j + sg*SH, T, 0, false); + simdgroup_load(ms0, ss + 8*j*TF + C + 8*j, TF, 0, false); + simdgroup_load(ms1, ss + 8*j*TF + C + 8*j + sg*SH, TF, 0, false); for (short i = 0; i < D8; ++i) { simdgroup_load (t, sq + 8*j*T + i*8, T, 0, false); @@ -2478,7 +2479,7 @@ kernel void kernel_flash_attn_ext_f16( // final rescale with 1/S and store to global memory if (sgitg == 0) { for (short j = 0; j < Q && iq1 + j < ne01; ++j) { - const half S = ss[j*T + 0]; + const float S = ss[j*TF + 0]; for (short i = tiisg; i < D4; i += NW) { dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sq4[j*T4 + i]/S; @@ -2494,8 +2495,6 @@ template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 8, 32>; template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256, 8, 32>; -#define HALF_MAX_HALF half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction. - template // head size, queries per threadgroup, cache items per threadgroup kernel void kernel_flash_attn_ext_vec_f16( device const char * q, @@ -2539,18 +2538,16 @@ kernel void kernel_flash_attn_ext_vec_f16( const short iq1 = tgpig[0]; const short D4 = D/4; - const short D8 = D/8; const short NW = N_SIMDWIDTH; const short SH = (C + 1); // shared memory per simdgroup in (half) - const short T = D + nsg*SH; // shared memory size per query in (half) - const short T4 = T/4; // shared memory size per query in (half4) + const short T = D + 2*nsg*SH; // shared memory size per query in (half) - threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data - threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4 - threadgroup half * ss = (threadgroup half *) (shared + sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix - threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*SH + 1*D); // same as above but in half4 - threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + 1*T); // scratch buffer for the results + //threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data + threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4 + threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix + threadgroup float4 * ss4 = (threadgroup float4 *) (shared + 2*sgitg*SH + 1*D); // same as above but in half4 + threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + 1*T); // scratch buffer for the results // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) half4 lo[D4/NW]; @@ -2579,8 +2576,8 @@ kernel void kernel_flash_attn_ext_vec_f16( threadgroup_barrier(mem_flags::mem_threadgroup); { - half S = { 0.0h }; - half M = { -HALF_MAX_HALF }; + float S = { 0.0h }; + float M = { -FLT_MAX/2 }; // assume K and V are same shape const short ne22 = ne12; @@ -2628,7 +2625,7 @@ kernel void kernel_flash_attn_ext_vec_f16( { #pragma unroll for (short cc = 0; cc < C/4; ++cc) { - half4 mqk = { 0.0h }; + float4 mqk = { 0.0h }; device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13)); @@ -2642,7 +2639,7 @@ kernel void kernel_flash_attn_ext_vec_f16( mk[2] = pk4[i + 2*(nb11/8)]; mk[3] = pk4[i + 3*(nb11/8)]; - mqk += mq[i] * mk; + mqk += (float4) (mq[i] * mk); } // reduce the results from the threads in the simdgroup @@ -2654,7 +2651,7 @@ kernel void kernel_flash_attn_ext_vec_f16( // mqk = mqk*scale + mask if (tiisg == 0) { - half4 mm = mp4[ic/4 + cc]; + float4 mm = (float4) mp4[ic/4 + cc]; mqk = mqk*scale + mm; ss4[cc] = mqk; @@ -2666,13 +2663,13 @@ kernel void kernel_flash_attn_ext_vec_f16( { const short p = tiisg; - const half m = M; - const half s = ss[p]; + const float m = M; + const float s = ss[p]; M = simd_max(max(M, s)); - const half ms = exp(m - M); - const half vs = exp(s - M); + const float ms = exp(m - M); + const float vs = exp(s - M); S = S*ms + simd_sum(vs); @@ -2696,6 +2693,7 @@ kernel void kernel_flash_attn_ext_vec_f16( #pragma unroll for (short ii = 0; ii < D4; ii += NW) { const short i = ii + tiisg; + lo[i/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0]; lo[i/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1]; lo[i/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2]; @@ -2724,18 +2722,18 @@ kernel void kernel_flash_attn_ext_vec_f16( // parallel reduce for (short r = nsg/2; r > 0; r >>= 1) { if (sgitg < r) { - const half S0 = ss[ 0]; - const half S1 = ss[r*SH + 0]; + const float S0 = ss[ 0]; + const float S1 = ss[r*SH + 0]; - const half M0 = ss[ 1]; - const half M1 = ss[r*SH + 1]; + const float M0 = ss[ 1]; + const float M1 = ss[r*SH + 1]; - const half M = max(M0, M1); + const float M = max(M0, M1); - const half ms0 = exp(M0 - M); - const half ms1 = exp(M1 - M); + const float ms0 = exp(M0 - M); + const float ms1 = exp(M1 - M); - const half S = S0*ms0 + S1*ms1; + const float S = S0*ms0 + S1*ms1; if (tiisg == 0) { ss[0] = S; @@ -2756,7 +2754,7 @@ kernel void kernel_flash_attn_ext_vec_f16( // final rescale with 1/S and store to global memory if (sgitg == 0) { - const half S = ss[0]; + const float S = ss[0]; for (short ii = 0; ii < D4; ii += NW) { short i = ii + tiisg; diff --git a/ggml.c b/ggml.c index f50cb948daeab..f2bbfa6f2273b 100644 --- a/ggml.c +++ b/ggml.c @@ -14882,12 +14882,13 @@ static void ggml_compute_forward_flash_attn_ext( struct ggml_tensor * dst) { switch (dst->op_params[1]) { case GGML_PREC_DEFAULT: + case GGML_PREC_F32: { + // uses F32 accumulators ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst); } break; default: { - // TODO: implement F32 precision GGML_ASSERT(false); } break; } From 9ca869876eddfb0c51d3b80440a11abdd6a3fe18 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 18 Apr 2024 21:41:32 +0300 Subject: [PATCH 100/100] batched-bench : add fattn arg --- examples/batched-bench/batched-bench.cpp | 28 ++++++++++++++---------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp index 1e34de620a41b..2924d8116f44f 100644 --- a/examples/batched-bench/batched-bench.cpp +++ b/examples/batched-bench/batched-bench.cpp @@ -32,7 +32,7 @@ int main(int argc, char ** argv) { gpt_params params; if (argc == 1 || argv[1][0] == '-') { - printf("usage: %s MODEL_PATH [N_KV_MAX] [N_BATCH] [N_UBATCH] [IS_PP_SHARED] [NGL] \n" , argv[0]); + printf("usage: %s MODEL_PATH [N_KV_MAX] [N_BATCH] [N_UBATCH] [FATTN] [IS_PP_SHARED] [NGL] \n" , argv[0]); printf(" , and PL are comma-separated lists of numbers without spaces\n\n"); printf(" example: %s ggml-model-f16.gguf 2048 2048 512 0 999 128,256,512 128,256 1,2,4,8,16,32\n\n", argv[0]); return 1 ; @@ -41,6 +41,7 @@ int main(int argc, char ** argv) { int n_kv_max = 2048; int n_batch = 2048; int n_ubatch = 512; + bool flash_attn = false; int is_pp_shared = 0; int n_gpu_layers = 0; @@ -66,23 +67,27 @@ int main(int argc, char ** argv) { } if (argc >= 6) { - is_pp_shared = std::atoi(argv[5]); + flash_attn = std::atoi(argv[5]); } if (argc >= 7) { - n_gpu_layers = std::atoi(argv[6]); + is_pp_shared = std::atoi(argv[6]); } if (argc >= 8) { - n_pp = parse_list(argv[7]); + n_gpu_layers = std::atoi(argv[7]); } if (argc >= 9) { - n_tg = parse_list(argv[8]); + n_pp = parse_list(argv[8]); } if (argc >= 10) { - n_pl = parse_list(argv[9]); + n_tg = parse_list(argv[9]); + } + + if (argc >= 11) { + n_pl = parse_list(argv[10]); } // init LLM @@ -108,10 +113,11 @@ int main(int argc, char ** argv) { llama_context_params ctx_params = llama_context_default_params(); - ctx_params.seed = 1234; - ctx_params.n_ctx = n_kv_max; - ctx_params.n_batch = n_batch; - ctx_params.n_ubatch = n_ubatch; + ctx_params.seed = 1234; + ctx_params.n_ctx = n_kv_max; + ctx_params.n_batch = n_batch; + ctx_params.n_ubatch = n_ubatch; + ctx_params.flash_attn = flash_attn; ctx_params.n_threads = params.n_threads; ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch; @@ -169,7 +175,7 @@ int main(int argc, char ** argv) { } LOG_TEE("\n"); - LOG_TEE("%s: n_kv_max = %d, n_batch = %d, n_ubatch = %d, is_pp_shared = %d, n_gpu_layers = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, n_batch, n_ubatch, is_pp_shared, n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch); + LOG_TEE("%s: n_kv_max = %d, n_batch = %d, n_ubatch = %d, flash_attn = %d, is_pp_shared = %d, n_gpu_layers = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, n_batch, n_ubatch, flash_attn, is_pp_shared, n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch); LOG_TEE("\n"); LOG_TEE("|%6s | %6s | %4s | %6s | %8s | %8s | %8s | %8s | %8s | %8s |\n", "PP", "TG", "B", "N_KV", "T_PP s", "S_PP t/s", "T_TG s", "S_TG t/s", "T s", "S t/s");