8 changes: 4 additions & 4 deletions pkg/api/types.go
@@ -119,8 +119,8 @@ type SessionsResponse struct {
Title string `json:"title"`
CreatedAt string `json:"created_at"`
NumMessages int `json:"num_messages"`
-InputTokens int `json:"input_tokens"`
-OutputTokens int `json:"output_tokens"`
+InputTokens int64 `json:"input_tokens"`
+OutputTokens int64 `json:"output_tokens"`
WorkingDir string `json:"working_dir,omitempty"`
}

@@ -131,8 +131,8 @@ type SessionResponse struct {
Messages []session.Message `json:"messages,omitempty"`
CreatedAt time.Time `json:"created_at"`
ToolsApproved bool `json:"tools_approved"`
-InputTokens int `json:"input_tokens"`
-OutputTokens int `json:"output_tokens"`
+InputTokens int64 `json:"input_tokens"`
+OutputTokens int64 `json:"output_tokens"`
WorkingDir string `json:"working_dir,omitempty"`
Pagination *PaginationMetadata `json:"pagination,omitempty"`
}
10 changes: 5 additions & 5 deletions pkg/chat/chat.go
@@ -118,11 +118,11 @@ type MessageStreamResponse struct {
}

type Usage struct {
-InputTokens int `json:"input_tokens"`
-OutputTokens int `json:"output_tokens"`
-CachedInputTokens int `json:"cached_input_tokens"`
-CachedOutputTokens int `json:"cached_output_tokens"`
-ReasoningTokens int `json:"reasoning_tokens,omitempty"`
+InputTokens int64 `json:"input_tokens"`
+OutputTokens int64 `json:"output_tokens"`
+CachedInputTokens int64 `json:"cached_input_tokens"`
+CacheWriteTokens int64 `json:"cached_output_tokens"`
+ReasoningTokens int64 `json:"reasoning_tokens,omitempty"`
}

// MessageStream interface represents a stream of chat completions
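Two things change in the Usage struct above: every counter widens from int to int64, and CachedOutputTokens becomes CacheWriteTokens, a name that matches what providers actually report (tokens written to the prompt cache, not cached output). The JSON tag stays cached_output_tokens, presumably to keep the wire format stable. A minimal sketch of the overflow risk that motivates the widening (illustrative values, not from this PR):

package main

import (
	"fmt"
	"math"
)

func main() {
	// Go's int is only 32 bits on 32-bit platforms, so a long-lived
	// cumulative token counter can wrap past math.MaxInt32 (~2.1 billion).
	var narrow int32 = math.MaxInt32
	narrow++            // signed overflow wraps in Go
	fmt.Println(narrow) // -2147483648

	// The same arithmetic stays safe in int64.
	var wide int64 = math.MaxInt32
	wide++
	fmt.Println(wide) // 2147483648
}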
8 changes: 4 additions & 4 deletions pkg/model/provider/anthropic/adapter.go
@@ -151,10 +151,10 @@ func (a *streamAdapter) Recv() (chat.MessageStreamResponse, error) {
case anthropic.MessageDeltaEvent:
if a.trackUsage {
response.Usage = &chat.Usage{
-InputTokens: int(eventVariant.Usage.InputTokens),
-OutputTokens: int(eventVariant.Usage.OutputTokens),
-CachedInputTokens: int(eventVariant.Usage.CacheReadInputTokens),
-CachedOutputTokens: int(eventVariant.Usage.CacheCreationInputTokens),
+InputTokens: eventVariant.Usage.InputTokens,
+OutputTokens: eventVariant.Usage.OutputTokens,
+CachedInputTokens: eventVariant.Usage.CacheReadInputTokens,
+CacheWriteTokens: eventVariant.Usage.CacheCreationInputTokens,
}
}
case anthropic.MessageStopEvent:
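The int() conversions can simply be dropped here and in the beta adapter below: the Anthropic Go SDK already exposes these usage counters as int64. Note also that CacheCreationInputTokens counts tokens written to the prompt cache, which is exactly what the CacheWriteTokens name now conveys; mapping it to CachedOutputTokens was a misnomer.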
8 changes: 4 additions & 4 deletions pkg/model/provider/anthropic/beta_adapter.go
@@ -112,10 +112,10 @@ func (a *betaStreamAdapter) Recv() (chat.MessageStreamResponse, error) {
}
case anthropic.BetaRawMessageDeltaEvent:
response.Usage = &chat.Usage{
-InputTokens: int(eventVariant.Usage.InputTokens),
-OutputTokens: int(eventVariant.Usage.OutputTokens),
-CachedInputTokens: int(eventVariant.Usage.CacheReadInputTokens),
-CachedOutputTokens: int(eventVariant.Usage.CacheCreationInputTokens),
+InputTokens: eventVariant.Usage.InputTokens,
+OutputTokens: eventVariant.Usage.OutputTokens,
+CachedInputTokens: eventVariant.Usage.CacheReadInputTokens,
+CacheWriteTokens: eventVariant.Usage.CacheCreationInputTokens,
}
case anthropic.BetaRawMessageStopEvent:
if a.toolCall {
8 changes: 4 additions & 4 deletions pkg/model/provider/base/base.go
@@ -26,15 +26,15 @@ func (c *Config) BaseConfig() Config {
// EmbeddingResult contains the embedding and usage information
type EmbeddingResult struct {
Embedding []float64
-InputTokens int
-TotalTokens int
+InputTokens int64
+TotalTokens int64
Cost float64
}

// BatchEmbeddingResult contains multiple embeddings and usage information
type BatchEmbeddingResult struct {
Embeddings [][]float64
-InputTokens int
-TotalTokens int
+InputTokens int64
+TotalTokens int64
Cost float64
}
13 changes: 5 additions & 8 deletions pkg/model/provider/dmr/client.go
@@ -634,8 +634,8 @@ func (c *Client) CreateEmbedding(ctx context.Context, text string) (*base.Embedd
copy(embedding, embedding32)

// Extract usage information
-inputTokens := int(response.Usage.PromptTokens)
-totalTokens := int(response.Usage.TotalTokens)
+inputTokens := response.Usage.PromptTokens
+totalTokens := response.Usage.TotalTokens

// DMR is local/free, so cost is 0
cost := 0.0
@@ -657,10 +657,7 @@ func (c *Client) CreateEmbedding(ctx context.Context, text string) (*base.Embedd
func (c *Client) CreateBatchEmbedding(ctx context.Context, texts []string) (*base.BatchEmbeddingResult, error) {
if len(texts) == 0 {
return &base.BatchEmbeddingResult{
-Embeddings: [][]float64{},
-InputTokens: 0,
-TotalTokens: 0,
-Cost: 0,
+Embeddings: [][]float64{},
}, nil
}

@@ -693,8 +690,8 @@ func (c *Client) CreateBatchEmbedding(ctx context.Context, texts []string) (*bas
}

// Extract usage information
-inputTokens := int(response.Usage.PromptTokens)
-totalTokens := int(response.Usage.TotalTokens)
+inputTokens := response.Usage.PromptTokens
+totalTokens := response.Usage.TotalTokens

// DMR is local/free, so cost is 0
cost := 0.0
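Since the underlying usage fields are already int64 (the diff assigns them straight into int64 variables), the int() conversions go away here too. The empty-input fast path above also slims down: Go zero-values any field omitted from a composite literal, so spelling out InputTokens: 0, TotalTokens: 0, and Cost: 0 was redundant. The OpenAI client below gets the same cleanup.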
9 changes: 4 additions & 5 deletions pkg/model/provider/gemini/adapter.go
@@ -177,11 +177,10 @@ func (g *StreamAdapter) Recv() (chat.MessageStreamResponse, error) {
// Handle token usage if present
if res.resp.UsageMetadata != nil && g.trackUsage {
resp.Usage = &chat.Usage{
-InputTokens: int(res.resp.UsageMetadata.PromptTokenCount),
-OutputTokens: int(res.resp.UsageMetadata.CandidatesTokenCount),
-CachedInputTokens: int(res.resp.UsageMetadata.CachedContentTokenCount),
-CachedOutputTokens: 0, // Gemini doesn't provide cached output tokens
-ReasoningTokens: int(res.resp.UsageMetadata.ThoughtsTokenCount),
+InputTokens: int64(res.resp.UsageMetadata.PromptTokenCount),
+OutputTokens: int64(res.resp.UsageMetadata.CandidatesTokenCount),
+CachedInputTokens: int64(res.resp.UsageMetadata.CachedContentTokenCount),
+ReasoningTokens: int64(res.resp.UsageMetadata.ThoughtsTokenCount),
}
}
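The Gemini adapter keeps explicit conversions, presumably because genai's UsageMetadata counts are narrower integers rather than int64; the CachedOutputTokens: 0 placeholder (and its comment) disappears along with the field.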

12 changes: 5 additions & 7 deletions pkg/model/provider/oaistream/adapter.go
@@ -114,17 +114,15 @@ func (a *StreamAdapter) Recv() (chat.MessageStreamResponse, error) {
if a.trackUsage {
usage := openaiResponse.Usage
response.Usage = &chat.Usage{
-InputTokens: int(usage.PromptTokens),
-OutputTokens: int(usage.CompletionTokens),
-CachedInputTokens: 0,
-CachedOutputTokens: 0,
-ReasoningTokens: 0,
+InputTokens: usage.PromptTokens,
+OutputTokens: usage.CompletionTokens,
}
if usage.JSON.PromptTokensDetails.Valid() {
-response.Usage.CachedInputTokens = int(usage.PromptTokensDetails.CachedTokens)
+response.Usage.CachedInputTokens = usage.PromptTokensDetails.CachedTokens
+response.Usage.InputTokens -= usage.PromptTokensDetails.CachedTokens
}
if usage.JSON.CompletionTokensDetails.Valid() {
-response.Usage.ReasoningTokens = int(usage.CompletionTokensDetails.ReasoningTokens)
+response.Usage.ReasoningTokens = usage.CompletionTokensDetails.ReasoningTokens
}
}
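OpenAI reports PromptTokens inclusive of any prompt-cache hits, so the adapter now splits them apart: with, say, PromptTokens = 1000 and CachedTokens = 400 (illustrative numbers), it records InputTokens = 600 and CachedInputTokens = 400. The explicit zero fields are gone for the same reason as in the DMR client above: Go zero-values omitted struct fields. The Responses-API adapter below applies the same cache split inline.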

9 changes: 3 additions & 6 deletions pkg/model/provider/openai/client.go
@@ -657,10 +657,7 @@ func (c *Client) CreateEmbedding(ctx context.Context, text string) (*base.Embedd
func (c *Client) CreateBatchEmbedding(ctx context.Context, texts []string) (*base.BatchEmbeddingResult, error) {
if len(texts) == 0 {
return &base.BatchEmbeddingResult{
-Embeddings: [][]float64{},
-InputTokens: 0,
-TotalTokens: 0,
-Cost: 0,
+Embeddings: [][]float64{},
}, nil
}

@@ -704,8 +701,8 @@ func (c *Client) CreateBatchEmbedding(ctx context.Context, texts []string) (*bas
}

// Extract usage information
-inputTokens := int(response.Usage.PromptTokens)
-totalTokens := int(response.Usage.TotalTokens)
+inputTokens := response.Usage.PromptTokens
+totalTokens := response.Usage.TotalTokens

// Cost calculation is handled at the strategy level using models.dev pricing
// Provider just returns token counts
5 changes: 3 additions & 2 deletions pkg/model/provider/openai/response_stream.go
@@ -210,8 +210,9 @@ func (a *ResponseStreamAdapter) Recv() (chat.MessageStreamResponse, error) {
u := event.Response.Usage
if u.TotalTokens > 0 {
response.Usage = &chat.Usage{
-InputTokens: int(u.InputTokens),
-OutputTokens: int(u.OutputTokens),
+InputTokens: u.InputTokens - u.InputTokensDetails.CachedTokens,
+OutputTokens: u.OutputTokens,
+CachedInputTokens: u.InputTokensDetails.CachedTokens,
}
}
// Check if there were any tool calls in the output
8 changes: 4 additions & 4 deletions pkg/rag/embed/embed.go
@@ -14,9 +14,9 @@ import (
// Embedder generates vector embeddings for text
type Embedder struct {
provider provider.Provider
-usageHandler func(tokens int, cost float64) // Callback to emit usage events
-batchSize int // Batch size for API calls
-maxConcurrency int // Maximum concurrent embedding batch requests
+usageHandler func(tokens int64, cost float64) // Callback to emit usage events
+batchSize int // Batch size for API calls
+maxConcurrency int // Maximum concurrent embedding batch requests
}

// Option is a functional option for configuring the Embedder
@@ -52,7 +52,7 @@ func New(p provider.Provider, opts ...Option) *Embedder {
}

// SetUsageHandler sets a callback to be called after each embedding with usage info
-func (e *Embedder) SetUsageHandler(handler func(tokens int, cost float64)) {
+func (e *Embedder) SetUsageHandler(handler func(tokens int64, cost float64)) {
e.usageHandler = handler
}

12 changes: 6 additions & 6 deletions pkg/rag/strategy/vector_store.go
@@ -56,7 +56,7 @@ type VectorStore struct {

similarityMetric string

-indexingTokens int // Track tokens used during indexing
+indexingTokens int64 // Track tokens used during indexing
indexingCost float64

modelID string // Full model ID (e.g., "openai/text-embedding-3-small") for pricing lookup
@@ -135,7 +135,7 @@ func NewVectorStore(cfg VectorStoreConfig) *VectorStore {

// Set usage handler to calculate cost from models.dev and emit events with CUMULATIVE totals
// This matches how chat completions calculate cost in runtime.go
-cfg.Embedder.SetUsageHandler(func(tokens int, _ float64) {
+cfg.Embedder.SetUsageHandler(func(tokens int64, _ float64) {
cost := s.calculateCost(context.Background(), tokens)
s.recordUsage(tokens, cost)
})
@@ -155,7 +155,7 @@ func (s *VectorStore) SetEmbeddingInputBuilder(builder EmbeddingInputBuilder) {
}

// calculateCost calculates embedding cost using models.dev pricing
-func (s *VectorStore) calculateCost(ctx context.Context, tokens int) float64 {
+func (s *VectorStore) calculateCost(ctx context.Context, tokens int64) float64 {
if s.modelsStore == nil || strings.HasPrefix(s.modelID, "dmr/") {
return 0
}
@@ -179,11 +179,11 @@ func (s *VectorStore) calculateCost(ctx context.Context, tokens int64) float64 {

// RecordUsage records usage and emits a usage event with cumulative totals.
// This is exported so strategies can track additional usage (e.g., semantic LLM calls).
-func (s *VectorStore) RecordUsage(tokens int, cost float64) {
+func (s *VectorStore) RecordUsage(tokens int64, cost float64) {
s.recordUsage(tokens, cost)
}

-func (s *VectorStore) recordUsage(tokens int, cost float64) {
+func (s *VectorStore) recordUsage(tokens int64, cost float64) {
if tokens == 0 && cost == 0 {
return
}
@@ -460,7 +460,7 @@ func (s *VectorStore) Close() error {
}

// GetIndexingUsage returns usage statistics from indexing
-func (s *VectorStore) GetIndexingUsage() (tokens int, cost float64) {
+func (s *VectorStore) GetIndexingUsage() (tokens int64, cost float64) {
return s.indexingTokens, s.indexingCost
}

2 changes: 1 addition & 1 deletion pkg/rag/types/types.go
@@ -18,7 +18,7 @@ type Event struct {
Message string
Progress *Progress
Error error
-TotalTokens int // For usage events
+TotalTokens int64 // For usage events
Cost float64 // For usage events
}

10 changes: 5 additions & 5 deletions pkg/runtime/event.go
@@ -190,14 +190,14 @@ type TokenUsageEvent struct {
}

type Usage struct {
-InputTokens int `json:"input_tokens"`
-OutputTokens int `json:"output_tokens"`
-ContextLength int `json:"context_length"`
-ContextLimit int `json:"context_limit"`
+InputTokens int64 `json:"input_tokens"`
+OutputTokens int64 `json:"output_tokens"`
+ContextLength int64 `json:"context_length"`
+ContextLimit int64 `json:"context_limit"`
Cost float64 `json:"cost"`
}

-func TokenUsage(sessionID, agentName string, inputTokens, outputTokens, contextLength, contextLimit int, cost float64) Event {
+func TokenUsage(sessionID, agentName string, inputTokens, outputTokens, contextLength, contextLimit int64, cost float64) Event {
return &TokenUsageEvent{
Type: "token_usage",
SessionID: sessionID,
21 changes: 11 additions & 10 deletions pkg/runtime/runtime.go
@@ -710,14 +710,14 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c
slog.Debug("Skipping empty assistant message (no content and no tool calls)", "agent", a.Name())
}

-contextLimit := 0
+var contextLimit int64
if m != nil {
-contextLimit = m.Limit.Context
+contextLimit = int64(m.Limit.Context)
}
events <- TokenUsage(sess.ID, r.currentAgent, sess.InputTokens, sess.OutputTokens, sess.InputTokens+sess.OutputTokens, contextLimit, sess.Cost)

if m != nil && r.sessionCompaction {
-if sess.InputTokens+sess.OutputTokens > int(float64(contextLimit)*0.9) {
+if sess.InputTokens+sess.OutputTokens > int64(float64(contextLimit)*0.9) {
// Avoid inserting a summary between assistant tool_use and tool_result messages.
// Defer compaction until after tool calls are processed in this iteration.
if len(res.Calls) == 0 {
@@ -734,7 +734,7 @@
// If tool_use occurred, perform compaction after tool results are appended
// to avoid splitting assistant tool_use and user tool_result adjacency.
if m != nil && r.sessionCompaction && len(res.Calls) > 0 {
-if sess.InputTokens+sess.OutputTokens > int(float64(contextLimit)*0.9) {
+if sess.InputTokens+sess.OutputTokens > int64(float64(contextLimit)*0.9) {
events <- SessionCompaction(sess.ID, "start", r.currentAgent)
r.Summarize(ctx, sess, events)
events <- TokenUsage(sess.ID, r.currentAgent, sess.InputTokens, sess.OutputTokens, sess.InputTokens+sess.OutputTokens, contextLimit, sess.Cost)
@@ -879,20 +879,21 @@ func (r *LocalRuntime) handleStream(ctx context.Context, stream chat.MessageStre

if response.Usage != nil {
if m != nil {
-sess.Cost += (float64(response.Usage.InputTokens)*m.Cost.Input +
-float64(response.Usage.OutputTokens+response.Usage.ReasoningTokens)*m.Cost.Output +
+cost := float64(response.Usage.InputTokens)*m.Cost.Input +
+float64(response.Usage.OutputTokens)*m.Cost.Output +
float64(response.Usage.CachedInputTokens)*m.Cost.CacheRead +
-float64(response.Usage.CachedOutputTokens)*m.Cost.CacheWrite) / 1e6
+float64(response.Usage.CacheWriteTokens)*m.Cost.CacheWrite
+sess.Cost += cost / 1e6
}

-sess.InputTokens = response.Usage.InputTokens + response.Usage.CachedInputTokens
-sess.OutputTokens = response.Usage.OutputTokens + response.Usage.CachedOutputTokens + response.Usage.ReasoningTokens
+sess.InputTokens = response.Usage.InputTokens + response.Usage.CachedInputTokens + response.Usage.CacheWriteTokens
+sess.OutputTokens = response.Usage.OutputTokens

modelName := "unknown"
if m != nil {
modelName = m.Name
}
-telemetry.RecordTokenUsage(ctx, modelName, int64(response.Usage.InputTokens), int64(response.Usage.OutputTokens+response.Usage.ReasoningTokens), sess.Cost)
+telemetry.RecordTokenUsage(ctx, modelName, sess.InputTokens, sess.OutputTokens, sess.Cost)
}

if len(response.Choices) == 0 {
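The accounting in this hunk also shifts: reasoning tokens drop out of the cost sum (providers typically fold them into the output count already, so adding them again double-billed), sess.InputTokens now includes cache reads and writes, sess.OutputTokens is the raw output count, and telemetry reports the session totals directly. For reference, a sketch of the per-response pricing arithmetic this implements (ModelCost is a hypothetical stand-in for m.Cost; rates are per million tokens, hence the 1e6 divisor):

// ModelCost stands in for the m.Cost fields referenced in the diff;
// each rate is USD per million tokens.
type ModelCost struct {
	Input, Output, CacheRead, CacheWrite float64
}

// responseCost mirrors the accumulation in handleStream: uncached input,
// output, cache reads, and cache writes are each priced separately.
func responseCost(u chat.Usage, c ModelCost) float64 {
	return (float64(u.InputTokens)*c.Input +
		float64(u.OutputTokens)*c.Output +
		float64(u.CachedInputTokens)*c.CacheRead +
		float64(u.CacheWriteTokens)*c.CacheWrite) / 1e6
}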
2 changes: 1 addition & 1 deletion pkg/runtime/runtime_test.go
@@ -125,7 +125,7 @@ func (b *streamBuilder) AddToolCallArguments(id, argsChunk string) *streamBuilde
return b
}

-func (b *streamBuilder) AddStopWithUsage(input, output int) *streamBuilder {
+func (b *streamBuilder) AddStopWithUsage(input, output int64) *streamBuilder {
b.responses = append(b.responses, chat.MessageStreamResponse{
Choices: []chat.MessageStreamChoice{{
Index: 0,
4 changes: 2 additions & 2 deletions pkg/session/session.go
@@ -65,8 +65,8 @@ type Session struct {
// If 0, there is no limit
MaxIterations int `json:"max_iterations"`

-InputTokens int `json:"input_tokens"`
-OutputTokens int `json:"output_tokens"`
+InputTokens int64 `json:"input_tokens"`
+OutputTokens int64 `json:"output_tokens"`
Cost float64 `json:"cost"`
}
