3 changes: 3 additions & 0 deletions backend/backend.proto
@@ -156,6 +156,8 @@ message PredictOptions {
string CorrelationId = 47;
string Tools = 48; // JSON array of available tools/functions for tool calling
string ToolChoice = 49; // JSON string or object specifying tool choice behavior
int32 Logprobs = 50; // Number of logprobs to return (maps to the OpenAI logprobs parameter)
int32 TopLogprobs = 51; // Number of top logprobs to return per token (maps to OpenAI top_logprobs parameter)
}

// The response message containing the result
@@ -166,6 +168,7 @@ message Reply {
double timing_prompt_processing = 4;
double timing_token_generation = 5;
bytes audio = 6;
bytes logprobs = 7; // JSON-encoded logprobs data matching OpenAI format
}

message GrammarTrigger {
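For orientation, the new Reply.logprobs bytes are expected to carry a JSON document shaped like the OpenAI chat-completions logprobs object. Below is a minimal Go sketch of decoding such a payload; the struct names are illustrative, and the actual schema.Logprobs type in this repository may be shaped differently.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Illustrative mirror of the OpenAI-style logprobs payload carried in Reply.logprobs.
// The real schema.Logprobs type used by the Go core may differ.
type TopLogprob struct {
	Token   string  `json:"token"`
	Logprob float64 `json:"logprob"`
}

type LogprobEntry struct {
	Token       string       `json:"token"`
	Logprob     float64      `json:"logprob"`
	TopLogprobs []TopLogprob `json:"top_logprobs,omitempty"`
}

type Logprobs struct {
	Content []LogprobEntry `json:"content"`
}

func main() {
	// Example of the JSON a backend could place into Reply.logprobs.
	raw := []byte(`{"content":[{"token":"Hello","logprob":-0.11,"top_logprobs":[{"token":"Hello","logprob":-0.11},{"token":"Hi","logprob":-2.4}]}]}`)

	var lp Logprobs
	if err := json.Unmarshal(raw, &lp); err != nil {
		panic(err)
	}
	fmt.Println(lp.Content[0].Token, lp.Content[0].TopLogprobs[1].Token) // Hello Hi
}
```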
118 changes: 111 additions & 7 deletions backend/cpp/llama-cpp/grpc-server.cpp
@@ -166,6 +166,34 @@ json parse_options(bool streaming, const backend::PredictOptions* predict, const
SRV_INF("Extracted tool_choice as string: %s\n", predict->toolchoice().c_str());
}
}

// Extract logprobs and top_logprobs from proto and add to JSON data
// Following server.cpp pattern: logprobs maps to n_probs when provided
if (predict->logprobs() > 0) {
data["logprobs"] = predict->logprobs();
// Map logprobs to n_probs (following server.cpp line 369 pattern)
// n_probs will be set by params_from_json_cmpl if logprobs is provided
data["n_probs"] = predict->logprobs();
SRV_INF("Using logprobs: %d\n", predict->logprobs());
}
if (predict->toplogprobs() > 0) {
data["top_logprobs"] = predict->toplogprobs();
SRV_INF("Using top_logprobs: %d\n", predict->toplogprobs());
}

// Extract logit_bias from proto and add to JSON data
if (!predict->logitbias().empty()) {
try {
// Parse logit_bias JSON string from proto
json logit_bias_json = json::parse(predict->logitbias());
// Add to data - llama.cpp server expects it as an object (map)
data["logit_bias"] = logit_bias_json;
SRV_INF("Using logit_bias: %s\n", predict->logitbias().c_str());
} catch (const json::parse_error& e) {
SRV_ERR("Failed to parse logit_bias JSON from proto: %s\n", e.what());
}
}

data["ignore_eos"] = predict->ignoreeos();
data["embeddings"] = predict->embeddings();
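Putting the extractions above together, a request with Logprobs=2, TopLogprobs=2 and a logit-bias JSON string would roughly produce the llama.cpp server options sketched below. This only illustrates the mapping; the real parse_options output also carries the remaining sampling parameters.

```go
package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	// Rough shape of the request JSON handed to the llama.cpp server for the new
	// fields (all other prediction options omitted for brevity).
	data := map[string]any{
		"logprobs":     2,                              // PredictOptions.Logprobs
		"n_probs":      2,                              // mirrored from logprobs, following server.cpp
		"top_logprobs": 2,                              // PredictOptions.TopLogprobs
		"logit_bias":   map[string]float64{"15043": 1}, // parsed from the PredictOptions.LogitBias JSON string
		"ignore_eos":   false,
		"embeddings":   false,
	}
	out, _ := json.MarshalIndent(data, "", "  ")
	fmt.Println(string(out))
}
```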

@@ -568,6 +596,28 @@ class BackendServiceImpl final : public backend::Backend::Service {
return Status::OK;
}

// Helper function to extract logprobs from JSON response
static json extract_logprobs_from_json(const json& res_json) {
json logprobs_json = json::object();

// Check for OAI-compatible format: choices[0].logprobs
if (res_json.contains("choices") && res_json["choices"].is_array() &&
res_json["choices"].size() > 0 && res_json["choices"][0].contains("logprobs")) {
logprobs_json = res_json["choices"][0]["logprobs"];
}
// Check for non-OAI format: completion_probabilities
else if (res_json.contains("completion_probabilities")) {
// Convert completion_probabilities to OAI format
logprobs_json["content"] = res_json["completion_probabilities"];
}
// Check for direct logprobs field
else if (res_json.contains("logprobs")) {
logprobs_json = res_json["logprobs"];
}

return logprobs_json;
}

grpc::Status PredictStream(grpc::ServerContext* context, const backend::PredictOptions* request, grpc::ServerWriter<backend::Reply>* writer) override {
json data = parse_options(true, request, ctx_server);

@@ -915,6 +965,13 @@ class BackendServiceImpl final : public backend::Backend::Service {
reply.set_timing_token_generation(timing_token_generation);
}

// Extract and set logprobs if present
json logprobs_json = extract_logprobs_from_json(res);
if (!logprobs_json.empty() && !logprobs_json.is_null()) {
std::string logprobs_str = logprobs_json.dump();
reply.set_logprobs(logprobs_str);
}

writer->Write(reply);
}
} else {
@@ -934,6 +991,13 @@ class BackendServiceImpl final : public backend::Backend::Service {
reply.set_timing_token_generation(timing_token_generation);
}

// Extract and set logprobs if present
json logprobs_json = extract_logprobs_from_json(first_res_json);
if (!logprobs_json.empty() && !logprobs_json.is_null()) {
std::string logprobs_str = logprobs_json.dump();
reply.set_logprobs(logprobs_str);
}

writer->Write(reply);
}

@@ -969,6 +1033,13 @@ class BackendServiceImpl final : public backend::Backend::Service {
reply.set_timing_token_generation(timing_token_generation);
}

// Extract and set logprobs if present
json logprobs_json = extract_logprobs_from_json(res);
if (!logprobs_json.empty() && !logprobs_json.is_null()) {
std::string logprobs_str = logprobs_json.dump();
reply.set_logprobs(logprobs_str);
}

writer->Write(reply);
}
} else {
@@ -988,6 +1059,13 @@ class BackendServiceImpl final : public backend::Backend::Service {
reply.set_timing_token_generation(timing_token_generation);
}

// Extract and set logprobs if present
json logprobs_json = extract_logprobs_from_json(res_json);
if (!logprobs_json.empty() && !logprobs_json.is_null()) {
std::string logprobs_str = logprobs_json.dump();
reply.set_logprobs(logprobs_str);
}

writer->Write(reply);
}
}
@@ -1335,28 +1413,54 @@ class BackendServiceImpl final : public backend::Backend::Service {
if (all_results.results.size() == 1) {
// single result
GGML_ASSERT(dynamic_cast<server_task_result_cmpl_final*>(all_results.results[0].get()) != nullptr);
reply->set_message(all_results.results[0]->to_json().value("content", ""));
json result_json = all_results.results[0]->to_json();
reply->set_message(result_json.value("content", ""));

int32_t tokens_predicted = all_results.results[0]->to_json().value("tokens_predicted", 0);
int32_t tokens_predicted = result_json.value("tokens_predicted", 0);
reply->set_tokens(tokens_predicted);
int32_t tokens_evaluated = all_results.results[0]->to_json().value("tokens_evaluated", 0);
int32_t tokens_evaluated = result_json.value("tokens_evaluated", 0);
reply->set_prompt_tokens(tokens_evaluated);

if (all_results.results[0]->to_json().contains("timings")) {
double timing_prompt_processing = all_results.results[0]->to_json().at("timings").value("prompt_ms", 0.0);
if (result_json.contains("timings")) {
double timing_prompt_processing = result_json.at("timings").value("prompt_ms", 0.0);
reply->set_timing_prompt_processing(timing_prompt_processing);
double timing_token_generation = all_results.results[0]->to_json().at("timings").value("predicted_ms", 0.0);
double timing_token_generation = result_json.at("timings").value("predicted_ms", 0.0);
reply->set_timing_token_generation(timing_token_generation);
}

// Extract and set logprobs if present
json logprobs_json = extract_logprobs_from_json(result_json);
if (!logprobs_json.empty() && !logprobs_json.is_null()) {
std::string logprobs_str = logprobs_json.dump();
reply->set_logprobs(logprobs_str);
}

} else {
// multiple results (multitask)
json arr = json::array();
json logprobs_arr = json::array();
bool has_logprobs = false;
for (auto & res : all_results.results) {
GGML_ASSERT(dynamic_cast<server_task_result_cmpl_final*>(res.get()) != nullptr);
arr.push_back(res->to_json().value("content", ""));
json res_json = res->to_json();
arr.push_back(res_json.value("content", ""));

// Extract logprobs for each result
json logprobs_json = extract_logprobs_from_json(res_json);
if (!logprobs_json.empty() && !logprobs_json.is_null()) {
has_logprobs = true;
logprobs_arr.push_back(logprobs_json);
} else {
logprobs_arr.push_back(json::object());
}
}
reply->set_message(arr);

// Set logprobs if any result has them
if (has_logprobs) {
std::string logprobs_str = logprobs_arr.dump();
reply->set_logprobs(logprobs_str);
}
}
}

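As a reference for the extract_logprobs_from_json helper above: it accepts three response shapes (OAI-style choices[0].logprobs, llama.cpp's completion_probabilities, or a bare logprobs field) and normalizes them to a single OpenAI-style object. A small Go sketch of the same normalization, useful for reasoning about what ends up in Reply.logprobs; the server itself ships the C++ version above.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// normalizeLogprobs mirrors extract_logprobs_from_json: it returns an
// OpenAI-style logprobs object regardless of which of the three known
// shapes the backend response used.
func normalizeLogprobs(res map[string]any) map[string]any {
	if choices, ok := res["choices"].([]any); ok && len(choices) > 0 {
		if first, ok := choices[0].(map[string]any); ok {
			if lp, ok := first["logprobs"].(map[string]any); ok {
				return lp // OAI-compatible format
			}
		}
	}
	if probs, ok := res["completion_probabilities"]; ok {
		return map[string]any{"content": probs} // non-OAI llama.cpp format
	}
	if lp, ok := res["logprobs"].(map[string]any); ok {
		return lp // direct logprobs field
	}
	return map[string]any{}
}

func main() {
	raw := []byte(`{"completion_probabilities":[{"token":"Hello","logprob":-0.2}]}`)
	var res map[string]any
	if err := json.Unmarshal(raw, &res); err != nil {
		panic(err)
	}
	out, _ := json.Marshal(normalizeLogprobs(res))
	fmt.Println(string(out)) // {"content":[{"logprob":-0.2,"token":"Hello"}]}
}
```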
37 changes: 36 additions & 1 deletion core/backend/llm.go
@@ -2,6 +2,7 @@ package backend

import (
"context"
"encoding/json"
"regexp"
"slices"
"strings"
@@ -24,6 +25,7 @@ type LLMResponse struct {
Response string // should this be []byte?
Usage TokenUsage
AudioOutput string
Logprobs *schema.Logprobs // Logprobs from the backend response
}

type TokenUsage struct {
@@ -33,7 +35,7 @@ type TokenUsage struct {
TimingTokenGeneration float64
}

func ModelInference(ctx context.Context, s string, messages schema.Messages, images, videos, audios []string, loader *model.ModelLoader, c *config.ModelConfig, cl *config.ModelConfigLoader, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool, tools string, toolChoice string) (func() (LLMResponse, error), error) {
func ModelInference(ctx context.Context, s string, messages schema.Messages, images, videos, audios []string, loader *model.ModelLoader, c *config.ModelConfig, cl *config.ModelConfigLoader, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool, tools string, toolChoice string, logprobs *int, topLogprobs *int, logitBias map[string]float64) (func() (LLMResponse, error), error) {
modelFile := c.Model

// Check if the modelFile exists, if it doesn't try to load it from the gallery
@@ -78,6 +80,19 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
opts.Audios = audios
opts.Tools = tools
opts.ToolChoice = toolChoice
if logprobs != nil {
opts.Logprobs = int32(*logprobs)
}
if topLogprobs != nil {
opts.TopLogprobs = int32(*topLogprobs)
}
if len(logitBias) > 0 {
// Serialize logit_bias map to JSON string for proto
logitBiasJSON, err := json.Marshal(logitBias)
if err == nil {
opts.LogitBias = string(logitBiasJSON)
}
}

tokenUsage := TokenUsage{}

@@ -109,6 +124,7 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
}

ss := ""
var logprobs *schema.Logprobs

var partialRune []byte
err := inferenceModel.PredictStream(ctx, opts, func(reply *proto.Reply) {
@@ -120,6 +136,14 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
tokenUsage.TimingTokenGeneration = reply.TimingTokenGeneration
tokenUsage.TimingPromptProcessing = reply.TimingPromptProcessing

// Parse logprobs from reply if present (collect from last chunk that has them)
if len(reply.Logprobs) > 0 {
var parsedLogprobs schema.Logprobs
if err := json.Unmarshal(reply.Logprobs, &parsedLogprobs); err == nil {
logprobs = &parsedLogprobs
}
}

// Process complete runes and accumulate them
var completeRunes []byte
for len(partialRune) > 0 {
@@ -145,6 +169,7 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
return LLMResponse{
Response: ss,
Usage: tokenUsage,
Logprobs: logprobs,
}, err
} else {
// TODO: Is the chicken bit the only way to get here? is that acceptable?
@@ -167,9 +192,19 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
response = c.TemplateConfig.ReplyPrefix + response
}

// Parse logprobs from reply if present
var logprobs *schema.Logprobs
if len(reply.Logprobs) > 0 {
var parsedLogprobs schema.Logprobs
if err := json.Unmarshal(reply.Logprobs, &parsedLogprobs); err == nil {
logprobs = &parsedLogprobs
}
}

return LLMResponse{
Response: response,
Usage: tokenUsage,
Logprobs: logprobs,
}, err
}
}
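A minimal sketch of the new option plumbing in ModelInference, shown against a local stand-in struct rather than the generated pb.PredictOptions type: optional ints become int32 proto fields and the logit-bias map is serialized to a JSON string.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Local stand-in for the pb.PredictOptions fields touched above;
// the real type is generated from the backend proto.
type predictOpts struct {
	Logprobs    int32
	TopLogprobs int32
	LogitBias   string
}

func applyLogprobOptions(opts *predictOpts, logprobs, topLogprobs *int, logitBias map[string]float64) {
	// Same plumbing as ModelInference: nil pointers leave the proto fields at
	// their zero value, and the logit-bias map becomes a JSON string.
	if logprobs != nil {
		opts.Logprobs = int32(*logprobs)
	}
	if topLogprobs != nil {
		opts.TopLogprobs = int32(*topLogprobs)
	}
	if len(logitBias) > 0 {
		if b, err := json.Marshal(logitBias); err == nil {
			opts.LogitBias = string(b)
		}
	}
}

func main() {
	n, top := 1, 3
	opts := &predictOpts{}
	applyLogprobOptions(opts, &n, &top, map[string]float64{"15043": 1})
	fmt.Printf("%+v\n", *opts)
}
```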
4 changes: 3 additions & 1 deletion core/backend/options.go
@@ -212,7 +212,7 @@ func gRPCPredictOpts(c config.ModelConfig, modelPath string) *pb.PredictOptions
}
}

return &pb.PredictOptions{
pbOpts := &pb.PredictOptions{
Temperature: float32(*c.Temperature),
TopP: float32(*c.TopP),
NDraft: c.NDraft,
@@ -249,4 +249,6 @@
TailFreeSamplingZ: float32(*c.TFZ),
TypicalP: float32(*c.TypicalP),
}
// Logprobs and TopLogprobs are set by the caller if provided
return pbOpts
}
77 changes: 77 additions & 0 deletions core/http/app_test.go
@@ -816,6 +816,83 @@ var _ = Describe("API test", func() {
Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty())
})

It("returns logprobs in chat completions when requested", func() {
topLogprobsVal := 3
response, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{
Model: "testmodel.ggml",
LogProbs: true,
TopLogProbs: topLogprobsVal,
Messages: []openai.ChatCompletionMessage{{Role: "user", Content: testPrompt}}})
Expect(err).ToNot(HaveOccurred())

Expect(len(response.Choices)).To(Equal(1))
Expect(response.Choices[0].Message).ToNot(BeNil())
Expect(response.Choices[0].Message.Content).ToNot(BeEmpty())

// Verify logprobs are present and have correct structure
Expect(response.Choices[0].LogProbs).ToNot(BeNil())
Expect(response.Choices[0].LogProbs.Content).ToNot(BeEmpty())

Expect(len(response.Choices[0].LogProbs.Content)).To(BeNumerically(">", 1))

foundAtLeastToken := ""
foundAtLeastBytes := []byte{}
foundAtLeastTopLogprobBytes := []byte{}
foundAtLeastTopLogprob := ""
// Verify logprobs content structure matches OpenAI format
for _, logprobContent := range response.Choices[0].LogProbs.Content {
// Bytes can be empty for certain tokens (special tokens, etc.), so we don't require it
if len(logprobContent.Bytes) > 0 {
foundAtLeastBytes = logprobContent.Bytes
}
if len(logprobContent.Token) > 0 {
foundAtLeastToken = logprobContent.Token
}
Expect(logprobContent.LogProb).To(BeNumerically("<=", 0)) // Logprobs are always <= 0
Expect(len(logprobContent.TopLogProbs)).To(BeNumerically(">", 1))

// If top_logprobs is requested, verify top_logprobs array respects the limit
if len(logprobContent.TopLogProbs) > 0 {
// Should respect top_logprobs limit (3 in this test)
Expect(len(logprobContent.TopLogProbs)).To(BeNumerically("<=", topLogprobsVal))
for _, topLogprob := range logprobContent.TopLogProbs {
if len(topLogprob.Bytes) > 0 {
foundAtLeastTopLogprobBytes = topLogprob.Bytes
}
if len(topLogprob.Token) > 0 {
foundAtLeastTopLogprob = topLogprob.Token
}
Expect(topLogprob.LogProb).To(BeNumerically("<=", 0))
}
}
}

Expect(foundAtLeastBytes).ToNot(BeEmpty())
Expect(foundAtLeastTopLogprobBytes).ToNot(BeEmpty())
Expect(foundAtLeastToken).ToNot(BeEmpty())
Expect(foundAtLeastTopLogprob).ToNot(BeEmpty())
})

It("applies logit_bias to chat completions when requested", func() {
// logit_bias is a map of token IDs (as strings) to bias values (-100 to 100)
// According to OpenAI API: modifies the likelihood of specified tokens appearing in the completion
logitBias := map[string]int{
"15043": 1, // Bias token ID 15043 (example token ID) with bias value 1
}
response, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{
Model: "testmodel.ggml",
Messages: []openai.ChatCompletionMessage{{Role: "user", Content: testPrompt}},
LogitBias: logitBias,
})
Expect(err).ToNot(HaveOccurred())
Expect(len(response.Choices)).To(Equal(1))
Expect(response.Choices[0].Message).ToNot(BeNil())
Expect(response.Choices[0].Message.Content).ToNot(BeEmpty())
// If logit_bias is applied, the response should be generated successfully
// We can't easily verify the bias effect without knowing the actual token IDs for the model,
// but the fact that the request succeeds confirms the API accepts and processes logit_bias
})

It("returns errors", func() {
_, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "foomodel", Prompt: testPrompt})
Expect(err).To(HaveOccurred())
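For completeness, the behaviour these tests exercise through the go-openai client can also be hit with a raw request. A sketch, assuming a LocalAI instance listening on localhost:8080 and the test model name used above:

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
)

func main() {
	// OpenAI-compatible chat request with logprobs, top_logprobs and logit_bias.
	// Host, port and model name are assumptions for this sketch.
	body := map[string]any{
		"model":        "testmodel.ggml",
		"messages":     []map[string]string{{"role": "user", "content": "How are you?"}},
		"logprobs":     true,
		"top_logprobs": 3,
		"logit_bias":   map[string]int{"15043": 1},
	}
	payload, _ := json.Marshal(body)

	resp, err := http.Post("http://localhost:8080/v1/chat/completions", "application/json", bytes.NewReader(payload))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	out, _ := io.ReadAll(resp.Body)
	fmt.Println(string(out)) // choices[0].logprobs.content should be populated
}
```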