Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions pkg/model/provider/anthropic/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,16 @@ import (

// streamAdapter adapts the Anthropic stream to our interface
type streamAdapter struct {
stream *ssestream.Stream[anthropic.MessageStreamEventUnion]
toolCall bool
toolID string
stream *ssestream.Stream[anthropic.MessageStreamEventUnion]
trackUsage bool
toolCall bool
toolID string
}

func newStreamAdapter(stream *ssestream.Stream[anthropic.MessageStreamEventUnion]) *streamAdapter {
func newStreamAdapter(stream *ssestream.Stream[anthropic.MessageStreamEventUnion], trackUsage bool) *streamAdapter {
return &streamAdapter{
stream: stream,
stream: stream,
trackUsage: trackUsage,
}
}

Expand Down Expand Up @@ -96,11 +98,13 @@ func (a *streamAdapter) Recv() (chat.MessageStreamResponse, error) {
return response, fmt.Errorf("unknown delta type: %T", deltaVariant)
}
case anthropic.MessageDeltaEvent:
response.Usage = &chat.Usage{
InputTokens: int(eventVariant.Usage.InputTokens),
OutputTokens: int(eventVariant.Usage.OutputTokens),
CachedInputTokens: int(eventVariant.Usage.CacheReadInputTokens),
CachedOutputTokens: int(eventVariant.Usage.CacheCreationInputTokens),
if a.trackUsage {
response.Usage = &chat.Usage{
InputTokens: int(eventVariant.Usage.InputTokens),
OutputTokens: int(eventVariant.Usage.OutputTokens),
CachedInputTokens: int(eventVariant.Usage.CacheReadInputTokens),
CachedOutputTokens: int(eventVariant.Usage.CacheCreationInputTokens),
}
}
case anthropic.MessageStopEvent:
if a.toolCall {
Expand Down
4 changes: 3 additions & 1 deletion pkg/model/provider/anthropic/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,9 @@ func (c *Client) CreateChatCompletionStream(
}

stream := client.Messages.NewStreaming(ctx, params)
ad := newStreamAdapter(stream)
trackUsage := c.ModelConfig.TrackUsage == nil || *c.ModelConfig.TrackUsage
ad := newStreamAdapter(stream, trackUsage)

slog.Debug("Anthropic chat completion stream created successfully", "model", c.ModelConfig.Model)
return ad, nil
}
Expand Down
10 changes: 6 additions & 4 deletions pkg/model/provider/gemini/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
type StreamAdapter struct {
ch chan result
model string
trackUsage bool
mu sync.Mutex
lastResponse *genai.GenerateContentResponse // Store last response for final message
}
Expand All @@ -30,10 +31,11 @@ type result struct {
}

// NewStreamAdapter constructs a StreamAdapter from Gemini's iterator
func NewStreamAdapter(iter func(func(*genai.GenerateContentResponse, error) bool), model string) *StreamAdapter {
func NewStreamAdapter(iter func(func(*genai.GenerateContentResponse, error) bool), model string, trackUsage bool) *StreamAdapter {
adapter := &StreamAdapter{
ch: make(chan result),
model: model,
ch: make(chan result),
model: model,
trackUsage: trackUsage,
}

go func() {
Expand Down Expand Up @@ -173,7 +175,7 @@ func (g *StreamAdapter) Recv() (chat.MessageStreamResponse, error) {
resp.ID = res.resp.ResponseID

// Handle token usage if present
if res.resp.UsageMetadata != nil {
if res.resp.UsageMetadata != nil && g.trackUsage {
resp.Usage = &chat.Usage{
InputTokens: int(res.resp.UsageMetadata.PromptTokenCount),
OutputTokens: int(res.resp.UsageMetadata.CandidatesTokenCount),
Expand Down
2 changes: 1 addition & 1 deletion pkg/model/provider/gemini/adapter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func TestStreamAdapter_FunctionCalls(t *testing.T) {
fn(mockResp, nil)
}

adapter := NewStreamAdapter(iter, "test-model")
adapter := NewStreamAdapter(iter, "test-model", true)

// Read the response
resp, err := adapter.Recv()
Expand Down
3 changes: 2 additions & 1 deletion pkg/model/provider/gemini/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,8 @@ func (c *Client) CreateChatCompletionStream(

// Build a fresh client per request when using the gateway
iter := client.Models.GenerateContentStream(ctx, c.ModelConfig.Model, contents, config)
return NewStreamAdapter(iter, c.ModelConfig.Model), nil
trackUsage := c.ModelConfig.TrackUsage == nil || *c.ModelConfig.TrackUsage
return NewStreamAdapter(iter, c.ModelConfig.Model, trackUsage), nil
}

// Rerank scores documents by relevance to the query using Gemini's structured
Expand Down
29 changes: 16 additions & 13 deletions pkg/model/provider/oaistream/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,20 +111,23 @@ func (a *StreamAdapter) Recv() (chat.MessageStreamResponse, error) {

// Check if Usage field is present using the JSON metadata
if openaiResponse.JSON.Usage.Valid() {
usage := openaiResponse.Usage
response.Usage = &chat.Usage{
InputTokens: int(usage.PromptTokens),
OutputTokens: int(usage.CompletionTokens),
CachedInputTokens: 0,
CachedOutputTokens: 0,
ReasoningTokens: 0,
}
if usage.JSON.PromptTokensDetails.Valid() {
response.Usage.CachedInputTokens = int(usage.PromptTokensDetails.CachedTokens)
}
if usage.JSON.CompletionTokensDetails.Valid() {
response.Usage.ReasoningTokens = int(usage.CompletionTokensDetails.ReasoningTokens)
if a.trackUsage {
usage := openaiResponse.Usage
response.Usage = &chat.Usage{
InputTokens: int(usage.PromptTokens),
OutputTokens: int(usage.CompletionTokens),
CachedInputTokens: 0,
CachedOutputTokens: 0,
ReasoningTokens: 0,
}
if usage.JSON.PromptTokensDetails.Valid() {
response.Usage.CachedInputTokens = int(usage.PromptTokensDetails.CachedTokens)
}
if usage.JSON.CompletionTokensDetails.Valid() {
response.Usage.ReasoningTokens = int(usage.CompletionTokensDetails.ReasoningTokens)
}
}

// Use the tracked finish reason instead of hardcoding stop
finishReason := a.lastFinishReason
if finishReason == chat.FinishReasonNull || finishReason == "" {
Expand Down
Loading