diff --git a/pkg/api/types.go b/pkg/api/types.go index 5a7b7fbee..5b2cbd654 100644 --- a/pkg/api/types.go +++ b/pkg/api/types.go @@ -119,8 +119,8 @@ type SessionsResponse struct { Title string `json:"title"` CreatedAt string `json:"created_at"` NumMessages int `json:"num_messages"` - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` + InputTokens int64 `json:"input_tokens"` + OutputTokens int64 `json:"output_tokens"` WorkingDir string `json:"working_dir,omitempty"` } @@ -131,8 +131,8 @@ type SessionResponse struct { Messages []session.Message `json:"messages,omitempty"` CreatedAt time.Time `json:"created_at"` ToolsApproved bool `json:"tools_approved"` - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` + InputTokens int64 `json:"input_tokens"` + OutputTokens int64 `json:"output_tokens"` WorkingDir string `json:"working_dir,omitempty"` Pagination *PaginationMetadata `json:"pagination,omitempty"` } diff --git a/pkg/chat/chat.go b/pkg/chat/chat.go index 096cef216..f12fd2503 100644 --- a/pkg/chat/chat.go +++ b/pkg/chat/chat.go @@ -118,11 +118,11 @@ type MessageStreamResponse struct { } type Usage struct { - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` - CachedInputTokens int `json:"cached_input_tokens"` - CachedOutputTokens int `json:"cached_output_tokens"` - ReasoningTokens int `json:"reasoning_tokens,omitempty"` + InputTokens int64 `json:"input_tokens"` + OutputTokens int64 `json:"output_tokens"` + CachedInputTokens int64 `json:"cached_input_tokens"` + CacheWriteTokens int64 `json:"cached_output_tokens"` + ReasoningTokens int64 `json:"reasoning_tokens,omitempty"` } // MessageStream interface represents a stream of chat completions diff --git a/pkg/model/provider/anthropic/adapter.go b/pkg/model/provider/anthropic/adapter.go index 41d562fe3..0730d5c95 100644 --- a/pkg/model/provider/anthropic/adapter.go +++ b/pkg/model/provider/anthropic/adapter.go @@ -151,10 +151,10 @@ func (a *streamAdapter) Recv() (chat.MessageStreamResponse, error) { case anthropic.MessageDeltaEvent: if a.trackUsage { response.Usage = &chat.Usage{ - InputTokens: int(eventVariant.Usage.InputTokens), - OutputTokens: int(eventVariant.Usage.OutputTokens), - CachedInputTokens: int(eventVariant.Usage.CacheReadInputTokens), - CachedOutputTokens: int(eventVariant.Usage.CacheCreationInputTokens), + InputTokens: eventVariant.Usage.InputTokens, + OutputTokens: eventVariant.Usage.OutputTokens, + CachedInputTokens: eventVariant.Usage.CacheReadInputTokens, + CacheWriteTokens: eventVariant.Usage.CacheCreationInputTokens, } } case anthropic.MessageStopEvent: diff --git a/pkg/model/provider/anthropic/beta_adapter.go b/pkg/model/provider/anthropic/beta_adapter.go index 3aaac4c5c..7f80e19ea 100644 --- a/pkg/model/provider/anthropic/beta_adapter.go +++ b/pkg/model/provider/anthropic/beta_adapter.go @@ -112,10 +112,10 @@ func (a *betaStreamAdapter) Recv() (chat.MessageStreamResponse, error) { } case anthropic.BetaRawMessageDeltaEvent: response.Usage = &chat.Usage{ - InputTokens: int(eventVariant.Usage.InputTokens), - OutputTokens: int(eventVariant.Usage.OutputTokens), - CachedInputTokens: int(eventVariant.Usage.CacheReadInputTokens), - CachedOutputTokens: int(eventVariant.Usage.CacheCreationInputTokens), + InputTokens: eventVariant.Usage.InputTokens, + OutputTokens: eventVariant.Usage.OutputTokens, + CachedInputTokens: eventVariant.Usage.CacheReadInputTokens, + CacheWriteTokens: eventVariant.Usage.CacheCreationInputTokens, } case anthropic.BetaRawMessageStopEvent: if a.toolCall { diff --git a/pkg/model/provider/base/base.go b/pkg/model/provider/base/base.go index cb9a7c288..54bffb92e 100644 --- a/pkg/model/provider/base/base.go +++ b/pkg/model/provider/base/base.go @@ -26,15 +26,15 @@ func (c *Config) BaseConfig() Config { // EmbeddingResult contains the embedding and usage information type EmbeddingResult struct { Embedding []float64 - InputTokens int - TotalTokens int + InputTokens int64 + TotalTokens int64 Cost float64 } // BatchEmbeddingResult contains multiple embeddings and usage information type BatchEmbeddingResult struct { Embeddings [][]float64 - InputTokens int - TotalTokens int + InputTokens int64 + TotalTokens int64 Cost float64 } diff --git a/pkg/model/provider/dmr/client.go b/pkg/model/provider/dmr/client.go index 3c0e27fa4..e97936a01 100644 --- a/pkg/model/provider/dmr/client.go +++ b/pkg/model/provider/dmr/client.go @@ -634,8 +634,8 @@ func (c *Client) CreateEmbedding(ctx context.Context, text string) (*base.Embedd copy(embedding, embedding32) // Extract usage information - inputTokens := int(response.Usage.PromptTokens) - totalTokens := int(response.Usage.TotalTokens) + inputTokens := response.Usage.PromptTokens + totalTokens := response.Usage.TotalTokens // DMR is local/free, so cost is 0 cost := 0.0 @@ -657,10 +657,7 @@ func (c *Client) CreateEmbedding(ctx context.Context, text string) (*base.Embedd func (c *Client) CreateBatchEmbedding(ctx context.Context, texts []string) (*base.BatchEmbeddingResult, error) { if len(texts) == 0 { return &base.BatchEmbeddingResult{ - Embeddings: [][]float64{}, - InputTokens: 0, - TotalTokens: 0, - Cost: 0, + Embeddings: [][]float64{}, }, nil } @@ -693,8 +690,8 @@ func (c *Client) CreateBatchEmbedding(ctx context.Context, texts []string) (*bas } // Extract usage information - inputTokens := int(response.Usage.PromptTokens) - totalTokens := int(response.Usage.TotalTokens) + inputTokens := response.Usage.PromptTokens + totalTokens := response.Usage.TotalTokens // DMR is local/free, so cost is 0 cost := 0.0 diff --git a/pkg/model/provider/gemini/adapter.go b/pkg/model/provider/gemini/adapter.go index 335a578d9..0bc8a408d 100644 --- a/pkg/model/provider/gemini/adapter.go +++ b/pkg/model/provider/gemini/adapter.go @@ -177,11 +177,10 @@ func (g *StreamAdapter) Recv() (chat.MessageStreamResponse, error) { // Handle token usage if present if res.resp.UsageMetadata != nil && g.trackUsage { resp.Usage = &chat.Usage{ - InputTokens: int(res.resp.UsageMetadata.PromptTokenCount), - OutputTokens: int(res.resp.UsageMetadata.CandidatesTokenCount), - CachedInputTokens: int(res.resp.UsageMetadata.CachedContentTokenCount), - CachedOutputTokens: 0, // Gemini doesn't provide cached output tokens - ReasoningTokens: int(res.resp.UsageMetadata.ThoughtsTokenCount), + InputTokens: int64(res.resp.UsageMetadata.PromptTokenCount), + OutputTokens: int64(res.resp.UsageMetadata.CandidatesTokenCount), + CachedInputTokens: int64(res.resp.UsageMetadata.CachedContentTokenCount), + ReasoningTokens: int64(res.resp.UsageMetadata.ThoughtsTokenCount), } } diff --git a/pkg/model/provider/oaistream/adapter.go b/pkg/model/provider/oaistream/adapter.go index 039599b0e..3125ac0ed 100644 --- a/pkg/model/provider/oaistream/adapter.go +++ b/pkg/model/provider/oaistream/adapter.go @@ -114,17 +114,15 @@ func (a *StreamAdapter) Recv() (chat.MessageStreamResponse, error) { if a.trackUsage { usage := openaiResponse.Usage response.Usage = &chat.Usage{ - InputTokens: int(usage.PromptTokens), - OutputTokens: int(usage.CompletionTokens), - CachedInputTokens: 0, - CachedOutputTokens: 0, - ReasoningTokens: 0, + InputTokens: usage.PromptTokens, + OutputTokens: usage.CompletionTokens, } if usage.JSON.PromptTokensDetails.Valid() { - response.Usage.CachedInputTokens = int(usage.PromptTokensDetails.CachedTokens) + response.Usage.CachedInputTokens = usage.PromptTokensDetails.CachedTokens + response.Usage.InputTokens -= usage.PromptTokensDetails.CachedTokens } if usage.JSON.CompletionTokensDetails.Valid() { - response.Usage.ReasoningTokens = int(usage.CompletionTokensDetails.ReasoningTokens) + response.Usage.ReasoningTokens = usage.CompletionTokensDetails.ReasoningTokens } } diff --git a/pkg/model/provider/openai/client.go b/pkg/model/provider/openai/client.go index 300b3c209..848ce6621 100644 --- a/pkg/model/provider/openai/client.go +++ b/pkg/model/provider/openai/client.go @@ -657,10 +657,7 @@ func (c *Client) CreateEmbedding(ctx context.Context, text string) (*base.Embedd func (c *Client) CreateBatchEmbedding(ctx context.Context, texts []string) (*base.BatchEmbeddingResult, error) { if len(texts) == 0 { return &base.BatchEmbeddingResult{ - Embeddings: [][]float64{}, - InputTokens: 0, - TotalTokens: 0, - Cost: 0, + Embeddings: [][]float64{}, }, nil } @@ -704,8 +701,8 @@ func (c *Client) CreateBatchEmbedding(ctx context.Context, texts []string) (*bas } // Extract usage information - inputTokens := int(response.Usage.PromptTokens) - totalTokens := int(response.Usage.TotalTokens) + inputTokens := response.Usage.PromptTokens + totalTokens := response.Usage.TotalTokens // Cost calculation is handled at the strategy level using models.dev pricing // Provider just returns token counts diff --git a/pkg/model/provider/openai/response_stream.go b/pkg/model/provider/openai/response_stream.go index 82d39fbb8..459a4aff6 100644 --- a/pkg/model/provider/openai/response_stream.go +++ b/pkg/model/provider/openai/response_stream.go @@ -210,8 +210,9 @@ func (a *ResponseStreamAdapter) Recv() (chat.MessageStreamResponse, error) { u := event.Response.Usage if u.TotalTokens > 0 { response.Usage = &chat.Usage{ - InputTokens: int(u.InputTokens), - OutputTokens: int(u.OutputTokens), + InputTokens: u.InputTokens - u.InputTokensDetails.CachedTokens, + OutputTokens: u.OutputTokens, + CachedInputTokens: u.InputTokensDetails.CachedTokens, } } // Check if there were any tool calls in the output diff --git a/pkg/rag/embed/embed.go b/pkg/rag/embed/embed.go index c0690f10a..f5e00a4a2 100644 --- a/pkg/rag/embed/embed.go +++ b/pkg/rag/embed/embed.go @@ -14,9 +14,9 @@ import ( // Embedder generates vector embeddings for text type Embedder struct { provider provider.Provider - usageHandler func(tokens int, cost float64) // Callback to emit usage events - batchSize int // Batch size for API calls - maxConcurrency int // Maximum concurrent embedding batch requests + usageHandler func(tokens int64, cost float64) // Callback to emit usage events + batchSize int // Batch size for API calls + maxConcurrency int // Maximum concurrent embedding batch requests } // Option is a functional option for configuring the Embedder @@ -52,7 +52,7 @@ func New(p provider.Provider, opts ...Option) *Embedder { } // SetUsageHandler sets a callback to be called after each embedding with usage info -func (e *Embedder) SetUsageHandler(handler func(tokens int, cost float64)) { +func (e *Embedder) SetUsageHandler(handler func(tokens int64, cost float64)) { e.usageHandler = handler } diff --git a/pkg/rag/strategy/vector_store.go b/pkg/rag/strategy/vector_store.go index 0e7c1a491..ef6a4bc5c 100644 --- a/pkg/rag/strategy/vector_store.go +++ b/pkg/rag/strategy/vector_store.go @@ -56,7 +56,7 @@ type VectorStore struct { similarityMetric string - indexingTokens int // Track tokens used during indexing + indexingTokens int64 // Track tokens used during indexing indexingCost float64 modelID string // Full model ID (e.g., "openai/text-embedding-3-small") for pricing lookup @@ -135,7 +135,7 @@ func NewVectorStore(cfg VectorStoreConfig) *VectorStore { // Set usage handler to calculate cost from models.dev and emit events with CUMULATIVE totals // This matches how chat completions calculate cost in runtime.go - cfg.Embedder.SetUsageHandler(func(tokens int, _ float64) { + cfg.Embedder.SetUsageHandler(func(tokens int64, _ float64) { cost := s.calculateCost(context.Background(), tokens) s.recordUsage(tokens, cost) }) @@ -155,7 +155,7 @@ func (s *VectorStore) SetEmbeddingInputBuilder(builder EmbeddingInputBuilder) { } // calculateCost calculates embedding cost using models.dev pricing -func (s *VectorStore) calculateCost(ctx context.Context, tokens int) float64 { +func (s *VectorStore) calculateCost(ctx context.Context, tokens int64) float64 { if s.modelsStore == nil || strings.HasPrefix(s.modelID, "dmr/") { return 0 } @@ -179,11 +179,11 @@ func (s *VectorStore) calculateCost(ctx context.Context, tokens int) float64 { // RecordUsage records usage and emits a usage event with cumulative totals. // This is exported so strategies can track additional usage (e.g., semantic LLM calls). -func (s *VectorStore) RecordUsage(tokens int, cost float64) { +func (s *VectorStore) RecordUsage(tokens int64, cost float64) { s.recordUsage(tokens, cost) } -func (s *VectorStore) recordUsage(tokens int, cost float64) { +func (s *VectorStore) recordUsage(tokens int64, cost float64) { if tokens == 0 && cost == 0 { return } @@ -460,7 +460,7 @@ func (s *VectorStore) Close() error { } // GetIndexingUsage returns usage statistics from indexing -func (s *VectorStore) GetIndexingUsage() (tokens int, cost float64) { +func (s *VectorStore) GetIndexingUsage() (tokens int64, cost float64) { return s.indexingTokens, s.indexingCost } diff --git a/pkg/rag/types/types.go b/pkg/rag/types/types.go index dd35f14e4..0eeaad03b 100644 --- a/pkg/rag/types/types.go +++ b/pkg/rag/types/types.go @@ -18,7 +18,7 @@ type Event struct { Message string Progress *Progress Error error - TotalTokens int // For usage events + TotalTokens int64 // For usage events Cost float64 // For usage events } diff --git a/pkg/runtime/event.go b/pkg/runtime/event.go index 15409ff3f..0e5d25350 100644 --- a/pkg/runtime/event.go +++ b/pkg/runtime/event.go @@ -190,14 +190,14 @@ type TokenUsageEvent struct { } type Usage struct { - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` - ContextLength int `json:"context_length"` - ContextLimit int `json:"context_limit"` + InputTokens int64 `json:"input_tokens"` + OutputTokens int64 `json:"output_tokens"` + ContextLength int64 `json:"context_length"` + ContextLimit int64 `json:"context_limit"` Cost float64 `json:"cost"` } -func TokenUsage(sessionID, agentName string, inputTokens, outputTokens, contextLength, contextLimit int, cost float64) Event { +func TokenUsage(sessionID, agentName string, inputTokens, outputTokens, contextLength, contextLimit int64, cost float64) Event { return &TokenUsageEvent{ Type: "token_usage", SessionID: sessionID, diff --git a/pkg/runtime/runtime.go b/pkg/runtime/runtime.go index 988a1c445..fb4f856ca 100644 --- a/pkg/runtime/runtime.go +++ b/pkg/runtime/runtime.go @@ -710,14 +710,14 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c slog.Debug("Skipping empty assistant message (no content and no tool calls)", "agent", a.Name()) } - contextLimit := 0 + var contextLimit int64 if m != nil { - contextLimit = m.Limit.Context + contextLimit = int64(m.Limit.Context) } events <- TokenUsage(sess.ID, r.currentAgent, sess.InputTokens, sess.OutputTokens, sess.InputTokens+sess.OutputTokens, contextLimit, sess.Cost) if m != nil && r.sessionCompaction { - if sess.InputTokens+sess.OutputTokens > int(float64(contextLimit)*0.9) { + if sess.InputTokens+sess.OutputTokens > int64(float64(contextLimit)*0.9) { // Avoid inserting a summary between assistant tool_use and tool_result messages. // Defer compaction until after tool calls are processed in this iteration. if len(res.Calls) == 0 { @@ -734,7 +734,7 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c // If tool_use occurred, perform compaction after tool results are appended // to avoid splitting assistant tool_use and user tool_result adjacency. if m != nil && r.sessionCompaction && len(res.Calls) > 0 { - if sess.InputTokens+sess.OutputTokens > int(float64(contextLimit)*0.9) { + if sess.InputTokens+sess.OutputTokens > int64(float64(contextLimit)*0.9) { events <- SessionCompaction(sess.ID, "start", r.currentAgent) r.Summarize(ctx, sess, events) events <- TokenUsage(sess.ID, r.currentAgent, sess.InputTokens, sess.OutputTokens, sess.InputTokens+sess.OutputTokens, contextLimit, sess.Cost) @@ -879,20 +879,21 @@ func (r *LocalRuntime) handleStream(ctx context.Context, stream chat.MessageStre if response.Usage != nil { if m != nil { - sess.Cost += (float64(response.Usage.InputTokens)*m.Cost.Input + - float64(response.Usage.OutputTokens+response.Usage.ReasoningTokens)*m.Cost.Output + + cost := float64(response.Usage.InputTokens)*m.Cost.Input + + float64(response.Usage.OutputTokens)*m.Cost.Output + float64(response.Usage.CachedInputTokens)*m.Cost.CacheRead + - float64(response.Usage.CachedOutputTokens)*m.Cost.CacheWrite) / 1e6 + float64(response.Usage.CacheWriteTokens)*m.Cost.CacheWrite + sess.Cost += cost / 1e6 } - sess.InputTokens = response.Usage.InputTokens + response.Usage.CachedInputTokens - sess.OutputTokens = response.Usage.OutputTokens + response.Usage.CachedOutputTokens + response.Usage.ReasoningTokens + sess.InputTokens = response.Usage.InputTokens + response.Usage.CachedInputTokens + response.Usage.CacheWriteTokens + sess.OutputTokens = response.Usage.OutputTokens modelName := "unknown" if m != nil { modelName = m.Name } - telemetry.RecordTokenUsage(ctx, modelName, int64(response.Usage.InputTokens), int64(response.Usage.OutputTokens+response.Usage.ReasoningTokens), sess.Cost) + telemetry.RecordTokenUsage(ctx, modelName, sess.InputTokens, sess.OutputTokens, sess.Cost) } if len(response.Choices) == 0 { diff --git a/pkg/runtime/runtime_test.go b/pkg/runtime/runtime_test.go index 726f17c2c..b633511bf 100644 --- a/pkg/runtime/runtime_test.go +++ b/pkg/runtime/runtime_test.go @@ -125,7 +125,7 @@ func (b *streamBuilder) AddToolCallArguments(id, argsChunk string) *streamBuilde return b } -func (b *streamBuilder) AddStopWithUsage(input, output int) *streamBuilder { +func (b *streamBuilder) AddStopWithUsage(input, output int64) *streamBuilder { b.responses = append(b.responses, chat.MessageStreamResponse{ Choices: []chat.MessageStreamChoice{{ Index: 0, diff --git a/pkg/session/session.go b/pkg/session/session.go index bf0a94d79..f24fb7bc9 100644 --- a/pkg/session/session.go +++ b/pkg/session/session.go @@ -65,8 +65,8 @@ type Session struct { // If 0, there is no limit MaxIterations int `json:"max_iterations"` - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` + InputTokens int64 `json:"input_tokens"` + OutputTokens int64 `json:"output_tokens"` Cost float64 `json:"cost"` } diff --git a/pkg/session/store.go b/pkg/session/store.go index 8141fbf00..1cbd0b565 100644 --- a/pkg/session/store.go +++ b/pkg/session/store.go @@ -136,12 +136,12 @@ func (s *SQLiteSessionStore) GetSession(ctx context.Context, id string) (*Sessio return nil, err } - inputTokens, err := strconv.Atoi(inputTokensStr) + inputTokens, err := strconv.ParseInt(inputTokensStr, 10, 64) if err != nil { return nil, err } - outputTokens, err := strconv.Atoi(outputTokensStr) + outputTokens, err := strconv.ParseInt(outputTokensStr, 10, 64) if err != nil { return nil, err } @@ -223,12 +223,12 @@ func (s *SQLiteSessionStore) GetSessions(ctx context.Context) ([]*Session, error return nil, err } - inputTokens, err := strconv.Atoi(inputTokensStr) + inputTokens, err := strconv.ParseInt(inputTokensStr, 10, 64) if err != nil { return nil, err } - outputTokens, err := strconv.Atoi(outputTokensStr) + outputTokens, err := strconv.ParseInt(outputTokensStr, 10, 64) if err != nil { return nil, err } diff --git a/pkg/tui/components/sidebar/sidebar.go b/pkg/tui/components/sidebar/sidebar.go index 914ff105b..cc1886241 100644 --- a/pkg/tui/components/sidebar/sidebar.go +++ b/pkg/tui/components/sidebar/sidebar.go @@ -134,7 +134,7 @@ func (m *model) SetToolsetInfo(availableTools int) { } // formatTokenCount formats a token count with K/M suffixes for readability -func formatTokenCount(count int) string { +func formatTokenCount(count int64) string { if count >= 1000000 { return fmt.Sprintf("%.1fM", float64(count)/1000000) } else if count >= 1000 { @@ -477,7 +477,7 @@ func (m *model) tokenUsage() string { return "" } - var totalTokens int + var totalTokens int64 var totalCost float64 for _, usage := range m.sessionUsage { totalTokens += usage.InputTokens + usage.OutputTokens @@ -504,7 +504,7 @@ func (m *model) tokenUsageSummary() string { return "" } - var totalTokens int + var totalTokens int64 var totalCost float64 for _, usage := range m.sessionUsage { totalTokens += usage.InputTokens + usage.OutputTokens