From 9b4294da9c9ff620d5c5ee38128d4f66c1315d9d Mon Sep 17 00:00:00 2001 From: pculliton Date: Tue, 30 Jul 2024 21:19:09 -0400 Subject: [PATCH 1/2] Adding Gemma 2 2B configs Updates to Q scaling and Gemma 2 model sizes to match v2 2B model. --- src/llama.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/llama.cpp b/src/llama.cpp index a207451f58507..fba6b17572d71 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -4969,6 +4969,7 @@ static void llm_load_hparams( hparams.attn_soft_cap = true; switch (hparams.n_layer) { + case 26: model.type = e_model::MODEL_2B; break; case 42: model.type = e_model::MODEL_9B; break; case 46: model.type = e_model::MODEL_27B; break; default: model.type = e_model::MODEL_UNKNOWN; @@ -11736,6 +11737,7 @@ struct llm_build_context { // ref: https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e switch (model.type) { + case e_model::MODEL_2B: Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k))); break; case e_model::MODEL_9B: Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k))); break; case e_model::MODEL_27B: Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd / n_head))); break; default: GGML_ABORT("fatal error"); From 36e6685d932f43d65dc1bdf2a5efb3145d0bd000 Mon Sep 17 00:00:00 2001 From: pculliton Date: Wed, 31 Jul 2024 11:00:16 -0400 Subject: [PATCH 2/2] Update src/llama.cpp Co-authored-by: slaren --- src/llama.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama.cpp b/src/llama.cpp index fba6b17572d71..e6f303d31b3bf 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -11737,7 +11737,7 @@ struct llm_build_context { // ref: https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e switch (model.type) { - case e_model::MODEL_2B: Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k))); break; + case e_model::MODEL_2B: case e_model::MODEL_9B: Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k))); break; case e_model::MODEL_27B: Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd / n_head))); break; default: GGML_ABORT("fatal error");