diff --git a/common/arg.cpp b/common/arg.cpp
index 8531f0871d44a..84dc6841e3866 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -811,6 +811,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.flash_attn = true;
}
).set_env("LLAMA_ARG_FLASH_ATTN"));
+ add_opt(common_arg(
+ {"-mla", "--mla-attn"},
+ string_format("enable Multi-head Latent Attention (default: %s)", params.mla_attn ? "enabled" : "disabled"),
+ [](common_params & params) {
+ params.mla_attn = true;
+ }
+ ).set_env("LLAMA_ARG_MLA_ATTN"));
add_opt(common_arg(
{"-p", "--prompt"}, "PROMPT",
"prompt to start generation with; for system message, use -sys",
diff --git a/common/common.cpp b/common/common.cpp
index 6448b7b03d6d2..b3438f2646bdf 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1132,6 +1132,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
cparams.cb_eval_user_data = params.cb_eval_user_data;
cparams.offload_kqv = !params.no_kv_offload;
cparams.flash_attn = params.flash_attn;
+ cparams.mla_attn = params.mla_attn;
cparams.no_perf = params.no_perf;
if (params.reranking) {
diff --git a/common/common.h b/common/common.h
index 1c0f199774976..207732a9957a8 100644
--- a/common/common.h
+++ b/common/common.h
@@ -325,6 +325,7 @@ struct common_params {
bool simple_io = false; // improves compatibility with subprocesses and limited consoles
bool cont_batching = true; // insert new sequences for decoding on-the-fly
bool flash_attn = false; // flash attention
+ bool mla_attn = false; // MLA attention for deepseek2
bool no_perf = false; // disable performance metrics
bool ctx_shift = true; // context shift on inifinite text generation
diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 6358a94e9b55f..d13196696ce33 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -4141,6 +4141,78 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
else:
return []
+ n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed"))
+ n_head_kv = self.hparams["num_key_value_heads"]
+ qk_nope_head_dim = self.hparams["qk_nope_head_dim"]
+ qk_rope_head_dim = self.hparams["qk_rope_head_dim"]
+ v_head_dim = self.hparams["v_head_dim"]
+ kv_lora_rank = self.hparams["kv_lora_rank"]
+
+ # (v2-lite) split q_proj into: q_proj and q_mqa_proj
+ if name.endswith("q_proj.weight"):
+ assert data_torch.shape[0] == n_head_kv * (qk_nope_head_dim + qk_rope_head_dim)
+ assert data_torch.shape[1] == n_embed
+
+ q_proj_with_mqa = data_torch.view(n_head_kv, qk_nope_head_dim + qk_rope_head_dim, n_embed)
+ q_proj, q_mqa_proj = torch.split(q_proj_with_mqa, [qk_nope_head_dim, qk_rope_head_dim], dim = 1)
+
+ q_proj = q_proj.reshape(n_head_kv * qk_nope_head_dim, n_embed)
+ q_mqa_proj = q_mqa_proj.reshape(n_head_kv * qk_rope_head_dim, n_embed)
+
+ return [
+ (self.map_tensor_name(name), q_proj),
+ (self.map_tensor_name(name.replace("q_proj", "q_mqa_proj")), q_mqa_proj)
+ ]
+
+ # (v2/v3/r1) split q_b_proj into: q_b_proj and q_b_mqa_proj
+ if name.endswith("q_b_proj.weight"):
+ q_lora_rank = self.hparams["q_lora_rank"]
+
+ assert data_torch.shape[0] == n_head_kv * (qk_nope_head_dim + qk_rope_head_dim)
+ assert data_torch.shape[1] == q_lora_rank
+
+ q_b_proj_with_mqa = data_torch.view(n_head_kv, qk_nope_head_dim + qk_rope_head_dim, q_lora_rank)
+ q_b_proj, q_b_mqa_proj = torch.split(q_b_proj_with_mqa, [qk_nope_head_dim, qk_rope_head_dim], dim = 1)
+
+ q_b_proj = q_b_proj.reshape(n_head_kv * qk_nope_head_dim, q_lora_rank)
+ q_b_mqa_proj = q_b_mqa_proj.reshape(n_head_kv * qk_rope_head_dim, q_lora_rank)
+
+ return [
+ (self.map_tensor_name(name), q_b_proj),
+ (self.map_tensor_name(name.replace("q_b_proj", "q_b_mqa_proj")), q_b_mqa_proj)
+ ]
+
+ # split kv_a_proj_with_mqa into: kv_a_proj and k_mqa_proj
+ if name.endswith("kv_a_proj_with_mqa.weight"):
+ assert data_torch.shape[0] == kv_lora_rank + qk_rope_head_dim
+ assert data_torch.shape[1] == n_embed
+
+ kv_a_proj_with_mqa = data_torch.view(kv_lora_rank + qk_rope_head_dim, n_embed)
+ kv_a_proj, k_mqa_proj = torch.split(kv_a_proj_with_mqa, [kv_lora_rank, qk_rope_head_dim], dim = 0)
+
+ return [
+ (self.map_tensor_name(name.replace("kv_a_proj_with_mqa", "kv_a_proj")), kv_a_proj),
+ (self.map_tensor_name(name.replace("kv_a_proj_with_mqa", "k_mqa_proj")), k_mqa_proj)
+ ]
+
+ # split kv_b_proj into: k_b_proj, v_b_proj, and k_b_trans_proj (for deepseek-mla)
+ if name.endswith("kv_b_proj.weight"):
+ assert data_torch.shape[0] == n_head_kv * (v_head_dim + qk_nope_head_dim)
+ assert data_torch.shape[1] == kv_lora_rank
+
+ kv_b_proj = data_torch.view(n_head_kv, v_head_dim + qk_nope_head_dim, kv_lora_rank)
+ k_b_proj, v_b_proj = torch.split(kv_b_proj, [qk_nope_head_dim, v_head_dim], dim = 1)
+
+ k_b_trans_proj = k_b_proj.transpose(1, 2).reshape(n_head_kv * kv_lora_rank, qk_nope_head_dim)
+ k_b_proj = k_b_proj.reshape(n_head_kv * qk_nope_head_dim, kv_lora_rank)
+ v_b_proj = v_b_proj.reshape(n_head_kv * v_head_dim, kv_lora_rank)
+
+ return [
+ (self.map_tensor_name(name.replace("kv_b_proj", "k_b_trans_proj")), k_b_trans_proj),
+ (self.map_tensor_name(name.replace("kv_b_proj", "k_b_proj")), k_b_proj),
+ (self.map_tensor_name(name.replace("kv_b_proj", "v_b_proj")), v_b_proj)
+ ]
+
return [(self.map_tensor_name(name), data_torch)]
def prepare_tensors(self):
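
Note: the kv_b_proj split performed by the converter above can be reproduced in isolation. The sketch below uses illustrative, roughly DeepSeek-V2-Lite-sized dimensions and random weights rather than a real checkpoint; only the view/split/reshape logic mirrors the code in this hunk.

```python
import torch

# illustrative, roughly V2-Lite-sized dimensions (not read from a config)
n_head_kv        = 16
qk_nope_head_dim = 128
v_head_dim       = 128
kv_lora_rank     = 512

# stand-in for the checkpoint's kv_b_proj.weight
kv_b_proj = torch.randn(n_head_kv * (qk_nope_head_dim + v_head_dim), kv_lora_rank)

# view as (heads, nope + v, rank) and split along the per-head feature dimension
kv_b = kv_b_proj.view(n_head_kv, qk_nope_head_dim + v_head_dim, kv_lora_rank)
k_b, v_b = torch.split(kv_b, [qk_nope_head_dim, v_head_dim], dim=1)

# k_b_trans is the per-head transpose of k_b, flattened for storage; it is what
# later lets the graph absorb W^UK into q_nope instead of decompressing every key
k_b_trans = k_b.transpose(1, 2).reshape(n_head_kv * kv_lora_rank, qk_nope_head_dim)
k_b       = k_b.reshape(n_head_kv * qk_nope_head_dim, kv_lora_rank)
v_b       = v_b.reshape(n_head_kv * v_head_dim, kv_lora_rank)

print(k_b_trans.shape, k_b.shape, v_b.shape)
# torch.Size([8192, 128]) torch.Size([2048, 512]) torch.Size([2048, 512])
```
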
diff --git a/examples/server/README.md b/examples/server/README.md
index a2a0903261e31..043c725d8d548 100644
--- a/examples/server/README.md
+++ b/examples/server/README.md
@@ -46,6 +46,7 @@ The project is under active development, and we are [looking for feedback and co
| `-ub, --ubatch-size N` | physical maximum batch size (default: 512)<br/>(env: LLAMA_ARG_UBATCH) |
| `--keep N` | number of tokens to keep from the initial prompt (default: 0, -1 = all) |
| `-fa, --flash-attn` | enable Flash Attention (default: disabled)<br/>(env: LLAMA_ARG_FLASH_ATTN) |
+| `-mla, --mla-attn` | enable Multi-head Latent Attention (default: disabled)<br/>(env: LLAMA_ARG_MLA_ATTN) |
| `--no-perf` | disable internal libllama performance timings (default: false)<br/>(env: LLAMA_ARG_NO_PERF) |
| `-e, --escape` | process escapes sequences (\n, \r, \t, \', \", \\) (default: true) |
| `--no-escape` | do not process escape sequences |
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index ecac5b4bb7f59..758efa2f3ef16 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -356,6 +356,13 @@ class MODEL_TENSOR(IntEnum):
ATTN_Q_B = auto()
ATTN_KV_A_MQA = auto()
ATTN_KV_B = auto()
+ ATTN_Q_MQA = auto()
+ ATTN_Q_B_MQA = auto()
+ ATTN_KV_A = auto()
+ ATTN_K_MQA = auto()
+ ATTN_K_B_TRANS = auto()
+ ATTN_K_B = auto()
+ ATTN_V_B = auto()
ATTN_Q_A_NORM = auto()
ATTN_KV_A_NORM = auto()
FFN_SUB_NORM = auto()
@@ -543,6 +550,13 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.ATTN_Q_B: "blk.{bid}.attn_q_b",
MODEL_TENSOR.ATTN_KV_A_MQA: "blk.{bid}.attn_kv_a_mqa",
MODEL_TENSOR.ATTN_KV_B: "blk.{bid}.attn_kv_b",
+ MODEL_TENSOR.ATTN_Q_MQA: "blk.{bid}.attn_q_mqa",
+ MODEL_TENSOR.ATTN_Q_B_MQA: "blk.{bid}.attn_q_b_mqa",
+ MODEL_TENSOR.ATTN_KV_A: "blk.{bid}.attn_kv_a",
+ MODEL_TENSOR.ATTN_K_MQA: "blk.{bid}.attn_k_mqa",
+ MODEL_TENSOR.ATTN_K_B_TRANS: "blk.{bid}.attn_k_b_trans",
+ MODEL_TENSOR.ATTN_K_B: "blk.{bid}.attn_k_b",
+ MODEL_TENSOR.ATTN_V_B: "blk.{bid}.attn_v_b",
MODEL_TENSOR.ATTN_Q_A_NORM: "blk.{bid}.attn_q_a_norm",
MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm",
MODEL_TENSOR.ATTN_SUB_NORM: "blk.{bid}.attn_sub_norm",
@@ -1041,6 +1055,13 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.ATTN_Q_B,
MODEL_TENSOR.ATTN_KV_A_MQA,
MODEL_TENSOR.ATTN_KV_B,
+ MODEL_TENSOR.ATTN_Q_MQA,
+ MODEL_TENSOR.ATTN_Q_B_MQA,
+ MODEL_TENSOR.ATTN_KV_A,
+ MODEL_TENSOR.ATTN_K_MQA,
+ MODEL_TENSOR.ATTN_K_B_TRANS,
+ MODEL_TENSOR.ATTN_K_B,
+ MODEL_TENSOR.ATTN_V_B,
MODEL_TENSOR.ATTN_Q_A_NORM,
MODEL_TENSOR.ATTN_KV_A_NORM,
MODEL_TENSOR.ATTN_OUT,
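
As a quick sanity check (this assumes the gguf-py changes in this patch are applied), the new enum members resolve to the per-block names registered above:

```python
from gguf.constants import MODEL_TENSOR, TENSOR_NAMES

for t in (MODEL_TENSOR.ATTN_K_B_TRANS, MODEL_TENSOR.ATTN_K_B, MODEL_TENSOR.ATTN_V_B):
    print(TENSOR_NAMES[t].format(bid=0))
# blk.0.attn_k_b_trans
# blk.0.attn_k_b
# blk.0.attn_v_b
```
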
diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index 617791e240b60..ae17da73af674 100644
--- a/gguf-py/gguf/tensor_mapping.py
+++ b/gguf-py/gguf/tensor_mapping.py
@@ -586,6 +586,34 @@ class TensorNameMap:
"model.layers.{bid}.self_attn.kv_b_proj", # deepseek2
),
+ MODEL_TENSOR.ATTN_Q_MQA: (
+ "model.layers.{bid}.self_attn.q_mqa_proj", # deepseek2 (v2-lite)
+ ),
+
+ MODEL_TENSOR.ATTN_Q_B_MQA: (
+ "model.layers.{bid}.self_attn.q_b_mqa_proj", # deepseek2 (v2/v3/r1)
+ ),
+
+ MODEL_TENSOR.ATTN_KV_A: (
+ "model.layers.{bid}.self_attn.kv_a_proj", # deepseek2
+ ),
+
+ MODEL_TENSOR.ATTN_K_MQA: (
+ "model.layers.{bid}.self_attn.k_mqa_proj", # deepseek2
+ ),
+
+ MODEL_TENSOR.ATTN_K_B_TRANS: (
+ "model.layers.{bid}.self_attn.k_b_trans_proj", # deepseek2 (mla only)
+ ),
+
+ MODEL_TENSOR.ATTN_K_B: (
+ "model.layers.{bid}.self_attn.k_b_proj", # deepseek2
+ ),
+
+ MODEL_TENSOR.ATTN_V_B: (
+ "model.layers.{bid}.self_attn.v_b_proj", # deepseek2
+ ),
+
MODEL_TENSOR.ATTN_Q_A_NORM: (
"model.layers.{bid}.self_attn.q_a_layernorm", # deepseek2
),
diff --git a/include/llama.h b/include/llama.h
index d62792c0a6760..be6dfc5d87154 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -343,6 +343,7 @@ extern "C" {
bool embeddings; // if true, extract embeddings (together with logits)
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
bool flash_attn; // whether to use flash attention [EXPERIMENTAL]
+ bool mla_attn; // MLA attention for deepseek2
bool no_perf; // whether to measure performance timings
// Abort callback
diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
index 97a1e7e5e01ef..cca3cad2c6cb8 100644
--- a/src/llama-arch.cpp
+++ b/src/llama-arch.cpp
@@ -997,6 +997,13 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_Q_A, "blk.%d.attn_q_a" },
{ LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" },
+ { LLM_TENSOR_ATTN_Q_MQA, "blk.%d.attn_q_mqa" },
+ { LLM_TENSOR_ATTN_Q_B_MQA, "blk.%d.attn_q_b_mqa" },
+ { LLM_TENSOR_ATTN_KV_A, "blk.%d.attn_kv_a" },
+ { LLM_TENSOR_ATTN_K_MQA, "blk.%d.attn_k_mqa" },
+ { LLM_TENSOR_ATTN_K_B_TRANS, "blk.%d.attn_k_b_trans" },
+ { LLM_TENSOR_ATTN_K_B, "blk.%d.attn_k_b" },
+ { LLM_TENSOR_ATTN_V_B, "blk.%d.attn_v_b" },
{ LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" },
{ LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
@@ -1333,23 +1340,13 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
- {LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
- {LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
- {LLM_TENSOR_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
- {LLM_TENSOR_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
- {LLM_TENSOR_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
- {LLM_TENSOR_ATTN_QKV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
- {LLM_TENSOR_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
- {LLM_TENSOR_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
- {LLM_TENSOR_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
- {LLM_TENSOR_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
- {LLM_TENSOR_FFN_DOWN_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
- {LLM_TENSOR_FFN_GATE_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
- {LLM_TENSOR_FFN_UP_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
- {LLM_TENSOR_ATTN_Q_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
- {LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
- {LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
- {LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+ {LLM_TENSOR_ATTN_Q_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+ {LLM_TENSOR_ATTN_Q_B_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+ {LLM_TENSOR_ATTN_KV_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+ {LLM_TENSOR_ATTN_K_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+ {LLM_TENSOR_ATTN_K_B_TRANS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+ {LLM_TENSOR_ATTN_K_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+ {LLM_TENSOR_ATTN_V_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_DEC_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
diff --git a/src/llama-arch.h b/src/llama-arch.h
index 122fdcebe0af6..cae591373c2de 100644
--- a/src/llama-arch.h
+++ b/src/llama-arch.h
@@ -277,6 +277,13 @@ enum llm_tensor {
LLM_TENSOR_ATTN_Q_B,
LLM_TENSOR_ATTN_KV_A_MQA,
LLM_TENSOR_ATTN_KV_B,
+ LLM_TENSOR_ATTN_Q_MQA,
+ LLM_TENSOR_ATTN_Q_B_MQA,
+ LLM_TENSOR_ATTN_KV_A,
+ LLM_TENSOR_ATTN_K_MQA,
+ LLM_TENSOR_ATTN_K_B_TRANS,
+ LLM_TENSOR_ATTN_K_B,
+ LLM_TENSOR_ATTN_V_B,
LLM_TENSOR_ATTN_Q_A_NORM,
LLM_TENSOR_ATTN_KV_A_NORM,
LLM_TENSOR_ATTN_SUB_NORM,
diff --git a/src/llama-cparams.h b/src/llama-cparams.h
index 252012f3d9405..6ebab857e236a 100644
--- a/src/llama-cparams.h
+++ b/src/llama-cparams.h
@@ -28,6 +28,7 @@ struct llama_cparams {
bool causal_attn;
bool offload_kqv;
bool flash_attn;
+ bool mla_attn;
bool no_perf;
enum llama_pooling_type pooling_type;
diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp
index feffdf0de52cf..37ab4adfdbd94 100644
--- a/src/llama-kv-cache.cpp
+++ b/src/llama-kv-cache.cpp
@@ -32,7 +32,7 @@ bool llama_kv_cache_init(
cache.recurrent = llama_model_is_recurrent(&model);
cache.v_trans = !cache.recurrent && !cparams.flash_attn;
- cache.can_shift = !cache.recurrent && model.arch != LLM_ARCH_DEEPSEEK2; // not supported due to MLA
+ cache.can_shift = !cache.recurrent && model.arch != LLM_ARCH_DEEPSEEK2; // not supported yet
LLAMA_LOG_INFO("%s: kv_size = %d, offload = %d, type_k = '%s', type_v = '%s', n_layer = %d, can_shift = %d\n",
__func__, kv_size, offload, ggml_type_name(type_k), ggml_type_name(type_v), n_layer, cache.can_shift);
@@ -91,8 +91,21 @@ bool llama_kv_cache_init(
return false;
}
- ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
- ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
+ int64_t n_embd_k;
+ int64_t n_embd_v;
+
+ // note: deepseek-mla stores the compressed (latent) KV representations
+ if (cparams.mla_attn && model.arch == LLM_ARCH_DEEPSEEK2) {
+ n_embd_k = hparams.n_lora_kv + hparams.n_rot;
+ n_embd_v = hparams.n_lora_kv;
+ } else {
+ n_embd_k = hparams.n_embd_k_gqa(i);
+ n_embd_v = hparams.n_embd_v_gqa(i);
+ }
+
+ ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k*kv_size);
+ ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v*kv_size);
+
ggml_format_name(k, "cache_k_l%d", i);
ggml_format_name(v, "cache_v_l%d", i);
cache.k_l.push_back(k);
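
To put numbers on the cache sizing above: a rough per-token comparison, assuming DeepSeek-V3-like hyper-parameters (values quoted from memory, purely for illustration) and an F16 cache.

```python
# DeepSeek-V3-like hyper-parameters, for illustration only
n_layer, n_head          = 61, 128
qk_nope, qk_rope, v_head = 128, 64, 128
kv_lora_rank             = 512
bytes_per_elem           = 2  # F16 cache

# non-MLA path: MLA is expanded to MHA, so full per-head K and V rows are cached
k_mha = n_head * (qk_nope + qk_rope)   # n_embd_k_gqa
v_mha = n_head * v_head                # n_embd_v_gqa

# MLA path: only the shared latent plus the RoPE-ed part is cached per layer
k_mla = kv_lora_rank + qk_rope         # n_lora_kv + n_rot
v_mla = kv_lora_rank                   # n_lora_kv

per_token_mha = n_layer * (k_mha + v_mha) * bytes_per_elem
per_token_mla = n_layer * (k_mla + v_mla) * bytes_per_elem
print(f"{per_token_mha / 2**20:.2f} MiB vs {per_token_mla / 2**10:.1f} KiB per token "
      f"(~{per_token_mha / per_token_mla:.0f}x smaller)")
# 4.77 MiB vs 129.6 KiB per token (~38x smaller)
```
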
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 1da4eae7e63e2..dc83718b968c6 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -2890,14 +2890,20 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
if (!is_lite) {
layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0);
- layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k}, 0);
+ layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_qk_nope}, 0);
+ layer.wq_b_mqa = create_tensor(tn(LLM_TENSOR_ATTN_Q_B_MQA, "weight", i), {q_lora_rank, n_head * n_embd_head_qk_rope}, 0);
} else {
- layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}, 0);
+ layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_head * n_embd_head_qk_nope}, 0);
+ layer.wq_mqa = create_tensor(tn(LLM_TENSOR_ATTN_Q_MQA, "weight", i), {n_embd, n_head * n_embd_head_qk_rope}, 0);
}
- layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)}, 0);
- layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, 0);
- layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_head * ( n_embd_head_v), n_embd}, 0);
+ layer.wkv_a = create_tensor(tn(LLM_TENSOR_ATTN_KV_A, "weight", i), {n_embd, kv_lora_rank}, 0);
+ layer.wk_mqa = create_tensor(tn(LLM_TENSOR_ATTN_K_MQA, "weight", i), {n_embd, n_embd_head_qk_rope}, 0);
+ layer.wk_b_trans = create_tensor(tn(LLM_TENSOR_ATTN_K_B_TRANS, "weight", i), {n_embd_head_qk_nope, n_head * kv_lora_rank}, 0);
+ layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K_B, "weight", i), {kv_lora_rank, n_head * n_embd_head_qk_nope}, 0);
+ layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_head * n_embd_head_v}, 0);
+
+ layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_embd_head_v, n_embd}, 0);
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
diff --git a/src/llama-model.h b/src/llama-model.h
index a7c30444786fd..1b9852402d7b5 100644
--- a/src/llama-model.h
+++ b/src/llama-model.h
@@ -152,23 +152,30 @@ struct llama_layer {
struct ggml_tensor * attn_norm_enc = nullptr;
// attention
- struct ggml_tensor * wq = nullptr;
- struct ggml_tensor * wk = nullptr;
- struct ggml_tensor * wv = nullptr;
- struct ggml_tensor * wo = nullptr;
- struct ggml_tensor * wqkv = nullptr;
- struct ggml_tensor * wq_a = nullptr;
- struct ggml_tensor * wq_b = nullptr;
- struct ggml_tensor * wkv_a_mqa = nullptr;
- struct ggml_tensor * wkv_b = nullptr;
- struct ggml_tensor * wq_cross = nullptr;
- struct ggml_tensor * wk_cross = nullptr;
- struct ggml_tensor * wv_cross = nullptr;
- struct ggml_tensor * wo_cross = nullptr;
- struct ggml_tensor * wq_enc = nullptr;
- struct ggml_tensor * wk_enc = nullptr;
- struct ggml_tensor * wv_enc = nullptr;
- struct ggml_tensor * wo_enc = nullptr;
+ struct ggml_tensor * wq = nullptr;
+ struct ggml_tensor * wk = nullptr;
+ struct ggml_tensor * wv = nullptr;
+ struct ggml_tensor * wo = nullptr;
+ struct ggml_tensor * wqkv = nullptr;
+ struct ggml_tensor * wq_a = nullptr;
+ struct ggml_tensor * wq_b = nullptr;
+ struct ggml_tensor * wkv_a_mqa = nullptr;
+ struct ggml_tensor * wkv_b = nullptr;
+ struct ggml_tensor * wq_mqa = nullptr;
+ struct ggml_tensor * wq_b_mqa = nullptr;
+ struct ggml_tensor * wkv_a = nullptr;
+ struct ggml_tensor * wk_mqa = nullptr;
+ struct ggml_tensor * wk_b_trans = nullptr;
+ struct ggml_tensor * wk_b = nullptr;
+ struct ggml_tensor * wv_b = nullptr;
+ struct ggml_tensor * wq_cross = nullptr;
+ struct ggml_tensor * wk_cross = nullptr;
+ struct ggml_tensor * wv_cross = nullptr;
+ struct ggml_tensor * wo_cross = nullptr;
+ struct ggml_tensor * wq_enc = nullptr;
+ struct ggml_tensor * wk_enc = nullptr;
+ struct ggml_tensor * wv_enc = nullptr;
+ struct ggml_tensor * wo_enc = nullptr;
// attention bias
struct ggml_tensor * bq = nullptr;
diff --git a/src/llama.cpp b/src/llama.cpp
index 607f278615969..72e999e0ac0b0 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -156,8 +156,7 @@ static struct ggml_tensor * llm_build_inp_embd(
static void llm_build_kv_store(
struct ggml_context * ctx,
- const llama_hparams & hparams,
- const llama_cparams & cparams,
+ struct llama_context & lctx,
const llama_kv_cache & kv,
struct ggml_cgraph * graph,
struct ggml_tensor * k_cur,
@@ -166,28 +165,41 @@ static void llm_build_kv_store(
int32_t kv_head,
const llm_build_cb & cb,
int64_t il) {
- const int64_t n_ctx = cparams.n_ctx;
+ const llama_model & model = lctx.model;
+ const llama_hparams & hparams = lctx.model.hparams;
+ const llama_cparams & cparams = lctx.cparams;
- const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
- const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+ const int64_t n_ctx = cparams.n_ctx;
GGML_ASSERT(kv.size == n_ctx);
- struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_k_gqa, ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa)*kv_head);
+ int64_t n_embd_k;
+ int64_t n_embd_v;
+
+ // note: deepseek-mla converts MLA to MQA so n_embd_k/n_embd_v change too
+ if (cparams.mla_attn && model.arch == LLM_ARCH_DEEPSEEK2) {
+ n_embd_k = hparams.n_lora_kv + hparams.n_rot;
+ n_embd_v = hparams.n_lora_kv;
+ } else {
+ n_embd_k = hparams.n_embd_k_gqa(il);
+ n_embd_v = hparams.n_embd_v_gqa(il);
+ }
+
+ struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_k, ggml_row_size(kv.k_l[il]->type, n_embd_k)*kv_head);
cb(k_cache_view, "k_cache_view", il);
// note: storing RoPE-ed version of K in the KV cache
ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view));
- assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens);
+ assert(v_cur->ne[0] == n_embd_v && v_cur->ne[1] == n_tokens);
struct ggml_tensor * v_cache_view = nullptr;
if (cparams.flash_attn) {
- v_cache_view = ggml_view_1d(ctx, kv.v_l[il], n_tokens*n_embd_v_gqa, ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa)*kv_head);
+ v_cache_view = ggml_view_1d(ctx, kv.v_l[il], n_tokens*n_embd_v, ggml_row_size(kv.v_l[il]->type, n_embd_v)*kv_head);
} else {
// note: the V cache is transposed when not using flash attention
- v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_v_gqa,
+ v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_v,
( n_ctx)*ggml_element_size(kv.v_l[il]),
(kv_head)*ggml_element_size(kv.v_l[il]));
@@ -542,8 +554,9 @@ static struct ggml_tensor * llm_build_kqv(
struct llama_context & lctx,
const llama_kv_cache & kv,
struct ggml_cgraph * graph,
+ struct ggml_tensor * wv_b,
struct ggml_tensor * wo,
- struct ggml_tensor * wo_b,
+ struct ggml_tensor * bo,
struct ggml_tensor * q_cur,
struct ggml_tensor * kq_mask,
int32_t n_tokens,
@@ -558,28 +571,28 @@ static struct ggml_tensor * llm_build_kqv(
const int64_t n_ctx = cparams.n_ctx;
const int64_t n_head = hparams.n_head(il);
const int64_t n_head_kv = hparams.n_head_kv(il);
- const int64_t n_embd_head_k = hparams.n_embd_head_k;
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
- const int64_t n_embd_head_v = hparams.n_embd_head_v;
+ const int64_t n_embd_head_k = hparams.n_embd_head_k;
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+ const int64_t n_embd_head_v = hparams.n_embd_head_v;
+
+ struct ggml_tensor * cur;
struct ggml_tensor * q = ggml_permute(ctx, q_cur, 0, 2, 1, 3);
cb(q, "q", il);
- struct ggml_tensor * k =
- ggml_view_3d(ctx, kv.k_l[il],
- n_embd_head_k, n_kv, n_head_kv,
- ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa),
- ggml_row_size(kv.k_l[il]->type, n_embd_head_k),
- 0);
- cb(k, "k", il);
-
- struct ggml_tensor * cur;
-
if (cparams.flash_attn) {
GGML_UNUSED(model);
GGML_UNUSED(n_ctx);
+ struct ggml_tensor * k =
+ ggml_view_3d(ctx, kv.k_l[il],
+ n_embd_head_k, n_kv, n_head_kv,
+ ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa),
+ ggml_row_size(kv.k_l[il]->type, n_embd_head_k),
+ 0);
+ cb(k, "k", il);
+
// split cached v into n_head heads (not transposed)
struct ggml_tensor * v =
ggml_view_3d(ctx, kv.v_l[il],
@@ -596,51 +609,137 @@ static struct ggml_tensor * llm_build_kqv(
cur = ggml_reshape_2d(ctx, cur, n_embd_head_v*n_head, n_tokens);
} else {
- struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
- cb(kq, "kq", il);
- // note: this op tends to require high floating point range
- // while for some models F16 is enough, for others it is not, so we default to F32 here
- ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
+ // MLA converted to MQA, optimised to use non-batched matrix multiplies
+ if (cparams.mla_attn && model.arch == LLM_ARCH_DEEPSEEK2) {
+ const int64_t n_embd_head_k_mqa = hparams.n_lora_kv + hparams.n_rot;
+ const int64_t n_embd_head_v_mqa = hparams.n_lora_kv;
- if (model.arch == LLM_ARCH_GROK) {
- // need to do the following:
- // multiply by attn_output_multiplyer of 0.08838834764831845
- // and then :
- // kq = 30 * tanh(kq / 30)
- // before the softmax below
+ // must ggml_cont() before the 2D view or else kq will have n_tokens <-> n_head swapped...
+ q = ggml_cont(ctx, q);
+ cb(q, "q_cont", il);
- kq = ggml_tanh(ctx, ggml_scale(ctx, kq, 0.08838834764831845f/30.0f));
- kq = ggml_scale(ctx, kq, 30);
- }
+ q = ggml_view_2d(ctx, q,
+ n_embd_head_k_mqa, n_head * n_tokens,
+ ggml_row_size(q->type, n_embd_head_k_mqa),
+ 0);
+ cb(q, "q_view", il);
- if (hparams.attn_soft_cap) {
- kq = ggml_scale(ctx, kq, 1.0f / hparams.f_attn_logit_softcapping);
- kq = ggml_tanh(ctx, kq);
- kq = ggml_scale(ctx, kq, hparams.f_attn_logit_softcapping);
- }
+ struct ggml_tensor * k =
+ ggml_view_2d(ctx, kv.k_l[il],
+ n_embd_head_k_mqa, n_kv,
+ ggml_row_size(kv.k_l[il]->type, n_embd_head_k_mqa),
+ 0);
+ cb(k, "k", il);
- kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
- cb(kq, "kq_soft_max_ext", il);
+ struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
+ cb(kq, "kq", il);
- GGML_ASSERT(kv.size == n_ctx);
+ // note: forcing F32 precision for kq doesn't seem necessary here
+ //ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
- // split cached v into n_head heads
- struct ggml_tensor * v =
- ggml_view_3d(ctx, kv.v_l[il],
- n_kv, n_embd_head_v, n_head_kv,
- ggml_element_size(kv.v_l[il])*n_ctx,
- ggml_element_size(kv.v_l[il])*n_ctx*n_embd_head_v,
+ kq = ggml_view_3d(ctx, kq,
+ n_kv, n_tokens, n_head,
+ ggml_row_size(kq->type, n_kv),
+ ggml_row_size(kq->type, n_kv * n_tokens),
0);
- cb(v, "v", il);
+ cb(kq, "kq_view", il);
+
+ kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
+ cb(kq, "kq_soft_max_ext", il);
+
+ kq = ggml_view_2d(ctx, kq,
+ n_kv, n_tokens * n_head,
+ ggml_row_size(kq->type, n_kv),
+ 0);
+ cb(kq, "kq_soft_max_view", il);
+
+ GGML_ASSERT(kv.size == n_ctx);
+
+ struct ggml_tensor * v =
+ ggml_view_2d(ctx, kv.v_l[il],
+ n_kv, n_embd_head_v_mqa,
+ ggml_element_size(kv.v_l[il])*n_ctx,
+ 0);
+ cb(v, "v", il);
+
+ struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq);
+ cb(kqv, "kqv_compressed", il);
+
+ kqv = ggml_view_3d(ctx, kqv,
+ n_embd_head_v_mqa, n_tokens, n_head,
+ ggml_row_size(kqv->type, n_embd_head_v_mqa),
+ ggml_row_size(kqv->type, n_embd_head_v_mqa * n_tokens),
+ 0);
+ cb(kqv, "kqv_view", il);
+
+ struct ggml_tensor * wv_b_view =
+ ggml_view_3d(ctx, wv_b, n_embd_head_v_mqa, n_embd_head_v, n_head,
+ ggml_row_size(wv_b->type, n_embd_head_v_mqa),
+ ggml_row_size(wv_b->type, n_embd_head_v * n_embd_head_v_mqa),
+ 0);
+ cb(wv_b_view, "wv_b_view", il);
+
+ // decompress the MQA result back to MHA
+ cur = ggml_mul_mat(ctx, wv_b_view, kqv);
+ cb(cur, "kqv", il);
+
+ // standard MHA/GQA non-flash-attention case
+ } else {
+ struct ggml_tensor * k =
+ ggml_view_3d(ctx, kv.k_l[il],
+ n_embd_head_k, n_kv, n_head_kv,
+ ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa),
+ ggml_row_size(kv.k_l[il]->type, n_embd_head_k),
+ 0);
+ cb(k, "k", il);
+
+ struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
+ cb(kq, "kq", il);
+
+ // note: this op tends to require high floating point range
+ // while for some models F16 is enough, for others it is not, so we default to F32 here
+ ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
+
+ if (model.arch == LLM_ARCH_GROK) {
+ // need to do the following:
+ // multiply by attn_output_multiplyer of 0.08838834764831845
+ // and then :
+ // kq = 30 * tanh(kq / 30)
+ // before the softmax below
+
+ kq = ggml_tanh(ctx, ggml_scale(ctx, kq, 0.08838834764831845f/30.0f));
+ kq = ggml_scale(ctx, kq, 30);
+ }
+
+ if (hparams.attn_soft_cap) {
+ kq = ggml_scale(ctx, kq, 1.0f / hparams.f_attn_logit_softcapping);
+ kq = ggml_tanh(ctx, kq);
+ kq = ggml_scale(ctx, kq, hparams.f_attn_logit_softcapping);
+ }
- struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq);
- cb(kqv, "kqv", il);
+ kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
+ cb(kq, "kq_soft_max_ext", il);
+
+ GGML_ASSERT(kv.size == n_ctx);
+
+ // split cached v into n_head heads
+ struct ggml_tensor * v =
+ ggml_view_3d(ctx, kv.v_l[il],
+ n_kv, n_embd_head_v, n_head_kv,
+ ggml_element_size(kv.v_l[il])*n_ctx,
+ ggml_element_size(kv.v_l[il])*n_ctx*n_embd_head_v,
+ 0);
+ cb(v, "v", il);
- struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3);
- cb(kqv_merged, "kqv_merged", il);
+ cur = ggml_mul_mat(ctx, v, kq);
+ cb(cur, "kqv", il);
+ }
+
+ cur = ggml_permute(ctx, cur, 0, 2, 1, 3);
+ cb(cur, "kqv_merged", il);
- cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_v*n_head, n_tokens);
+ cur = ggml_cont_2d(ctx, cur, n_embd_head_v*n_head, n_tokens);
cb(cur, "kqv_merged_cont", il);
}
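
Why the MLA-as-MQA branch above can avoid batched matrix multiplies: with a single shared K head, every (head, token) query row scores against the same K matrix, so the per-head batched kq can be collapsed into one 2D matmul and then viewed back into 3D before the softmax. A small torch sketch with illustrative shapes (not the model's):

```python
import torch

n_head, n_tokens, n_kv, d = 8, 3, 5, 16   # illustrative sizes only

q = torch.randn(n_head, n_tokens, d)      # per-head queries
k = torch.randn(n_kv, d)                  # the single shared (MQA) key head

# reference: broadcast the shared k over heads -> scores of shape (n_head, n_tokens, n_kv)
kq_batched = torch.einsum('htd,nd->htn', q, k)

# the 2D trick: collapse (n_head, n_tokens) into one row dimension, do one matmul,
# then view the result back as 3D
kq_flat = (q.reshape(n_head * n_tokens, d) @ k.T).view(n_head, n_tokens, n_kv)

assert torch.allclose(kq_batched, kq_flat, atol=1e-4)
```
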
@@ -650,12 +749,12 @@ static struct ggml_tensor * llm_build_kqv(
cur = llm_build_lora_mm(lctx, ctx, wo, cur);
}
- if (wo_b) {
+ if (bo) {
cb(cur, "kqv_wo", il);
}
- if (wo_b) {
- cur = ggml_add(ctx, cur, wo_b);
+ if (bo) {
+ cur = ggml_add(ctx, cur, bo);
}
return cur;
@@ -666,8 +765,9 @@ static struct ggml_tensor * llm_build_kv(
struct llama_context & lctx,
const llama_kv_cache & kv,
struct ggml_cgraph * graph,
+ struct ggml_tensor * wv_b,
struct ggml_tensor * wo,
- struct ggml_tensor * wo_b,
+ struct ggml_tensor * bo,
struct ggml_tensor * k_cur,
struct ggml_tensor * v_cur,
struct ggml_tensor * q_cur,
@@ -678,20 +778,17 @@ static struct ggml_tensor * llm_build_kv(
float kq_scale,
const llm_build_cb & cb,
int il) {
- const llama_hparams & hparams = lctx.model.hparams;
- const llama_cparams & cparams = lctx.cparams;
-
// these nodes are added to the graph together so that they are not reordered
// by doing so, the number of splits in the graph is reduced
ggml_build_forward_expand(graph, q_cur);
ggml_build_forward_expand(graph, k_cur);
ggml_build_forward_expand(graph, v_cur);
- llm_build_kv_store(ctx, hparams, cparams, kv, graph, k_cur, v_cur, n_tokens, kv_head, cb, il);
+ llm_build_kv_store(ctx, lctx, kv, graph, k_cur, v_cur, n_tokens, kv_head, cb, il);
struct ggml_tensor * cur;
- cur = llm_build_kqv(ctx, lctx, kv, graph, wo, wo_b, q_cur, kq_mask, n_tokens, n_kv, kq_scale, cb, il);
+ cur = llm_build_kqv(ctx, lctx, kv, graph, wv_b, wo, bo, q_cur, kq_mask, n_tokens, n_kv, kq_scale, cb, il);
cb(cur, "kqv_out", il);
return cur;
@@ -1546,7 +1643,7 @@ struct llm_build_context {
cb(Kcur, "Kcur", il);
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
- model.layers[il].wo, model.layers[il].bo,
+ nullptr, model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
}
@@ -1723,7 +1820,7 @@ struct llm_build_context {
cb(Kcur, "Kcur", il);
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
- model.layers[il].wo, model.layers[il].bo,
+ nullptr, model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
}
@@ -1861,7 +1958,7 @@ struct llm_build_context {
cb(Kcur, "Kcur", il);
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
- model.layers[il].wo, NULL,
+ nullptr, model.layers[il].wo, nullptr,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
@@ -1966,7 +2063,7 @@ struct llm_build_context {
);
cb(Kcur, "Kcur", il);
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
- model.layers[il].wo, NULL,
+ nullptr, model.layers[il].wo, nullptr,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
@@ -2087,7 +2184,7 @@ struct llm_build_context {
cb(Kcur, "Kcur", il);
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
- model.layers[il].wo, NULL,
+ nullptr, model.layers[il].wo, nullptr,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
@@ -2211,7 +2308,7 @@ struct llm_build_context {
cb(Kcur, "Kcur", il);
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
- model.layers[il].wo, model.layers[il].bo,
+ nullptr, model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il);
}
@@ -2363,7 +2460,7 @@ struct llm_build_context {
cb(Kcur, "Kcur", il);
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
- model.layers[il].wo, NULL,
+ nullptr, model.layers[il].wo, nullptr,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
@@ -2475,7 +2572,7 @@ struct llm_build_context {
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
- model.layers[il].wo, model.layers[il].bo,
+ nullptr, model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
@@ -2569,7 +2666,7 @@ struct llm_build_context {
cb(Qcur, "Qcur", il);
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
- model.layers[il].wo, NULL,
+ nullptr, model.layers[il].wo, nullptr,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
@@ -2864,7 +2961,7 @@ struct llm_build_context {
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
- model.layers[il].wo, model.layers[il].bo,
+ nullptr, model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
@@ -2996,13 +3093,13 @@ struct llm_build_context {
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
- model.layers[il].wo, model.layers[il].bo,
+ nullptr, model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
} else {
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
- model.layers[il].wo, model.layers[il].bo,
+ nullptr, model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
}
@@ -3147,7 +3244,7 @@ struct llm_build_context {
cb(Kcur, "Kcur", il);
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
- model.layers[il].wo, NULL,
+ nullptr, model.layers[il].wo, nullptr,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
@@ -3266,7 +3363,7 @@ struct llm_build_context {
cb(Kcur, "Kcur", il);
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
- model.layers[il].wo, NULL,
+ nullptr, model.layers[il].wo, nullptr,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
@@ -3380,7 +3477,7 @@ struct llm_build_context {
cb(Kcur, "Kcur", il);
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
- model.layers[il].wo, model.layers[il].bo,
+ nullptr, model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
@@ -3498,7 +3595,7 @@ struct llm_build_context {
cb(Kcur, "Kcur", il);
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
- model.layers[il].wo, model.layers[il].bo,
+ nullptr, model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
@@ -3613,7 +3710,7 @@ struct llm_build_context {
cb(Kcur, "Kcur", il);
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
- model.layers[il].wo, model.layers[il].bo,
+ nullptr, model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
@@ -3772,7 +3869,7 @@ struct llm_build_context {
cb(Kcur, "Kcur", il);
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
- model.layers[il].wo, model.layers[il].bo,
+ nullptr, model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il);
}
@@ -3897,7 +3994,7 @@ struct llm_build_context {
cb(Kcur, "Kcur", il);
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
- model.layers[il].wo, model.layers[il].bo,
+ nullptr, model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il);
}
@@ -4022,7 +4119,7 @@ struct llm_build_context {
cb(Kcur, "Kcur", il);
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
- model.layers[il].wo, NULL,
+ nullptr, model.layers[il].wo, nullptr,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
struct ggml_tensor * sa_out = cur;
@@ -4124,7 +4221,7 @@ struct llm_build_context {
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
- model.layers[il].wo, model.layers[il].bo,
+ nullptr, model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
@@ -4235,7 +4332,7 @@ struct llm_build_context {
cb(Kcur, "Kcur", il);
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
- model.layers[il].wo, model.layers[il].bo,
+ nullptr, model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
@@ -4355,7 +4452,7 @@ struct llm_build_context {
cb(Kcur, "Kcur", il);
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
- model.layers[il].wo, NULL,
+ nullptr, model.layers[il].wo, nullptr,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
@@ -4473,7 +4570,7 @@ struct llm_build_context {
cb(Kcur, "Kcur", il);
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
- model.layers[il].wo, model.layers[il].bo,
+ nullptr, model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
@@ -4667,7 +4764,7 @@ struct llm_build_context {
cb(k_states, "k_states", il);
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
- model.layers[il].wo, NULL,
+ nullptr, model.layers[il].wo, nullptr,
k_states, v_states, q_states, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
}
@@ -4789,7 +4886,7 @@ struct llm_build_context {
cb(Kcur, "Kcur", il);
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
- model.layers[il].wo, NULL,
+ nullptr, model.layers[il].wo, nullptr,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il);
}
@@ -4908,7 +5005,7 @@ struct llm_build_context {
cb(Kcur, "Kcur", il);
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
- model.layers[il].wo, NULL,
+ nullptr, model.layers[il].wo, nullptr,
Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f, cb, il);
}
@@ -5045,7 +5142,7 @@ struct llm_build_context {
cb(Kcur, "Kcur", il);
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
- model.layers[il].wo, model.layers[il].bo,
+ nullptr, model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
@@ -5244,7 +5341,7 @@ struct llm_build_context {
cb(Kcur, "Kcur", il);
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
- model.layers[il].wo, model.layers[il].bo,
+ nullptr, model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
@@ -5381,8 +5478,9 @@ struct llm_build_context {
cb(Kcur, "Kcur", il);
}
- cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur,
- KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f / sqrtf(float(n_embd_head)), cb, il);
+ cur = llm_build_kv(ctx0, lctx, kv_self, gf,
+ nullptr, model.layers[il].wo, model.layers[il].bo,
+ Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f / sqrtf(float(n_embd_head)), cb, il);
}
if (il == n_layer - 1) {
@@ -5507,7 +5605,7 @@ struct llm_build_context {
cb(Kcur, "Kcur", il);
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
- model.layers[il].wo, nullptr,
+ nullptr, model.layers[il].wo, nullptr,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
@@ -5626,7 +5724,7 @@ struct llm_build_context {
cb(Kcur, "Kcur_rope", il);
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
- model.layers[il].wo, NULL,
+ nullptr, model.layers[il].wo, nullptr,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
@@ -5758,7 +5856,7 @@ struct llm_build_context {
cb(Kcur, "Kcur_rope", il);
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
- model.layers[il].wo, NULL,
+ nullptr, model.layers[il].wo, nullptr,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
@@ -5888,7 +5986,7 @@ struct llm_build_context {
cb(Qcur, "Vcur", il);
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
- model.layers[il].wo, NULL,
+ nullptr, model.layers[il].wo, nullptr,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
@@ -5997,7 +6095,7 @@ struct llm_build_context {
cb(Kcur, "Kcur", il);
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
- model.layers[il].wo, model.layers[il].bo,
+ nullptr, model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
@@ -6140,7 +6238,7 @@ struct llm_build_context {
cb(Kcur, "Kcur", il);
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
- model.layers[il].wo, NULL,
+ nullptr, model.layers[il].wo, nullptr,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
@@ -6289,7 +6387,7 @@ struct llm_build_context {
cb(Kcur, "Kcur", il);
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
- model.layers[il].wo, model.layers[il].bo,
+ nullptr, model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
}
@@ -6388,8 +6486,8 @@ struct llm_build_context {
const float kq_scale = 1.0f*mscale*mscale/sqrtf(float(hparams.n_embd_head_k));
const float attn_factor_scaled = 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale));
- const uint32_t n_embd_head_qk_rope = hparams.n_rot;
const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
+ const uint32_t n_embd_head_qk_rope = hparams.n_rot;
const uint32_t kv_lora_rank = hparams.n_lora_kv;
struct ggml_tensor * cur;
@@ -6407,7 +6505,6 @@ struct llm_build_context {
for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * inpSA = inpL;
- // norm
cur = llm_build_norm(ctx0, inpL, hparams,
model.layers[il].attn_norm, NULL,
LLM_NORM_RMS, cb, il);
@@ -6415,115 +6512,160 @@ struct llm_build_context {
// self_attention
{
- struct ggml_tensor * q = NULL;
+ struct ggml_tensor * q_nope = nullptr;
+ struct ggml_tensor * q_mqa = nullptr;
if (!is_lite) {
// {n_embd, q_lora_rank} * {n_embd, n_tokens} -> {q_lora_rank, n_tokens}
- q = ggml_mul_mat(ctx0, model.layers[il].wq_a, cur);
- cb(q, "q", il);
+ struct ggml_tensor * q_compressed = ggml_mul_mat(ctx0, model.layers[il].wq_a, cur);
+ cb(q_compressed, "q_compressed", il);
- q = llm_build_norm(ctx0, q, hparams,
- model.layers[il].attn_q_a_norm, NULL,
+ q_compressed = llm_build_norm(ctx0, q_compressed, hparams,
+ model.layers[il].attn_q_a_norm, nullptr,
LLM_NORM_RMS, cb, il);
- cb(q, "q", il);
+ cb(q_compressed, "q_compressed_norm", il);
- // {q_lora_rank, n_head * hparams.n_embd_head_k} * {q_lora_rank, n_tokens} -> {n_head * hparams.n_embd_head_k, n_tokens}
- q = ggml_mul_mat(ctx0, model.layers[il].wq_b, q);
- cb(q, "q", il);
+ // {q_lora_rank, n_head * n_embd_head_qk_nope} * {q_lora_rank, n_tokens} -> {n_head * n_embd_head_qk_nope, n_tokens}
+ q_nope = ggml_mul_mat(ctx0, model.layers[il].wq_b, q_compressed);
+ cb(q_nope, "q_nope", il);
+
+ // {q_lora_rank, n_head * n_embd_head_qk_rope} * {q_lora_rank, n_tokens} -> {n_head * n_embd_head_qk_rope, n_tokens}
+ q_mqa = ggml_mul_mat(ctx0, model.layers[il].wq_b_mqa, q_compressed);
+ cb(q_mqa, "q_mqa", il);
} else {
- q = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
- cb(q, "q", il);
+ // {n_embd, n_head * n_embd_head_qk_nope} * {n_embd, n_tokens} -> {n_head * n_embd_head_qk_nope, n_tokens}
+ q_nope = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+ cb(q_nope, "q_nope", il);
+
+ // {n_embd, n_head * n_embd_head_qk_rope} * {n_embd, n_tokens} -> {n_head * n_embd_head_qk_rope, n_tokens}
+ q_mqa = ggml_mul_mat(ctx0, model.layers[il].wq_mqa, cur);
+ cb(q_mqa, "q_mqa", il);
}
- // split into {n_head * n_embd_head_qk_nope, n_tokens}
- struct ggml_tensor * q_nope = ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens,
- ggml_row_size(q->type, hparams.n_embd_head_k),
- ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
+ // {n_embd_head_qk_nope, n_head, n_tokens}
+ struct ggml_tensor * q_nope_view = ggml_view_3d(ctx0, q_nope, n_embd_head_qk_nope, n_head, n_tokens,
+ ggml_row_size(q_nope->type, n_embd_head_qk_nope),
+ ggml_row_size(q_nope->type, n_head * n_embd_head_qk_nope),
0);
- cb(q_nope, "q_nope", il);
+ cb(q_nope_view, "q_nope_view", il);
- // and {n_head * n_embd_head_qk_rope, n_tokens}
- struct ggml_tensor * q_pe = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens,
- ggml_row_size(q->type, hparams.n_embd_head_k),
- ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
- ggml_row_size(q->type, n_embd_head_qk_nope));
- cb(q_pe, "q_pe", il);
+ // {n_embd_head_qk_rope, n_head, n_tokens}
+ struct ggml_tensor * q_mqa_view = ggml_view_3d(ctx0, q_mqa, n_embd_head_qk_rope, n_head, n_tokens,
+ ggml_row_size(q_mqa->type, n_embd_head_qk_rope),
+ ggml_row_size(q_mqa->type, n_head * n_embd_head_qk_rope),
+ 0);
+ cb(q_mqa_view, "q_mqa_view", il);
- // {n_embd, kv_lora_rank + n_embd_head_qk_rope} * {n_embd, n_tokens} -> {kv_lora_rank + n_embd_head_qk_rope, n_tokens}
- struct ggml_tensor * kv_pe_compresseed = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur);
- cb(kv_pe_compresseed, "kv_pe_compresseed", il);
+ q_mqa_view = ggml_rope_ext(ctx0, q_mqa_view, inp_pos, nullptr,
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor_scaled, beta_fast, beta_slow
+ );
+ cb(q_mqa_view, "q_mqa_view_rope", il);
- // split into {kv_lora_rank, n_tokens}
- struct ggml_tensor * kv_compressed = ggml_view_2d(ctx0, kv_pe_compresseed, kv_lora_rank, n_tokens,
- kv_pe_compresseed->nb[1],
- 0);
+ // {n_embd, kv_lora_rank} * {n_embd, n_tokens} -> {kv_lora_rank, n_tokens}
+ struct ggml_tensor * kv_compressed = ggml_mul_mat(ctx0, model.layers[il].wkv_a, cur);
cb(kv_compressed, "kv_compressed", il);
- // and {n_embd_head_qk_rope, n_tokens}
- struct ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_pe_compresseed, n_embd_head_qk_rope, 1, n_tokens,
- kv_pe_compresseed->nb[1],
- kv_pe_compresseed->nb[1],
- ggml_row_size(kv_pe_compresseed->type, kv_lora_rank));
- cb(k_pe, "k_pe", il);
-
- // TODO: the CUDA backend used to not support non-cont. (RMS) norm, investigate removing ggml_cont
- kv_compressed = ggml_cont(ctx0, kv_compressed);
kv_compressed = llm_build_norm(ctx0, kv_compressed, hparams,
- model.layers[il].attn_kv_a_norm, NULL,
+ model.layers[il].attn_kv_a_norm, nullptr,
LLM_NORM_RMS, cb, il);
- cb(kv_compressed, "kv_compressed", il);
+ cb(kv_compressed, "kv_compressed_norm", il);
- // {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens}
- struct ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed);
- cb(kv, "kv", il);
+ // {n_embd, n_embd_head_qk_rope} * {n_embd, n_tokens} -> {n_embd_head_qk_rope, n_tokens}
+ struct ggml_tensor * k_mqa = ggml_mul_mat(ctx0, model.layers[il].wk_mqa, cur);
+ cb(k_mqa, "k_mqa", il);
- // split into {n_head * n_embd_head_qk_nope, n_tokens}
- struct ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens,
- ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
- ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
+ // {n_embd_head_qk_rope, 1, n_tokens}
+ struct ggml_tensor * k_mqa_view = ggml_view_3d(ctx0, k_mqa, n_embd_head_qk_rope, 1, n_tokens,
+ ggml_row_size(k_mqa->type, n_embd_head_qk_rope),
+ ggml_row_size(k_mqa->type, n_embd_head_qk_rope),
0);
- cb(k_nope, "k_nope", il);
+ cb(k_mqa_view, "k_mqa_view", il);
- // and {n_head * n_embd_head_v, n_tokens}
- struct ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens,
- ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)),
- ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head),
- ggml_row_size(kv->type, (n_embd_head_qk_nope)));
- cb(v_states, "v_states", il);
-
- v_states = ggml_cont(ctx0, v_states);
- cb(v_states, "v_states", il);
-
- v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens,
- ggml_row_size(kv->type, hparams.n_embd_head_v * n_head),
- 0);
- cb(v_states, "v_states", il);
-
- q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this
- q_pe = ggml_rope_ext(
- ctx0, q_pe, inp_pos, nullptr,
- n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
- ext_factor, attn_factor_scaled, beta_fast, beta_slow
+ k_mqa_view = ggml_rope_ext(ctx0, k_mqa_view, inp_pos, nullptr,
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor_scaled, beta_fast, beta_slow
);
- cb(q_pe, "q_pe", il);
+ cb(k_mqa_view, "k_mqa_view_rope", il);
+
+ if (!cparams.mla_attn) {
+ // {kv_lora_rank, n_head * n_embd_head_qk_nope} * {kv_lora_rank, n_tokens} -> {n_head * n_embd_head_qk_nope, n_tokens}
+ struct ggml_tensor * k_nope = ggml_mul_mat(ctx0, model.layers[il].wk_b, kv_compressed);
+ cb(k_nope, "k_nope", il);
+
+ // {n_embd_head_qk_nope, n_head, n_tokens}
+ struct ggml_tensor * k_nope_view = ggml_view_3d(ctx0, k_nope, n_embd_head_qk_nope, n_head, n_tokens,
+ ggml_row_size(k_nope->type, n_embd_head_qk_nope),
+ ggml_row_size(k_nope->type, n_head * n_embd_head_qk_nope),
+ 0);
+ cb(k_nope_view, "k_nope_view", il);
+
+ // TODO: build_k_shift() and build_defrag(); the RoPE-ed part is stored in the first n_rot dims, as they expect
+ // {n_embd_head_qk_rope + n_embd_head_qk_nope, n_head, n_tokens}
+ struct ggml_tensor * q_states = ggml_concat(ctx0, q_mqa_view, q_nope_view, 0);
+ cb(q_states, "q_states", il);
+
+ // TODO: build_k_shift() and build_defrag(); the RoPE-ed part is stored in the first n_rot dims, as they expect
+ // {n_embd_head_qk_rope + n_embd_head_qk_nope, n_head, n_tokens}
+ struct ggml_tensor * k_states = ggml_concat(ctx0, ggml_repeat(ctx0, k_mqa_view, q_mqa_view), k_nope_view, 0);
+ cb(k_states, "k_states", il);
+
+ // {kv_lora_rank, n_head * n_embd_head_v} * {kv_lora_rank, n_tokens} -> {n_head * n_embd_head_v, n_tokens}
+ struct ggml_tensor * v_states = ggml_mul_mat(ctx0, model.layers[il].wv_b, kv_compressed);
+ cb(v_states, "v_states", il);
+
+ // note: this has essentially converted MLA into MHA (with very large KV-cache overhead)
+ cur = llm_build_kv(ctx0, lctx, kv_self, gf,
+ nullptr, model.layers[il].wo, nullptr,
+ k_states, v_states, q_states, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
+ } else {
+ // {n_embd_head_qk_nope, kv_lora_rank, n_head}
+ struct ggml_tensor * wk_b_trans_view = ggml_view_3d(ctx0, model.layers[il].wk_b_trans,
+ n_embd_head_qk_nope, kv_lora_rank, n_head,
+ ggml_row_size(model.layers[il].wk_b_trans->type, n_embd_head_qk_nope),
+ ggml_row_size(model.layers[il].wk_b_trans->type, kv_lora_rank * n_embd_head_qk_nope),
+ 0);
+ cb(wk_b_trans_view, "wk_b_trans_view", il);
+
+ // {n_embd_head_qk_nope, n_tokens, n_head}
+ q_nope_view = ggml_permute(ctx0, q_nope_view, 0, 2, 1, 3);
+ cb(q_nope_view, "q_nope_view_perm", il);
+
+ // {n_embd_head_qk_nope, kv_lora_rank, n_head} * {n_embd_head_qk_nope, n_tokens, n_head} = {kv_lora_rank, n_tokens, n_head}
+ struct ggml_tensor * q_nope_absorbed = ggml_mul_mat(ctx0, wk_b_trans_view, q_nope_view);
+ cb(q_nope_absorbed, "q_nope_absorbed", il);
+
+ // {kv_lora_rank, n_head, n_tokens}
+ q_nope_absorbed = ggml_permute(ctx0, q_nope_absorbed, 0, 2, 1, 3);
+ cb(q_nope_absorbed, "q_nope_absorbed_perm", il);
+
+ // TODO: build_k_shift() and build_defrag(); the RoPE-ed part is stored in the first n_rot dims, as they expect
+ // {n_embd_head_qk_rope + kv_lora_rank, n_head, n_tokens}
+ struct ggml_tensor * q_states = ggml_concat(ctx0, q_mqa_view, q_nope_absorbed, 0);
+ cb(q_states, "q_states", il);
+
+ // {kv_lora_rank, 1, n_tokens}
+ struct ggml_tensor * kv_compressed_view = ggml_view_3d(ctx0, kv_compressed,
+ kv_lora_rank, 1, n_tokens,
+ ggml_row_size(kv_compressed->type, kv_lora_rank),
+ ggml_row_size(kv_compressed->type, kv_lora_rank),
+ 0);
+ cb(kv_compressed_view, "kv_compressed_view", il);
- // shared RoPE key
- k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this
- k_pe = ggml_rope_ext(
- ctx0, k_pe, inp_pos, nullptr,
- n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
- ext_factor, attn_factor_scaled, beta_fast, beta_slow
- );
- cb(k_pe, "k_pe", il);
+ // TODO: build_k_shift() and build_defrag(); the RoPE-ed part is stored in the first n_rot dims, as they expect
+ // {n_embd_head_qk_rope + kv_lora_rank, 1, n_tokens}
+ struct ggml_tensor * k_states = ggml_concat(ctx0, k_mqa_view, kv_compressed_view, 0);
+ cb(k_states, "k_states", il);
- struct ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0);
- cb(q_states, "q_states", il);
+ // {kv_lora_rank, 1, n_tokens}
+ struct ggml_tensor * v_states = kv_compressed;
+ cb(v_states, "v_states", il);
- struct ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0);
- cb(k_states, "k_states", il);
+ // note: this has essentially converted MLA into MQA (with very low KV-cache overhead)
+ cur = llm_build_kv(ctx0, lctx, kv_self, gf,
+ model.layers[il].wv_b, model.layers[il].wo, nullptr,
+ k_states, v_states, q_states, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
+ }
- cur = llm_build_kv(ctx0, lctx, kv_self, gf,
- model.layers[il].wo, NULL,
- k_states, v_states, q_states, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
}
if (il == n_layer - 1) {
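
The identity behind the wk_b_trans path above (per head, ignoring the RoPE part): scoring a query against the decompressed key W_kb·c gives the same result as scoring the transformed query W_kbᵀ·q_nope against the cached latent c directly, so only the latent ever needs to be cached. A minimal numerical check with illustrative dimensions (W_kb, q_nope, c are stand-in names for one head's slice of wk_b, the non-RoPE query, and kv_compressed):

```python
import torch

d_nope, d_rank, n_kv = 128, 512, 7        # illustrative sizes only
dtype = torch.float64                     # double precision so the check is tight

W_kb   = torch.randn(d_nope, d_rank, dtype=dtype)  # per-head slice of wk_b (latent -> key)
q_nope = torch.randn(d_nope, dtype=dtype)          # one head's query, non-RoPE part
c      = torch.randn(n_kv, d_rank, dtype=dtype)    # cached latents, one row per position

# decompress-then-score: k_nope_i = W_kb @ c_i, then q_nope . k_nope_i
scores_decompressed = (c @ W_kb.T) @ q_nope

# absorb-then-score: q_absorbed = W_kb^T @ q_nope, then q_absorbed . c_i
scores_absorbed = c @ (W_kb.T @ q_nope)

assert torch.allclose(scores_decompressed, scores_absorbed)
```
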
@@ -6680,7 +6822,7 @@ struct llm_build_context {
cb(Kcur, "Kcur", il);
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
- NULL, NULL,
+ nullptr, nullptr, nullptr,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cur = llm_build_norm(ctx0, cur, hparams,
@@ -6932,7 +7074,7 @@ struct llm_build_context {
struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
cb(Vcur, "Vcur", il);
- llm_build_kv_store(ctx0, hparams, cparams, kv_self, gf, Kcur, Vcur, n_tokens, kv_head, cb, il);
+ llm_build_kv_store(ctx0, lctx, kv_self, gf, Kcur, Vcur, n_tokens, kv_head, cb, il);
struct ggml_tensor * k =
ggml_view_3d(ctx0, kv_self.k_l[il],
@@ -7134,7 +7276,7 @@ struct llm_build_context {
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
- model.layers[il].wo, model.layers[il].bo,
+ nullptr, model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/float(n_embd_head), cb, il);
}
@@ -7260,7 +7402,7 @@ struct llm_build_context {
cb(Kcur, "Kcur_rope", il);
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
- model.layers[il].wo, NULL,
+ nullptr, model.layers[il].wo, nullptr,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
@@ -7379,7 +7521,7 @@ struct llm_build_context {
cb(Kcur, "Kcur", il);
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
- model.layers[il].wo, model.layers[il].bo,
+ nullptr, model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
@@ -7505,7 +7647,7 @@ struct llm_build_context {
cb(Kcur, "Kcur", il);
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
- model.layers[il].wo, model.layers[il].bo,
+ nullptr, model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
@@ -7876,7 +8018,7 @@ struct llm_build_context {
cb(Kcur, "Kcur", il);
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
- model.layers[il].wo, nullptr,
+ nullptr, model.layers[il].wo, nullptr,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
if (hparams.swin_norm) {
@@ -9358,6 +9500,7 @@ struct llama_context_params llama_context_default_params() {
/*.embeddings =*/ false,
/*.offload_kqv =*/ true,
/*.flash_attn =*/ false,
+ /*.mla_attn =*/ false,
/*.no_perf =*/ true,
/*.abort_callback =*/ nullptr,
/*.abort_callback_data =*/ nullptr,
@@ -9594,6 +9737,7 @@ struct llama_context * llama_init_from_model(
cparams.embeddings = params.embeddings;
cparams.offload_kqv = params.offload_kqv;
cparams.flash_attn = params.flash_attn;
+ cparams.mla_attn = params.mla_attn;
cparams.no_perf = params.no_perf;
cparams.pooling_type = params.pooling_type;