diff --git a/pkg/model/provider/anthropic/adapter.go b/pkg/model/provider/anthropic/adapter.go index 0955a2d9a..65c1b5339 100644 --- a/pkg/model/provider/anthropic/adapter.go +++ b/pkg/model/provider/anthropic/adapter.go @@ -13,14 +13,16 @@ import ( // streamAdapter adapts the Anthropic stream to our interface type streamAdapter struct { - stream *ssestream.Stream[anthropic.MessageStreamEventUnion] - toolCall bool - toolID string + stream *ssestream.Stream[anthropic.MessageStreamEventUnion] + trackUsage bool + toolCall bool + toolID string } -func newStreamAdapter(stream *ssestream.Stream[anthropic.MessageStreamEventUnion]) *streamAdapter { +func newStreamAdapter(stream *ssestream.Stream[anthropic.MessageStreamEventUnion], trackUsage bool) *streamAdapter { return &streamAdapter{ - stream: stream, + stream: stream, + trackUsage: trackUsage, } } @@ -96,11 +98,13 @@ func (a *streamAdapter) Recv() (chat.MessageStreamResponse, error) { return response, fmt.Errorf("unknown delta type: %T", deltaVariant) } case anthropic.MessageDeltaEvent: - response.Usage = &chat.Usage{ - InputTokens: int(eventVariant.Usage.InputTokens), - OutputTokens: int(eventVariant.Usage.OutputTokens), - CachedInputTokens: int(eventVariant.Usage.CacheReadInputTokens), - CachedOutputTokens: int(eventVariant.Usage.CacheCreationInputTokens), + if a.trackUsage { + response.Usage = &chat.Usage{ + InputTokens: int(eventVariant.Usage.InputTokens), + OutputTokens: int(eventVariant.Usage.OutputTokens), + CachedInputTokens: int(eventVariant.Usage.CacheReadInputTokens), + CachedOutputTokens: int(eventVariant.Usage.CacheCreationInputTokens), + } } case anthropic.MessageStopEvent: if a.toolCall { diff --git a/pkg/model/provider/anthropic/client.go b/pkg/model/provider/anthropic/client.go index cb08854cc..91277ecae 100644 --- a/pkg/model/provider/anthropic/client.go +++ b/pkg/model/provider/anthropic/client.go @@ -256,7 +256,9 @@ func (c *Client) CreateChatCompletionStream( } stream := client.Messages.NewStreaming(ctx, params) - ad := newStreamAdapter(stream) + trackUsage := c.ModelConfig.TrackUsage == nil || *c.ModelConfig.TrackUsage + ad := newStreamAdapter(stream, trackUsage) + slog.Debug("Anthropic chat completion stream created successfully", "model", c.ModelConfig.Model) return ad, nil } diff --git a/pkg/model/provider/gemini/adapter.go b/pkg/model/provider/gemini/adapter.go index fb8a8b0dc..335a578d9 100644 --- a/pkg/model/provider/gemini/adapter.go +++ b/pkg/model/provider/gemini/adapter.go @@ -19,6 +19,7 @@ import ( type StreamAdapter struct { ch chan result model string + trackUsage bool mu sync.Mutex lastResponse *genai.GenerateContentResponse // Store last response for final message } @@ -30,10 +31,11 @@ type result struct { } // NewStreamAdapter constructs a StreamAdapter from Gemini's iterator -func NewStreamAdapter(iter func(func(*genai.GenerateContentResponse, error) bool), model string) *StreamAdapter { +func NewStreamAdapter(iter func(func(*genai.GenerateContentResponse, error) bool), model string, trackUsage bool) *StreamAdapter { adapter := &StreamAdapter{ - ch: make(chan result), - model: model, + ch: make(chan result), + model: model, + trackUsage: trackUsage, } go func() { @@ -173,7 +175,7 @@ func (g *StreamAdapter) Recv() (chat.MessageStreamResponse, error) { resp.ID = res.resp.ResponseID // Handle token usage if present - if res.resp.UsageMetadata != nil { + if res.resp.UsageMetadata != nil && g.trackUsage { resp.Usage = &chat.Usage{ InputTokens: int(res.resp.UsageMetadata.PromptTokenCount), OutputTokens: int(res.resp.UsageMetadata.CandidatesTokenCount), diff --git a/pkg/model/provider/gemini/adapter_test.go b/pkg/model/provider/gemini/adapter_test.go index 768c0b712..08ffabb9e 100644 --- a/pkg/model/provider/gemini/adapter_test.go +++ b/pkg/model/provider/gemini/adapter_test.go @@ -34,7 +34,7 @@ func TestStreamAdapter_FunctionCalls(t *testing.T) { fn(mockResp, nil) } - adapter := NewStreamAdapter(iter, "test-model") + adapter := NewStreamAdapter(iter, "test-model", true) // Read the response resp, err := adapter.Recv() diff --git a/pkg/model/provider/gemini/client.go b/pkg/model/provider/gemini/client.go index c71d2242e..ab77db128 100644 --- a/pkg/model/provider/gemini/client.go +++ b/pkg/model/provider/gemini/client.go @@ -388,7 +388,8 @@ func (c *Client) CreateChatCompletionStream( // Build a fresh client per request when using the gateway iter := client.Models.GenerateContentStream(ctx, c.ModelConfig.Model, contents, config) - return NewStreamAdapter(iter, c.ModelConfig.Model), nil + trackUsage := c.ModelConfig.TrackUsage == nil || *c.ModelConfig.TrackUsage + return NewStreamAdapter(iter, c.ModelConfig.Model, trackUsage), nil } // Rerank scores documents by relevance to the query using Gemini's structured diff --git a/pkg/model/provider/oaistream/adapter.go b/pkg/model/provider/oaistream/adapter.go index 6af809e75..039599b0e 100644 --- a/pkg/model/provider/oaistream/adapter.go +++ b/pkg/model/provider/oaistream/adapter.go @@ -111,20 +111,23 @@ func (a *StreamAdapter) Recv() (chat.MessageStreamResponse, error) { // Check if Usage field is present using the JSON metadata if openaiResponse.JSON.Usage.Valid() { - usage := openaiResponse.Usage - response.Usage = &chat.Usage{ - InputTokens: int(usage.PromptTokens), - OutputTokens: int(usage.CompletionTokens), - CachedInputTokens: 0, - CachedOutputTokens: 0, - ReasoningTokens: 0, - } - if usage.JSON.PromptTokensDetails.Valid() { - response.Usage.CachedInputTokens = int(usage.PromptTokensDetails.CachedTokens) - } - if usage.JSON.CompletionTokensDetails.Valid() { - response.Usage.ReasoningTokens = int(usage.CompletionTokensDetails.ReasoningTokens) + if a.trackUsage { + usage := openaiResponse.Usage + response.Usage = &chat.Usage{ + InputTokens: int(usage.PromptTokens), + OutputTokens: int(usage.CompletionTokens), + CachedInputTokens: 0, + CachedOutputTokens: 0, + ReasoningTokens: 0, + } + if usage.JSON.PromptTokensDetails.Valid() { + response.Usage.CachedInputTokens = int(usage.PromptTokensDetails.CachedTokens) + } + if usage.JSON.CompletionTokensDetails.Valid() { + response.Usage.ReasoningTokens = int(usage.CompletionTokensDetails.ReasoningTokens) + } } + // Use the tracked finish reason instead of hardcoding stop finishReason := a.lastFinishReason if finishReason == chat.FinishReasonNull || finishReason == "" {