2 changes: 2 additions & 0 deletions src/llama-arch.cpp
@@ -757,6 +757,7 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
},
},
{
@@ -777,6 +778,7 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
},
},
{
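These two entries register the fused tensor's naming pattern for the Qwen3 and Qwen3-MoE tensor maps; the block index is substituted into the %d placeholder when the loader resolves a concrete tensor name. A minimal sketch of that substitution (the helper below is illustrative, not part of llama.cpp):

#include <cstdio>
#include <string>

// expand a per-block pattern such as "blk.%d.attn_qkv" and append the usual suffix
static std::string blk_tensor_name(const char * fmt, int bid, const char * suffix = "weight") {
    char buf[128];
    std::snprintf(buf, sizeof(buf), fmt, bid);
    return std::string(buf) + "." + suffix; // e.g. "blk.3.attn_qkv.weight" for block 3
}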
69 changes: 69 additions & 0 deletions src/llama-graph.cpp
@@ -3,6 +3,7 @@
#include "llama-impl.h"
#include "llama-batch.h"
#include "llama-cparams.h"
#include "llama-model.h"

#include "llama-kv-cache.h"
#include "llama-kv-cache-iswa.h"
@@ -632,6 +633,74 @@ ggml_tensor * llm_graph_context::build_lora_mm(
return res;
}

static bool disable_fusion() {
const char * disable_fusion = getenv("LLAMA_GRAPH_DISABLE_FUSION");
return disable_fusion != nullptr && atoi(disable_fusion) != 0;
}


void llm_graph_context::build_qkv(const llama_layer & layer,
ggml_tensor * cur,
int64_t n_embd_head_q,
int64_t n_embd_head_k,
int64_t n_embd_head_v,
int32_t n_head,
int32_t n_head_kv,
ggml_tensor ** q_out,
ggml_tensor ** k_out,
ggml_tensor ** v_out,
int il) const {
if (disable_fusion() || !layer.wqkv || (loras && !loras->empty())) {
*q_out = build_lora_mm(layer.wq, cur);
cb(*q_out, "Qcur", il);

*k_out = build_lora_mm(layer.wk, cur);
cb(*k_out, "Kcur", il);

*v_out = build_lora_mm(layer.wv, cur);
cb(*v_out, "Vcur", il);

*q_out = ggml_reshape_3d(ctx0, *q_out, n_embd_head_q, n_head, n_tokens);
*k_out = ggml_reshape_3d(ctx0, *k_out, n_embd_head_k, n_head_kv, n_tokens);
*v_out = ggml_reshape_3d(ctx0, *v_out, n_embd_head_v, n_head_kv, n_tokens);

return;
}


ggml_tensor * qkv = ggml_mul_mat(ctx0, layer.wqkv, cur);
cb(qkv, "wqkv", il);

const int64_t q_offset = 0;
const int64_t k_offset = n_embd_head_q * n_head;
const int64_t v_offset = k_offset + n_embd_head_k * n_head_kv;
const size_t elt_size = ggml_element_size(qkv);

ggml_tensor * Qcur = ggml_view_3d(
ctx0, qkv,
n_embd_head_q, n_head, n_tokens,
n_embd_head_q * elt_size, qkv->nb[1],
q_offset * elt_size);
ggml_tensor * Kcur = ggml_view_3d(
ctx0, qkv,
n_embd_head_k, n_head_kv, n_tokens,
n_embd_head_k * elt_size, qkv->nb[1],
k_offset * elt_size);
ggml_tensor * Vcur = ggml_view_3d(
ctx0, qkv,
n_embd_head_v, n_head_kv, n_tokens,
n_embd_head_v * elt_size, qkv->nb[1],
v_offset * elt_size);

cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);

*q_out = Qcur;
*k_out = Kcur;
*v_out = Vcur;
}

ggml_tensor * llm_graph_context::build_lora_mm_id(
ggml_tensor * w, // ggml_tensor * as
ggml_tensor * cur, // ggml_tensor * b
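build_qkv keeps the old three-matmul path whenever fusion cannot apply — LLAMA_GRAPH_DISABLE_FUSION is set to a non-zero value, the layer has no fused wqkv weight, or LoRA adapters are loaded — and otherwise runs a single matmul against wqkv and carves the result into Q/K/V views. A minimal sketch of how the view offsets follow from the fused row layout (standalone helper, assuming a non-quantized output type; not llama.cpp API):

#include <cstddef>
#include <cstdint>

struct qkv_byte_offsets { size_t q, k, v; };

// each token's fused output row is laid out as [ Q | K | V ]:
// Q spans n_embd_head_q*n_head elements, K spans n_embd_head_k*n_head_kv, V follows K
static qkv_byte_offsets fused_qkv_offsets(int64_t n_embd_head_q, int64_t n_embd_head_k,
                                          int32_t n_head, int32_t n_head_kv, size_t elt_size) {
    const int64_t k_off = n_embd_head_q * n_head;            // elements preceding K
    const int64_t v_off = k_off + n_embd_head_k * n_head_kv; // elements preceding V
    return { 0, size_t(k_off) * elt_size, size_t(v_off) * elt_size };
}

Running with LLAMA_GRAPH_DISABLE_FUSION=1 therefore reproduces the previous, unfused graph, which makes it easy to A/B the fused path.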
14 changes: 14 additions & 0 deletions src/llama-graph.h
@@ -24,6 +24,8 @@ class llama_kv_cache_iswa_context;
class llama_memory_recurrent_context;
class llama_memory_hybrid_context;

struct llama_layer;

// certain models (typically multi-modal) can produce different types of graphs
enum llm_graph_type {
LLM_GRAPH_TYPE_DEFAULT,
@@ -604,6 +606,18 @@ struct llm_graph_context {
ggml_tensor * w,
ggml_tensor * cur) const;

void build_qkv(const llama_layer & layer,
ggml_tensor * cur,
int64_t n_embd_head_q,
int64_t n_embd_head_k,
int64_t n_embd_head_v,
int32_t n_head,
int32_t n_head_kv,
ggml_tensor ** q_out,
ggml_tensor ** k_out,
ggml_tensor ** v_out,
int il) const;

// do mat_mul_id, while optionally apply lora
ggml_tensor * build_lora_mm_id(
ggml_tensor * w, // ggml_tensor * as
34 changes: 34 additions & 0 deletions src/llama-model-loader.cpp
@@ -840,6 +840,40 @@ struct ggml_tensor * llama_model_loader::create_tensor_as_view(struct ggml_conte
return tensor;
}

struct ggml_tensor * llama_model_loader::create_contiguous_tensor(struct ggml_context * ctx, const std::string & fused_name, const std::initializer_list<int64_t> & ne,
                                                                   std::vector<ggml_tensor**> tensors, int flags) {

(void)flags;

if (weights_map.find(fused_name) != weights_map.end()) {
return nullptr;
}

if (ggml_get_tensor(ctx, fused_name.c_str()) != nullptr) {
return nullptr;
}

const ggml_type type = (*tensors[0])->type;

struct ggml_tensor * fused = ggml_new_tensor(ctx, type, ne.size(), ne.begin());

if (!fused) {
return nullptr;
}

ggml_set_name(fused, fused_name.c_str());

size_t offset = 0;
for (ggml_tensor **tensor : tensors) {
std::initializer_list<int64_t> ne = { (*tensor)->ne[0], (*tensor)->ne[1], (*tensor)->ne[2], (*tensor)->ne[3] };
struct ggml_tensor * view = create_tensor_as_view(ctx, fused, ggml_get_name(*tensor), ne, offset, false);
*tensor = view;
offset += ggml_nbytes(*tensor);
}

return fused;
}

void llama_model_loader::done_getting_tensors() const {
if (n_created != n_tensors) {
throw std::runtime_error(format("%s: wrong number of tensors; expected %d, got %d", __func__, n_tensors, n_created));
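create_contiguous_tensor allocates one tensor for the whole fused matrix and rebinds each source tensor pointer to a view at an increasing byte offset inside it; it returns nullptr when the GGUF already ships a tensor under the fused name or when one already exists in the context. A hypothetical call shape (the helper name is made up; the real call site is the create_contiguous lambda in llama-model.cpp below), assuming three weight metas of matching type:

// illustrative only: on success, wq/wk/wv end up pointing at views inside the returned tensor
static ggml_tensor * try_fuse_qkv(llama_model_loader & ml, ggml_context * ctx,
                                  const std::string & fused_name,
                                  ggml_tensor *& wq, ggml_tensor *& wk, ggml_tensor *& wv) {
    const std::initializer_list<int64_t> ne = {
        wq->ne[0],                         // shared input width (n_embd)
        wq->ne[1] + wk->ne[1] + wv->ne[1], // concatenated output widths
    };
    return ml.create_contiguous_tensor(ctx, fused_name, ne, { &wq, &wk, &wv });
}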
3 changes: 3 additions & 0 deletions src/llama-model-loader.h
@@ -147,6 +147,9 @@ struct llama_model_loader {

struct ggml_tensor * create_tensor_as_view(struct ggml_context * ctx, struct ggml_tensor * base, const std::string & name, const std::initializer_list<int64_t> & ne, size_t offset, bool required = true);

struct ggml_tensor * create_contiguous_tensor(struct ggml_context * ctx, const std::string & fused_name, const std::initializer_list<int64_t> & ne,
                                              std::vector<ggml_tensor**> tensors, int flags = 0);

void done_getting_tensors() const;

void init_mappings(bool prefetch = true, llama_mlocks * mlock_mmaps = nullptr);
176 changes: 150 additions & 26 deletions src/llama-model.cpp
@@ -1,5 +1,6 @@
#include "llama-model.h"

#include "gguf.h"
#include "llama-impl.h"
#include "llama-mmap.h"
#include "llama-batch.h"
@@ -2428,6 +2429,115 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
return ml.create_tensor(ctx, tn, ne, flags);
};

struct tensor_def {
LLM_TN_IMPL tn;
std::vector<int64_t> ne;
int flags;
ggml_tensor ** out;
};

auto create_contiguous = [&](const LLM_TN_IMPL & fused_tn,
std::initializer_list<int64_t> ne,
std::initializer_list<tensor_def> reqs) -> ggml_tensor * {
ggml_backend_buffer_type_t fused_buft = nullptr;

std::vector<const ggml_tensor*> tensor_metas;

for (size_t i = 0; i < reqs.size(); ++i) {
const tensor_def & req = reqs.begin()[i];
const bool required = (req.flags & llama_model_loader::TENSOR_NOT_REQUIRED) == 0;
const ggml_tensor * tensor_meta = ml.check_tensor_dims(req.tn.str(), req.ne, required);

if (!tensor_meta) {
return nullptr;
}

tensor_metas.push_back(tensor_meta);

*req.out = const_cast<ggml_tensor*>(tensor_meta);

if (!*req.out) {
return nullptr;
}

llm_tensor tn_tensor = req.tn.tensor;
if (tn_tensor == LLM_TENSOR_TOKEN_EMBD && (req.flags & llama_model_loader::TENSOR_DUPLICATED)) {
tn_tensor = LLM_TENSOR_OUTPUT;
}

llm_tensor_info info;
try {
info = llm_tensor_info_for(tn_tensor);
} catch (const std::out_of_range &) {
throw std::runtime_error(format("missing tensor info mapping for %s", req.tn.str().c_str()));
}

bool bias = req.tn.suffix != nullptr && strcmp(req.tn.suffix, "bias") == 0;
ggml_op op = bias ? (info.op == GGML_OP_MUL_MAT_ID ? GGML_OP_ADD_ID : GGML_OP_ADD) : info.op;

buft_list_t * buft_list = nullptr;
switch (info.layer) {
case LLM_TENSOR_LAYER_INPUT:
buft_list = pimpl->dev_input.buft_list;
break;
case LLM_TENSOR_LAYER_OUTPUT:
buft_list = pimpl->dev_output.buft_list;
break;
case LLM_TENSOR_LAYER_REPEATING:
buft_list = pimpl->dev_layer.at(req.tn.bid).buft_list;
break;
default:
GGML_ABORT("invalid layer %d for tensor %s", info.layer, req.tn.str().c_str());
}

ggml_backend_buffer_type_t buft = select_weight_buft(hparams, *req.out, op, *buft_list);
if (!buft) {
return nullptr;
}

auto * buft_dev = ggml_backend_buft_get_device(buft);
if (ml.use_mmap && buft_dev && buft == ggml_backend_dev_host_buffer_type(buft_dev)) {
auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
if (!cpu_dev) {
throw std::runtime_error("no CPU backend found");
}
buft = ggml_backend_dev_buffer_type(cpu_dev);
}

//TODO: check buft overrides

if (!fused_buft) {
fused_buft = buft;
} else if (fused_buft != buft) {
return nullptr;
}
}

if (!fused_buft) {
return nullptr;
}

ggml_context * ctx = ctx_for_buft(fused_buft);

std::vector<ggml_tensor**> tensor_req(reqs.size());

ggml_type type = tensor_metas[0]->type;
for (size_t i = 0; i < reqs.size(); ++i) {

// all fused sub-tensors must share the same ggml type
if (tensor_metas[i]->type != type) {
return nullptr;
}

const auto & req = reqs.begin()[i];
tensor_req[i] = req.out;
}

ggml_tensor * fused = ml.create_contiguous_tensor(ctx, fused_tn.str(), ne, tensor_req, 0);

return fused;
};

layers.resize(n_layer);

// TODO: move to a separate function
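The create_contiguous helper above only produces a fused tensor when every requested sub-tensor exists, all of them share one ggml type, and all of them resolve to the same backend buffer type; otherwise it returns nullptr and the per-architecture code below falls back to loading Q, K and V separately. A compressed restatement of that gate (names hypothetical, for illustration only):

// summarizes the checks performed by the create_contiguous lambda above
static bool can_fuse_qkv(const ggml_tensor * q, const ggml_tensor * k, const ggml_tensor * v,
                         ggml_backend_buffer_type_t buft_q,
                         ggml_backend_buffer_type_t buft_k,
                         ggml_backend_buffer_type_t buft_v) {
    const bool all_present = q && k && v;
    const bool same_type   = all_present && q->type == k->type && q->type == v->type;
    const bool same_buft   = buft_q == buft_k && buft_q == buft_v; // one buffer type for the fused allocation
    return all_present && same_type && same_buft;
}

When the gate fails, layer.wqkv stays null, the plain create_tensor calls run as before, and build_qkv later picks the unfused path for that layer.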
@@ -3297,9 +3407,19 @@ bool llama_model::load_tensors(llama_model_loader & ml) {

layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);

layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
layer.wqkv = create_contiguous(
tn(LLM_TENSOR_ATTN_QKV, "weight", i),
{n_embd, n_embd_head_k * n_head + n_embd_gqa * 2},
{
{ tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0, &layer.wq },
{ tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0, &layer.wk },
{ tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0, &layer.wv },
});
if (!layer.wqkv) {
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
}
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);

layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
@@ -3328,9 +3448,19 @@ bool llama_model::load_tensors(llama_model_loader & ml) {

layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);

layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
layer.wqkv = create_contiguous(
tn(LLM_TENSOR_ATTN_QKV, "weight", i),
{n_embd, n_embd_head_k * n_head + n_embd_gqa * 2},
{
{ tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0, &layer.wq },
{ tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0, &layer.wk },
{ tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0, &layer.wv },
});
if (!layer.wqkv) {
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
}
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);

layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
@@ -9388,18 +9518,15 @@ struct llm_build_qwen3 : public llm_graph_context {
// self-attention
{
// compute Q and K and RoPE them
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);

ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
cb(Kcur, "Kcur", il);

ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
cb(Vcur, "Vcur", il);
ggml_tensor * Qcur = nullptr;
ggml_tensor * Kcur = nullptr;
ggml_tensor * Vcur = nullptr;

Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
build_qkv(model.layers[il], cur, n_embd_head,
n_embd_head_k, n_embd_head_v, n_head, n_head_kv,
&Qcur, &Kcur, &Vcur, il
);

Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
cb(Qcur, "Qcur_normed", il);
@@ -9509,18 +9636,15 @@ struct llm_build_qwen3moe : public llm_graph_context {
// self_attention
{
// compute Q and K and RoPE them
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);

ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
cb(Kcur, "Kcur", il);

ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
cb(Vcur, "Vcur", il);
ggml_tensor * Qcur = nullptr;
ggml_tensor * Kcur = nullptr;
ggml_tensor * Vcur = nullptr;

Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
build_qkv(model.layers[il], cur, n_embd_head,
n_embd_head_k, n_embd_head_v, n_head, n_head_kv,
&Qcur, &Kcur, &Vcur, il
);

Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
cb(Qcur, "Qcur_normed", il);