From 763de82364d3da8f576bccd9efd4509a5d67f0d6 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Thu, 9 Oct 2025 19:33:30 +0300
Subject: [PATCH 1/5] graph : reuse hybrid graphs

---
 src/llama-graph.cpp         | 41 ++++++++++++++++++++++++++++++++++---
 src/llama-graph.h           | 10 +++++++--
 src/llama-memory-hybrid.cpp |  2 +-
 3 files changed, 47 insertions(+), 6 deletions(-)

diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index f29a1e98c9103..4aa827620e850 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -458,8 +458,43 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
-    inp_attn->set_input(ubatch);
-    inp_rs->set_input(ubatch);
+    mctx->get_attn()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);
+    mctx->get_attn()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch);
+
+    mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
+
+    const int64_t n_rs = mctx->get_recr()->get_n_rs();
+
+    if (inp_rs->s_copy) {
+        GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer));
+        int32_t * data = (int32_t *) inp_rs->s_copy->data;
+
+        // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
+        for (uint32_t i = 0; i < n_rs; ++i) {
+            data[i] = mctx->get_recr()->s_copy(i);
+        }
+    }
+}
+
+bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) {
+    const auto * mctx = static_cast<const llama_memory_hybrid_context *>(params.mctx);
+
+    this->mctx = mctx;
+
+    bool res = true;
+
+    res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
+    //res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
+
+    res &= inp_attn->self_kq_mask->ne[0] == mctx->get_attn()->get_n_kv();
+    res &= inp_attn->self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
+
+    res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
+
+    res &= inp_rs->s_copy_main->ne[0]  == params.ubatch.n_seqs;
+    res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs;
+
+    return res;
 }
 
 //
@@ -1879,7 +1914,7 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
     auto inp_rs   = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
     auto inp_attn = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
 
-    auto inp = std::make_unique<llm_graph_input_mem_hybrid>(std::move(inp_attn), std::move(inp_rs), mctx_cur);
+    auto inp = std::make_unique<llm_graph_input_mem_hybrid>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur);
 
     return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
 }
diff --git a/src/llama-graph.h b/src/llama-graph.h
index d0c3934f67927..25e50238f52d1 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -364,22 +364,28 @@ class llm_graph_input_attn_cross : public llm_graph_input_i {
 class llm_graph_input_mem_hybrid : public llm_graph_input_i {
 public:
     llm_graph_input_mem_hybrid(
+            const llama_cparams & cparams,
             std::unique_ptr<llm_graph_input_attn_kv> inp_attn,
-            std::unique_ptr<llm_graph_input_rs> inp_rs,
-            const llama_memory_hybrid_context * mctx) :
+            std::unique_ptr<llm_graph_input_rs>      inp_rs,
+            const llama_memory_hybrid_context *      mctx) :
         inp_attn(std::move(inp_attn)),
         inp_rs(std::move(inp_rs)),
+        cparams(cparams),
         mctx(mctx) { }
     virtual ~llm_graph_input_mem_hybrid() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
+    bool can_reuse(const llm_graph_params & params) override;
+
     std::unique_ptr<llm_graph_input_attn_kv> inp_attn;
     std::unique_ptr<llm_graph_input_rs>      inp_rs;
 
     llm_graph_input_attn_kv * get_attn() const { return inp_attn.get(); }
     llm_graph_input_rs      * get_recr() const { return inp_rs.get();   }
 
+    const llama_cparams cparams;
+
     const llama_memory_hybrid_context * mctx;
 };
 
diff --git a/src/llama-memory-hybrid.cpp b/src/llama-memory-hybrid.cpp
index dfb8439e01bdf..a1b45e4a3cce3 100644
--- a/src/llama-memory-hybrid.cpp
+++ b/src/llama-memory-hybrid.cpp
@@ -222,7 +222,7 @@ llama_memory_hybrid_context::llama_memory_hybrid_context(
     ubatches(std::move(ubatches)),
     // note: here we copy the ubatches. not sure if this is ideal
     ctx_attn(new llama_kv_cache_context(mem->get_mem_attn(), std::move(sinfos_attn), this->ubatches)),
-    ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)),
+    ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(),                  this->ubatches)),
     status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
 }
 

From 08365ca15617db18fc1c3f8b0d41ddb6de56f405 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Thu, 9 Oct 2025 19:36:17 +0300
Subject: [PATCH 2/5] graph : reuse recurrent graphs

---
 src/llama-graph.cpp | 15 +++++++++++++++
 src/llama-graph.h   |  2 ++
 2 files changed, 17 insertions(+)

diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 4aa827620e850..c71e8bae31dfb 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -251,6 +251,21 @@ void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
     }
 }
 
+bool llm_graph_input_rs::can_reuse(const llm_graph_params & params) {
+    const auto * mctx = static_cast<const llama_memory_recurrent_context *>(params.mctx);
+
+    this->mctx = mctx;
+
+    bool res = true;
+
+    res &= s_copy->ne[0] == mctx->get_n_rs();
+
+    res &= s_copy_main->ne[0]  == params.ubatch.n_seqs;
+    res &= s_copy_extra->ne[0] == mctx->get_n_rs() - params.ubatch.n_seqs;
+
+    return res;
+}
+
 void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
     GGML_UNUSED(ubatch);
 
diff --git a/src/llama-graph.h b/src/llama-graph.h
index 25e50238f52d1..944d129c3e11e 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -224,6 +224,8 @@ class llm_graph_input_rs : public llm_graph_input_i {
 
     void set_input(const llama_ubatch * ubatch) override;
 
+    bool can_reuse(const llm_graph_params & params) override;
+
     ggml_tensor * s_copy; // I32 [n_rs]
 
     // views of s_copy, computed once per graph

From 1bf3f6ad67d4aa393d5b18cec6afb149dd57bcfc Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Fri, 10 Oct 2025 10:44:41 +0300
Subject: [PATCH 3/5] graph : fix reuse check for recurrent inputs

---
 src/llama-graph.cpp | 11 ++++++++++-
 src/llama-graph.h   |  4 ++++
 2 files changed, 14 insertions(+), 1 deletion(-)

diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index c71e8bae31dfb..6d1a6cc24595c 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -263,6 +263,9 @@ bool llm_graph_input_rs::can_reuse(const llm_graph_params & params) {
     res &= s_copy_main->ne[0]  == params.ubatch.n_seqs;
     res &= s_copy_extra->ne[0] == mctx->get_n_rs() - params.ubatch.n_seqs;
 
+    res &= head == mctx->get_head();
+    res &= rs_z == mctx->get_rs_z();
+
     return res;
 }
 
@@ -509,6 +512,9 @@ bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) {
     res &= inp_rs->s_copy_main->ne[0]  == params.ubatch.n_seqs;
     res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs;
 
+    res &= inp_rs->head == mctx->get_recr()->get_head();
+    res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z();
+
     return res;
 }
 
@@ -1858,6 +1864,9 @@ static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
     inp->s_copy_main  = ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0);
     inp->s_copy_extra = ggml_view_1d(ctx0, inp->s_copy, n_rs - n_seqs, n_seqs * inp->s_copy->nb[0]);
 
+    inp->head = mctx_cur->get_head();
+    inp->rs_z = mctx_cur->get_rs_z();
+
     return inp;
 }
 
@@ -1926,7 +1935,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
 llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
     const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
 
-    auto inp_rs   = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
+    auto inp_rs   = build_rs_inp_impl (ctx0, ubatch, mctx_cur->get_recr());
     auto inp_attn = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
 
     auto inp = std::make_unique<llm_graph_input_mem_hybrid>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur);
diff --git a/src/llama-graph.h b/src/llama-graph.h
index 944d129c3e11e..caba9779b5d48 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -234,6 +234,10 @@ class llm_graph_input_rs : public llm_graph_input_i {
     ggml_tensor * s_copy_extra; // I32 [n_rs - n_seqs]
 
     const llama_memory_recurrent_context * mctx;
+
+    // used in view offsets, need to match for valid graph reuse
+    uint32_t head;
+    int32_t  rs_z;
 };
 
 class llm_graph_input_cross_embd : public llm_graph_input_i {

From 08dc973ed9b3eb47b39d55a314af436278c7b38c Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Fri, 10 Oct 2025 10:57:35 +0300
Subject: [PATCH 4/5] memory : move the recurrent state into the memory context

---
 src/llama-graph.cpp            | 13 ++++++++-----
 src/llama-graph.h              |  8 ++++----
 src/llama-memory-recurrent.cpp | 17 ++++++++++-------
 src/llama-memory-recurrent.h   |  6 ++++--
 4 files changed, 26 insertions(+), 18 deletions(-)

diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 6d1a6cc24595c..caf2d78c5ea4b 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -235,6 +235,12 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
     }
 }
 
+llm_graph_input_rs::llm_graph_input_rs(const llama_memory_recurrent_context * mctx) :
+    mctx(mctx),
+    head(mctx->get_head()),
+    rs_z(mctx->get_rs_z()) {
+}
+
 void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
     GGML_UNUSED(ubatch);
 
@@ -263,8 +269,8 @@ bool llm_graph_input_rs::can_reuse(const llm_graph_params & params) {
     res &= s_copy_main->ne[0]  == params.ubatch.n_seqs;
     res &= s_copy_extra->ne[0] == mctx->get_n_rs() - params.ubatch.n_seqs;
 
-    res &= head == mctx->get_head();
-    res &= rs_z == mctx->get_rs_z();
+    res &= this->head == mctx->get_head();
+    res &= this->rs_z == mctx->get_rs_z();
 
     return res;
 }
@@ -1864,9 +1870,6 @@ static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
     inp->s_copy_main  = ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0);
     inp->s_copy_extra = ggml_view_1d(ctx0, inp->s_copy, n_rs - n_seqs, n_seqs * inp->s_copy->nb[0]);
 
-    inp->head = mctx_cur->get_head();
-    inp->rs_z = mctx_cur->get_rs_z();
-
     return inp;
 }
 
diff --git a/src/llama-graph.h b/src/llama-graph.h
index caba9779b5d48..44192c66a2633 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -219,7 +219,7 @@ class llm_graph_input_cls : public llm_graph_input_i {
 
 class llm_graph_input_rs : public llm_graph_input_i {
 public:
-    llm_graph_input_rs(const llama_memory_recurrent_context * mctx) : mctx(mctx) {}
+    llm_graph_input_rs(const llama_memory_recurrent_context * mctx);
     virtual ~llm_graph_input_rs() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
@@ -235,9 +235,9 @@ class llm_graph_input_rs : public llm_graph_input_i {
 
     const llama_memory_recurrent_context * mctx;
 
-    // used in view offsets, need to match for valid graph reuse
-    uint32_t head;
-    int32_t  rs_z;
+    // need to match for valid graph reuse
+    const uint32_t head;
+    const int32_t  rs_z;
 };
 
 class llm_graph_input_cross_embd : public llm_graph_input_i {
diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp
index d67f5a5f47b87..28d1b2a623901 100644
--- a/src/llama-memory-recurrent.cpp
+++ b/src/llama-memory-recurrent.cpp
@@ -1088,12 +1088,15 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
 llama_memory_recurrent_context::llama_memory_recurrent_context(llama_memory_status status) : status(status) {}
 
 llama_memory_recurrent_context::llama_memory_recurrent_context(
-        llama_memory_recurrent * mem) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), is_full(true) {
+        llama_memory_recurrent * mem) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem),
+    n_rs(mem->size), head(0), rs_z(0), size(mem->size) {
 }
 
 llama_memory_recurrent_context::llama_memory_recurrent_context(
         llama_memory_recurrent * mem,
-        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), ubatches(std::move(ubatches)) {}
+        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), ubatches(std::move(ubatches)),
+    n_rs(mem->n), head(mem->head), rs_z(mem->rs_z), size(mem->size) {
+}
 
 llama_memory_recurrent_context::~llama_memory_recurrent_context() = default;
 
@@ -1134,19 +1137,19 @@ const llama_ubatch & llama_memory_recurrent_context::get_ubatch() const {
 }
 
 uint32_t llama_memory_recurrent_context::get_n_rs() const {
-    return is_full ? mem->size : mem->n;
+    return n_rs;
 }
 
 uint32_t llama_memory_recurrent_context::get_head() const {
-    return is_full ? 0 : mem->head;
+    return head;
 }
 
 int32_t llama_memory_recurrent_context::get_rs_z() const {
-    return is_full ? 0 : mem->rs_z;
+    return rs_z;
 }
 
 uint32_t llama_memory_recurrent_context::get_size() const {
-    return mem->size;
+    return size;
 }
 
 ggml_tensor * llama_memory_recurrent_context::get_r_l(int32_t il) const {
@@ -1158,5 +1161,5 @@ ggml_tensor * llama_memory_recurrent_context::get_s_l(int32_t il) const {
 }
 
 int32_t llama_memory_recurrent_context::s_copy(int i) const {
-    return mem->cells[i + mem->head].src0;
+    return mem->cells[i + head].src0;
 }
diff --git a/src/llama-memory-recurrent.h b/src/llama-memory-recurrent.h
index 077c6e3ce938d..c99b155bcbc42 100644
--- a/src/llama-memory-recurrent.h
+++ b/src/llama-memory-recurrent.h
@@ -175,8 +175,10 @@ class llama_memory_recurrent_context : public llama_memory_context_i {
 
     //
     // data needed for building the compute graph for the current ubatch:
-    // TODO: extract all the state like `head` and `n` here
     //
 
-    const bool is_full = false;
+    const uint32_t n_rs = 0;
+    const uint32_t head = 0;
+    const int32_t  rs_z = -1;
+    const uint32_t size = 0;
 };

From 7641e6f2824ef700e0ce9a99b7f708fda2151d26 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Fri, 10 Oct 2025 19:41:10 +0300
Subject: [PATCH 5/5] Revert "memory : move the recurrent state into the memory context"

This reverts commit 00f115fe810815d4a22a6dee0acc346131e970e1.
---
 src/llama-graph.cpp            | 13 +++++--------
 src/llama-graph.h              |  8 ++++----
 src/llama-memory-recurrent.cpp | 17 +++++++----------
 src/llama-memory-recurrent.h   |  6 ++----
 4 files changed, 18 insertions(+), 26 deletions(-)

diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index caf2d78c5ea4b..6d1a6cc24595c 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -235,12 +235,6 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
     }
 }
 
-llm_graph_input_rs::llm_graph_input_rs(const llama_memory_recurrent_context * mctx) :
-    mctx(mctx),
-    head(mctx->get_head()),
-    rs_z(mctx->get_rs_z()) {
-}
-
 void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
     GGML_UNUSED(ubatch);
 
@@ -269,8 +263,8 @@ bool llm_graph_input_rs::can_reuse(const llm_graph_params & params) {
     res &= s_copy_main->ne[0]  == params.ubatch.n_seqs;
     res &= s_copy_extra->ne[0] == mctx->get_n_rs() - params.ubatch.n_seqs;
 
-    res &= this->head == mctx->get_head();
-    res &= this->rs_z == mctx->get_rs_z();
+    res &= head == mctx->get_head();
+    res &= rs_z == mctx->get_rs_z();
 
     return res;
 }
@@ -1870,6 +1864,9 @@ static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
     inp->s_copy_main  = ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0);
     inp->s_copy_extra = ggml_view_1d(ctx0, inp->s_copy, n_rs - n_seqs, n_seqs * inp->s_copy->nb[0]);
 
+    inp->head = mctx_cur->get_head();
+    inp->rs_z = mctx_cur->get_rs_z();
+
     return inp;
 }
 
diff --git a/src/llama-graph.h b/src/llama-graph.h
index 44192c66a2633..caba9779b5d48 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -219,7 +219,7 @@ class llm_graph_input_cls : public llm_graph_input_i {
 
 class llm_graph_input_rs : public llm_graph_input_i {
 public:
-    llm_graph_input_rs(const llama_memory_recurrent_context * mctx);
+    llm_graph_input_rs(const llama_memory_recurrent_context * mctx) : mctx(mctx) {}
     virtual ~llm_graph_input_rs() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
@@ -235,9 +235,9 @@ class llm_graph_input_rs : public llm_graph_input_i {
 
     const llama_memory_recurrent_context * mctx;
 
-    // need to match for valid graph reuse
-    const uint32_t head;
-    const int32_t  rs_z;
+    // used in view offsets, need to match for valid graph reuse
+    uint32_t head;
+    int32_t  rs_z;
 };
 
 class llm_graph_input_cross_embd : public llm_graph_input_i {
diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp
index 28d1b2a623901..d67f5a5f47b87 100644
--- a/src/llama-memory-recurrent.cpp
+++ b/src/llama-memory-recurrent.cpp
@@ -1088,15 +1088,12 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
 llama_memory_recurrent_context::llama_memory_recurrent_context(llama_memory_status status) : status(status) {}
 
 llama_memory_recurrent_context::llama_memory_recurrent_context(
-        llama_memory_recurrent * mem) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem),
-    n_rs(mem->size), head(0), rs_z(0), size(mem->size) {
+        llama_memory_recurrent * mem) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), is_full(true) {
 }
 
 llama_memory_recurrent_context::llama_memory_recurrent_context(
         llama_memory_recurrent * mem,
-        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), ubatches(std::move(ubatches)),
-    n_rs(mem->n), head(mem->head), rs_z(mem->rs_z), size(mem->size) {
-}
+        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), ubatches(std::move(ubatches)) {}
 
 llama_memory_recurrent_context::~llama_memory_recurrent_context() = default;
 
@@ -1137,19 +1134,19 @@ const llama_ubatch & llama_memory_recurrent_context::get_ubatch() const {
 }
 
 uint32_t llama_memory_recurrent_context::get_n_rs() const {
-    return n_rs;
+    return is_full ? mem->size : mem->n;
 }
 
 uint32_t llama_memory_recurrent_context::get_head() const {
-    return head;
+    return is_full ? 0 : mem->head;
 }
 
 int32_t llama_memory_recurrent_context::get_rs_z() const {
-    return rs_z;
+    return is_full ? 0 : mem->rs_z;
 }
 
 uint32_t llama_memory_recurrent_context::get_size() const {
-    return size;
+    return mem->size;
 }
 
 ggml_tensor * llama_memory_recurrent_context::get_r_l(int32_t il) const {
@@ -1161,5 +1158,5 @@ ggml_tensor * llama_memory_recurrent_context::get_s_l(int32_t il) const {
 }
 
 int32_t llama_memory_recurrent_context::s_copy(int i) const {
-    return mem->cells[i + head].src0;
+    return mem->cells[i + mem->head].src0;
 }
diff --git a/src/llama-memory-recurrent.h b/src/llama-memory-recurrent.h
index c99b155bcbc42..077c6e3ce938d 100644
--- a/src/llama-memory-recurrent.h
+++ b/src/llama-memory-recurrent.h
@@ -175,10 +175,8 @@ class llama_memory_recurrent_context : public llama_memory_context_i {
 
     //
     // data needed for building the compute graph for the current ubatch:
+    // TODO: extract all the state like `head` and `n` here
     //
 
-    const uint32_t n_rs = 0;
-    const uint32_t head = 0;
-    const int32_t  rs_z = -1;
-    const uint32_t size = 0;
+    const bool is_full = false;
 };
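
The sketch below is not part of the patch series; it is a minimal illustration, under stated assumptions, of the contract these commits extend to the recurrent and hybrid inputs. The graph_cache struct and the try_reuse_graph() helper are hypothetical stand-ins for the surrounding llama.cpp graph machinery, while llm_graph_input_i, llm_graph_params and llama_ubatch are the real types touched by the patches: before building a graph for a new ubatch, every input is asked via can_reuse() whether the previously built graph is still valid (matching tensor shapes and, after PATCH 3, the head/rs_z view offsets); only if all inputs agree is the cached graph kept, with set_input() re-run to refresh the input data.

    // illustrative sketch only -- graph_cache and try_reuse_graph are hypothetical helpers;
    // the inputs implement llm_graph_input_i from src/llama-graph.h
    #include <memory>
    #include <vector>

    #include "llama-graph.h"

    struct graph_cache {
        ggml_cgraph * gf = nullptr;                               // graph built for the previous ubatch
        std::vector<std::unique_ptr<llm_graph_input_i>> inputs;   // inputs captured when it was built
    };

    static bool try_reuse_graph(graph_cache & cache, const llm_graph_params & params, const llama_ubatch & ubatch) {
        if (cache.gf == nullptr) {
            return false;                 // nothing has been built yet
        }
        for (auto & inp : cache.inputs) {
            // each input checks tensor shapes and, for the recurrent state,
            // the head/rs_z offsets baked into its views (PATCH 3)
            if (!inp->can_reuse(params)) {
                return false;             // any mismatch forces a full graph rebuild
            }
        }
        for (auto & inp : cache.inputs) {
            inp->set_input(&ubatch);      // only the input data is refreshed
        }
        return true;                      // cache.gf can be evaluated as-is
    }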