From 7d654464228ab95d769d5c6f9bc4b0b1ce7ee242 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Sun, 16 Nov 2025 11:00:39 +0100 Subject: [PATCH 1/2] feat: add support to logprobs in results Signed-off-by: Ettore Di Giacinto --- backend/backend.proto | 3 + backend/cpp/llama-cpp/grpc-server.cpp | 105 ++++++++++++++++++++++-- core/backend/llm.go | 30 ++++++- core/backend/options.go | 4 +- core/http/app_test.go | 57 +++++++++++++ core/http/endpoints/openai/chat.go | 20 ++++- core/http/endpoints/openai/inference.go | 25 +++++- core/schema/openai.go | 13 +++ core/schema/prediction.go | 83 +++++++++++++++++++ 9 files changed, 329 insertions(+), 11 deletions(-) diff --git a/backend/backend.proto b/backend/backend.proto index a367523de5c6..187294236862 100644 --- a/backend/backend.proto +++ b/backend/backend.proto @@ -156,6 +156,8 @@ message PredictOptions { string CorrelationId = 47; string Tools = 48; // JSON array of available tools/functions for tool calling string ToolChoice = 49; // JSON string or object specifying tool choice behavior + int32 Logprobs = 50; // Number of top logprobs to return (maps to OpenAI logprobs parameter) + int32 TopLogprobs = 51; // Number of top logprobs to return per token (maps to OpenAI top_logprobs parameter) } // The response message containing the result @@ -166,6 +168,7 @@ message Reply { double timing_prompt_processing = 4; double timing_token_generation = 5; bytes audio = 6; + bytes logprobs = 7; // JSON-encoded logprobs data matching OpenAI format } message GrammarTrigger { diff --git a/backend/cpp/llama-cpp/grpc-server.cpp b/backend/cpp/llama-cpp/grpc-server.cpp index a26d995a458c..51320a99fd78 100644 --- a/backend/cpp/llama-cpp/grpc-server.cpp +++ b/backend/cpp/llama-cpp/grpc-server.cpp @@ -166,6 +166,21 @@ json parse_options(bool streaming, const backend::PredictOptions* predict, const SRV_INF("Extracted tool_choice as string: %s\n", predict->toolchoice().c_str()); } } + + // Extract logprobs and top_logprobs from proto and add to JSON data + // Following server.cpp pattern: logprobs maps to n_probs when provided + if (predict->logprobs() > 0) { + data["logprobs"] = predict->logprobs(); + // Map logprobs to n_probs (following server.cpp line 369 pattern) + // n_probs will be set by params_from_json_cmpl if logprobs is provided + data["n_probs"] = predict->logprobs(); + SRV_INF("Using logprobs: %d\n", predict->logprobs()); + } + if (predict->toplogprobs() > 0) { + data["top_logprobs"] = predict->toplogprobs(); + SRV_INF("Using top_logprobs: %d\n", predict->toplogprobs()); + } + data["ignore_eos"] = predict->ignoreeos(); data["embeddings"] = predict->embeddings(); @@ -568,6 +583,28 @@ class BackendServiceImpl final : public backend::Backend::Service { return Status::OK; } + // Helper function to extract logprobs from JSON response + static json extract_logprobs_from_json(const json& res_json) { + json logprobs_json = json::object(); + + // Check for OAI-compatible format: choices[0].logprobs + if (res_json.contains("choices") && res_json["choices"].is_array() && + res_json["choices"].size() > 0 && res_json["choices"][0].contains("logprobs")) { + logprobs_json = res_json["choices"][0]["logprobs"]; + } + // Check for non-OAI format: completion_probabilities + else if (res_json.contains("completion_probabilities")) { + // Convert completion_probabilities to OAI format + logprobs_json["content"] = res_json["completion_probabilities"]; + } + // Check for direct logprobs field + else if (res_json.contains("logprobs")) { + logprobs_json = 
res_json["logprobs"]; + } + + return logprobs_json; + } + grpc::Status PredictStream(grpc::ServerContext* context, const backend::PredictOptions* request, grpc::ServerWriter* writer) override { json data = parse_options(true, request, ctx_server); @@ -915,6 +952,13 @@ class BackendServiceImpl final : public backend::Backend::Service { reply.set_timing_token_generation(timing_token_generation); } + // Extract and set logprobs if present + json logprobs_json = extract_logprobs_from_json(res); + if (!logprobs_json.empty() && !logprobs_json.is_null()) { + std::string logprobs_str = logprobs_json.dump(); + reply.set_logprobs(logprobs_str); + } + writer->Write(reply); } } else { @@ -934,6 +978,13 @@ class BackendServiceImpl final : public backend::Backend::Service { reply.set_timing_token_generation(timing_token_generation); } + // Extract and set logprobs if present + json logprobs_json = extract_logprobs_from_json(first_res_json); + if (!logprobs_json.empty() && !logprobs_json.is_null()) { + std::string logprobs_str = logprobs_json.dump(); + reply.set_logprobs(logprobs_str); + } + writer->Write(reply); } @@ -969,6 +1020,13 @@ class BackendServiceImpl final : public backend::Backend::Service { reply.set_timing_token_generation(timing_token_generation); } + // Extract and set logprobs if present + json logprobs_json = extract_logprobs_from_json(res); + if (!logprobs_json.empty() && !logprobs_json.is_null()) { + std::string logprobs_str = logprobs_json.dump(); + reply.set_logprobs(logprobs_str); + } + writer->Write(reply); } } else { @@ -988,6 +1046,13 @@ class BackendServiceImpl final : public backend::Backend::Service { reply.set_timing_token_generation(timing_token_generation); } + // Extract and set logprobs if present + json logprobs_json = extract_logprobs_from_json(res_json); + if (!logprobs_json.empty() && !logprobs_json.is_null()) { + std::string logprobs_str = logprobs_json.dump(); + reply.set_logprobs(logprobs_str); + } + writer->Write(reply); } } @@ -1335,28 +1400,54 @@ class BackendServiceImpl final : public backend::Backend::Service { if (all_results.results.size() == 1) { // single result GGML_ASSERT(dynamic_cast(all_results.results[0].get()) != nullptr); - reply->set_message(all_results.results[0]->to_json().value("content", "")); + json result_json = all_results.results[0]->to_json(); + reply->set_message(result_json.value("content", "")); - int32_t tokens_predicted = all_results.results[0]->to_json().value("tokens_predicted", 0); + int32_t tokens_predicted = result_json.value("tokens_predicted", 0); reply->set_tokens(tokens_predicted); - int32_t tokens_evaluated = all_results.results[0]->to_json().value("tokens_evaluated", 0); + int32_t tokens_evaluated = result_json.value("tokens_evaluated", 0); reply->set_prompt_tokens(tokens_evaluated); - if (all_results.results[0]->to_json().contains("timings")) { - double timing_prompt_processing = all_results.results[0]->to_json().at("timings").value("prompt_ms", 0.0); + if (result_json.contains("timings")) { + double timing_prompt_processing = result_json.at("timings").value("prompt_ms", 0.0); reply->set_timing_prompt_processing(timing_prompt_processing); - double timing_token_generation = all_results.results[0]->to_json().at("timings").value("predicted_ms", 0.0); + double timing_token_generation = result_json.at("timings").value("predicted_ms", 0.0); reply->set_timing_token_generation(timing_token_generation); } + // Extract and set logprobs if present + json logprobs_json = extract_logprobs_from_json(result_json); + if 
(!logprobs_json.empty() && !logprobs_json.is_null()) { + std::string logprobs_str = logprobs_json.dump(); + reply->set_logprobs(logprobs_str); + } + } else { // multiple results (multitask) json arr = json::array(); + json logprobs_arr = json::array(); + bool has_logprobs = false; for (auto & res : all_results.results) { GGML_ASSERT(dynamic_cast(res.get()) != nullptr); - arr.push_back(res->to_json().value("content", "")); + json res_json = res->to_json(); + arr.push_back(res_json.value("content", "")); + + // Extract logprobs for each result + json logprobs_json = extract_logprobs_from_json(res_json); + if (!logprobs_json.empty() && !logprobs_json.is_null()) { + has_logprobs = true; + logprobs_arr.push_back(logprobs_json); + } else { + logprobs_arr.push_back(json::object()); + } } reply->set_message(arr); + + // Set logprobs if any result has them + if (has_logprobs) { + std::string logprobs_str = logprobs_arr.dump(); + reply->set_logprobs(logprobs_str); + } } } diff --git a/core/backend/llm.go b/core/backend/llm.go index 3cd74d9a4953..8a418c62c3b0 100644 --- a/core/backend/llm.go +++ b/core/backend/llm.go @@ -2,6 +2,7 @@ package backend import ( "context" + "encoding/json" "regexp" "slices" "strings" @@ -24,6 +25,7 @@ type LLMResponse struct { Response string // should this be []byte? Usage TokenUsage AudioOutput string + Logprobs *schema.Logprobs // Logprobs from the backend response } type TokenUsage struct { @@ -33,7 +35,7 @@ type TokenUsage struct { TimingTokenGeneration float64 } -func ModelInference(ctx context.Context, s string, messages schema.Messages, images, videos, audios []string, loader *model.ModelLoader, c *config.ModelConfig, cl *config.ModelConfigLoader, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool, tools string, toolChoice string) (func() (LLMResponse, error), error) { +func ModelInference(ctx context.Context, s string, messages schema.Messages, images, videos, audios []string, loader *model.ModelLoader, c *config.ModelConfig, cl *config.ModelConfigLoader, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool, tools string, toolChoice string, logprobs *int, topLogprobs *int) (func() (LLMResponse, error), error) { modelFile := c.Model // Check if the modelFile exists, if it doesn't try to load it from the gallery @@ -78,6 +80,12 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima opts.Audios = audios opts.Tools = tools opts.ToolChoice = toolChoice + if logprobs != nil { + opts.Logprobs = int32(*logprobs) + } + if topLogprobs != nil { + opts.TopLogprobs = int32(*topLogprobs) + } tokenUsage := TokenUsage{} @@ -109,6 +117,7 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima } ss := "" + var logprobs *schema.Logprobs var partialRune []byte err := inferenceModel.PredictStream(ctx, opts, func(reply *proto.Reply) { @@ -120,6 +129,14 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima tokenUsage.TimingTokenGeneration = reply.TimingTokenGeneration tokenUsage.TimingPromptProcessing = reply.TimingPromptProcessing + // Parse logprobs from reply if present (collect from last chunk that has them) + if len(reply.Logprobs) > 0 { + var parsedLogprobs schema.Logprobs + if err := json.Unmarshal(reply.Logprobs, &parsedLogprobs); err == nil { + logprobs = &parsedLogprobs + } + } + // Process complete runes and accumulate them var completeRunes []byte for len(partialRune) > 0 { @@ -145,6 +162,7 @@ func ModelInference(ctx context.Context, s string, 
messages schema.Messages, ima return LLMResponse{ Response: ss, Usage: tokenUsage, + Logprobs: logprobs, }, err } else { // TODO: Is the chicken bit the only way to get here? is that acceptable? @@ -167,9 +185,19 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima response = c.TemplateConfig.ReplyPrefix + response } + // Parse logprobs from reply if present + var logprobs *schema.Logprobs + if len(reply.Logprobs) > 0 { + var parsedLogprobs schema.Logprobs + if err := json.Unmarshal(reply.Logprobs, &parsedLogprobs); err == nil { + logprobs = &parsedLogprobs + } + } + return LLMResponse{ Response: response, Usage: tokenUsage, + Logprobs: logprobs, }, err } } diff --git a/core/backend/options.go b/core/backend/options.go index 9a6b8993e903..d0965188a1ff 100644 --- a/core/backend/options.go +++ b/core/backend/options.go @@ -212,7 +212,7 @@ func gRPCPredictOpts(c config.ModelConfig, modelPath string) *pb.PredictOptions } } - return &pb.PredictOptions{ + pbOpts := &pb.PredictOptions{ Temperature: float32(*c.Temperature), TopP: float32(*c.TopP), NDraft: c.NDraft, @@ -249,4 +249,6 @@ func gRPCPredictOpts(c config.ModelConfig, modelPath string) *pb.PredictOptions TailFreeSamplingZ: float32(*c.TFZ), TypicalP: float32(*c.TypicalP), } + // Logprobs and TopLogprobs are set by the caller if provided + return pbOpts } diff --git a/core/http/app_test.go b/core/http/app_test.go index bce9c56903bb..ffb9c14e64f6 100644 --- a/core/http/app_test.go +++ b/core/http/app_test.go @@ -816,6 +816,63 @@ var _ = Describe("API test", func() { Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty()) }) + It("returns logprobs in chat completions when requested", func() { + topLogprobsVal := 3 + response, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{ + Model: "testmodel.ggml", + LogProbs: true, + TopLogProbs: topLogprobsVal, + Messages: []openai.ChatCompletionMessage{{Role: "user", Content: testPrompt}}}) + Expect(err).ToNot(HaveOccurred()) + + Expect(len(response.Choices)).To(Equal(1)) + Expect(response.Choices[0].Message).ToNot(BeNil()) + Expect(response.Choices[0].Message.Content).ToNot(BeEmpty()) + + // Verify logprobs are present and have correct structure + Expect(response.Choices[0].LogProbs).ToNot(BeNil()) + Expect(response.Choices[0].LogProbs.Content).ToNot(BeEmpty()) + + Expect(len(response.Choices[0].LogProbs.Content)).To(BeNumerically(">", 1)) + + foundatLeastToken := "" + foundAtLeastBytes := []byte{} + foundAtLeastTopLogprobBytes := []byte{} + foundatLeastTopLogprob := "" + // Verify logprobs content structure matches OpenAI format + for _, logprobContent := range response.Choices[0].LogProbs.Content { + // Bytes can be empty for certain tokens (special tokens, etc.), so we don't require it + if len(logprobContent.Bytes) > 0 { + foundAtLeastBytes = logprobContent.Bytes + } + if len(logprobContent.Token) > 0 { + foundatLeastToken = logprobContent.Token + } + Expect(logprobContent.LogProb).To(BeNumerically("<=", 0)) // Logprobs are always <= 0 + Expect(len(logprobContent.TopLogProbs)).To(BeNumerically(">", 1)) + + // If top_logprobs is requested, verify top_logprobs array respects the limit + if len(logprobContent.TopLogProbs) > 0 { + // Should respect top_logprobs limit (3 in this test) + Expect(len(logprobContent.TopLogProbs)).To(BeNumerically("<=", topLogprobsVal)) + for _, topLogprob := range logprobContent.TopLogProbs { + if len(topLogprob.Bytes) > 0 { + foundAtLeastTopLogprobBytes = topLogprob.Bytes + } + if len(topLogprob.Token) > 0 { + 
foundatLeastTopLogprob = topLogprob.Token + } + Expect(topLogprob.LogProb).To(BeNumerically("<=", 0)) + } + } + } + + Expect(foundAtLeastBytes).ToNot(BeEmpty()) + Expect(foundAtLeastTopLogprobBytes).ToNot(BeEmpty()) + Expect(foundatLeastToken).ToNot(BeEmpty()) + Expect(foundatLeastTopLogprob).ToNot(BeEmpty()) + }) + It("returns errors", func() { _, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "foomodel", Prompt: testPrompt}) Expect(err).To(HaveOccurred()) diff --git a/core/http/endpoints/openai/chat.go b/core/http/endpoints/openai/chat.go index d121a081eb47..a4b7b640818e 100644 --- a/core/http/endpoints/openai/chat.go +++ b/core/http/endpoints/openai/chat.go @@ -635,7 +635,25 @@ func handleQuestion(config *config.ModelConfig, cl *config.ModelConfigLoader, in } } - predFunc, err := backend.ModelInference(input.Context, prompt, input.Messages, images, videos, audios, ml, config, cl, o, nil, toolsJSON, toolChoiceJSON) + // Extract logprobs from request + // According to OpenAI API: logprobs is boolean, top_logprobs (0-20) controls how many top tokens per position + var logprobs *int + var topLogprobs *int + if input.Logprobs.IsEnabled() { + // If logprobs is enabled, use top_logprobs if provided, otherwise default to 1 + if input.TopLogprobs != nil { + topLogprobs = input.TopLogprobs + // For backend compatibility, set logprobs to the top_logprobs value + logprobs = input.TopLogprobs + } else { + // Default to 1 if logprobs is true but top_logprobs not specified + val := 1 + logprobs = &val + topLogprobs = &val + } + } + + predFunc, err := backend.ModelInference(input.Context, prompt, input.Messages, images, videos, audios, ml, config, cl, o, nil, toolsJSON, toolChoiceJSON, logprobs, topLogprobs) if err != nil { log.Error().Err(err).Msg("model inference failed") return "", err diff --git a/core/http/endpoints/openai/inference.go b/core/http/endpoints/openai/inference.go index 95d3ee24671d..aef42f9259d1 100644 --- a/core/http/endpoints/openai/inference.go +++ b/core/http/endpoints/openai/inference.go @@ -55,9 +55,27 @@ func ComputeChoices( } } + // Extract logprobs from request + // According to OpenAI API: logprobs is boolean, top_logprobs (0-20) controls how many top tokens per position + var logprobs *int + var topLogprobs *int + if req.Logprobs.IsEnabled() { + // If logprobs is enabled, use top_logprobs if provided, otherwise default to 1 + if req.TopLogprobs != nil { + topLogprobs = req.TopLogprobs + // For backend compatibility, set logprobs to the top_logprobs value + logprobs = req.TopLogprobs + } else { + // Default to 1 if logprobs is true but top_logprobs not specified + val := 1 + logprobs = &val + topLogprobs = &val + } + } + // get the model function to call for the result predFunc, err := backend.ModelInference( - req.Context, predInput, req.Messages, images, videos, audios, loader, config, bcl, o, tokenCallback, toolsJSON, toolChoiceJSON) + req.Context, predInput, req.Messages, images, videos, audios, loader, config, bcl, o, tokenCallback, toolsJSON, toolChoiceJSON, logprobs, topLogprobs) if err != nil { return result, backend.TokenUsage{}, err } @@ -78,6 +96,11 @@ func ComputeChoices( finetunedResponse := backend.Finetune(*config, predInput, prediction.Response) cb(finetunedResponse, &result) + // Add logprobs to the last choice if present + if prediction.Logprobs != nil && len(result) > 0 { + result[len(result)-1].Logprobs = prediction.Logprobs + } + //result = append(result, Choice{Text: prediction}) } diff --git a/core/schema/openai.go 
b/core/schema/openai.go index fd24ec2808d3..604968f5e1e4 100644 --- a/core/schema/openai.go +++ b/core/schema/openai.go @@ -54,6 +54,19 @@ type Choice struct { Message *Message `json:"message,omitempty"` Delta *Message `json:"delta,omitempty"` Text string `json:"text,omitempty"` + Logprobs *Logprobs `json:"logprobs,omitempty"` +} + +type Logprobs struct { + Content []LogprobContent `json:"content,omitempty"` +} + +type LogprobContent struct { + ID int32 `json:"id"` + Token string `json:"token"` + Bytes []int `json:"bytes,omitempty"` + Logprob float64 `json:"logprob"` + TopLogprobs []LogprobContent `json:"top_logprobs,omitempty"` } type Content struct { diff --git a/core/schema/prediction.go b/core/schema/prediction.go index 40345ba50b14..0b5f441e2cb7 100644 --- a/core/schema/prediction.go +++ b/core/schema/prediction.go @@ -1,5 +1,82 @@ package schema +import ( + "encoding/json" + + "gopkg.in/yaml.v3" +) + +// LogprobsValue represents the logprobs parameter which is a boolean. +// According to OpenAI API: true means return log probabilities, false/null means don't return them. +// The actual number of top logprobs per token is controlled by top_logprobs (0-5). +type LogprobsValue struct { + Enabled bool // true if logprobs should be returned +} + +// UnmarshalJSON implements json.Unmarshaler to handle boolean +func (l *LogprobsValue) UnmarshalJSON(data []byte) error { + // Try to unmarshal as boolean + var b bool + if err := json.Unmarshal(data, &b); err == nil { + l.Enabled = b + return nil + } + + // If it's null, set to false + var n *bool + if err := json.Unmarshal(data, &n); err == nil { + l.Enabled = false + return nil + } + + // Try as integer for backward compatibility (treat > 0 as true) + var i int + if err := json.Unmarshal(data, &i); err == nil { + l.Enabled = i > 0 + return nil + } + + return json.Unmarshal(data, &l.Enabled) +} + +// MarshalJSON implements json.Marshaler +func (l LogprobsValue) MarshalJSON() ([]byte, error) { + return json.Marshal(l.Enabled) +} + +// UnmarshalYAML implements yaml.Unmarshaler to handle boolean +func (l *LogprobsValue) UnmarshalYAML(value *yaml.Node) error { + switch value.Kind { + case yaml.ScalarNode: + switch value.Tag { + case "!!bool": + var b bool + if err := value.Decode(&b); err != nil { + return err + } + l.Enabled = b + return nil + case "!!int": + // For backward compatibility, treat integer > 0 as true + var i int + if err := value.Decode(&i); err != nil { + return err + } + l.Enabled = i > 0 + return nil + case "!!null": + l.Enabled = false + return nil + } + } + return value.Decode(&l.Enabled) +} + +// IsEnabled returns true if logprobs should be returned +func (l *LogprobsValue) IsEnabled() bool { + return l.Enabled +} + // @Description PredictionOptions contains prediction parameters for model inference type PredictionOptions struct { @@ -38,6 +115,12 @@ type PredictionOptions struct { TypicalP *float64 `json:"typical_p,omitempty" yaml:"typical_p,omitempty"` Seed *int `json:"seed,omitempty" yaml:"seed,omitempty"` + // OpenAI API logprobs parameters + // logprobs: boolean - if true, returns log probabilities of each output token + // top_logprobs: integer 0-20 - number of most likely tokens to return at each token position + Logprobs LogprobsValue `json:"logprobs,omitempty" yaml:"logprobs,omitempty"` // Whether to return log probabilities (true/false) + TopLogprobs *int `json:"top_logprobs,omitempty" yaml:"top_logprobs,omitempty"` // Number of top logprobs per token (0-20) + NegativePrompt string `json:"negative_prompt,omitempty" 
yaml:"negative_prompt,omitempty"` RopeFreqBase float32 `json:"rope_freq_base,omitempty" yaml:"rope_freq_base,omitempty"` RopeFreqScale float32 `json:"rope_freq_scale,omitempty" yaml:"rope_freq_scale,omitempty"` From eaa956943a9e7bfb541e93444ca10b32fb5b6dbb Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Sun, 16 Nov 2025 11:09:04 +0100 Subject: [PATCH 2/2] feat: add support to logitbias Signed-off-by: Ettore Di Giacinto --- backend/cpp/llama-cpp/grpc-server.cpp | 13 +++++++++++++ core/backend/llm.go | 9 ++++++++- core/http/app_test.go | 20 ++++++++++++++++++++ core/http/endpoints/openai/chat.go | 9 ++++++++- core/http/endpoints/openai/inference.go | 9 ++++++++- core/http/endpoints/openai/realtime.go | 2 +- core/schema/prediction.go | 5 +++-- 7 files changed, 61 insertions(+), 6 deletions(-) diff --git a/backend/cpp/llama-cpp/grpc-server.cpp b/backend/cpp/llama-cpp/grpc-server.cpp index 51320a99fd78..a27c7166d798 100644 --- a/backend/cpp/llama-cpp/grpc-server.cpp +++ b/backend/cpp/llama-cpp/grpc-server.cpp @@ -181,6 +181,19 @@ json parse_options(bool streaming, const backend::PredictOptions* predict, const SRV_INF("Using top_logprobs: %d\n", predict->toplogprobs()); } + // Extract logit_bias from proto and add to JSON data + if (!predict->logitbias().empty()) { + try { + // Parse logit_bias JSON string from proto + json logit_bias_json = json::parse(predict->logitbias()); + // Add to data - llama.cpp server expects it as an object (map) + data["logit_bias"] = logit_bias_json; + SRV_INF("Using logit_bias: %s\n", predict->logitbias().c_str()); + } catch (const json::parse_error& e) { + SRV_ERR("Failed to parse logit_bias JSON from proto: %s\n", e.what()); + } + } + data["ignore_eos"] = predict->ignoreeos(); data["embeddings"] = predict->embeddings(); diff --git a/core/backend/llm.go b/core/backend/llm.go index 8a418c62c3b0..c00f5876d852 100644 --- a/core/backend/llm.go +++ b/core/backend/llm.go @@ -35,7 +35,7 @@ type TokenUsage struct { TimingTokenGeneration float64 } -func ModelInference(ctx context.Context, s string, messages schema.Messages, images, videos, audios []string, loader *model.ModelLoader, c *config.ModelConfig, cl *config.ModelConfigLoader, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool, tools string, toolChoice string, logprobs *int, topLogprobs *int) (func() (LLMResponse, error), error) { +func ModelInference(ctx context.Context, s string, messages schema.Messages, images, videos, audios []string, loader *model.ModelLoader, c *config.ModelConfig, cl *config.ModelConfigLoader, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool, tools string, toolChoice string, logprobs *int, topLogprobs *int, logitBias map[string]float64) (func() (LLMResponse, error), error) { modelFile := c.Model // Check if the modelFile exists, if it doesn't try to load it from the gallery @@ -86,6 +86,13 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima if topLogprobs != nil { opts.TopLogprobs = int32(*topLogprobs) } + if len(logitBias) > 0 { + // Serialize logit_bias map to JSON string for proto + logitBiasJSON, err := json.Marshal(logitBias) + if err == nil { + opts.LogitBias = string(logitBiasJSON) + } + } tokenUsage := TokenUsage{} diff --git a/core/http/app_test.go b/core/http/app_test.go index ffb9c14e64f6..7935d2c91a2f 100644 --- a/core/http/app_test.go +++ b/core/http/app_test.go @@ -873,6 +873,26 @@ var _ = Describe("API test", func() { Expect(foundatLeastTopLogprob).ToNot(BeEmpty()) }) + It("applies 
logit_bias to chat completions when requested", func() { + // logit_bias is a map of token IDs (as strings) to bias values (-100 to 100) + // According to OpenAI API: modifies the likelihood of specified tokens appearing in the completion + logitBias := map[string]int{ + "15043": 1, // Bias token ID 15043 (example token ID) with bias value 1 + } + response, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{ + Model: "testmodel.ggml", + Messages: []openai.ChatCompletionMessage{{Role: "user", Content: testPrompt}}, + LogitBias: logitBias, + }) + Expect(err).ToNot(HaveOccurred()) + Expect(len(response.Choices)).To(Equal(1)) + Expect(response.Choices[0].Message).ToNot(BeNil()) + Expect(response.Choices[0].Message.Content).ToNot(BeEmpty()) + // If logit_bias is applied, the response should be generated successfully + // We can't easily verify the bias effect without knowing the actual token IDs for the model, + // but the fact that the request succeeds confirms the API accepts and processes logit_bias + }) + It("returns errors", func() { _, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "foomodel", Prompt: testPrompt}) Expect(err).To(HaveOccurred()) diff --git a/core/http/endpoints/openai/chat.go b/core/http/endpoints/openai/chat.go index a4b7b640818e..a68ce16c7600 100644 --- a/core/http/endpoints/openai/chat.go +++ b/core/http/endpoints/openai/chat.go @@ -653,7 +653,14 @@ func handleQuestion(config *config.ModelConfig, cl *config.ModelConfigLoader, in } } - predFunc, err := backend.ModelInference(input.Context, prompt, input.Messages, images, videos, audios, ml, config, cl, o, nil, toolsJSON, toolChoiceJSON, logprobs, topLogprobs) + // Extract logit_bias from request + // According to OpenAI API: logit_bias is a map of token IDs (as strings) to bias values (-100 to 100) + var logitBias map[string]float64 + if len(input.LogitBias) > 0 { + logitBias = input.LogitBias + } + + predFunc, err := backend.ModelInference(input.Context, prompt, input.Messages, images, videos, audios, ml, config, cl, o, nil, toolsJSON, toolChoiceJSON, logprobs, topLogprobs, logitBias) if err != nil { log.Error().Err(err).Msg("model inference failed") return "", err diff --git a/core/http/endpoints/openai/inference.go b/core/http/endpoints/openai/inference.go index aef42f9259d1..37b14c98bcfa 100644 --- a/core/http/endpoints/openai/inference.go +++ b/core/http/endpoints/openai/inference.go @@ -73,9 +73,16 @@ func ComputeChoices( } } + // Extract logit_bias from request + // According to OpenAI API: logit_bias is a map of token IDs (as strings) to bias values (-100 to 100) + var logitBias map[string]float64 + if len(req.LogitBias) > 0 { + logitBias = req.LogitBias + } + // get the model function to call for the result predFunc, err := backend.ModelInference( - req.Context, predInput, req.Messages, images, videos, audios, loader, config, bcl, o, tokenCallback, toolsJSON, toolChoiceJSON, logprobs, topLogprobs) + req.Context, predInput, req.Messages, images, videos, audios, loader, config, bcl, o, tokenCallback, toolsJSON, toolChoiceJSON, logprobs, topLogprobs, logitBias) if err != nil { return result, backend.TokenUsage{}, err } diff --git a/core/http/endpoints/openai/realtime.go b/core/http/endpoints/openai/realtime.go index f715bb2b4281..b9c7d8e532d0 100644 --- a/core/http/endpoints/openai/realtime.go +++ b/core/http/endpoints/openai/realtime.go @@ -1087,7 +1087,7 @@ func processTextResponse(config *config.ModelConfig, session *Session, prompt st // For example, 
the model might return a special token or JSON indicating a function call /* - predFunc, err := backend.ModelInference(context.Background(), prompt, input.Messages, images, videos, audios, ml, *config, o, nil) + predFunc, err := backend.ModelInference(context.Background(), prompt, input.Messages, images, videos, audios, ml, *config, o, nil, "", "", nil, nil, nil) result, tokenUsage, err := ComputeChoices(input, prompt, config, startupOptions, ml, func(s string, c *[]schema.Choice) { if !shouldUseFn { diff --git a/core/schema/prediction.go b/core/schema/prediction.go index 0b5f441e2cb7..cf1eda841166 100644 --- a/core/schema/prediction.go +++ b/core/schema/prediction.go @@ -118,8 +118,9 @@ type PredictionOptions struct { // OpenAI API logprobs parameters // logprobs: boolean - if true, returns log probabilities of each output token // top_logprobs: integer 0-20 - number of most likely tokens to return at each token position - Logprobs LogprobsValue `json:"logprobs,omitempty" yaml:"logprobs,omitempty"` // Whether to return log probabilities (true/false) - TopLogprobs *int `json:"top_logprobs,omitempty" yaml:"top_logprobs,omitempty"` // Number of top logprobs per token (0-20) + Logprobs LogprobsValue `json:"logprobs,omitempty" yaml:"logprobs,omitempty"` // Whether to return log probabilities (true/false) + TopLogprobs *int `json:"top_logprobs,omitempty" yaml:"top_logprobs,omitempty"` // Number of top logprobs per token (0-20) + LogitBias map[string]float64 `json:"logit_bias,omitempty" yaml:"logit_bias,omitempty"` // Map of token IDs to bias values (-100 to 100) NegativePrompt string `json:"negative_prompt,omitempty" yaml:"negative_prompt,omitempty"` RopeFreqBase float32 `json:"rope_freq_base,omitempty" yaml:"rope_freq_base,omitempty"`
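
To make the two new parameters concrete, here is a minimal client-side sketch against the OpenAI-compatible chat endpoint. It is illustrative only: the base URL, model name, and prompt are assumptions, and the response structs simply mirror the shape added to core/schema/openai.go rather than importing it.

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"log"
	"net/http"
)

type chatRequest struct {
	Model       string              `json:"model"`
	Messages    []map[string]string `json:"messages"`
	Logprobs    bool                `json:"logprobs"`
	TopLogprobs int                 `json:"top_logprobs,omitempty"`
}

type topLogprob struct {
	Token   string  `json:"token"`
	Logprob float64 `json:"logprob"`
	Bytes   []int   `json:"bytes,omitempty"`
}

type logprobContent struct {
	Token       string       `json:"token"`
	Logprob     float64      `json:"logprob"`
	Bytes       []int        `json:"bytes,omitempty"`
	TopLogprobs []topLogprob `json:"top_logprobs,omitempty"`
}

type chatResponse struct {
	Choices []struct {
		Message struct {
			Content string `json:"content"`
		} `json:"message"`
		Logprobs *struct {
			Content []logprobContent `json:"content"`
		} `json:"logprobs"`
	} `json:"choices"`
}

func main() {
	body, err := json.Marshal(chatRequest{
		Model:       "testmodel.ggml", // model name taken from the tests above; adjust as needed
		Messages:    []map[string]string{{"role": "user", "content": "Say hello"}},
		Logprobs:    true, // boolean, per the OpenAI API
		TopLogprobs: 3,    // 0-20 alternatives per generated token
	})
	if err != nil {
		log.Fatal(err)
	}

	resp, err := http.Post("http://localhost:8080/v1/chat/completions", "application/json", bytes.NewReader(body))
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()

	var out chatResponse
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		log.Fatal(err)
	}

	for _, choice := range out.Choices {
		fmt.Println("content:", choice.Message.Content)
		if choice.Logprobs == nil {
			continue // logprobs are omitted unless requested
		}
		for _, tok := range choice.Logprobs.Content {
			// Log-probabilities are always <= 0; the top_logprobs slice is
			// capped at the requested top_logprobs value.
			fmt.Printf("  %q logprob=%.4f alternatives=%d\n", tok.Token, tok.Logprob, len(tok.TopLogprobs))
		}
	}
}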
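
The custom LogprobsValue type accepts more than the canonical boolean, so a short sketch of the request shapes it is meant to decode may help. This assumes the module path github.com/mudler/LocalAI (not shown in this diff) and that the snippet is built against the LocalAI module.

package main

import (
	"encoding/json"
	"fmt"

	"github.com/mudler/LocalAI/core/schema" // assumed module path
)

func main() {
	cases := []string{
		`{"logprobs": true, "top_logprobs": 3}`, // canonical OpenAI form
		`{"logprobs": 2}`,                       // legacy integer form, > 0 is treated as enabled
		`{"logprobs": null}`,                    // explicit null leaves logprobs disabled
		`{}`,                                    // omitted field, zero value is disabled
	}

	for _, raw := range cases {
		var opts schema.PredictionOptions
		if err := json.Unmarshal([]byte(raw), &opts); err != nil {
			fmt.Printf("%-40s => error: %v\n", raw, err)
			continue
		}
		top := 0
		if opts.TopLogprobs != nil {
			top = *opts.TopLogprobs
		}
		fmt.Printf("%-40s => enabled=%v top_logprobs=%d\n", raw, opts.Logprobs.IsEnabled(), top)
	}
}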
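
For logit_bias, the OpenAI-style map is serialized to a JSON string before it is handed to the backend, and parse_options() in grpc-server.cpp parses it back into an object. A hedged sketch of that hand-off, with placeholder token IDs (real IDs are model and tokenizer specific):

package main

import (
	"encoding/json"
	"fmt"
	"log"
)

func main() {
	// OpenAI-style logit_bias: token IDs (as strings) mapped to bias values
	// in the range -100..100. The IDs below are placeholders, not real tokens.
	logitBias := map[string]float64{
		"15043": 5,    // nudge this token up
		"50256": -100, // effectively forbid this token
	}

	// Mirrors what ModelInference does before filling the proto's LogitBias
	// string field: the map is serialized to JSON once, on the Go side.
	payload, err := json.Marshal(logitBias)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println("PredictOptions.LogitBias:", string(payload))

	// parse_options() json::parse()s this string and stores it as
	// data["logit_bias"], so the object must round-trip without loss.
	var roundTrip map[string]float64
	if err := json.Unmarshal(payload, &roundTrip); err != nil {
		log.Fatal(err)
	}
	fmt.Println("decoded on the backend side:", roundTrip)
}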