diff --git a/pkg/common/logprobs_test.go b/pkg/common/logprobs_test.go new file mode 100644 index 00000000..443afebb --- /dev/null +++ b/pkg/common/logprobs_test.go @@ -0,0 +1,253 @@ +/* +Copyright 2025 The llm-d-inference-sim Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package common + +import ( + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("Logprobs", func() { + + Context("GenerateTextLogprobs", func() { + It("should generate correct text logprobs structure", func() { + tokens := []string{" Paris", ",", " the", " capital"} + logprobsCount := 2 + + logprobs := GenerateTextLogprobs(tokens, logprobsCount) + + Expect(logprobs).NotTo(BeNil()) + Expect(logprobs.Tokens).To(HaveLen(len(tokens))) + Expect(logprobs.TokenLogprobs).To(HaveLen(len(tokens))) + Expect(logprobs.TopLogprobs).To(HaveLen(len(tokens))) + Expect(logprobs.TextOffset).To(HaveLen(len(tokens))) + + // Check that each top logprobs entry has the expected number of alternatives + for i, topLogprob := range logprobs.TopLogprobs { + Expect(topLogprob).To(HaveLen(logprobsCount)) + // Check that the main token is included in the alternatives + Expect(topLogprob).To(HaveKey(tokens[i])) + } + + // Check text offsets are calculated correctly (byte-based) + expectedOffsets := []int{0, 6, 7, 11} // " Paris" - 6, "," - 1, " the" -4, " capital" - 11 + for i, expected := range expectedOffsets { + Expect(logprobs.TextOffset[i]).To(Equal(expected)) + } + + // Check deterministic logprobs + expectedLogprob0 := -1.0 // defaultLogprob - float64(0%3)*0.1 + Expect(logprobs.TokenLogprobs[0]).To(Equal(expectedLogprob0)) + }) + }) + + Context("GenerateChatLogprobs", func() { + It("should generate correct chat logprobs structure", func() { + tokens := []string{"4"} + topLogprobsCount := 3 + + logprobs := GenerateChatLogprobs(tokens, topLogprobsCount) + + Expect(logprobs).NotTo(BeNil()) + Expect(logprobs.Content).To(HaveLen(len(tokens))) + + content := logprobs.Content[0] + Expect(content.Token).To(Equal(tokens[0])) + Expect(content.Bytes).To(HaveLen(len(tokens[0]))) + Expect(content.TopLogprobs).To(HaveLen(topLogprobsCount)) + + // Check that the main token is the first in top logprobs + Expect(content.TopLogprobs[0].Token).To(Equal(tokens[0])) + + // Check alternative tokens follow the pattern + expectedAlt1 := "4_1" + Expect(content.TopLogprobs[1].Token).To(Equal(expectedAlt1)) + + // Check byte conversion + expectedBytes := []int{52} // byte value of '4' + for i, expected := range expectedBytes { + Expect(content.Bytes[i]).To(Equal(expected)) + } + + // Check deterministic logprobs + expectedLogprob := -1.0 // defaultLogprob - float64(0%3)*0.1 + Expect(content.Logprob).To(Equal(expectedLogprob)) + }) + }) + + Context("calculateLogprob", func() { + It("should calculate main token probabilities correctly", func() { + // Test position cycle behavior (cycle of 3) + // Position 0: -1.0 - (0 % 3) * 0.1 = -1.0 + result0 := calculateLogprob(0, 0) + Expect(result0).To(Equal(-1.0)) + + // Position 1: -1.0 - (1 % 3) * 0.1 = -1.1 
+ result1 := calculateLogprob(1, 0) + Expect(result1).To(Equal(-1.1)) + + // Position 2: -1.0 - (2 % 3) * 0.1 = -1.2 + result2 := calculateLogprob(2, 0) + Expect(result2).To(Equal(-1.2)) + + // Position 3: -1.0 - (3 % 3) * 0.1 = -1.0 (cycle repeats) + result3 := calculateLogprob(3, 0) + Expect(result3).To(Equal(-1.0)) + + // Position 4: -1.0 - (4 % 3) * 0.1 = -1.1 (cycle repeats) + result4 := calculateLogprob(4, 0) + Expect(result4).To(Equal(-1.1)) + }) + + It("should calculate alternative token probabilities correctly", func() { + // Test alternative token decrements (0.5 per alternative index) + tokenPosition := 0 // Start with position 0 (main logprob = -1.0) + + // Alternative 1: -1.0 - 1 * 0.5 = -1.5 + alt1 := calculateLogprob(tokenPosition, 1) + Expect(alt1).To(Equal(-1.5)) + + // Alternative 2: -1.0 - 2 * 0.5 = -2.0 + alt2 := calculateLogprob(tokenPosition, 2) + Expect(alt2).To(Equal(-2.0)) + + // Alternative 3: -1.0 - 3 * 0.5 = -2.5 + alt3 := calculateLogprob(tokenPosition, 3) + Expect(alt3).To(Equal(-2.5)) + }) + + It("should combine position cycle and alternative index correctly", func() { + // Test with position 1 (main logprob = -1.1) + tokenPosition := 1 + + // Main token: -1.0 - (1 % 3) * 0.1 = -1.1 + main := calculateLogprob(tokenPosition, 0) + Expect(main).To(Equal(-1.1)) + + // Alternative 1: -1.1 - 1 * 0.5 = -1.6 + alt1 := calculateLogprob(tokenPosition, 1) + Expect(alt1).To(Equal(-1.6)) + + // Alternative 2: -1.1 - 2 * 0.5 = -2.1 + alt2 := calculateLogprob(tokenPosition, 2) + Expect(alt2).To(Equal(-2.1)) + }) + + It("should handle large position values correctly", func() { + // Test with large position values to ensure cycle works + largePosition := 100 + + // Position 100: -1.0 - (100 % 3) * 0.1 = -1.0 - 1 * 0.1 = -1.1 + result := calculateLogprob(largePosition, 0) + Expect(result).To(Equal(-1.1)) + + // With alternative: -1.1 - 1 * 0.5 = -1.6 + resultAlt := calculateLogprob(largePosition, 1) + Expect(resultAlt).To(Equal(-1.6)) + }) + + It("should handle edge cases correctly", func() { + // Test with zero values + result := calculateLogprob(0, 0) + Expect(result).To(Equal(-1.0)) + + // Test with large alternative index + largeAlt := calculateLogprob(0, 10) + expectedLargeAlt := -1.0 - float64(10)*0.5 // -6.0 + Expect(largeAlt).To(Equal(expectedLargeAlt)) + }) + }) + + Context("Other scenarios", func() { + It("should handle empty tokens for text logprobs", func() { + logprobs := GenerateTextLogprobs([]string{}, 2) + + Expect(logprobs).NotTo(BeNil()) + Expect(logprobs.Tokens).To(BeEmpty()) + }) + + It("should handle empty tokens for chat logprobs", func() { + logprobs := GenerateChatLogprobs([]string{}, 2) + + Expect(logprobs).NotTo(BeNil()) + Expect(logprobs.Content).To(BeEmpty()) + }) + + It("should verify probability pattern as token position grows", func() { + // Test the cycling pattern of probabilities + + // Test first cycle (positions 0-2) + prob0 := calculateLogprob(0, 0) + prob1 := calculateLogprob(1, 0) + prob2 := calculateLogprob(2, 0) + + Expect(prob0).To(Equal(-1.0)) // defaultLogprob + Expect(prob1).To(Equal(-1.1)) // defaultLogprob - 1*0.1 + Expect(prob2).To(Equal(-1.2)) // defaultLogprob - 2*0.1 + + // Test second cycle (positions 3-5) - should repeat the pattern + prob3 := calculateLogprob(3, 0) + prob4 := calculateLogprob(4, 0) + prob5 := calculateLogprob(5, 0) + + Expect(prob3).To(Equal(prob0)) // Should equal position 0 + Expect(prob4).To(Equal(prob1)) // Should equal position 1 + Expect(prob5).To(Equal(prob2)) // Should equal position 2 + + // Test 
third cycle (positions 6-8) - should repeat again + prob6 := calculateLogprob(6, 0) + prob7 := calculateLogprob(7, 0) + prob8 := calculateLogprob(8, 0) + + Expect(prob6).To(Equal(prob0)) // Should equal position 0 + Expect(prob7).To(Equal(prob1)) // Should equal position 1 + Expect(prob8).To(Equal(prob2)) // Should equal position 2 + + // Verify the cycling pattern continues for larger positions + for i := 0; i < 20; i++ { + expectedProb := defaultLogprob - float64(i%positionCycle)*positionDecrement + actualProb := calculateLogprob(i, 0) + Expect(actualProb).To(Equal(expectedProb), "Position %d should have probability %f", i, expectedProb) + } + }) + }) + + Context("No Limits", func() { + It("should allow unlimited logprobs count", func() { + tokens := []string{"test"} + + // Test text completion (no clamping) + textLogprobs := GenerateTextLogprobs(tokens, 10) + Expect(textLogprobs.TopLogprobs[0]).To(HaveLen(10)) + + // Test chat completion (no clamping) + chatLogprobs := GenerateChatLogprobs(tokens, 25) + Expect(chatLogprobs.Content[0].TopLogprobs).To(HaveLen(25)) + + // Test high count + textLogprobs = GenerateTextLogprobs(tokens, 100) + Expect(textLogprobs.TopLogprobs[0]).To(HaveLen(100)) + + chatLogprobs = GenerateChatLogprobs(tokens, 50) + Expect(chatLogprobs.Content[0].TopLogprobs).To(HaveLen(50)) + + // Test minimum (at least 1) + textLogprobs = GenerateTextLogprobs(tokens, 0) + Expect(textLogprobs.TopLogprobs[0]).To(HaveLen(1)) + }) + }) +}) diff --git a/pkg/common/logprobs_utils.go b/pkg/common/logprobs_utils.go new file mode 100644 index 00000000..eed44200 --- /dev/null +++ b/pkg/common/logprobs_utils.go @@ -0,0 +1,260 @@ +/* +Copyright 2025 The llm-d-inference-sim Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package common + +import ( + "fmt" + + openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api" +) + +const ( + // Default logprob value + defaultLogprob = -1.0 + // Cycle length for position-based variation + positionCycle = 3 + // Logprob decrement per cycle position + positionDecrement = 0.1 + // Logprob decrement per alternative token + alternativeDecrement = 0.5 +) + +// NOTE: These functions produce synthetic data for API shape compatibility. +// The logprobs are deterministic placeholders and have no semantic meaning. + +// calculateLogprob calculates synthetic log probabilities using a deterministic algorithm. +// For the main token (alternativeIndex = 0), it uses a cycle of 3 positions with decreasing probability. +// For alternative tokens, it decreases probability by 0.5 per alternative index. 
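+// Example: tokenPosition=4 with alternativeIndex=2 yields -1.0 - (4%3)*0.1 - 2*0.5 = -2.1.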
+// +// Algorithm: +// - Main token: defaultLogprob - (tokenPosition % 3) * 0.1 +// - Alternative: mainTokenLogprob - alternativeIndex * 0.5 +func calculateLogprob(tokenPosition int, alternativeIndex int) float64 { + // Calculate main token probability based on position cycle + mainLogprob := defaultLogprob - float64(tokenPosition%positionCycle)*positionDecrement + + // For main token (index 0), return the main probability + if alternativeIndex == 0 { + return mainLogprob + } + + // For alternatives, decrease by alternativeDecrement per index + return mainLogprob - float64(alternativeIndex)*alternativeDecrement +} + +// GenerateSingleTokenChatLogprobs generates logprobs for a single token in chat completion streaming +func GenerateSingleTokenChatLogprobs(token string, tokenPosition int, topLogprobsCount int) *openaiserverapi.LogprobsContent { + if token == "" { + return nil + } + + // Calculate main token probability + mainLogprob := calculateLogprob(tokenPosition, 0) + tokenBytes := stringToIntBytes(token) + + content := openaiserverapi.LogprobsContent{ + Token: token, + Logprob: mainLogprob, + Bytes: tokenBytes, + } + + // Generate top alternatives if requested + if topLogprobsCount > 0 { + // Pre-size alternatives slice + content.TopLogprobs = make([]openaiserverapi.LogprobsContent, topLogprobsCount) + + // Main token first + content.TopLogprobs[0] = openaiserverapi.LogprobsContent{ + Token: token, + Logprob: mainLogprob, + Bytes: tokenBytes, + } + + // Alternative tokens + for j := 1; j < topLogprobsCount; j++ { + altToken := fmt.Sprintf("%s_%d", token, j) + altLogprob := calculateLogprob(tokenPosition, j) + altBytes := stringToIntBytes(altToken) + + content.TopLogprobs[j] = openaiserverapi.LogprobsContent{ + Token: altToken, + Logprob: altLogprob, + Bytes: altBytes, + } + } + } + + return &content +} + +// GenerateSingleTokenTextLogprobs generates logprobs for a single token in text completion streaming +func GenerateSingleTokenTextLogprobs(token string, tokenPosition int, logprobsCount int) *openaiserverapi.TextLogprobs { + if token == "" { + return nil + } + + // Ensure minimum count + if logprobsCount <= 0 { + logprobsCount = 1 // Include the main token, at a minimum + } + + logprobs := &openaiserverapi.TextLogprobs{ + Tokens: []string{token}, + TokenLogprobs: make([]float64, 1), + TopLogprobs: make([]map[string]float64, 1), + TextOffset: []int{0}, + } + + // Calculate main token probability + mainLogprob := calculateLogprob(tokenPosition, 0) + logprobs.TokenLogprobs[0] = mainLogprob + + topLogprobs := make(map[string]float64, logprobsCount) + topLogprobs[token] = mainLogprob + + // Add alternative tokens + for j := 1; j < logprobsCount; j++ { + altToken := fmt.Sprintf("%s_%d", token, j) + altLogprob := calculateLogprob(tokenPosition, j) + topLogprobs[altToken] = altLogprob + } + + logprobs.TopLogprobs[0] = topLogprobs + + return logprobs +} + +// GenerateTextLogprobs generates synthetic log probabilities for text completion responses +func GenerateTextLogprobs(tokens []string, logprobsCount int) *openaiserverapi.TextLogprobs { + // Return empty struct for empty input (not nil) + if len(tokens) == 0 { + return &openaiserverapi.TextLogprobs{ + Tokens: []string{}, + TokenLogprobs: []float64{}, + TopLogprobs: []map[string]float64{}, + TextOffset: []int{}, + } + } + + // Ensure minimum count + if logprobsCount <= 0 { + logprobsCount = 1 // Include the main token, at least + } + + // Avoid reallocations + numTokens := len(tokens) + logprobs := &openaiserverapi.TextLogprobs{ + Tokens: 
tokens, + TokenLogprobs: make([]float64, numTokens), + TopLogprobs: make([]map[string]float64, numTokens), + TextOffset: make([]int, numTokens), + } + + offset := 0 + for i, token := range tokens { + logprobs.TextOffset[i] = offset + offset += len(token) // Use byte length + + // Calculate main token probability using helper function + mainLogprob := calculateLogprob(i, 0) + logprobs.TokenLogprobs[i] = mainLogprob + + topLogprobs := make(map[string]float64, logprobsCount) + topLogprobs[token] = mainLogprob + + // Add alternative tokens using helper function + for j := 1; j < logprobsCount; j++ { + altToken := fmt.Sprintf("%s_%d", token, j) + altLogprob := calculateLogprob(i, j) + topLogprobs[altToken] = altLogprob + } + + logprobs.TopLogprobs[i] = topLogprobs + } + + return logprobs +} + +// GenerateChatLogprobs generates synthetic log probabilities for chat completion responses +func GenerateChatLogprobs(tokens []string, topLogprobsCount int) *openaiserverapi.ChatLogprobs { + // Return empty struct for empty input (not nil) + if len(tokens) == 0 { + return &openaiserverapi.ChatLogprobs{ + Content: []openaiserverapi.LogprobsContent{}, + } + } + + numTokens := len(tokens) + logprobs := &openaiserverapi.ChatLogprobs{ + Content: make([]openaiserverapi.LogprobsContent, numTokens), + } + + for i, token := range tokens { + // Calculate main token probability using helper function + mainLogprob := calculateLogprob(i, 0) + + tokenBytes := stringToIntBytes(token) + + content := openaiserverapi.LogprobsContent{ + Token: token, + Logprob: mainLogprob, + Bytes: tokenBytes, + } + + // Generate top alternatives if requested + if topLogprobsCount > 0 { + // Pre-size alternatives slice + content.TopLogprobs = make([]openaiserverapi.LogprobsContent, topLogprobsCount) + + // Main token first + content.TopLogprobs[0] = openaiserverapi.LogprobsContent{ + Token: token, + Logprob: mainLogprob, + Bytes: tokenBytes, + } + + // Alternative tokens using helper function + for j := 1; j < topLogprobsCount; j++ { + altToken := fmt.Sprintf("%s_%d", token, j) + altLogprob := calculateLogprob(i, j) + altBytes := stringToIntBytes(altToken) + + content.TopLogprobs[j] = openaiserverapi.LogprobsContent{ + Token: altToken, + Logprob: altLogprob, + Bytes: altBytes, + } + } + } + + logprobs.Content[i] = content + } + + return logprobs +} + +// stringToIntBytes converts a string to []int of byte values inline +func stringToIntBytes(s string) []int { + if s == "" { + return nil + } + out := make([]int, len(s)) + for i := range out { + out[i] = int(s[i]) + } + return out +} diff --git a/pkg/llm-d-inference-sim/simulator.go b/pkg/llm-d-inference-sim/simulator.go index 89bec6c7..df5acc29 100644 --- a/pkg/llm-d-inference-sim/simulator.go +++ b/pkg/llm-d-inference-sim/simulator.go @@ -551,13 +551,14 @@ func (s *VllmSimulator) responseSentCallback(model string, isChatCompletion bool // createCompletionResponse creates the response for completion requests, supports both completion request types (text and chat) // as defined by isChatCompletion +// logprobs - nil if no logprobs needed, otherwise number of logprob options to include // respTokens - tokenized content to be sent in the response // toolCalls - tool calls to be sent in the response // finishReason - a pointer to string that represents finish reason, can be nil or stop or length, ... // usageData - usage (tokens statistics) for this response // modelName - display name returned to the client and used in metrics. 
It is either the first alias // from --served-model-name (for a base-model request) or the LoRA adapter name (for a LoRA request). -func (s *VllmSimulator) createCompletionResponse(isChatCompletion bool, respTokens []string, toolCalls []openaiserverapi.ToolCall, +func (s *VllmSimulator) createCompletionResponse(logprobs *int, isChatCompletion bool, respTokens []string, toolCalls []openaiserverapi.ToolCall, finishReason *string, usageData *openaiserverapi.Usage, modelName string, doRemoteDecode bool) openaiserverapi.CompletionResponse { baseResp := openaiserverapi.CreateBaseCompletionResponse(chatComplIDPrefix+s.random.GenerateUUIDString(), time.Now().Unix(), modelName, usageData) @@ -585,13 +586,42 @@ func (s *VllmSimulator) createCompletionResponse(isChatCompletion bool, respToke } else { message.Content = openaiserverapi.Content{Raw: respText} } - return openaiserverapi.CreateChatCompletionResponse(baseResp, - []openaiserverapi.ChatRespChoice{openaiserverapi.CreateChatRespChoice(baseChoice, message)}) + + choice := openaiserverapi.CreateChatRespChoice(baseChoice, message) + + // Generate logprobs if requested + if logprobs != nil && toolCalls == nil { + if logprobsData := common.GenerateChatLogprobs(respTokens, *logprobs); logprobsData != nil && len(logprobsData.Content) > 0 { + choice.Logprobs = logprobsData + } else { + // Set to nil if generation failed or content is empty + choice.Logprobs = nil + } + } else { + // Explicitly ensure logprobs is nil when not requested + choice.Logprobs = nil + } + + return openaiserverapi.CreateChatCompletionResponse(baseResp, []openaiserverapi.ChatRespChoice{choice}) + } + + choice := openaiserverapi.CreateTextRespChoice(baseChoice, respText) + + // Generate logprobs if requested for text completion + if logprobs != nil && *logprobs > 0 { + if logprobsData := common.GenerateTextLogprobs(respTokens, *logprobs); logprobsData != nil && len(logprobsData.Tokens) > 0 { + choice.Logprobs = logprobsData + } else { + // Set to nil if generation failed or tokens is empty + choice.Logprobs = nil + } + } else { + // Explicitly ensure logprobs is nil when not requested + choice.Logprobs = nil } baseResp.Object = textCompletionObject - return openaiserverapi.CreateTextCompletionResponse(baseResp, - []openaiserverapi.TextRespChoice{openaiserverapi.CreateTextRespChoice(baseChoice, respText)}) + return openaiserverapi.CreateTextCompletionResponse(baseResp, []openaiserverapi.TextRespChoice{choice}) } // sendResponse sends response for completion API, supports both completions (text and chat) @@ -604,7 +634,13 @@ func (s *VllmSimulator) createCompletionResponse(isChatCompletion bool, respToke // usageData - usage (tokens statistics) for this response func (s *VllmSimulator) sendResponse(reqCtx *openaiserverapi.CompletionReqCtx, respTokens []string, toolCalls []openaiserverapi.ToolCall, modelName string, finishReason string, usageData *openaiserverapi.Usage) { - resp := s.createCompletionResponse(reqCtx.IsChatCompletion, respTokens, toolCalls, &finishReason, usageData, modelName, + // Extract logprob data from request (unified approach) + var logprobs *int + if toolCalls == nil { + logprobs = reqCtx.CompletionReq.GetLogprobs() + } + + resp := s.createCompletionResponse(logprobs, reqCtx.IsChatCompletion, respTokens, toolCalls, &finishReason, usageData, modelName, reqCtx.CompletionReq.IsDoRemoteDecode()) // calculate how long to wait before returning the response, time is based on number of tokens diff --git a/pkg/llm-d-inference-sim/simulator_test.go 
b/pkg/llm-d-inference-sim/simulator_test.go index 010b82ad..38e0ad93 100644 --- a/pkg/llm-d-inference-sim/simulator_test.go +++ b/pkg/llm-d-inference-sim/simulator_test.go @@ -399,6 +399,207 @@ var _ = Describe("Simulator", func() { }) }) + Context("logprobs functionality", func() { + DescribeTable("streaming chat completions with logprobs", + func(mode string, logprobs bool, topLogprobs int) { + ctx := context.TODO() + client, err := startServer(ctx, mode) + Expect(err).NotTo(HaveOccurred()) + + openaiclient, params := getOpenAIClientAndChatParams(client, testModel, testUserMessage, true) + params.Logprobs = param.NewOpt(logprobs) + if logprobs && topLogprobs > 0 { + params.TopLogprobs = param.NewOpt(int64(topLogprobs)) + } + + stream := openaiclient.Chat.Completions.NewStreaming(ctx, params) + defer func() { + err := stream.Close() + Expect(err).NotTo(HaveOccurred()) + }() + + tokens := []string{} + chunksWithLogprobs := 0 + + for stream.Next() { + chunk := stream.Current() + for _, choice := range chunk.Choices { + if choice.FinishReason == "" && choice.Delta.Content != "" { + tokens = append(tokens, choice.Delta.Content) + + // Check logprobs in streaming chunks + if logprobs && len(choice.Logprobs.Content) > 0 { + chunksWithLogprobs++ + logprobContent := choice.Logprobs.Content[0] + Expect(logprobContent.Token).To(Equal(choice.Delta.Content)) + Expect(logprobContent.Logprob).To(BeNumerically("<=", 0)) + + if topLogprobs > 0 { + Expect(logprobContent.TopLogprobs).To(HaveLen(topLogprobs)) + Expect(logprobContent.TopLogprobs[0].Token).To(Equal(choice.Delta.Content)) + } + } + } + } + } + + msg := strings.Join(tokens, "") + if mode == common.ModeRandom { + Expect(dataset.IsValidText(msg)).To(BeTrue()) + } else { + Expect(msg).Should(Equal(testUserMessage)) + } + + // Verify logprobs behaviour + if logprobs { + Expect(chunksWithLogprobs).To(BeNumerically(">", 0), "Should have chunks with logprobs") + } else { + Expect(chunksWithLogprobs).To(Equal(0), "Should not have chunks with logprobs when not requested") + } + }, + func(mode string, logprobs bool, topLogprobs int) string { + return fmt.Sprintf("mode: %s logprobs: %t top_logprobs: %d", mode, logprobs, topLogprobs) + }, + Entry(nil, common.ModeEcho, true, 0), // logprobs=true, default top_logprobs + Entry(nil, common.ModeEcho, true, 2), // logprobs=true, top_logprobs=2 + Entry(nil, common.ModeEcho, false, 0), // logprobs=false + ) + + DescribeTable("streaming text completions with logprobs", + func(mode string, logprobsCount int) { + ctx := context.TODO() + client, err := startServer(ctx, mode) + Expect(err).NotTo(HaveOccurred()) + + openaiclient, params := getOpenAIClentAndCompletionParams(client, testModel, testUserMessage, true) + if logprobsCount > 0 { + params.Logprobs = param.NewOpt(int64(logprobsCount)) + } + + stream := openaiclient.Completions.NewStreaming(ctx, params) + defer func() { + err := stream.Close() + Expect(err).NotTo(HaveOccurred()) + }() + + tokens := []string{} + chunksWithLogprobs := 0 + + for stream.Next() { + chunk := stream.Current() + for _, choice := range chunk.Choices { + if choice.FinishReason == "" && choice.Text != "" { + tokens = append(tokens, choice.Text) + + // Check logprobs in streaming chunks + if logprobsCount > 0 && len(choice.Logprobs.Tokens) > 0 { + chunksWithLogprobs++ + Expect(choice.Logprobs.Tokens[0]).To(Equal(choice.Text)) + Expect(choice.Logprobs.TokenLogprobs[0]).To(BeNumerically("<=", 0)) + Expect(choice.Logprobs.TopLogprobs[0]).To(HaveLen(logprobsCount)) + 
Expect(choice.Logprobs.TopLogprobs[0]).To(HaveKey(choice.Text)) + } + } + } + } + + text := strings.Join(tokens, "") + if mode == common.ModeRandom { + Expect(dataset.IsValidText(text)).To(BeTrue()) + } else { + Expect(text).Should(Equal(testUserMessage)) + } + + // Verify logprobs behaviour + if logprobsCount > 0 { + Expect(chunksWithLogprobs).To(BeNumerically(">", 0), "Should have chunks with logprobs") + } else { + Expect(chunksWithLogprobs).To(Equal(0), "Should not have chunks with logprobs when not requested") + } + }, + func(mode string, logprobsCount int) string { + return fmt.Sprintf("mode: %s logprobs: %d", mode, logprobsCount) + }, + Entry(nil, common.ModeEcho, 0), // No logprobs + Entry(nil, common.ModeEcho, 2), // logprobs=2 + ) + + DescribeTable("non-streaming completions with logprobs", + func(isChat bool, mode string, logprobsParam interface{}) { + ctx := context.TODO() + client, err := startServer(ctx, mode) + Expect(err).NotTo(HaveOccurred()) + + var resp interface{} + + if isChat { + openaiclient, params := getOpenAIClientAndChatParams(client, testModel, testUserMessage, false) + if logprobsParam != nil { + if logprobs, ok := logprobsParam.(bool); ok && logprobs { + params.Logprobs = param.NewOpt(true) + params.TopLogprobs = param.NewOpt(int64(2)) + } + } + resp, err = openaiclient.Chat.Completions.New(ctx, params) + } else { + openaiclient, params := getOpenAIClentAndCompletionParams(client, testModel, testUserMessage, false) + if logprobsParam != nil { + if logprobsCount, ok := logprobsParam.(int); ok && logprobsCount > 0 { + params.Logprobs = param.NewOpt(int64(logprobsCount)) + } + } + resp, err = openaiclient.Completions.New(ctx, params) + } + + Expect(err).NotTo(HaveOccurred()) + + // Verify logprobs in non-streaming response + if isChat { + chatResp := resp.(*openai.ChatCompletion) + Expect(chatResp.Choices).ShouldNot(BeEmpty()) + + if logprobsParam != nil { + // When logprobs requested, Content should be populated + Expect(chatResp.Choices[0].Logprobs.Content).NotTo(BeEmpty()) + + tokens := common.Tokenize(chatResp.Choices[0].Message.Content) + Expect(chatResp.Choices[0].Logprobs.Content).To(HaveLen(len(tokens))) + } else { + // When logprobs not requested, Content should be empty/nil + // Note: SDK uses nullable types, so we check the Content field + Expect(chatResp.Choices[0].Logprobs.Content).To(BeNil()) + } + } else { + textResp := resp.(*openai.Completion) + Expect(textResp.Choices).ShouldNot(BeEmpty()) + + if logprobsParam != nil { + // When logprobs requested, fields should be populated + Expect(textResp.Choices[0].Logprobs.Tokens).NotTo(BeNil()) + + tokens := common.Tokenize(textResp.Choices[0].Text) + Expect(textResp.Choices[0].Logprobs.Tokens).To(HaveLen(len(tokens))) + } else { + // When logprobs not requested, all fields should be empty/nil + Expect(textResp.Choices[0].Logprobs.Tokens).To(BeNil()) + } + } + }, + func(isChat bool, mode string, logprobsParam interface{}) string { + apiType := "text" + if isChat { + apiType = "chat" + } + return fmt.Sprintf("%s mode: %s logprobs: %v", apiType, mode, logprobsParam) + }, + Entry(nil, true, common.ModeEcho, true), // Chat with logprobs + Entry(nil, true, common.ModeEcho, nil), // Chat without logprobs + Entry(nil, false, common.ModeEcho, 2), // Text with logprobs=2 + Entry(nil, false, common.ModeEcho, nil), // Text without logprobs + ) + + }) + Context("max-model-len context window validation", func() { It("Should reject requests exceeding context window", func() { ctx := context.TODO() diff --git 
a/pkg/llm-d-inference-sim/streaming.go b/pkg/llm-d-inference-sim/streaming.go index 7f220043..2a136beb 100644 --- a/pkg/llm-d-inference-sim/streaming.go +++ b/pkg/llm-d-inference-sim/streaming.go @@ -38,6 +38,8 @@ type streamingContext struct { nPromptTokens int nCachedPromptTokens int requestID string + // Logprobs configuration - nil if no logprobs, otherwise number of options + logprobs *int } // sendStreamingResponse creates and sends a streaming response for completion requests of both types (text and chat) @@ -184,14 +186,25 @@ func (s *VllmSimulator) createUsageChunk(context *streamingContext, usageData *o } // createTextCompletionChunk creates and returns a CompletionRespChunk, a single chunk of streamed completion API response, -// for text completion +// for text completion. func (s *VllmSimulator) createTextCompletionChunk(context *streamingContext, token string, finishReason *string) openaiserverapi.CompletionRespChunk { baseChunk := openaiserverapi.CreateBaseCompletionResponse(chatComplIDPrefix+s.random.GenerateUUIDString(), context.creationTime, context.model, nil) baseChunk.Object = textCompletionObject - return openaiserverapi.CreateTextCompletionResponse(baseChunk, - []openaiserverapi.TextRespChoice{ - openaiserverapi.CreateTextRespChoice(openaiserverapi.CreateBaseResponseChoice(0, finishReason), token)}) + + choice := openaiserverapi.CreateTextRespChoice(openaiserverapi.CreateBaseResponseChoice(0, finishReason), token) + + // Generate logprobs if requested and token is not empty + if context.logprobs != nil && token != "" && *context.logprobs > 0 { + // Use token position based on current time + tokenPosition := int(context.creationTime) % 1000 // Simple position simulation + logprobs := common.GenerateSingleTokenTextLogprobs(token, tokenPosition, *context.logprobs) + if logprobs != nil { + choice.Logprobs = logprobs + } + } + + return openaiserverapi.CreateTextCompletionResponse(baseChunk, []openaiserverapi.TextRespChoice{choice}) } // createChatCompletionChunk creates and returns a CompletionRespChunk, a single chunk of streamed completion @@ -213,6 +226,18 @@ func (s *VllmSimulator) createChatCompletionChunk(context *streamingContext, tok chunk.Choices[0].Delta.ToolCalls = []openaiserverapi.ToolCall{*tool} } else if len(token) > 0 { chunk.Choices[0].Delta.Content.Raw = token + + // Generate logprobs if requested and token is not empty + if context.logprobs != nil { + // Use token position based on current time + tokenPosition := int(context.creationTime) % 1000 // Simple position simulation + logprobs := common.GenerateSingleTokenChatLogprobs(token, tokenPosition, *context.logprobs) + if logprobs != nil { + chunk.Choices[0].Logprobs = &openaiserverapi.ChatLogprobs{ + Content: []openaiserverapi.LogprobsContent{*logprobs}, + } + } + } } return &chunk diff --git a/pkg/llm-d-inference-sim/worker.go b/pkg/llm-d-inference-sim/worker.go index c1fe3719..1dc796a6 100644 --- a/pkg/llm-d-inference-sim/worker.go +++ b/pkg/llm-d-inference-sim/worker.go @@ -134,6 +134,9 @@ func (s *VllmSimulator) processRequest(reqCtx *openaiserverapi.CompletionReqCtx) doRemotePrefill: req.IsDoRemotePrefill(), nPromptTokens: usageData.PromptTokens, nCachedPromptTokens: reqCtx.CompletionReq.GetNumberOfCachedPromptTokens(), + requestID: req.GetRequestID(), + // Logprobs configuration + logprobs: req.GetLogprobs(), }, responseTokens, toolCalls, finishReason, usageDataToSend, ) diff --git a/pkg/openai-server-api/request.go b/pkg/openai-server-api/request.go index 9326b52b..d611f464 100644 --- 
a/pkg/openai-server-api/request.go +++ b/pkg/openai-server-api/request.go @@ -72,6 +72,8 @@ type CompletionRequest interface { // for chat completion - max_completion_tokens field is used // for text completion - max_tokens field is used ExtractMaxTokens() *int64 + // GetLogprobs returns nil if no logprobs needed, or pointer to number of logprob options to include + GetLogprobs() *int } // baseCompletionRequest contains base completion request related information @@ -187,6 +189,12 @@ type ChatCompletionRequest struct { // ToolChoice controls which (if any) tool is called by the model. // It can be a string ("none", "auto", "required") or an object specifying the function. ToolChoice ToolChoice `json:"tool_choice,omitzero"` + + // Logprobs controls whether log probabilities are included in the response + Logprobs bool `json:"logprobs,omitempty"` + + // TopLogprobs controls how many alternative tokens to include in the logprobs + TopLogprobs *int `json:"top_logprobs,omitempty"` } var _ CompletionRequest = (*ChatCompletionRequest)(nil) @@ -272,6 +280,18 @@ func (req *ChatCompletionRequest) ExtractMaxTokens() *int64 { return req.GetMaxCompletionTokens() } +func (c *ChatCompletionRequest) GetLogprobs() *int { + if !c.Logprobs { + return nil // No logprobs requested + } + if c.TopLogprobs != nil { + return c.TopLogprobs // Return the top_logprobs value + } + // Default to 1 if logprobs=true but no top_logprobs specified + defaultVal := 1 + return &defaultVal +} + // v1/completion // TextCompletionRequest defines structure of /completion request type TextCompletionRequest struct { @@ -285,6 +305,12 @@ type TextCompletionRequest struct { // The token count of your prompt plus `max_tokens` cannot exceed the model's // context length. MaxTokens *int64 `json:"max_tokens"` + + // Logprobs includes the log probabilities on the logprobs most likely tokens, + // as well the chosen tokens. For example, if logprobs is 5, the API will return + // a list of the 5 most likely tokens. The API will always return the logprob + // of the sampled token, so there may be up to logprobs+1 elements in the response. 
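+	// Note: the simulator does not clamp this value; the requested number of
+	// alternatives is generated as-is (see pkg/common/logprobs_utils.go).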
+ Logprobs *int `json:"logprobs,omitempty"` } var _ CompletionRequest = (*TextCompletionRequest)(nil) @@ -320,3 +346,7 @@ func (req *TextCompletionRequest) ExtractPrompt() string { func (req *TextCompletionRequest) ExtractMaxTokens() *int64 { return req.MaxTokens } + +func (t *TextCompletionRequest) GetLogprobs() *int { + return t.Logprobs +} diff --git a/pkg/openai-server-api/response.go b/pkg/openai-server-api/response.go index 0398a858..a2bf066a 100644 --- a/pkg/openai-server-api/response.go +++ b/pkg/openai-server-api/response.go @@ -64,6 +64,36 @@ type Usage struct { TotalTokens int `json:"total_tokens"` } +// LogprobsContent represents logprobs for a single token in chat completions +type LogprobsContent struct { + // Token is the token string + Token string `json:"token"` + // Logprob is the log probability of the token + Logprob float64 `json:"logprob"` + // Bytes is the byte representation of the token + Bytes []int `json:"bytes"` + // TopLogprobs is the list of top alternative tokens along their log probabilities + TopLogprobs []LogprobsContent `json:"top_logprobs,omitempty"` +} + +// ChatLogprobs represents logprobs for chat completion responses +type ChatLogprobs struct { + // Content is an array of logprobs for each token in the content + Content []LogprobsContent `json:"content"` +} + +// TextLogprobs represents logprobs for text completion responses +type TextLogprobs struct { + // Tokens is an array of tokens + Tokens []string `json:"tokens"` + // TokenLogprobs is an array of log probabilities for each token + TokenLogprobs []float64 `json:"token_logprobs"` + // TopLogprobs is an array of objects containing the top alternative tokens + TopLogprobs []map[string]float64 `json:"top_logprobs"` + // TextOffset is an array of character offsets + TextOffset []int `json:"text_offset"` +} + // ChatCompletionResponse defines structure of /chat/completion response type ChatCompletionResponse struct { baseCompletionResponse @@ -175,6 +205,8 @@ type ChatRespChoice struct { baseResponseChoice // Message contains choice's Message Message Message `json:"message"` + // Logprobs contains the log probabilities for the response + Logprobs *ChatLogprobs `json:"logprobs,omitempty"` } // TextCompletionResponse defines structure of /completion response @@ -189,6 +221,8 @@ type TextRespChoice struct { baseResponseChoice // Text defines request's content Text string `json:"text"` + // Logprobs contains the log probabilities for the response + Logprobs *TextLogprobs `json:"logprobs,omitempty"` } // CompletionRespChunk is an interface that defines a single response chunk @@ -206,6 +240,8 @@ type ChatRespChunkChoice struct { baseResponseChoice // Delta is a content of the chunk Delta Message `json:"delta"` + // Logprobs contains the log probabilities for the response chunk + Logprobs *ChatLogprobs `json:"logprobs,omitempty"` } // CompletionError defines the simulator's response in case of an error @@ -266,15 +302,15 @@ func CreateBaseResponseChoice(index int, finishReason *string) baseResponseChoic } func CreateChatRespChoice(base baseResponseChoice, message Message) ChatRespChoice { - return ChatRespChoice{baseResponseChoice: base, Message: message} + return ChatRespChoice{baseResponseChoice: base, Message: message, Logprobs: nil} } func CreateChatRespChunkChoice(base baseResponseChoice, message Message) ChatRespChunkChoice { - return ChatRespChunkChoice{baseResponseChoice: base, Delta: message} + return ChatRespChunkChoice{baseResponseChoice: base, Delta: message, Logprobs: nil} } func 
CreateTextRespChoice(base baseResponseChoice, text string) TextRespChoice { - return TextRespChoice{baseResponseChoice: base, Text: text} + return TextRespChoice{baseResponseChoice: base, Text: text, Logprobs: nil} } func CreateBaseCompletionResponse(id string, created int64, model string, usage *Usage) baseCompletionResponse {
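For illustration only (not part of the patch): a minimal sketch of how the new helpers in pkg/common could be exercised directly to inspect the synthetic payloads. It assumes the module path github.com/llm-d/llm-d-inference-sim used by the imports above; the JSON field names come from the struct tags added in pkg/openai-server-api/response.go.

package main

import (
	"encoding/json"
	"fmt"

	"github.com/llm-d/llm-d-inference-sim/pkg/common"
)

func main() {
	// Synthetic chat logprobs: one Content entry per token, with the token itself
	// listed first among the top alternatives ("<token>_1", "<token>_2", ...).
	chat := common.GenerateChatLogprobs([]string{"Hello", " world"}, 2)
	out, _ := json.MarshalIndent(chat, "", "  ")
	fmt.Println(string(out))

	// Synthetic text-completion logprobs: parallel arrays plus byte-based text offsets.
	text := common.GenerateTextLogprobs([]string{"Hello", " world"}, 2)
	out, _ = json.MarshalIndent(text, "", "  ")
	fmt.Println(string(out))
}

Because the values come from calculateLogprob, the main-token logprobs cycle through -1.0, -1.1, -1.2 by token position, matching the assertions in logprobs_test.go.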