67 changes: 63 additions & 4 deletions src/llama-graph.cpp
@@ -251,6 +251,24 @@ void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
}
}

bool llm_graph_input_rs::can_reuse(const llm_graph_params & params) {
    const auto * mctx = static_cast<const llama_memory_recurrent_context *>(params.mctx);

    this->mctx = mctx;

    bool res = true;

    res &= s_copy->ne[0] == mctx->get_n_rs();
@compilade (Collaborator) commented on Oct 9, 2025:

mctx->get_head() (the start of the slot) and mctx->get_rs_z() (the first zeroed state) are used in view offsets, and so would need to match too; otherwise the graph can't really be re-used.

The case where they wouldn't match (but n_rs matches) is when ubatches of the same size with different sequences are used.

E.g. seq_ids 0, 1, with 1 token and then seq_ids 2, 3 with 1 token, in consecutive ubatches, repeatedly.

This probably happens when using -ub 1 in the llama-parallel example, I think (because it uses a single seq_id per ubatch at a time, but ends up using different seq_ids while using the same size of ubatches).

(Note that I didn't actually test the changes yet, so I don't know if this is a real problem)
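
For context, here is a minimal hypothetical sketch (tensor names and sizes are made up for illustration; this is not code from the PR) of how a recurrent-state view bakes the head offset into the built graph:

```cpp
#include "ggml.h"

int main() {
    // hypothetical sizes, for illustration only
    ggml_init_params ip = { /*mem_size*/ 16*1024*1024, /*mem_buffer*/ nullptr, /*no_alloc*/ false };
    ggml_context * ctx0 = ggml_init(ip);

    const int64_t d_state = 16; // per-cell state size
    const int64_t n_cells =  8; // total recurrent cells
    ggml_tensor * s_all = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, d_state, n_cells);

    const uint32_t head = 4; // slot start chosen when the graph is built
    const int64_t  n_rs = 2; // states participating in this ubatch

    // the byte offset head*nb[1] becomes a constant of the built graph; a later ubatch
    // with a different head (or a different first-zeroed state rs_z) needs a new graph
    ggml_tensor * s_slot = ggml_view_2d(ctx0, s_all, d_state, n_rs,
                                        s_all->nb[1], head * s_all->nb[1]);
    (void) s_slot;

    ggml_free(ctx0);
    return 0;
}
```

Because the offset is fixed at build time, reusing the graph for a ubatch whose head or rs_z differs would silently address the wrong cells.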

@ggerganov (Member, Author) replied:

There is a check earlier for whether the sequences are the same:

llama.cpp/src/llama-graph.h

Lines 443 to 457 in 638e2c2

// when we split the batch using "equal_seqs" we have to verify that the participating sequences are the same
// the reason is because the set of attention streams would be different for different sequences
if (can_reuse_ubatch && ubatch.equal_seqs()) {
    if (!ubatch.data) {
        // if the old ubatch does not own its data, then we cannot guarantee that it is still alive, and
        // therefore we cannot perform the sequence id check. normally should never happen
        can_reuse_ubatch = false;
    } else {
        for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
            can_reuse_ubatch &= ubatch.seq_id_unq[s] == other.ubatch.seq_id_unq[s];
        }
    }
}

This check applies to all graphs and if not satisfied, we don't attempt to reuse the graph. I think this should cover this case.

@compilade (Collaborator) replied on Oct 9, 2025:

> There is a check earlier for whether the sequences are the same

Right, this should cover the case where different sequences are used.

However, I don't think it covers the case when a sequence is cleared (which will make mctx->get_rs_z() differ).


I'm noticing different perplexity with and without graph-reuse with a Q8_0 mamba-130m on CPU.

(this is on the first 10 chunks of calibration_datav3)

| params | LLAMA_GRAPH_REUSE_DISABLE | PPL |
| --- | --- | --- |
| -b 512 | 0 | 7.7852 |
| -b 2048 | 0 | 7.8628 |
| -b 512 | 1 | 7.7852 |
| -b 2048 | 1 | 7.7852 |

I'm not sure what exactly is causing it, but I suspect it's related to either rs_z or head, since this doesn't seem to happen with non-recurrent models (I tested with a Q8_0 TinyLlama).

@compilade (Collaborator) replied:

@ggerganov
Checking for head and rs_z mismatch does seem to help with the case in my previous comment, making the graph-reuse case have the same PPL as when it's not used.

Patch with changes
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 7f0c974f1..aad42d62d 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -258,6 +258,9 @@ bool llm_graph_input_rs::can_reuse(const llm_graph_params & params) {
 
     bool res = true;
 
+    res &= this->head == mctx->get_head();
+    res &= this->rs_z == mctx->get_rs_z();
+
     res &= s_copy->ne[0] == mctx->get_n_rs();
 
     res &= s_copy_main->ne[0]  == params.ubatch.n_seqs;
@@ -482,6 +485,9 @@ bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) {
     res &= inp_attn->self_kq_mask->ne[0] == mctx->get_attn()->get_n_kv();
     res &= inp_attn->self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
 
+    res &= inp_rs->head == mctx->get_recr()->get_head();
+    res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z();
+
     res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
 
     res &= inp_rs->s_copy_main->ne[0]  == params.ubatch.n_seqs;
@@ -1827,6 +1833,9 @@ static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
     inp->s_copy_main  = ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0);
     inp->s_copy_extra = ggml_view_1d(ctx0, inp->s_copy, n_rs - n_seqs, n_seqs * inp->s_copy->nb[0]);
 
+    inp->head = mctx_cur->get_head();
+    inp->rs_z = mctx_cur->get_rs_z();
+
     return inp;
 }
 
diff --git a/src/llama-graph.h b/src/llama-graph.h
index 394e88432..a596461bb 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -234,6 +234,10 @@ public:
     ggml_tensor * s_copy_extra;  // I32 [n_rs - n_seqs]
 
     const llama_memory_recurrent_context * mctx;
+
+    // used in view offsets, need to match for valid graph reuse
+    uint32_t head;
+    int32_t rs_z;
 };
 
 class llm_graph_input_cross_embd : public llm_graph_input_i {

It might not be ideal to expose another way to get head and rs_z. But the constructor of llm_graph_input_rs would need access to llama-memory-recurrent.h to use mctx->get_head() and mctx->get_rs_z().

Strangely enough, hybrid models like Falcon-H1 don't manifest the same problem as mamba-130m; I can't reproduce the original problem with that.

@ggerganov (Member, Author) replied:

> It might not be ideal to expose another way to get head and rs_z. But the constructor of llm_graph_input_rs would need access to llama-memory-recurrent.h to use mctx->get_head() and mctx->get_rs_z().

Can you clarify what you mean here? The proposed solution seems OK to me.

On a related topic, would it be possible to avoid these offsets through the use of ggml_set_rows() in a similar way as we avoided the KV cache offset for the regular attention?
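
For reference, a rough hypothetical sketch of the idea behind that question (the tensor names, sizes, and the I64 index dtype are assumptions of this sketch, not code from the PR): with ggml_set_rows() the destination rows are supplied by an index tensor filled in at set_input() time, so no offset is compiled into the graph.

```cpp
#include "ggml.h"

int main() {
    ggml_init_params ip = { /*mem_size*/ 16*1024*1024, /*mem_buffer*/ nullptr, /*no_alloc*/ false };
    ggml_context * ctx0 = ggml_init(ip);

    const int64_t d_state = 16, n_cells = 8, n_seqs = 2;

    ggml_tensor * s_all  = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, d_state, n_cells); // full state buffer
    ggml_tensor * s_new  = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, d_state, n_seqs);  // updated states
    ggml_tensor * s_idxs = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_seqs);           // destination rows
    ggml_set_input(s_idxs); // values are written per ubatch in set_input()

    // scatter the updated states into the cells named by s_idxs; the destination is
    // runtime data rather than a view offset fixed at graph-build time
    ggml_tensor * s_out = ggml_set_rows(ctx0, s_all, s_new, s_idxs);
    (void) s_out;

    ggml_free(ctx0);
    return 0;
}
```

With this approach the built graph stays valid across ubatches; only the contents of s_idxs change.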

@ggerganov (Member, Author) replied on Oct 10, 2025:

@compilade I improved the state management of the recurrent memory with 6589d3b. The recurrent memory context now keeps immutable values such as head, rs_z, etc., which can be used in the can_reuse() logic without duplicating this state in the inputs.
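
As a rough illustration of that direction (hypothetical class and member names; not the actual change in 6589d3b), the context could snapshot these values at construction and expose them through const getters:

```cpp
#include <cstdint>

// hypothetical sketch, not the code from 6589d3b
class llama_memory_recurrent_context_sketch {
public:
    llama_memory_recurrent_context_sketch(uint32_t head, int32_t rs_z, uint32_t n_rs)
        : head(head), rs_z(rs_z), n_rs(n_rs) {}

    // immutable for the lifetime of the context, so can_reuse() can compare against them
    uint32_t get_head() const { return head; }
    int32_t  get_rs_z() const { return rs_z; }
    uint32_t get_n_rs() const { return n_rs; }

private:
    const uint32_t head;
    const int32_t  rs_z;
    const uint32_t n_rs;
};
```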


    res &= s_copy_main->ne[0]  == params.ubatch.n_seqs;
    res &= s_copy_extra->ne[0] == mctx->get_n_rs() - params.ubatch.n_seqs;

    res &= head == mctx->get_head();
    res &= rs_z == mctx->get_rs_z();

    return res;
}

void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
    GGML_UNUSED(ubatch);
@@ -436,8 +454,46 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
}

void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
    inp_attn->set_input(ubatch);
    inp_rs->set_input(ubatch);
    mctx->get_attn()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);
    mctx->get_attn()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch);

    mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);

    const int64_t n_rs = mctx->get_recr()->get_n_rs();

    if (inp_rs->s_copy) {
        GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer));
        int32_t * data = (int32_t *) inp_rs->s_copy->data;

        // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
        for (uint32_t i = 0; i < n_rs; ++i) {
            data[i] = mctx->get_recr()->s_copy(i);
        }
    }
}

bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) {
    const auto * mctx = static_cast<const llama_memory_hybrid_context *>(params.mctx);

    this->mctx = mctx;

    bool res = true;

    res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
    //res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there

    res &= inp_attn->self_kq_mask->ne[0] == mctx->get_attn()->get_n_kv();
    res &= inp_attn->self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);

    res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();

    res &= inp_rs->s_copy_main->ne[0]  == params.ubatch.n_seqs;
    res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs;

    res &= inp_rs->head == mctx->get_recr()->get_head();
    res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z();

    return res;
}

//
@@ -1777,6 +1833,9 @@ static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
    inp->s_copy_main  = ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0);
    inp->s_copy_extra = ggml_view_1d(ctx0, inp->s_copy, n_rs - n_seqs, n_seqs * inp->s_copy->nb[0]);

    inp->head = mctx_cur->get_head();
    inp->rs_z = mctx_cur->get_rs_z();

    return inp;
}

@@ -1845,10 +1904,10 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
    const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);

    auto inp_rs = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
    auto inp_rs = build_rs_inp_impl (ctx0, ubatch, mctx_cur->get_recr());
    auto inp_attn = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());

    auto inp = std::make_unique<llm_graph_input_mem_hybrid>(std::move(inp_attn), std::move(inp_rs), mctx_cur);
    auto inp = std::make_unique<llm_graph_input_mem_hybrid>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur);

    return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
}
16 changes: 14 additions & 2 deletions src/llama-graph.h
@@ -224,6 +224,8 @@ class llm_graph_input_rs : public llm_graph_input_i {

    void set_input(const llama_ubatch * ubatch) override;

    bool can_reuse(const llm_graph_params & params) override;

    ggml_tensor * s_copy; // I32 [n_rs]

    // views of s_copy, computed once per graph
@@ -232,6 +234,10 @@ class llm_graph_input_rs : public llm_graph_input_i {
    ggml_tensor * s_copy_extra; // I32 [n_rs - n_seqs]

    const llama_memory_recurrent_context * mctx;

    // used in view offsets, need to match for valid graph reuse
    uint32_t head;
    int32_t rs_z;
};

class llm_graph_input_cross_embd : public llm_graph_input_i {
@@ -360,22 +366,28 @@ class llm_graph_input_attn_cross : public llm_graph_input_i {
class llm_graph_input_mem_hybrid : public llm_graph_input_i {
public:
    llm_graph_input_mem_hybrid(
            const llama_cparams & cparams,
            std::unique_ptr<llm_graph_input_attn_kv> inp_attn,
            std::unique_ptr<llm_graph_input_rs> inp_rs,
            const llama_memory_hybrid_context * mctx) :
            std::unique_ptr<llm_graph_input_rs> inp_rs,
            const llama_memory_hybrid_context * mctx) :
        inp_attn(std::move(inp_attn)),
        inp_rs(std::move(inp_rs)),
        cparams(cparams),
        mctx(mctx) { }
    virtual ~llm_graph_input_mem_hybrid() = default;

    void set_input(const llama_ubatch * ubatch) override;

    bool can_reuse(const llm_graph_params & params) override;

    std::unique_ptr<llm_graph_input_attn_kv> inp_attn;
    std::unique_ptr<llm_graph_input_rs> inp_rs;

    llm_graph_input_attn_kv * get_attn() const { return inp_attn.get(); }
    llm_graph_input_rs * get_recr() const { return inp_rs.get(); }

    const llama_cparams cparams;

    const llama_memory_hybrid_context * mctx;
};

2 changes: 1 addition & 1 deletion src/llama-memory-hybrid.cpp
@@ -222,7 +222,7 @@ llama_memory_hybrid_context::llama_memory_hybrid_context(
    ubatches(std::move(ubatches)),
    // note: here we copy the ubatches. not sure if this is ideal
    ctx_attn(new llama_kv_cache_context(mem->get_mem_attn(), std::move(sinfos_attn), this->ubatches)),
    ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)),
    ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)),
    status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
}
