From f15e933fb1ad24e26a28ba3106a6990182a10bd9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Thu, 16 May 2024 13:42:24 +0200 Subject: [PATCH 1/2] llama : use n_embd_head_v instead of n_embd_head_k when reshaping kqv --- llama.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llama.cpp b/llama.cpp index 0ef756a52860c..50109e95fe95f 100644 --- a/llama.cpp +++ b/llama.cpp @@ -6655,7 +6655,7 @@ static struct ggml_tensor * llm_build_kqv( ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32); } - cur = ggml_reshape_2d(ctx, cur, n_embd_head_k*n_head, n_tokens); + cur = ggml_reshape_2d(ctx, cur, n_embd_head_v*n_head, n_tokens); } else { struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); cb(kq, "kq", il); @@ -6700,7 +6700,7 @@ static struct ggml_tensor * llm_build_kqv( struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3); cb(kqv_merged, "kqv_merged", il); - cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_k*n_head, n_tokens); + cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_v*n_head, n_tokens); cb(cur, "kqv_merged_cont", il); } From 886f89da697515ea05c5e93f3a071f5165f0b4ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Fri, 17 May 2024 09:25:36 +0200 Subject: [PATCH 2/2] llama : use n_embd_v_gqa and n_embd_head_v instead of n_embd_k_gqa and n_embd_head_k when making a view of cached value vectors. --- llama.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/llama.cpp b/llama.cpp index 50109e95fe95f..aff3816590d25 100644 --- a/llama.cpp +++ b/llama.cpp @@ -6622,6 +6622,7 @@ static struct ggml_tensor * llm_build_kqv( const int64_t n_embd_head_k = hparams.n_embd_head_k; const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(); const int64_t n_embd_head_v = hparams.n_embd_head_v; + const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(); struct ggml_tensor * q = ggml_permute(ctx, q_cur, 0, 2, 1, 3); cb(q, "q", il); @@ -6644,8 +6645,8 @@ static struct ggml_tensor * llm_build_kqv( struct ggml_tensor * v = ggml_view_3d(ctx, kv.v_l[il], n_embd_head_v, n_kv, n_head_kv, - ggml_row_size(kv.v_l[il]->type, n_embd_k_gqa), - ggml_row_size(kv.v_l[il]->type, n_embd_head_k), + ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa), + ggml_row_size(kv.v_l[il]->type, n_embd_head_v), 0); cb(v, "v", il);