diff --git a/common/arg.cpp b/common/arg.cpp index 8531f0871d44a..84dc6841e3866 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -811,6 +811,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.flash_attn = true; } ).set_env("LLAMA_ARG_FLASH_ATTN")); + add_opt(common_arg( + {"-mla", "--mla-attn"}, + string_format("enable Multi-head Latent Attention (default: %s)", params.mla_attn ? "enabled" : "disabled"), + [](common_params & params) { + params.mla_attn = true; + } + ).set_env("LLAMA_ARG_MLA_ATTN")); add_opt(common_arg( {"-p", "--prompt"}, "PROMPT", "prompt to start generation with; for system message, use -sys", diff --git a/common/common.cpp b/common/common.cpp index 6448b7b03d6d2..b3438f2646bdf 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1132,6 +1132,7 @@ struct llama_context_params common_context_params_to_llama(const common_params & cparams.cb_eval_user_data = params.cb_eval_user_data; cparams.offload_kqv = !params.no_kv_offload; cparams.flash_attn = params.flash_attn; + cparams.mla_attn = params.mla_attn; cparams.no_perf = params.no_perf; if (params.reranking) { diff --git a/common/common.h b/common/common.h index 1c0f199774976..207732a9957a8 100644 --- a/common/common.h +++ b/common/common.h @@ -325,6 +325,7 @@ struct common_params { bool simple_io = false; // improves compatibility with subprocesses and limited consoles bool cont_batching = true; // insert new sequences for decoding on-the-fly bool flash_attn = false; // flash attention + bool mla_attn = false; // MLA attention for deepseek2 bool no_perf = false; // disable performance metrics bool ctx_shift = true; // context shift on inifinite text generation diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 6358a94e9b55f..d13196696ce33 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -4141,6 +4141,78 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter else: return [] + n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed")) + n_head_kv = self.hparams["num_key_value_heads"] + qk_nope_head_dim = self.hparams["qk_nope_head_dim"] + qk_rope_head_dim = self.hparams["qk_rope_head_dim"] + v_head_dim = self.hparams["v_head_dim"] + kv_lora_rank = self.hparams["kv_lora_rank"] + + # (v2-lite) split q_proj into: q_proj and q_mqa_proj + if name.endswith("q_proj.weight"): + assert data_torch.shape[0] == n_head_kv * (qk_nope_head_dim + qk_rope_head_dim) + assert data_torch.shape[1] == n_embed + + q_proj_with_mqa = data_torch.view(n_head_kv, qk_nope_head_dim + qk_rope_head_dim, n_embed) + q_proj, q_mqa_proj = torch.split(q_proj_with_mqa, [qk_nope_head_dim, qk_rope_head_dim], dim = 1) + + q_proj = q_proj.reshape(n_head_kv * qk_nope_head_dim, n_embed) + q_mqa_proj = q_mqa_proj.reshape(n_head_kv * qk_rope_head_dim, n_embed) + + return [ + (self.map_tensor_name(name), q_proj), + (self.map_tensor_name(name.replace("q_proj", "q_mqa_proj")), q_mqa_proj) + ] + + # (v2/v3/r1) split q_b_proj into: q_b_proj and q_b_mqa_proj + if name.endswith("q_b_proj.weight"): + q_lora_rank = self.hparams["q_lora_rank"] + + assert data_torch.shape[0] == n_head_kv * (qk_nope_head_dim + qk_rope_head_dim) + assert data_torch.shape[1] == q_lora_rank + + q_b_proj_with_mqa = data_torch.view(n_head_kv, qk_nope_head_dim + qk_rope_head_dim, q_lora_rank) + q_b_proj, q_b_mqa_proj = torch.split(q_b_proj_with_mqa, [qk_nope_head_dim, qk_rope_head_dim], dim = 1) + + q_b_proj = q_b_proj.reshape(n_head_kv * qk_nope_head_dim, q_lora_rank) + 
q_b_mqa_proj = q_b_mqa_proj.reshape(n_head_kv * qk_rope_head_dim, q_lora_rank) + + return [ + (self.map_tensor_name(name), q_b_proj), + (self.map_tensor_name(name.replace("q_b_proj", "q_b_mqa_proj")), q_b_mqa_proj) + ] + + # split kv_a_proj_with_mqa into: kv_a_proj and k_mqa_proj + if name.endswith("kv_a_proj_with_mqa.weight"): + assert data_torch.shape[0] == kv_lora_rank + qk_rope_head_dim + assert data_torch.shape[1] == n_embed + + kv_a_proj_with_mqa = data_torch.view(kv_lora_rank + qk_rope_head_dim, n_embed) + kv_a_proj, k_mqa_proj = torch.split(kv_a_proj_with_mqa, [kv_lora_rank, qk_rope_head_dim], dim = 0) + + return [ + (self.map_tensor_name(name.replace("kv_a_proj_with_mqa", "kv_a_proj")), kv_a_proj), + (self.map_tensor_name(name.replace("kv_a_proj_with_mqa", "k_mqa_proj")), k_mqa_proj) + ] + + # split kv_b_proj into: k_b_proj, v_b_proj, and k_b_trans_proj (for deepseek-mla) + if name.endswith("kv_b_proj.weight"): + assert data_torch.shape[0] == n_head_kv * (v_head_dim + qk_nope_head_dim) + assert data_torch.shape[1] == kv_lora_rank + + kv_b_proj = data_torch.view(n_head_kv, v_head_dim + qk_nope_head_dim, kv_lora_rank) + k_b_proj, v_b_proj = torch.split(kv_b_proj, [qk_nope_head_dim, v_head_dim], dim = 1) + + k_b_trans_proj = k_b_proj.transpose(1, 2).reshape(n_head_kv * kv_lora_rank, qk_nope_head_dim) + k_b_proj = k_b_proj.reshape(n_head_kv * qk_nope_head_dim, kv_lora_rank) + v_b_proj = v_b_proj.reshape(n_head_kv * v_head_dim, kv_lora_rank) + + return [ + (self.map_tensor_name(name.replace("kv_b_proj", "k_b_trans_proj")), k_b_trans_proj), + (self.map_tensor_name(name.replace("kv_b_proj", "k_b_proj")), k_b_proj), + (self.map_tensor_name(name.replace("kv_b_proj", "v_b_proj")), v_b_proj) + ] + return [(self.map_tensor_name(name), data_torch)] def prepare_tensors(self): diff --git a/examples/server/README.md b/examples/server/README.md index a2a0903261e31..043c725d8d548 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -46,6 +46,7 @@ The project is under active development, and we are [looking for feedback and co | `-ub, --ubatch-size N` | physical maximum batch size (default: 512)
(env: LLAMA_ARG_UBATCH) | | `--keep N` | number of tokens to keep from the initial prompt (default: 0, -1 = all) | | `-fa, --flash-attn` | enable Flash Attention (default: disabled)
(env: LLAMA_ARG_FLASH_ATTN) | +| `-mla, --mla-attn` | enable Multi-head Latent Attention (default: disabled)
(env: LLAMA_ARG_MLA_ATTN) | | `--no-perf` | disable internal libllama performance timings (default: false)
(env: LLAMA_ARG_NO_PERF) | | `-e, --escape` | process escapes sequences (\n, \r, \t, \', \", \\) (default: true) | | `--no-escape` | do not process escape sequences | diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index ecac5b4bb7f59..758efa2f3ef16 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -356,6 +356,13 @@ class MODEL_TENSOR(IntEnum): ATTN_Q_B = auto() ATTN_KV_A_MQA = auto() ATTN_KV_B = auto() + ATTN_Q_MQA = auto() + ATTN_Q_B_MQA = auto() + ATTN_KV_A = auto() + ATTN_K_MQA = auto() + ATTN_K_B_TRANS = auto() + ATTN_K_B = auto() + ATTN_V_B = auto() ATTN_Q_A_NORM = auto() ATTN_KV_A_NORM = auto() FFN_SUB_NORM = auto() @@ -543,6 +550,13 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.ATTN_Q_B: "blk.{bid}.attn_q_b", MODEL_TENSOR.ATTN_KV_A_MQA: "blk.{bid}.attn_kv_a_mqa", MODEL_TENSOR.ATTN_KV_B: "blk.{bid}.attn_kv_b", + MODEL_TENSOR.ATTN_Q_MQA: "blk.{bid}.attn_q_mqa", + MODEL_TENSOR.ATTN_Q_B_MQA: "blk.{bid}.attn_q_b_mqa", + MODEL_TENSOR.ATTN_KV_A: "blk.{bid}.attn_kv_a", + MODEL_TENSOR.ATTN_K_MQA: "blk.{bid}.attn_k_mqa", + MODEL_TENSOR.ATTN_K_B_TRANS: "blk.{bid}.attn_k_b_trans", + MODEL_TENSOR.ATTN_K_B: "blk.{bid}.attn_k_b", + MODEL_TENSOR.ATTN_V_B: "blk.{bid}.attn_v_b", MODEL_TENSOR.ATTN_Q_A_NORM: "blk.{bid}.attn_q_a_norm", MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm", MODEL_TENSOR.ATTN_SUB_NORM: "blk.{bid}.attn_sub_norm", @@ -1041,6 +1055,13 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.ATTN_Q_B, MODEL_TENSOR.ATTN_KV_A_MQA, MODEL_TENSOR.ATTN_KV_B, + MODEL_TENSOR.ATTN_Q_MQA, + MODEL_TENSOR.ATTN_Q_B_MQA, + MODEL_TENSOR.ATTN_KV_A, + MODEL_TENSOR.ATTN_K_MQA, + MODEL_TENSOR.ATTN_K_B_TRANS, + MODEL_TENSOR.ATTN_K_B, + MODEL_TENSOR.ATTN_V_B, MODEL_TENSOR.ATTN_Q_A_NORM, MODEL_TENSOR.ATTN_KV_A_NORM, MODEL_TENSOR.ATTN_OUT, diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 617791e240b60..ae17da73af674 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -586,6 +586,34 @@ class TensorNameMap: "model.layers.{bid}.self_attn.kv_b_proj", # deepseek2 ), + MODEL_TENSOR.ATTN_Q_MQA: ( + "model.layers.{bid}.self_attn.q_mqa_proj", # deepseek2 (v2-lite) + ), + + MODEL_TENSOR.ATTN_Q_B_MQA: ( + "model.layers.{bid}.self_attn.q_b_mqa_proj", # deepseek2 (v2/v3/r1) + ), + + MODEL_TENSOR.ATTN_KV_A: ( + "model.layers.{bid}.self_attn.kv_a_proj", # deepseek2 + ), + + MODEL_TENSOR.ATTN_K_MQA: ( + "model.layers.{bid}.self_attn.k_mqa_proj", # deepseek2 + ), + + MODEL_TENSOR.ATTN_K_B_TRANS: ( + "model.layers.{bid}.self_attn.k_b_trans_proj", # deepseek2 (mla only) + ), + + MODEL_TENSOR.ATTN_K_B: ( + "model.layers.{bid}.self_attn.k_b_proj", # deepseek2 + ), + + MODEL_TENSOR.ATTN_V_B: ( + "model.layers.{bid}.self_attn.v_b_proj", # deepseek2 + ), + MODEL_TENSOR.ATTN_Q_A_NORM: ( "model.layers.{bid}.self_attn.q_a_layernorm", # deepseek2 ), diff --git a/include/llama.h b/include/llama.h index d62792c0a6760..be6dfc5d87154 100644 --- a/include/llama.h +++ b/include/llama.h @@ -343,6 +343,7 @@ extern "C" { bool embeddings; // if true, extract embeddings (together with logits) bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU bool flash_attn; // whether to use flash attention [EXPERIMENTAL] + bool mla_attn; // MLA attention for deepseek2 bool no_perf; // whether to measure performance timings // Abort callback diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 97a1e7e5e01ef..cca3cad2c6cb8 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -997,6 +997,13 @@ static const 
std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, { LLM_TENSOR_ATTN_Q_A, "blk.%d.attn_q_a" }, { LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" }, + { LLM_TENSOR_ATTN_Q_MQA, "blk.%d.attn_q_mqa" }, + { LLM_TENSOR_ATTN_Q_B_MQA, "blk.%d.attn_q_b_mqa" }, + { LLM_TENSOR_ATTN_KV_A, "blk.%d.attn_kv_a" }, + { LLM_TENSOR_ATTN_K_MQA, "blk.%d.attn_k_mqa" }, + { LLM_TENSOR_ATTN_K_B_TRANS, "blk.%d.attn_k_b_trans" }, + { LLM_TENSOR_ATTN_K_B, "blk.%d.attn_k_b" }, + { LLM_TENSOR_ATTN_V_B, "blk.%d.attn_v_b" }, { LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" }, { LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" }, { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, @@ -1333,23 +1340,13 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = { {LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ATTN_QKV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_FFN_DOWN_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_FFN_GATE_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_FFN_UP_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ATTN_Q_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_Q_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_Q_B_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_KV_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_K_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_K_B_TRANS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_K_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_V_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_DEC_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, diff --git a/src/llama-arch.h b/src/llama-arch.h index 122fdcebe0af6..cae591373c2de 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -277,6 +277,13 @@ enum llm_tensor { LLM_TENSOR_ATTN_Q_B, LLM_TENSOR_ATTN_KV_A_MQA, LLM_TENSOR_ATTN_KV_B, + LLM_TENSOR_ATTN_Q_MQA, + LLM_TENSOR_ATTN_Q_B_MQA, + LLM_TENSOR_ATTN_KV_A, + LLM_TENSOR_ATTN_K_MQA, + LLM_TENSOR_ATTN_K_B_TRANS, + LLM_TENSOR_ATTN_K_B, + LLM_TENSOR_ATTN_V_B, LLM_TENSOR_ATTN_Q_A_NORM, LLM_TENSOR_ATTN_KV_A_NORM, LLM_TENSOR_ATTN_SUB_NORM, diff --git a/src/llama-cparams.h b/src/llama-cparams.h index 252012f3d9405..6ebab857e236a 100644 --- a/src/llama-cparams.h +++
b/src/llama-cparams.h @@ -28,6 +28,7 @@ struct llama_cparams { bool causal_attn; bool offload_kqv; bool flash_attn; + bool mla_attn; bool no_perf; enum llama_pooling_type pooling_type; diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index feffdf0de52cf..37ab4adfdbd94 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -32,7 +32,7 @@ bool llama_kv_cache_init( cache.recurrent = llama_model_is_recurrent(&model); cache.v_trans = !cache.recurrent && !cparams.flash_attn; - cache.can_shift = !cache.recurrent && model.arch != LLM_ARCH_DEEPSEEK2; // not supported due to MLA + cache.can_shift = !cache.recurrent && model.arch != LLM_ARCH_DEEPSEEK2; // not supported yet LLAMA_LOG_INFO("%s: kv_size = %d, offload = %d, type_k = '%s', type_v = '%s', n_layer = %d, can_shift = %d\n", __func__, kv_size, offload, ggml_type_name(type_k), ggml_type_name(type_v), n_layer, cache.can_shift); @@ -91,8 +91,21 @@ bool llama_kv_cache_init( return false; } - ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); - ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); + int64_t n_embd_k; + int64_t n_embd_v; + + // note: deepseek-mla stores the compressed versions + if (cparams.mla_attn && model.arch == LLM_ARCH_DEEPSEEK2) { + n_embd_k = hparams.n_lora_kv + hparams.n_rot; + n_embd_v = hparams.n_lora_kv; + } else { + n_embd_k = hparams.n_embd_k_gqa(i); + n_embd_v = hparams.n_embd_v_gqa(i); + } + + ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k*kv_size); + ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v*kv_size); + ggml_format_name(k, "cache_k_l%d", i); ggml_format_name(v, "cache_v_l%d", i); cache.k_l.push_back(k); diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 1da4eae7e63e2..dc83718b968c6 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -2890,14 +2890,20 @@ bool llama_model::load_tensors(llama_model_loader & ml) { if (!is_lite) { layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0); - layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k}, 0); + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_qk_nope}, 0); + layer.wq_b_mqa = create_tensor(tn(LLM_TENSOR_ATTN_Q_B_MQA, "weight", i), {q_lora_rank, n_head * n_embd_head_qk_rope}, 0); } else { - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_head * n_embd_head_qk_nope}, 0); + layer.wq_mqa = create_tensor(tn(LLM_TENSOR_ATTN_Q_MQA, "weight", i), {n_embd, n_head * n_embd_head_qk_rope}, 0); } - layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)}, 0); - layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_head * ( n_embd_head_v), n_embd}, 0); + layer.wkv_a = create_tensor(tn(LLM_TENSOR_ATTN_KV_A, "weight", i), {n_embd, kv_lora_rank}, 0); + layer.wk_mqa = create_tensor(tn(LLM_TENSOR_ATTN_K_MQA, "weight", i), {n_embd, n_embd_head_qk_rope}, 0); + layer.wk_b_trans = create_tensor(tn(LLM_TENSOR_ATTN_K_B_TRANS, "weight", i), {n_embd_head_qk_nope, n_head * kv_lora_rank}, 0); + layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K_B, "weight", i), {kv_lora_rank, n_head * n_embd_head_qk_nope}, 0); + layer.wv_b = 
create_tensor(tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_head * n_embd_head_v}, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_embd_head_v, n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); diff --git a/src/llama-model.h b/src/llama-model.h index a7c30444786fd..1b9852402d7b5 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -152,23 +152,30 @@ struct llama_layer { struct ggml_tensor * attn_norm_enc = nullptr; // attention - struct ggml_tensor * wq = nullptr; - struct ggml_tensor * wk = nullptr; - struct ggml_tensor * wv = nullptr; - struct ggml_tensor * wo = nullptr; - struct ggml_tensor * wqkv = nullptr; - struct ggml_tensor * wq_a = nullptr; - struct ggml_tensor * wq_b = nullptr; - struct ggml_tensor * wkv_a_mqa = nullptr; - struct ggml_tensor * wkv_b = nullptr; - struct ggml_tensor * wq_cross = nullptr; - struct ggml_tensor * wk_cross = nullptr; - struct ggml_tensor * wv_cross = nullptr; - struct ggml_tensor * wo_cross = nullptr; - struct ggml_tensor * wq_enc = nullptr; - struct ggml_tensor * wk_enc = nullptr; - struct ggml_tensor * wv_enc = nullptr; - struct ggml_tensor * wo_enc = nullptr; + struct ggml_tensor * wq = nullptr; + struct ggml_tensor * wk = nullptr; + struct ggml_tensor * wv = nullptr; + struct ggml_tensor * wo = nullptr; + struct ggml_tensor * wqkv = nullptr; + struct ggml_tensor * wq_a = nullptr; + struct ggml_tensor * wq_b = nullptr; + struct ggml_tensor * wkv_a_mqa = nullptr; + struct ggml_tensor * wkv_b = nullptr; + struct ggml_tensor * wq_mqa = nullptr; + struct ggml_tensor * wq_b_mqa = nullptr; + struct ggml_tensor * wkv_a = nullptr; + struct ggml_tensor * wk_mqa = nullptr; + struct ggml_tensor * wk_b_trans = nullptr; + struct ggml_tensor * wk_b = nullptr; + struct ggml_tensor * wv_b = nullptr; + struct ggml_tensor * wq_cross = nullptr; + struct ggml_tensor * wk_cross = nullptr; + struct ggml_tensor * wv_cross = nullptr; + struct ggml_tensor * wo_cross = nullptr; + struct ggml_tensor * wq_enc = nullptr; + struct ggml_tensor * wk_enc = nullptr; + struct ggml_tensor * wv_enc = nullptr; + struct ggml_tensor * wo_enc = nullptr; // attention bias struct ggml_tensor * bq = nullptr; diff --git a/src/llama.cpp b/src/llama.cpp index 607f278615969..72e999e0ac0b0 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -156,8 +156,7 @@ static struct ggml_tensor * llm_build_inp_embd( static void llm_build_kv_store( struct ggml_context * ctx, - const llama_hparams & hparams, - const llama_cparams & cparams, + struct llama_context & lctx, const llama_kv_cache & kv, struct ggml_cgraph * graph, struct ggml_tensor * k_cur, @@ -166,28 +165,41 @@ static void llm_build_kv_store( int32_t kv_head, const llm_build_cb & cb, int64_t il) { - const int64_t n_ctx = cparams.n_ctx; + const llama_model & model = lctx.model; + const llama_hparams & hparams = lctx.model.hparams; + const llama_cparams & cparams = lctx.cparams; - const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); - const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); + const int64_t n_ctx = cparams.n_ctx; GGML_ASSERT(kv.size == n_ctx); - struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_k_gqa, ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa)*kv_head); + int64_t n_embd_k; + int64_t n_embd_v; + + // note: deepseek-mla converts MLA to MQA so n_embd_k/n_embd_v change too + if (cparams.mla_attn && model.arch == LLM_ARCH_DEEPSEEK2) { + n_embd_k = hparams.n_lora_kv + hparams.n_rot; + n_embd_v = hparams.n_lora_kv; 
+ } else { + n_embd_k = hparams.n_embd_k_gqa(il); + n_embd_v = hparams.n_embd_v_gqa(il); + } + + struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_k, ggml_row_size(kv.k_l[il]->type, n_embd_k)*kv_head); cb(k_cache_view, "k_cache_view", il); // note: storing RoPE-ed version of K in the KV cache ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view)); - assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens); + assert(v_cur->ne[0] == n_embd_v && v_cur->ne[1] == n_tokens); struct ggml_tensor * v_cache_view = nullptr; if (cparams.flash_attn) { - v_cache_view = ggml_view_1d(ctx, kv.v_l[il], n_tokens*n_embd_v_gqa, ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa)*kv_head); + v_cache_view = ggml_view_1d(ctx, kv.v_l[il], n_tokens*n_embd_v, ggml_row_size(kv.v_l[il]->type, n_embd_v)*kv_head); } else { // note: the V cache is transposed when not using flash attention - v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_v_gqa, + v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_v, ( n_ctx)*ggml_element_size(kv.v_l[il]), (kv_head)*ggml_element_size(kv.v_l[il])); @@ -542,8 +554,9 @@ static struct ggml_tensor * llm_build_kqv( struct llama_context & lctx, const llama_kv_cache & kv, struct ggml_cgraph * graph, + struct ggml_tensor * wv_b, struct ggml_tensor * wo, - struct ggml_tensor * wo_b, + struct ggml_tensor * bo, struct ggml_tensor * q_cur, struct ggml_tensor * kq_mask, int32_t n_tokens, @@ -558,28 +571,28 @@ static struct ggml_tensor * llm_build_kqv( const int64_t n_ctx = cparams.n_ctx; const int64_t n_head = hparams.n_head(il); const int64_t n_head_kv = hparams.n_head_kv(il); - const int64_t n_embd_head_k = hparams.n_embd_head_k; const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); - const int64_t n_embd_head_v = hparams.n_embd_head_v; + const int64_t n_embd_head_k = hparams.n_embd_head_k; const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); + const int64_t n_embd_head_v = hparams.n_embd_head_v; + + struct ggml_tensor * cur; struct ggml_tensor * q = ggml_permute(ctx, q_cur, 0, 2, 1, 3); cb(q, "q", il); - struct ggml_tensor * k = - ggml_view_3d(ctx, kv.k_l[il], - n_embd_head_k, n_kv, n_head_kv, - ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa), - ggml_row_size(kv.k_l[il]->type, n_embd_head_k), - 0); - cb(k, "k", il); - - struct ggml_tensor * cur; - if (cparams.flash_attn) { GGML_UNUSED(model); GGML_UNUSED(n_ctx); + struct ggml_tensor * k = + ggml_view_3d(ctx, kv.k_l[il], + n_embd_head_k, n_kv, n_head_kv, + ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa), + ggml_row_size(kv.k_l[il]->type, n_embd_head_k), + 0); + cb(k, "k", il); + // split cached v into n_head heads (not transposed) struct ggml_tensor * v = ggml_view_3d(ctx, kv.v_l[il], @@ -596,51 +609,137 @@ static struct ggml_tensor * llm_build_kqv( cur = ggml_reshape_2d(ctx, cur, n_embd_head_v*n_head, n_tokens); } else { - struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); - cb(kq, "kq", il); - // note: this op tends to require high floating point range - // while for some models F16 is enough, for others it is not, so we default to F32 here - ggml_mul_mat_set_prec(kq, GGML_PREC_F32); + // MLA converted to MQA, optimised to use non-batched matrix multiplies + if (cparams.mla_attn && model.arch == LLM_ARCH_DEEPSEEK2) { + const int64_t n_embd_head_k_mqa = hparams.n_lora_kv + hparams.n_rot; + const int64_t n_embd_head_v_mqa = hparams.n_lora_kv; - if (model.arch == LLM_ARCH_GROK) { - // need to do the following: - // multiply by attn_output_multiplyer of 0.08838834764831845 - // and
then : - // kq = 30 * tanh(kq / 30) - // before the softmax below + // must cont for the 2D view or else kq will have n_tokens <-> n_head swapped... + q = ggml_cont(ctx, q); + cb(q, "q_cont", il); - kq = ggml_tanh(ctx, ggml_scale(ctx, kq, 0.08838834764831845f/30.0f)); - kq = ggml_scale(ctx, kq, 30); - } + q = ggml_view_2d(ctx, q, + n_embd_head_k_mqa, n_head * n_tokens, + ggml_row_size(q->type, n_embd_head_k_mqa), + 0); + cb(q, "q_view", il); - if (hparams.attn_soft_cap) { - kq = ggml_scale(ctx, kq, 1.0f / hparams.f_attn_logit_softcapping); - kq = ggml_tanh(ctx, kq); - kq = ggml_scale(ctx, kq, hparams.f_attn_logit_softcapping); - } + struct ggml_tensor * k = + ggml_view_2d(ctx, kv.k_l[il], + n_embd_head_k_mqa, n_kv, + ggml_row_size(kv.k_l[il]->type, n_embd_head_k_mqa), + 0); + cb(k, "k", il); - kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias); - cb(kq, "kq_soft_max_ext", il); + struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); + cb(kq, "kq", il); - GGML_ASSERT(kv.size == n_ctx); + // note: this doesn't seem necessary + //ggml_mul_mat_set_prec(kq, GGML_PREC_F32); - // split cached v into n_head heads - struct ggml_tensor * v = - ggml_view_3d(ctx, kv.v_l[il], - n_kv, n_embd_head_v, n_head_kv, - ggml_element_size(kv.v_l[il])*n_ctx, - ggml_element_size(kv.v_l[il])*n_ctx*n_embd_head_v, + kq = ggml_view_3d(ctx, kq, + n_kv, n_tokens, n_head, + ggml_row_size(kq->type, n_kv), + ggml_row_size(kq->type, n_kv * n_tokens), 0); - cb(v, "v", il); + cb(kq, "kq_view", il); + + kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias); + cb(kq, "kq_soft_max_ext", il); + + kq = ggml_view_2d(ctx, kq, + n_kv, n_tokens * n_head, + ggml_row_size(kq->type, n_kv), + 0); + cb(kq, "kq_soft_max_view", il); + + GGML_ASSERT(kv.size == n_ctx); + + struct ggml_tensor * v = + ggml_view_2d(ctx, kv.v_l[il], + n_kv, n_embd_head_v_mqa, + ggml_element_size(kv.v_l[il])*n_ctx, + 0); + cb(v, "v", il); + + struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq); + cb(kqv, "kqv_compressed", il); + + kqv = ggml_view_3d(ctx, kqv, + n_embd_head_v_mqa, n_tokens, n_head, + ggml_row_size(kqv->type, n_embd_head_v_mqa), + ggml_row_size(kqv->type, n_embd_head_v_mqa * n_tokens), + 0); + cb(kqv, "kqv_view", il); + + struct ggml_tensor * wv_b_view = + ggml_view_3d(ctx, wv_b, n_embd_head_v_mqa, n_embd_head_v, n_head, + ggml_row_size(wv_b->type, n_embd_head_v_mqa), + ggml_row_size(wv_b->type, n_embd_head_v * n_embd_head_v_mqa), + 0); + cb(wv_b_view, "wv_b_view", il); + + // decompress the MQA to MHA + cur = ggml_mul_mat(ctx, wv_b_view, kqv); + cb(cur, "kqv", il); + + // standard MHA/GQA non-flash-attention case + } else { + struct ggml_tensor * k = + ggml_view_3d(ctx, kv.k_l[il], + n_embd_head_k, n_kv, n_head_kv, + ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa), + ggml_row_size(kv.k_l[il]->type, n_embd_head_k), + 0); + cb(k, "k", il); + + struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); + cb(kq, "kq", il); + + // note: this op tends to require high floating point range + // while for some models F16 is enough, for others it is not, so we default to F32 here + ggml_mul_mat_set_prec(kq, GGML_PREC_F32); + + if (model.arch == LLM_ARCH_GROK) { + // need to do the following: + // multiply by attn_output_multiplyer of 0.08838834764831845 + // and then : + // kq = 30 * tanh(kq / 30) + // before the softmax below + + kq = ggml_tanh(ctx, ggml_scale(ctx, kq, 0.08838834764831845f/30.0f)); + kq = ggml_scale(ctx, kq, 30); + } + + if (hparams.attn_soft_cap) { + kq = ggml_scale(ctx, kq, 1.0f /
hparams.f_attn_logit_softcapping); + kq = ggml_tanh(ctx, kq); + kq = ggml_scale(ctx, kq, hparams.f_attn_logit_softcapping); + } - struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq); - cb(kqv, "kqv", il); + kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias); + cb(kq, "kq_soft_max_ext", il); + + GGML_ASSERT(kv.size == n_ctx); + + // split cached v into n_head heads + struct ggml_tensor * v = + ggml_view_3d(ctx, kv.v_l[il], + n_kv, n_embd_head_v, n_head_kv, + ggml_element_size(kv.v_l[il])*n_ctx, + ggml_element_size(kv.v_l[il])*n_ctx*n_embd_head_v, + 0); + cb(v, "v", il); - struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3); - cb(kqv_merged, "kqv_merged", il); + cur = ggml_mul_mat(ctx, v, kq); + cb(cur, "kqv", il); + } + + cur = ggml_permute(ctx, cur, 0, 2, 1, 3); + cb(cur, "kqv_merged", il); - cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_v*n_head, n_tokens); + cur = ggml_cont_2d(ctx, cur, n_embd_head_v*n_head, n_tokens); cb(cur, "kqv_merged_cont", il); } @@ -650,12 +749,12 @@ static struct ggml_tensor * llm_build_kqv( cur = llm_build_lora_mm(lctx, ctx, wo, cur); } - if (wo_b) { + if (bo) { cb(cur, "kqv_wo", il); } - if (wo_b) { - cur = ggml_add(ctx, cur, wo_b); + if (bo) { + cur = ggml_add(ctx, cur, bo); } return cur; @@ -666,8 +765,9 @@ static struct ggml_tensor * llm_build_kv( struct llama_context & lctx, const llama_kv_cache & kv, struct ggml_cgraph * graph, + struct ggml_tensor * wv_b, struct ggml_tensor * wo, - struct ggml_tensor * wo_b, + struct ggml_tensor * bo, struct ggml_tensor * k_cur, struct ggml_tensor * v_cur, struct ggml_tensor * q_cur, @@ -678,20 +778,17 @@ static struct ggml_tensor * llm_build_kv( float kq_scale, const llm_build_cb & cb, int il) { - const llama_hparams & hparams = lctx.model.hparams; - const llama_cparams & cparams = lctx.cparams; - // these nodes are added to the graph together so that they are not reordered // by doing so, the number of splits in the graph is reduced ggml_build_forward_expand(graph, q_cur); ggml_build_forward_expand(graph, k_cur); ggml_build_forward_expand(graph, v_cur); - llm_build_kv_store(ctx, hparams, cparams, kv, graph, k_cur, v_cur, n_tokens, kv_head, cb, il); + llm_build_kv_store(ctx, lctx, kv, graph, k_cur, v_cur, n_tokens, kv_head, cb, il); struct ggml_tensor * cur; - cur = llm_build_kqv(ctx, lctx, kv, graph, wo, wo_b, q_cur, kq_mask, n_tokens, n_kv, kq_scale, cb, il); + cur = llm_build_kqv(ctx, lctx, kv, graph, wv_b, wo, bo, q_cur, kq_mask, n_tokens, n_kv, kq_scale, cb, il); cb(cur, "kqv_out", il); return cur; @@ -1546,7 +1643,7 @@ struct llm_build_context { cb(Kcur, "Kcur", il); cur = llm_build_kv(ctx0, lctx, kv_self, gf, - model.layers[il].wo, model.layers[il].bo, + nullptr, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il); } @@ -1723,7 +1820,7 @@ struct llm_build_context { cb(Kcur, "Kcur", il); cur = llm_build_kv(ctx0, lctx, kv_self, gf, - model.layers[il].wo, model.layers[il].bo, + nullptr, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il); } @@ -1861,7 +1958,7 @@ struct llm_build_context { cb(Kcur, "Kcur", il); cur = llm_build_kv(ctx0, lctx, kv_self, gf, - model.layers[il].wo, NULL, + nullptr, model.layers[il].wo, nullptr, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -1966,7 +2063,7 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); cur = llm_build_kv(ctx0, lctx, kv_self, gf, - model.layers[il].wo, 
NULL, + nullptr, model.layers[il].wo, nullptr, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -2087,7 +2184,7 @@ struct llm_build_context { cb(Kcur, "Kcur", il); cur = llm_build_kv(ctx0, lctx, kv_self, gf, - model.layers[il].wo, NULL, + nullptr, model.layers[il].wo, nullptr, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -2211,7 +2308,7 @@ struct llm_build_context { cb(Kcur, "Kcur", il); cur = llm_build_kv(ctx0, lctx, kv_self, gf, - model.layers[il].wo, model.layers[il].bo, + nullptr, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il); } @@ -2363,7 +2460,7 @@ struct llm_build_context { cb(Kcur, "Kcur", il); cur = llm_build_kv(ctx0, lctx, kv_self, gf, - model.layers[il].wo, NULL, + nullptr, model.layers[il].wo, nullptr, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -2475,7 +2572,7 @@ struct llm_build_context { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); cur = llm_build_kv(ctx0, lctx, kv_self, gf, - model.layers[il].wo, model.layers[il].bo, + nullptr, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -2569,7 +2666,7 @@ struct llm_build_context { cb(Qcur, "Qcur", il); cur = llm_build_kv(ctx0, lctx, kv_self, gf, - model.layers[il].wo, NULL, + nullptr, model.layers[il].wo, nullptr, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -2864,7 +2961,7 @@ struct llm_build_context { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); cur = llm_build_kv(ctx0, lctx, kv_self, gf, - model.layers[il].wo, model.layers[il].bo, + nullptr, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -2996,13 +3093,13 @@ struct llm_build_context { Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); cur = llm_build_kv(ctx0, lctx, kv_self, gf, - model.layers[il].wo, model.layers[il].bo, + nullptr, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } else { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); cur = llm_build_kv(ctx0, lctx, kv_self, gf, - model.layers[il].wo, model.layers[il].bo, + nullptr, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } } @@ -3147,7 +3244,7 @@ struct llm_build_context { cb(Kcur, "Kcur", il); cur = llm_build_kv(ctx0, lctx, kv_self, gf, - model.layers[il].wo, NULL, + nullptr, model.layers[il].wo, nullptr, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -3266,7 +3363,7 @@ struct llm_build_context { cb(Kcur, "Kcur", il); cur = llm_build_kv(ctx0, lctx, kv_self, gf, - model.layers[il].wo, NULL, + nullptr, model.layers[il].wo, nullptr, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -3380,7 +3477,7 @@ struct llm_build_context { cb(Kcur, "Kcur", il); cur = llm_build_kv(ctx0, lctx, kv_self, gf, - model.layers[il].wo, model.layers[il].bo, + nullptr, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -3498,7 +3595,7 @@ struct llm_build_context { 
cb(Kcur, "Kcur", il); cur = llm_build_kv(ctx0, lctx, kv_self, gf, - model.layers[il].wo, model.layers[il].bo, + nullptr, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -3613,7 +3710,7 @@ struct llm_build_context { cb(Kcur, "Kcur", il); cur = llm_build_kv(ctx0, lctx, kv_self, gf, - model.layers[il].wo, model.layers[il].bo, + nullptr, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -3772,7 +3869,7 @@ struct llm_build_context { cb(Kcur, "Kcur", il); cur = llm_build_kv(ctx0, lctx, kv_self, gf, - model.layers[il].wo, model.layers[il].bo, + nullptr, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il); } @@ -3897,7 +3994,7 @@ struct llm_build_context { cb(Kcur, "Kcur", il); cur = llm_build_kv(ctx0, lctx, kv_self, gf, - model.layers[il].wo, model.layers[il].bo, + nullptr, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il); } @@ -4022,7 +4119,7 @@ struct llm_build_context { cb(Kcur, "Kcur", il); cur = llm_build_kv(ctx0, lctx, kv_self, gf, - model.layers[il].wo, NULL, + nullptr, model.layers[il].wo, nullptr, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } struct ggml_tensor * sa_out = cur; @@ -4124,7 +4221,7 @@ struct llm_build_context { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); cur = llm_build_kv(ctx0, lctx, kv_self, gf, - model.layers[il].wo, model.layers[il].bo, + nullptr, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -4235,7 +4332,7 @@ struct llm_build_context { cb(Kcur, "Kcur", il); cur = llm_build_kv(ctx0, lctx, kv_self, gf, - model.layers[il].wo, model.layers[il].bo, + nullptr, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -4355,7 +4452,7 @@ struct llm_build_context { cb(Kcur, "Kcur", il); cur = llm_build_kv(ctx0, lctx, kv_self, gf, - model.layers[il].wo, NULL, + nullptr, model.layers[il].wo, nullptr, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -4473,7 +4570,7 @@ struct llm_build_context { cb(Kcur, "Kcur", il); cur = llm_build_kv(ctx0, lctx, kv_self, gf, - model.layers[il].wo, model.layers[il].bo, + nullptr, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -4667,7 +4764,7 @@ struct llm_build_context { cb(k_states, "k_states", il); cur = llm_build_kv(ctx0, lctx, kv_self, gf, - model.layers[il].wo, NULL, + nullptr, model.layers[il].wo, nullptr, k_states, v_states, q_states, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il); } @@ -4789,7 +4886,7 @@ struct llm_build_context { cb(Kcur, "Kcur", il); cur = llm_build_kv(ctx0, lctx, kv_self, gf, - model.layers[il].wo, NULL, + nullptr, model.layers[il].wo, nullptr, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il); } @@ -4908,7 +5005,7 @@ struct llm_build_context { cb(Kcur, "Kcur", il); cur = llm_build_kv(ctx0, lctx, kv_self, gf, - model.layers[il].wo, NULL, + nullptr, model.layers[il].wo, nullptr, Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f, cb, il); } @@ -5045,7 +5142,7 @@ struct llm_build_context { cb(Kcur, "Kcur", il); cur = 
llm_build_kv(ctx0, lctx, kv_self, gf, - model.layers[il].wo, model.layers[il].bo, + nullptr, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -5244,7 +5341,7 @@ struct llm_build_context { cb(Kcur, "Kcur", il); cur = llm_build_kv(ctx0, lctx, kv_self, gf, - model.layers[il].wo, model.layers[il].bo, + nullptr, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -5381,8 +5478,9 @@ struct llm_build_context { cb(Kcur, "Kcur", il); } - cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, - KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f / sqrtf(float(n_embd_head)), cb, il); + cur = llm_build_kv(ctx0, lctx, kv_self, gf, + nullptr, model.layers[il].wo, model.layers[il].bo, + Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f / sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -5507,7 +5605,7 @@ struct llm_build_context { cb(Kcur, "Kcur", il); cur = llm_build_kv(ctx0, lctx, kv_self, gf, - model.layers[il].wo, nullptr, + nullptr, model.layers[il].wo, nullptr, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -5626,7 +5724,7 @@ struct llm_build_context { cb(Kcur, "Kcur_rope", il); cur = llm_build_kv(ctx0, lctx, kv_self, gf, - model.layers[il].wo, NULL, + nullptr, model.layers[il].wo, nullptr, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -5758,7 +5856,7 @@ struct llm_build_context { cb(Kcur, "Kcur_rope", il); cur = llm_build_kv(ctx0, lctx, kv_self, gf, - model.layers[il].wo, NULL, + nullptr, model.layers[il].wo, nullptr, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -5888,7 +5986,7 @@ struct llm_build_context { cb(Qcur, "Vcur", il); cur = llm_build_kv(ctx0, lctx, kv_self, gf, - model.layers[il].wo, NULL, + nullptr, model.layers[il].wo, nullptr, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -5997,7 +6095,7 @@ struct llm_build_context { cb(Kcur, "Kcur", il); cur = llm_build_kv(ctx0, lctx, kv_self, gf, - model.layers[il].wo, model.layers[il].bo, + nullptr, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -6140,7 +6238,7 @@ struct llm_build_context { cb(Kcur, "Kcur", il); cur = llm_build_kv(ctx0, lctx, kv_self, gf, - model.layers[il].wo, NULL, + nullptr, model.layers[il].wo, nullptr, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -6289,7 +6387,7 @@ struct llm_build_context { cb(Kcur, "Kcur", il); cur = llm_build_kv(ctx0, lctx, kv_self, gf, - model.layers[il].wo, model.layers[il].bo, + nullptr, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il); } @@ -6388,8 +6486,8 @@ struct llm_build_context { const float kq_scale = 1.0f*mscale*mscale/sqrtf(float(hparams.n_embd_head_k)); const float attn_factor_scaled = 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)); - const uint32_t n_embd_head_qk_rope = hparams.n_rot; const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot; + const uint32_t n_embd_head_qk_rope = hparams.n_rot; const uint32_t kv_lora_rank = hparams.n_lora_kv; struct ggml_tensor * cur; @@ -6407,7 +6505,6 @@ struct llm_build_context { for (int 
il = 0; il < n_layer; ++il) { struct ggml_tensor * inpSA = inpL; - // norm cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, cb, il); @@ -6415,115 +6512,160 @@ struct llm_build_context { // self_attention { - struct ggml_tensor * q = NULL; + struct ggml_tensor * q_nope = nullptr; + struct ggml_tensor * q_mqa = nullptr; if (!is_lite) { // {n_embd, q_lora_rank} * {n_embd, n_tokens} -> {q_lora_rank, n_tokens} - q = ggml_mul_mat(ctx0, model.layers[il].wq_a, cur); - cb(q, "q", il); + struct ggml_tensor * q_compressed = ggml_mul_mat(ctx0, model.layers[il].wq_a, cur); + cb(q_compressed, "q_compressed", il); - q = llm_build_norm(ctx0, q, hparams, - model.layers[il].attn_q_a_norm, NULL, + q_compressed = llm_build_norm(ctx0, q_compressed, hparams, + model.layers[il].attn_q_a_norm, nullptr, LLM_NORM_RMS, cb, il); - cb(q, "q", il); + cb(q_compressed, "q_compressed_norm", il); - // {q_lora_rank, n_head * hparams.n_embd_head_k} * {q_lora_rank, n_tokens} -> {n_head * hparams.n_embd_head_k, n_tokens} - q = ggml_mul_mat(ctx0, model.layers[il].wq_b, q); - cb(q, "q", il); + // {q_lora_rank, n_head * n_embd_head_qk_nope} * {q_lora_rank, n_tokens} -> {n_head * n_embd_head_qk_nope, n_tokens} + q_nope = ggml_mul_mat(ctx0, model.layers[il].wq_b, q_compressed); + cb(q_nope, "q_nope", il); + + // {q_lora_rank, n_head * n_embd_head_qk_rope} * {q_lora_rank, n_tokens} -> {n_head * n_embd_head_qk_rope, n_tokens} + q_mqa = ggml_mul_mat(ctx0, model.layers[il].wq_b_mqa, q_compressed); + cb(q_mqa, "q_mqa", il); } else { - q = ggml_mul_mat(ctx0, model.layers[il].wq, cur); - cb(q, "q", il); + // {n_embd, n_head * n_embd_head_qk_nope} * {n_embd, n_tokens} -> {n_head * n_embd_head_qk_nope, n_tokens} + q_nope = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + cb(q_nope, "q_nope", il); + + // {n_embd, n_head * n_embd_head_qk_rope} * {n_embd, n_tokens} -> {n_head * n_embd_head_qk_rope, n_tokens} + q_mqa = ggml_mul_mat(ctx0, model.layers[il].wq_mqa, cur); + cb(q_mqa, "q_mqa", il); } - // split into {n_head * n_embd_head_qk_nope, n_tokens} - struct ggml_tensor * q_nope = ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens, - ggml_row_size(q->type, hparams.n_embd_head_k), - ggml_row_size(q->type, hparams.n_embd_head_k * n_head), + // {n_embd_head_qk_nope, n_head, n_tokens} + struct ggml_tensor * q_nope_view = ggml_view_3d(ctx0, q_nope, n_embd_head_qk_nope, n_head, n_tokens, + ggml_row_size(q_nope->type, n_embd_head_qk_nope), + ggml_row_size(q_nope->type, n_head * n_embd_head_qk_nope), 0); - cb(q_nope, "q_nope", il); + cb(q_nope_view, "q_nope_view", il); - // and {n_head * n_embd_head_qk_rope, n_tokens} - struct ggml_tensor * q_pe = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens, - ggml_row_size(q->type, hparams.n_embd_head_k), - ggml_row_size(q->type, hparams.n_embd_head_k * n_head), - ggml_row_size(q->type, n_embd_head_qk_nope)); - cb(q_pe, "q_pe", il); + // {n_embd_head_qk_rope, n_head, n_tokens} + struct ggml_tensor * q_mqa_view = ggml_view_3d(ctx0, q_mqa, n_embd_head_qk_rope, n_head, n_tokens, + ggml_row_size(q_mqa->type, n_embd_head_qk_rope), + ggml_row_size(q_mqa->type, n_head * n_embd_head_qk_rope), + 0); + cb(q_mqa_view, "q_mqa_view", il); - // {n_embd, kv_lora_rank + n_embd_head_qk_rope} * {n_embd, n_tokens} -> {kv_lora_rank + n_embd_head_qk_rope, n_tokens} - struct ggml_tensor * kv_pe_compresseed = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur); - cb(kv_pe_compresseed, "kv_pe_compresseed", il); + q_mqa_view = ggml_rope_ext(ctx0, q_mqa_view, inp_pos, nullptr, 
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor_scaled, beta_fast, beta_slow + ); + cb(q_mqa_view, "q_mqa_view_rope", il); - // split into {kv_lora_rank, n_tokens} - struct ggml_tensor * kv_compressed = ggml_view_2d(ctx0, kv_pe_compresseed, kv_lora_rank, n_tokens, - kv_pe_compresseed->nb[1], - 0); + // {n_embd, kv_lora_rank} * {n_embd, n_tokens} -> {kv_lora_rank, n_tokens} + struct ggml_tensor * kv_compressed = ggml_mul_mat(ctx0, model.layers[il].wkv_a, cur); cb(kv_compressed, "kv_compressed", il); - // and {n_embd_head_qk_rope, n_tokens} - struct ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_pe_compresseed, n_embd_head_qk_rope, 1, n_tokens, - kv_pe_compresseed->nb[1], - kv_pe_compresseed->nb[1], - ggml_row_size(kv_pe_compresseed->type, kv_lora_rank)); - cb(k_pe, "k_pe", il); - - // TODO: the CUDA backend used to not support non-cont. (RMS) norm, investigate removing ggml_cont - kv_compressed = ggml_cont(ctx0, kv_compressed); kv_compressed = llm_build_norm(ctx0, kv_compressed, hparams, - model.layers[il].attn_kv_a_norm, NULL, + model.layers[il].attn_kv_a_norm, nullptr, LLM_NORM_RMS, cb, il); - cb(kv_compressed, "kv_compressed", il); + cb(kv_compressed, "kv_compressed_norm", il); - // {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens} - struct ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed); - cb(kv, "kv", il); + // {n_embd, n_embd_head_qk_rope} * {n_embd, n_tokens} -> {n_embd_head_qk_rope, n_tokens} + struct ggml_tensor * k_mqa = ggml_mul_mat(ctx0, model.layers[il].wk_mqa, cur); + cb(k_mqa, "k_mqa", il); - // split into {n_head * n_embd_head_qk_nope, n_tokens} - struct ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens, - ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v), - ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), + // {n_embd_head_qk_rope, 1, n_tokens} + struct ggml_tensor * k_mqa_view = ggml_view_3d(ctx0, k_mqa, n_embd_head_qk_rope, 1, n_tokens, + ggml_row_size(k_mqa->type, n_embd_head_qk_rope), + ggml_row_size(k_mqa->type, n_embd_head_qk_rope), 0); - cb(k_nope, "k_nope", il); + cb(k_mqa_view, "k_mqa_view", il); - // and {n_head * n_embd_head_v, n_tokens} - struct ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens, - ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)), - ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head), - ggml_row_size(kv->type, (n_embd_head_qk_nope))); - cb(v_states, "v_states", il); - - v_states = ggml_cont(ctx0, v_states); - cb(v_states, "v_states", il); - - v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens, - ggml_row_size(kv->type, hparams.n_embd_head_v * n_head), - 0); - cb(v_states, "v_states", il); - - q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend used to not support non-cont. 
RoPE, investigate removing this - q_pe = ggml_rope_ext( - ctx0, q_pe, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor_scaled, beta_fast, beta_slow + k_mqa_view = ggml_rope_ext(ctx0, k_mqa_view, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor_scaled, beta_fast, beta_slow ); - cb(q_pe, "q_pe", il); + cb(k_mqa_view, "k_mqa_view_rope", il); + + if (!cparams.mla_attn) { + // {kv_lora_rank, n_head * n_embd_head_qk_nope} * {kv_lora_rank, n_tokens} -> {n_head * n_embd_head_qk_nope, n_tokens} + struct ggml_tensor * k_nope = ggml_mul_mat(ctx0, model.layers[il].wk_b, kv_compressed); + cb(k_nope, "k_nope", il); + + // {n_embd_head_qk_nope, n_head, n_tokens} + struct ggml_tensor * k_nope_view = ggml_view_3d(ctx0, k_nope, n_embd_head_qk_nope, n_head, n_tokens, + ggml_row_size(k_nope->type, n_embd_head_qk_nope), + ggml_row_size(k_nope->type, n_head * n_embd_head_qk_nope), + 0); + cb(k_nope_view, "k_nope_view", il); + + // TODO: build_k_shift() and build_defrag(); the RoPEed part is the first n_rot as they expect + // {n_embd_head_qk_rope + n_embd_head_qk_nope, n_head, n_tokens} + struct ggml_tensor * q_states = ggml_concat(ctx0, q_mqa_view, q_nope_view, 0); + cb(q_states, "q_states", il); + + // TODO: build_k_shift() and build_defrag(); the RoPEed part is the first n_rot as they expect + // {n_embd_head_qk_rope + n_embd_head_qk_nope, n_head, n_tokens} + struct ggml_tensor * k_states = ggml_concat(ctx0, ggml_repeat(ctx0, k_mqa_view, q_mqa_view), k_nope_view, 0); + cb(k_states, "k_states", il); + + // {kv_lora_rank, n_head * n_embd_head_v} * {kv_lora_rank, n_tokens} -> {n_head * n_embd_head_v, n_tokens} + struct ggml_tensor * v_states = ggml_mul_mat(ctx0, model.layers[il].wv_b, kv_compressed); + cb(v_states, "v_states", il); + + // note: this has essentially converted MLA into MHA (with very large KV-cache overhead) + cur = llm_build_kv(ctx0, lctx, kv_self, gf, + nullptr, model.layers[il].wo, nullptr, + k_states, v_states, q_states, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il); + } else { + // {n_embd_head_qk_nope, kv_lora_rank, n_head} + struct ggml_tensor * wk_b_trans_view = ggml_view_3d(ctx0, model.layers[il].wk_b_trans, + n_embd_head_qk_nope, kv_lora_rank, n_head, + ggml_row_size(model.layers[il].wk_b->type, n_embd_head_qk_nope), + ggml_row_size(model.layers[il].wk_b->type, kv_lora_rank * n_embd_head_qk_nope), + 0); + cb(wk_b_trans_view, "wk_b_trans_view", il); + + // {n_embd_head_qk_nope, n_tokens, n_head} + q_nope_view = ggml_permute(ctx0, q_nope_view, 0, 2, 1, 3); + cb(q_nope_view, "q_nope_view_perm", il); + + // {n_embd_head_qk_nope, kv_lora_rank, n_head} * {n_embd_head_qk_nope, n_tokens, n_head} = {kv_lora_rank, n_tokens, n_head} + struct ggml_tensor * q_nope_absorbed = ggml_mul_mat(ctx0, wk_b_trans_view, q_nope_view); + cb(q_nope_absorbed, "q_nope_absorbed", il); + + // {n_embd_head_qk_rope, n_head, n_tokens} + q_nope_absorbed = ggml_permute(ctx0, q_nope_absorbed, 0, 2, 1, 3); + cb(q_nope_absorbed, "q_nope_absorbed_perm", il); + + // TODO: build_k_shift() and build_defrag(); the RoPEed part is the first n_rot as they expect + // {n_embd_head_qk_rope + kv_lora_rank, n_head, n_tokens} + struct ggml_tensor * q_states = ggml_concat(ctx0, q_mqa_view, q_nope_absorbed, 0); + cb(q_states, "q_states", il); + + // {kv_lora_rank, 1, n_tokens} + struct ggml_tensor * kv_compressed_view = ggml_view_3d(ctx0, kv_compressed, + kv_lora_rank, 1, n_tokens, + ggml_row_size(k_mqa->type, kv_lora_rank), + 
ggml_row_size(k_mqa->type, kv_lora_rank), + 0); + cb(kv_compressed_view, "kv_compressed_view", il); - // shared RoPE key - k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this - k_pe = ggml_rope_ext( - ctx0, k_pe, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor_scaled, beta_fast, beta_slow - ); - cb(k_pe, "k_pe", il); + // TODO: build_k_shift() and build_defrag(); the RoPEed part is the first n_rot as they expect + // {n_embd_head_qk_rope + kv_lora_rank, 1, n_tokens} + struct ggml_tensor * k_states = ggml_concat(ctx0, k_mqa_view, kv_compressed_view, 0); + cb(k_states, "k_states", il); - struct ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0); - cb(q_states, "q_states", il); + // {kv_lora_rank, 1, n_tokens} + struct ggml_tensor * v_states = kv_compressed; + cb(v_states, "v_states", il); - struct ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0); - cb(k_states, "k_states", il); + // note: this has essentially converted MLA into MQA (with very low KV-cache overhead) + cur = llm_build_kv(ctx0, lctx, kv_self, gf, + model.layers[il].wv_b, model.layers[il].wo, nullptr, + k_states, v_states, q_states, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il); + } - cur = llm_build_kv(ctx0, lctx, kv_self, gf, - model.layers[il].wo, NULL, - k_states, v_states, q_states, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il); } if (il == n_layer - 1) { @@ -6680,7 +6822,7 @@ struct llm_build_context { cb(Kcur, "Kcur", il); cur = llm_build_kv(ctx0, lctx, kv_self, gf, - NULL, NULL, + nullptr, nullptr, nullptr, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); cur = llm_build_norm(ctx0, cur, hparams, @@ -6932,7 +7074,7 @@ struct llm_build_context { struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); cb(Vcur, "Vcur", il); - llm_build_kv_store(ctx0, hparams, cparams, kv_self, gf, Kcur, Vcur, n_tokens, kv_head, cb, il); + llm_build_kv_store(ctx0, lctx, kv_self, gf, Kcur, Vcur, n_tokens, kv_head, cb, il); struct ggml_tensor * k = ggml_view_3d(ctx0, kv_self.k_l[il], @@ -7134,7 +7276,7 @@ struct llm_build_context { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); cur = llm_build_kv(ctx0, lctx, kv_self, gf, - model.layers[il].wo, model.layers[il].bo, + nullptr, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/float(n_embd_head), cb, il); } @@ -7260,7 +7402,7 @@ struct llm_build_context { cb(Kcur, "Kcur_rope", il); cur = llm_build_kv(ctx0, lctx, kv_self, gf, - model.layers[il].wo, NULL, + nullptr, model.layers[il].wo, nullptr, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -7379,7 +7521,7 @@ struct llm_build_context { cb(Kcur, "Kcur", il); cur = llm_build_kv(ctx0, lctx, kv_self, gf, - model.layers[il].wo, model.layers[il].bo, + nullptr, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -7505,7 +7647,7 @@ struct llm_build_context { cb(Kcur, "Kcur", il); cur = llm_build_kv(ctx0, lctx, kv_self, gf, - model.layers[il].wo, model.layers[il].bo, + nullptr, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -7876,7 +8018,7 @@ struct llm_build_context { cb(Kcur, "Kcur", il); cur = llm_build_kv(ctx0, 
lctx, kv_self, gf, - model.layers[il].wo, nullptr, + nullptr, model.layers[il].wo, nullptr, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); if (hparams.swin_norm) { @@ -9358,6 +9500,7 @@ struct llama_context_params llama_context_default_params() { /*.embeddings =*/ false, /*.offload_kqv =*/ true, /*.flash_attn =*/ false, + /*.mla_attn =*/ false, /*.no_perf =*/ true, /*.abort_callback =*/ nullptr, /*.abort_callback_data =*/ nullptr, @@ -9594,6 +9737,7 @@ struct llama_context * llama_init_from_model( cparams.embeddings = params.embeddings; cparams.offload_kqv = params.offload_kqv; cparams.flash_attn = params.flash_attn; + cparams.mla_attn = params.mla_attn; cparams.no_perf = params.no_perf; cparams.pooling_type = params.pooling_type;
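Note for reviewers (not part of the patch): the sketch below mirrors the kv_b_proj split added to convert_hf_to_gguf.py above and the per-token cache widths implied by the llama_kv_cache_init() change, to make the shape bookkeeping easier to check. The hyperparameter values are made-up placeholders, not taken from any particular DeepSeek2 config, and the script is a standalone illustration rather than part of the conversion code.

# standalone sketch, assuming placeholder DeepSeek2-style hparams
import torch

n_head_kv        = 16    # placeholder
qk_nope_head_dim = 128   # placeholder
qk_rope_head_dim = 64    # placeholder
v_head_dim       = 128   # placeholder
kv_lora_rank     = 512   # placeholder

# fake kv_b_proj weight: per head, [qk_nope | v] rows over kv_lora_rank columns
kv_b_proj = torch.randn(n_head_kv * (qk_nope_head_dim + v_head_dim), kv_lora_rank)

# same split as the converter above
kv_b = kv_b_proj.view(n_head_kv, qk_nope_head_dim + v_head_dim, kv_lora_rank)
k_b_proj, v_b_proj = torch.split(kv_b, [qk_nope_head_dim, v_head_dim], dim=1)

# the transposed copy is what the MLA graph multiplies q_nope by
# (wk_b_trans_view -> q_nope_absorbed in llm_build_context)
k_b_trans_proj = k_b_proj.transpose(1, 2).reshape(n_head_kv * kv_lora_rank, qk_nope_head_dim)
k_b_proj       = k_b_proj.reshape(n_head_kv * qk_nope_head_dim, kv_lora_rank)
v_b_proj       = v_b_proj.reshape(n_head_kv * v_head_dim, kv_lora_rank)

assert k_b_trans_proj.shape == (n_head_kv * kv_lora_rank, qk_nope_head_dim)
assert k_b_proj.shape       == (n_head_kv * qk_nope_head_dim, kv_lora_rank)
assert v_b_proj.shape       == (n_head_kv * v_head_dim, kv_lora_rank)

# per-token, per-layer cache widths (in elements), matching llama_kv_cache_init():
# with --mla-attn only the compressed latent plus the shared RoPE key is cached,
# without it the graph decompresses to MHA and caches the full K/V
mla_k = kv_lora_rank + qk_rope_head_dim
mla_v = kv_lora_rank
mha_k = n_head_kv * (qk_nope_head_dim + qk_rope_head_dim)
mha_v = n_head_kv * v_head_dim
print(f"cache per token per layer: mla = {mla_k + mla_v}, mha = {mha_k + mha_v} elements")

With these placeholder numbers the MLA path caches 1088 elements per token per layer versus 5120 for the decompressed MHA path; the exact ratio depends on the real model hyperparameters.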