Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 34 additions & 2 deletions openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,18 @@ import (
type ChatCompletionNewParamsWrapper struct {
openai.ChatCompletionNewParams `json:""`
Stream bool `json:"stream,omitempty"`
MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"`
}

func (c ChatCompletionNewParamsWrapper) MarshalJSON() ([]byte, error) {
type shadow ChatCompletionNewParamsWrapper
return param.MarshalWithExtras(c, (*shadow)(&c), map[string]any{
extras := map[string]any{
"stream": c.Stream,
})
}
if c.MaxCompletionTokens != nil {
extras["max_completion_tokens"] = *c.MaxCompletionTokens
}
return param.MarshalWithExtras(c, (*shadow)(&c), extras)
}

func (c *ChatCompletionNewParamsWrapper) UnmarshalJSON(raw []byte) error {
Expand All @@ -43,6 +48,33 @@ func (c *ChatCompletionNewParamsWrapper) UnmarshalJSON(raw []byte) error {
c.ChatCompletionNewParams.StreamOptions = openai.ChatCompletionStreamOptionsParam{}
}

// Extract max_completion_tokens if present and positive
// OpenAI API requires positive integers for token limits
var data map[string]any
if err := json.Unmarshal(raw, &data); err == nil {
if val, exists := data["max_completion_tokens"]; exists {
// Field is explicitly set, convert to int
var tokens int
switch v := val.(type) {
case float64:
tokens = int(v)
case int:
tokens = v
case int64:
tokens = int(v)
default:
// Invalid type, skip
return nil
}
// Only set if positive (0 and negative values are invalid)
if tokens > 0 {
c.MaxCompletionTokens = &tokens
// Set it in the underlying params as well
c.ChatCompletionNewParams.MaxCompletionTokens = openai.Int(int64(tokens))
}
}
}

return nil
}

Expand Down
86 changes: 86 additions & 0 deletions openai_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package aibridge_test

import (
"encoding/json"
"testing"

"github.com/coder/aibridge"
Expand Down Expand Up @@ -131,3 +132,88 @@ func TestOpenAILastUserPrompt(t *testing.T) {
})
}
}

func TestMaxCompletionTokens(t *testing.T) {
t.Parallel()

t.Run("unmarshal max_completion_tokens from JSON", func(t *testing.T) {
jsonStr := `{
"model": "gpt-4o",
"messages": [{"role": "user", "content": "Hello"}],
"max_completion_tokens": 1024
}`

var wrapper aibridge.ChatCompletionNewParamsWrapper
err := json.Unmarshal([]byte(jsonStr), &wrapper)
require.NoError(t, err)
require.NotNil(t, wrapper.MaxCompletionTokens)
require.Equal(t, 1024, *wrapper.MaxCompletionTokens)
})

t.Run("unmarshal max_completion_tokens with zero value ignored", func(t *testing.T) {
jsonStr := `{
"model": "gpt-4o",
"messages": [{"role": "user", "content": "Hello"}],
"max_completion_tokens": 0
}`

var wrapper aibridge.ChatCompletionNewParamsWrapper
err := json.Unmarshal([]byte(jsonStr), &wrapper)
require.NoError(t, err)
require.Nil(t, wrapper.MaxCompletionTokens, "max_completion_tokens should not be set when 0 (invalid value)")
})

t.Run("unmarshal max_completion_tokens with negative value ignored", func(t *testing.T) {
jsonStr := `{
"model": "gpt-4o",
"messages": [{"role": "user", "content": "Hello"}],
"max_completion_tokens": -100
}`

var wrapper aibridge.ChatCompletionNewParamsWrapper
err := json.Unmarshal([]byte(jsonStr), &wrapper)
require.NoError(t, err)
require.Nil(t, wrapper.MaxCompletionTokens, "max_completion_tokens should not be set when negative (invalid value)")
})

t.Run("marshal max_completion_tokens to JSON", func(t *testing.T) {
maxTokens := 2048
wrapper := aibridge.ChatCompletionNewParamsWrapper{
ChatCompletionNewParams: openai.ChatCompletionNewParams{
Model: openai.ChatModelGPT4o,
Messages: []openai.ChatCompletionMessageParamUnion{
openai.UserMessage("Hello"),
},
},
MaxCompletionTokens: &maxTokens,
}

jsonBytes, err := json.Marshal(wrapper)
require.NoError(t, err)

var result map[string]interface{}
err = json.Unmarshal(jsonBytes, &result)
require.NoError(t, err)
require.Equal(t, float64(2048), result["max_completion_tokens"])
})

t.Run("max_completion_tokens not set when nil", func(t *testing.T) {
wrapper := aibridge.ChatCompletionNewParamsWrapper{
ChatCompletionNewParams: openai.ChatCompletionNewParams{
Model: openai.ChatModelGPT4o,
Messages: []openai.ChatCompletionMessageParamUnion{
openai.UserMessage("Hello"),
},
},
}

jsonBytes, err := json.Marshal(wrapper)
require.NoError(t, err)

var result map[string]interface{}
err = json.Unmarshal(jsonBytes, &result)
require.NoError(t, err)
_, exists := result["max_completion_tokens"]
require.False(t, exists, "max_completion_tokens should not be present when nil")
})
}