From d1d89e0b26727c07baea5290e218011bc5513333 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Mon, 14 Jul 2025 16:09:49 +0800 Subject: [PATCH 1/3] llama-context: add ability to get logits --- src/llama-context.cpp | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 06e93b19cbf40..c55bbe4ea817d 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -731,7 +731,8 @@ int llama_context::encode(const llama_batch & batch_inp) { const auto & hparams = model.hparams; - const int64_t n_embd = hparams.n_embd; + const int64_t n_embd = hparams.n_embd; + const int32_t n_vocab = model.vocab.n_tokens(); // note: during encode, we always pass the full sequence starting from pos = 0 if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, true)) { @@ -791,10 +792,22 @@ int llama_context::encode(const llama_batch & batch_inp) { } } + auto * t_logits = res->get_logits(); auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd(); + // extract logits + if (t_logits && n_outputs > 0) { + ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits); + GGML_ASSERT(backend_res != nullptr); + GGML_ASSERT(logits != nullptr); + + if (n_outputs) { + ggml_backend_tensor_get_async(backend_res, t_logits, logits, 0, n_outputs*n_vocab*sizeof(float)); + } + } + // extract embeddings - if (t_embd) { + if (cparams.embeddings && t_embd) { ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd); GGML_ASSERT(backend_embd != nullptr); From c013397a5adabb2a3aa795b6bb25c4188dbfb27f Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Mon, 14 Jul 2025 18:40:19 +0800 Subject: [PATCH 2/3] refactor if statement Co-authored-by: Georgi Gerganov --- src/llama-context.cpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index c55bbe4ea817d..5ab007623080d 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -796,14 +796,12 @@ int llama_context::encode(const llama_batch & batch_inp) { auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd(); // extract logits - if (t_logits && n_outputs > 0) { + if (logits && t_logits) { ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits); GGML_ASSERT(backend_res != nullptr); GGML_ASSERT(logits != nullptr); - if (n_outputs) { - ggml_backend_tensor_get_async(backend_res, t_logits, logits, 0, n_outputs*n_vocab*sizeof(float)); - } + ggml_backend_tensor_get_async(backend_res, t_logits, logits, 0, n_tokens*n_vocab*sizeof(float)); } // extract embeddings From 1b9ca1c6fc1212ab2eeaeed40a408e8a57a4a32f Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Mon, 14 Jul 2025 18:40:35 +0800 Subject: [PATCH 3/3] refactor if statement Co-authored-by: Georgi Gerganov --- src/llama-context.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 5ab007623080d..7c07b047b0dd9 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -805,7 +805,7 @@ int llama_context::encode(const llama_batch & batch_inp) { } // extract embeddings - if (cparams.embeddings && t_embd) { + if (embd && t_embd) { ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd); GGML_ASSERT(backend_embd != nullptr);