From b9f122d21ad6cd44c2cba731bf2fd0439083a435 Mon Sep 17 00:00:00 2001 From: Omar Alani Date: Tue, 12 May 2026 21:17:22 -0500 Subject: [PATCH 1/3] feat(terminal): add context token status --- internal/assistant/anthropic.go | 5 +- internal/assistant/client.go | 13 +- internal/assistant/openai_chat.go | 5 +- internal/assistant/openai_responses.go | 7 +- internal/assistant/runtime.go | 69 ++++++++--- internal/assistant/runtime_test.go | 3 + internal/assistant/sse.go | 38 +++++- internal/assistant/tool_loop.go | 14 ++- internal/assistant/usage.go | 123 +++++++++++++++++++ internal/assistant/usage_events.go | 31 +++++ internal/assistant/usage_test.go | 107 ++++++++++++++++ internal/model/usage.go | 37 ++++++ internal/terminal/app.go | 9 +- internal/terminal/async_events.go | 58 ++++++--- internal/terminal/auth_commands.go | 9 +- internal/terminal/input.go | 8 +- internal/terminal/render_parity_test.go | 23 ++++ internal/terminal/render_test.go | 43 ++++++- internal/terminal/runtime_buffers.go | 3 + internal/terminal/token_usage.go | 73 +++++++++++ internal/terminal/token_usage_export_test.go | 7 ++ internal/terminal/token_usage_test.go | 38 ++++++ 22 files changed, 667 insertions(+), 56 deletions(-) create mode 100644 internal/assistant/usage.go create mode 100644 internal/assistant/usage_events.go create mode 100644 internal/assistant/usage_test.go create mode 100644 internal/model/usage.go create mode 100644 internal/terminal/token_usage.go create mode 100644 internal/terminal/token_usage_export_test.go create mode 100644 internal/terminal/token_usage_test.go diff --git a/internal/assistant/anthropic.go b/internal/assistant/anthropic.go index 73bfda2..bfc83b3 100644 --- a/internal/assistant/anthropic.go +++ b/internal/assistant/anthropic.go @@ -23,7 +23,8 @@ func (client *HTTPCompletionClient) completeAnthropic( return nil, err } var response struct { - Error providerError `json:"error"` + Error providerError `json:"error"` + Usage map[string]any `json:"usage"` Content []struct { Type string `json:"type"` Text string `json:"text"` @@ -46,7 +47,7 @@ func (client *HTTPCompletionClient) completeAnthropic( return nil, oops.In("assistant").Code("anthropic_empty").Errorf("provider returned an empty response") } - return textCompletionResult(text), nil + return textCompletionResult(text, usageFromObject(response.Usage)), nil } func anthropicPayload(request *CompletionRequest) map[string]any { diff --git a/internal/assistant/client.go b/internal/assistant/client.go index f5359fc..f1b0a37 100644 --- a/internal/assistant/client.go +++ b/internal/assistant/client.go @@ -33,6 +33,7 @@ const ( jsonToolParamsKey = "parameters" jsonCallIDKey = "call_id" jsonOutputKey = "output" + jsonOutputTokensKey = "output_tokens" jsonToolChoiceKey = "tool_choice" jsonTextKey = "text" jsonThinkingKey = "thinking" @@ -64,9 +65,10 @@ type CompletionRequest struct { // CompletionResult is a provider response plus model-visible side effects. type CompletionResult struct { - Text string `json:"text"` - Thinking []string `json:"thinking,omitempty"` - ToolEvents []ToolEvent `json:"tool_events,omitempty"` + Text string `json:"text"` + Thinking []string `json:"thinking,omitempty"` + ToolEvents []ToolEvent `json:"tool_events,omitempty"` + Usage model.TokenUsage `json:"usage,omitempty"` } // ToolEvent captures one tool call for persistence and TUI rendering. @@ -95,6 +97,7 @@ type providerResult struct { OutputItems []any Thinking []string ToolCalls []toolCall + Usage model.TokenUsage } // HTTPCompletionClient is a small provider client for built-in API families. @@ -133,6 +136,6 @@ func (client *HTTPCompletionClient) Complete( } } -func textCompletionResult(text string) *CompletionResult { - return &CompletionResult{Text: strings.TrimSpace(text), Thinking: nil, ToolEvents: nil} +func textCompletionResult(text string, usage model.TokenUsage) *CompletionResult { + return &CompletionResult{Text: strings.TrimSpace(text), Thinking: nil, ToolEvents: nil, Usage: usage} } diff --git a/internal/assistant/openai_chat.go b/internal/assistant/openai_chat.go index b937d66..8dafea0 100644 --- a/internal/assistant/openai_chat.go +++ b/internal/assistant/openai_chat.go @@ -29,7 +29,8 @@ func (client *HTTPCompletionClient) completeOpenAIChat( return nil, err } var response struct { - Error providerError `json:"error"` + Error providerError `json:"error"` + Usage map[string]any `json:"usage"` Choices []struct { Message struct { Content string `json:"content"` @@ -46,7 +47,7 @@ func (client *HTTPCompletionClient) completeOpenAIChat( return nil, oops.In("assistant").Code("openai_chat_empty").Errorf("provider returned an empty response") } - return textCompletionResult(response.Choices[0].Message.Content), nil + return textCompletionResult(response.Choices[0].Message.Content, usageFromObject(response.Usage)), nil } func openAIChatMessages(request *CompletionRequest) []map[string]string { diff --git a/internal/assistant/openai_responses.go b/internal/assistant/openai_responses.go index 689a27f..f206790 100644 --- a/internal/assistant/openai_responses.go +++ b/internal/assistant/openai_responses.go @@ -6,6 +6,8 @@ import ( "strings" "github.com/samber/oops" + + "github.com/omarluq/librecode/internal/model" ) func (client *HTTPCompletionClient) completeOpenAIResponses( @@ -36,7 +38,7 @@ func (client *HTTPCompletionClient) completeResponsesLoop( input []any, stream bool, ) (*CompletionResult, error) { - result := &CompletionResult{Text: "", Thinking: nil, ToolEvents: nil} + result := &CompletionResult{Text: "", Thinking: nil, ToolEvents: nil, Usage: model.EmptyTokenUsage()} for { payload := responsesPayload(request, input, stream) providerResult, err := client.requestResponses(ctx, endpoint, headers, payload, stream, request.OnEvent) @@ -44,6 +46,7 @@ func (client *HTTPCompletionClient) completeResponsesLoop( return nil, err } result.Thinking = append(result.Thinking, providerResult.Thinking...) + result.Usage = mergeUsage(result.Usage, providerResult.Usage) if err := validateToolCalls(providerResult.ToolCalls); err != nil { return nil, err } @@ -181,6 +184,7 @@ func providerResultFromResponse(response map[string]any) *providerResult { OutputItems: outputItems, Thinking: thinkingFromOutput(outputItems), ToolCalls: toolCallsFromOutput(outputItems), + Usage: usageFromObject(response["usage"]), } } @@ -195,6 +199,7 @@ func providerResultFromOutputItems(outputItems []any, fallbackText string) *prov OutputItems: outputItems, Thinking: thinkingFromOutput(outputItems), ToolCalls: toolCallsFromOutput(outputItems), + Usage: model.EmptyTokenUsage(), } } diff --git a/internal/assistant/runtime.go b/internal/assistant/runtime.go index f7aa761..62a6f15 100644 --- a/internal/assistant/runtime.go +++ b/internal/assistant/runtime.go @@ -70,30 +70,35 @@ const ( StreamEventToolResult StreamEventKind = "tool_result" // StreamEventSkillLoaded carries an explicitly loaded Agent Skill. StreamEventSkillLoaded StreamEventKind = "skill_loaded" + // StreamEventUsage carries estimated or provider-reported token usage. + StreamEventUsage StreamEventKind = "usage" ) // StreamEvent is emitted during prompt execution before final persistence. type StreamEvent struct { - ToolEvent *ToolEvent `json:"tool_event,omitempty"` - Kind StreamEventKind `json:"kind"` - Text string `json:"text,omitempty"` + ToolEvent *ToolEvent `json:"tool_event,omitempty"` + Usage *model.TokenUsage `json:"usage,omitempty"` + Kind StreamEventKind `json:"kind"` + Text string `json:"text,omitempty"` } // PromptResponse describes persisted prompt output. type PromptResponse struct { - SessionID string `json:"session_id"` - UserEntryID string `json:"user_entry_id"` - AssistantEntryID string `json:"assistant_entry_id"` - Text string `json:"text"` - Thinking []string `json:"thinking,omitempty"` - ToolEvents []ToolEvent `json:"tool_events,omitempty"` - Cached bool `json:"cached"` + SessionID string `json:"session_id"` + UserEntryID string `json:"user_entry_id"` + AssistantEntryID string `json:"assistant_entry_id"` + Text string `json:"text"` + Thinking []string `json:"thinking,omitempty"` + ToolEvents []ToolEvent `json:"tool_events,omitempty"` + Usage model.TokenUsage `json:"usage,omitempty"` + Cached bool `json:"cached"` } type responseBundle struct { Text string Thinking []string ToolEvents []ToolEvent + Usage model.TokenUsage } // NewRuntime creates an assistant runtime. @@ -197,6 +202,7 @@ func (runtime *Runtime) Prompt(ctx context.Context, request *PromptRequest) (*Pr Text: bundle.Text, Thinking: bundle.Thinking, ToolEvents: bundle.ToolEvents, + Usage: bundle.Usage, Cached: cached, }, nil } @@ -368,7 +374,12 @@ func (runtime *Runtime) respond( ) { if strings.HasPrefix(prompt, slashPrefix) { slashResponse, slashToolEvents, slashErr := runtime.respondToSlashCommand(ctx, cwd, prompt, onEvent) - return &responseBundle{Text: slashResponse, Thinking: nil, ToolEvents: slashToolEvents}, false, slashErr + return &responseBundle{ + Text: slashResponse, + Thinking: nil, + ToolEvents: slashToolEvents, + Usage: model.EmptyTokenUsage(), + }, false, slashErr } cacheKey := runtime.cacheKey(sessionID, prompt) @@ -377,7 +388,12 @@ func (runtime *Runtime) respond( return nil, false, oops.In("assistant").Code("cache_get").Wrapf(err, "read response cache") } if found { - return &responseBundle{Text: cachedResponse, Thinking: nil, ToolEvents: nil}, true, nil + return &responseBundle{ + Text: cachedResponse, + Thinking: nil, + ToolEvents: nil, + Usage: model.EmptyTokenUsage(), + }, true, nil } bundle, err = runtime.modelResponse(ctx, sessionID, cwd, prompt, onEvent, onRetry) @@ -449,7 +465,12 @@ func (runtime *Runtime) respondToSkillCommand( if err != nil { return "", nil, err } - emitStreamEvent(onEvent, StreamEvent{ToolEvent: &toolEvent, Kind: StreamEventSkillLoaded, Text: skill.Name}) + emitStreamEvent(onEvent, StreamEvent{ + ToolEvent: &toolEvent, + Usage: nil, + Kind: StreamEventSkillLoaded, + Text: skill.Name, + }) return result, []ToolEvent{toolEvent}, nil } @@ -569,10 +590,13 @@ func (runtime *Runtime) modelResponse( } } + messages := messageEntities(sessionMessages) + estimatedUsage := estimateTokenUsage(systemPrompt, messages, &selectedModel) + runtime.emitUsage(ctx, onEvent, estimatedUsage) request := &CompletionRequest{ Model: selectedModel, Auth: auth, - Messages: messageEntities(sessionMessages), + Messages: messages, SessionID: sessionID, SystemPrompt: systemPrompt, ThinkingLevel: runtime.cfg.Assistant.ThinkingLevel, @@ -583,7 +607,15 @@ func (runtime *Runtime) modelResponse( if err != nil { return nil, err } - return &responseBundle{Text: result.Text, Thinking: result.Thinking, ToolEvents: result.ToolEvents}, nil + usage := mergeUsage(estimatedUsage, result.Usage) + runtime.emitUsage(ctx, onEvent, usage) + + return &responseBundle{ + Text: result.Text, + Thinking: result.Thinking, + ToolEvents: result.ToolEvents, + Usage: usage, + }, nil } func (runtime *Runtime) completeWithRetry( @@ -701,7 +733,12 @@ func (runtime *Runtime) emitActivatedSkillReads( slog.Any("error", err), ) } - emitStreamEvent(onEvent, StreamEvent{ToolEvent: &toolEvent, Kind: StreamEventSkillLoaded, Text: skill.Name}) + emitStreamEvent(onEvent, StreamEvent{ + ToolEvent: &toolEvent, + Usage: nil, + Kind: StreamEventSkillLoaded, + Text: skill.Name, + }) toolEvents = append(toolEvents, toolEvent) } diff --git a/internal/assistant/runtime_test.go b/internal/assistant/runtime_test.go index 764e060..7590d7b 100644 --- a/internal/assistant/runtime_test.go +++ b/internal/assistant/runtime_test.go @@ -372,6 +372,7 @@ func (client *retryCompletionClient) Complete( Text: client.response + " for " + request.Messages[len(request.Messages)-1].Content, Thinking: nil, ToolEvents: nil, + Usage: model.EmptyTokenUsage(), }, nil } @@ -384,6 +385,7 @@ func (testCompletionClient) Complete( if request.OnEvent != nil { request.OnEvent(assistant.StreamEvent{ ToolEvent: nil, + Usage: nil, Kind: assistant.StreamEventTextDelta, Text: "test assistant response for " + request.Messages[len(request.Messages)-1].Content, }) @@ -393,6 +395,7 @@ func (testCompletionClient) Complete( Text: "test assistant response for " + request.Messages[len(request.Messages)-1].Content, Thinking: nil, ToolEvents: nil, + Usage: model.TokenUsage{InputTokens: 12, OutputTokens: 4, ContextTokens: 12, ContextWindow: 1000}, }, nil } diff --git a/internal/assistant/sse.go b/internal/assistant/sse.go index 331ca07..d8fe6bd 100644 --- a/internal/assistant/sse.go +++ b/internal/assistant/sse.go @@ -7,6 +7,8 @@ import ( "strings" "github.com/samber/oops" + + "github.com/omarluq/librecode/internal/model" ) type sseAccumulator struct { @@ -29,12 +31,26 @@ func (accumulator *sseAccumulator) add(event map[string]any, onEvent func(Stream if response, ok := event["response"].(map[string]any); ok { accumulator.finalResponse = response } + if usage, ok := event["usage"].(map[string]any); ok { + accumulator.finalResponse = ensureSSEFinalResponse(accumulator.finalResponse) + accumulator.finalResponse["usage"] = usage + } if text, delta := thinkingTextFromSSEEvent(event); delta && text != "" { - emitStreamEvent(onEvent, StreamEvent{ToolEvent: nil, Kind: StreamEventThinkingDelta, Text: text}) + emitStreamEvent(onEvent, StreamEvent{ + ToolEvent: nil, + Usage: nil, + Kind: StreamEventThinkingDelta, + Text: text, + }) } if text, delta := textFromSSEEvent(event); delta && text != "" { accumulator.parts = append(accumulator.parts, text) - emitStreamEvent(onEvent, StreamEvent{ToolEvent: nil, Kind: StreamEventTextDelta, Text: text}) + emitStreamEvent(onEvent, StreamEvent{ + ToolEvent: nil, + Usage: nil, + Kind: StreamEventTextDelta, + Text: text, + }) } if item, ok := event["item"].(map[string]any); ok { accumulator.addItem(item) @@ -69,6 +85,14 @@ func (accumulator *sseAccumulator) addArguments(event map[string]any, arguments accumulator.items = upsertSSEItem(accumulator.items, item) } +func ensureSSEFinalResponse(response map[string]any) map[string]any { + if response != nil { + return response + } + + return map[string]any{} +} + func sseItemID(event map[string]any) string { for _, key := range []string{"item_id", "output_item_id", "id"} { if value := stringValue(event[key]); value != "" { @@ -109,7 +133,9 @@ func parseSSEResult(reader io.Reader, onEvent func(StreamEvent)) (*providerResul if accumulator.finalResponse != nil { result := providerResultFromResponse(accumulator.finalResponse) if len(result.OutputItems) == 0 && len(accumulator.items) > 0 { + usage := result.Usage result = providerResultFromOutputItems(accumulator.items, fallbackText) + result.Usage = usage } if strings.TrimSpace(result.Text) == "" { result.Text = fallbackText @@ -121,7 +147,13 @@ func parseSSEResult(reader io.Reader, onEvent func(StreamEvent)) (*providerResul return providerResultFromOutputItems(accumulator.items, fallbackText), nil } - return &providerResult{Text: fallbackText, OutputItems: nil, Thinking: nil, ToolCalls: nil}, nil + return &providerResult{ + Text: fallbackText, + OutputItems: nil, + Thinking: nil, + ToolCalls: nil, + Usage: model.EmptyTokenUsage(), + }, nil } func scanSSEResponse(scanner *bufio.Scanner, onEvent func(StreamEvent)) (accumulator *sseAccumulator, err error) { diff --git a/internal/assistant/tool_loop.go b/internal/assistant/tool_loop.go index 6cd652f..339524e 100644 --- a/internal/assistant/tool_loop.go +++ b/internal/assistant/tool_loop.go @@ -39,7 +39,12 @@ func executeToolCalls( outputs := make([]any, 0, len(calls)) events := make([]ToolEvent, 0, len(calls)) for _, call := range calls { - emitStreamEvent(onEvent, StreamEvent{ToolEvent: nil, Kind: StreamEventToolStart, Text: call.Name}) + emitStreamEvent(onEvent, StreamEvent{ + ToolEvent: nil, + Usage: nil, + Kind: StreamEventToolStart, + Text: call.Name, + }) result, err := registry.Execute(ctx, call.Name, call.Arguments) resultText := result.Text() detailsJSON := encodeToolDetails(result.Details) @@ -59,7 +64,12 @@ func executeToolCalls( } event.Result = resultText events = append(events, event) - emitStreamEvent(onEvent, StreamEvent{ToolEvent: &event, Kind: StreamEventToolResult, Text: ""}) + emitStreamEvent(onEvent, StreamEvent{ + ToolEvent: &event, + Usage: nil, + Kind: StreamEventToolResult, + Text: "", + }) outputs = append(outputs, map[string]any{ jsonTypeKey: functionCallOutputType, jsonCallIDKey: call.ID, diff --git a/internal/assistant/usage.go b/internal/assistant/usage.go new file mode 100644 index 0000000..b42d022 --- /dev/null +++ b/internal/assistant/usage.go @@ -0,0 +1,123 @@ +package assistant + +import ( + "encoding/json" + "strings" + "unicode/utf8" + + "github.com/omarluq/librecode/internal/database" + "github.com/omarluq/librecode/internal/model" +) + +func estimateTokenUsage( + systemPrompt string, + messages []database.MessageEntity, + selectedModel *model.Model, +) model.TokenUsage { + inputTokens := estimateInputTokens(systemPrompt, messages) + + return model.TokenUsage{ + ContextWindow: selectedModel.ContextWindow, + ContextTokens: inputTokens, + InputTokens: inputTokens, + OutputTokens: 0, + } +} + +func estimateInputTokens(systemPrompt string, messages []database.MessageEntity) int { + count := estimateTokens(systemPrompt) + for index := range messages { + count += estimateTokens(messages[index].Content) + } + + return count +} + +func estimateTokens(text string) int { + trimmed := strings.TrimSpace(text) + if trimmed == "" { + return 0 + } + runes := utf8.RuneCountInString(trimmed) + if runes == 0 { + return 0 + } + + // Rough cross-provider estimate used until provider usage arrives. + return max(1, (runes+3)/4) +} + +func mergeUsage(estimated, reported model.TokenUsage) model.TokenUsage { + usage := estimated + if reported.ContextWindow > 0 { + usage.ContextWindow = reported.ContextWindow + } + if reported.ContextTokens > usage.ContextTokens { + usage.ContextTokens = reported.ContextTokens + } + if reported.InputTokens > 0 { + usage.InputTokens = reported.InputTokens + } + if reported.OutputTokens > 0 { + usage.OutputTokens = reported.OutputTokens + } + if reportedTotal := reported.TotalTokens(); reportedTotal > usage.ContextTokens { + usage.ContextTokens = reportedTotal + } + + return usage +} + +func usageFromObject(value any) model.TokenUsage { + object, ok := value.(map[string]any) + if !ok { + return model.EmptyTokenUsage() + } + input := usageInputTokens(object) + output := intFromAny(firstPresent(object, jsonOutputTokensKey, "completion_tokens")) + + return model.TokenUsage{ContextWindow: 0, ContextTokens: 0, InputTokens: input, OutputTokens: output} +} + +func usageInputTokens(object map[string]any) int { + input := intFromAny(firstPresent(object, "input_tokens", "prompt_tokens")) + if input > 0 { + return input + } + if total := intFromAny(object["total_tokens"]); total > 0 { + output := intFromAny(firstPresent(object, jsonOutputTokensKey, "completion_tokens")) + if output > 0 && total > output { + return total - output + } + } + + return 0 +} + +func firstPresent(object map[string]any, keys ...string) any { + for _, key := range keys { + if value, ok := object[key]; ok { + return value + } + } + + return nil +} + +func intFromAny(value any) int { + switch typed := value.(type) { + case int: + return typed + case int64: + return int(typed) + case float64: + return int(typed) + case json.Number: + parsed, err := typed.Int64() + if err == nil { + return int(parsed) + } + } + + return 0 +} diff --git a/internal/assistant/usage_events.go b/internal/assistant/usage_events.go new file mode 100644 index 0000000..63f7f6a --- /dev/null +++ b/internal/assistant/usage_events.go @@ -0,0 +1,31 @@ +package assistant + +import ( + "context" + + "github.com/omarluq/librecode/internal/model" +) + +func (runtime *Runtime) emitUsage(ctx context.Context, onEvent func(StreamEvent), usage model.TokenUsage) { + if !usage.HasAny() { + return + } + emitStreamEvent(onEvent, StreamEvent{ + ToolEvent: nil, + Usage: &usage, + Kind: StreamEventUsage, + Text: "", + }) + payload := map[string]any{ + "context_window": usage.ContextWindow, + "context_tokens": usage.ContextTokens, + "input_tokens": usage.InputTokens, + jsonOutputTokensKey: usage.OutputTokens, + } + runtime.emit(ctx, "usage", payload) + if runtime.extensions != nil { + if err := runtime.extensions.Emit(ctx, "usage", payload); err != nil { + runtime.logger.Debug("emit usage extension event failed", "error", err) + } + } +} diff --git a/internal/assistant/usage_test.go b/internal/assistant/usage_test.go new file mode 100644 index 0000000..f43c822 --- /dev/null +++ b/internal/assistant/usage_test.go @@ -0,0 +1,107 @@ +//nolint:testpackage // These tests exercise unexported usage parsing helpers. +package assistant + +import ( + "encoding/json" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/omarluq/librecode/internal/model" +) + +func TestUsageFromObjectParsesProviderShapes(t *testing.T) { + t.Parallel() + + tests := []usageParseTest{ + { + name: "openai responses", + usage: map[string]any{ + "input_tokens": float64(123), + jsonOutputTokensKey: float64(45), + }, + expected: model.TokenUsage{ContextWindow: 0, ContextTokens: 0, InputTokens: 123, OutputTokens: 45}, + }, + { + name: "chat completions", + usage: map[string]any{ + "prompt_tokens": json.Number("77"), + "completion_tokens": json.Number("9"), + }, + expected: model.TokenUsage{ContextWindow: 0, ContextTokens: 0, InputTokens: 77, OutputTokens: 9}, + }, + { + name: "total tokens does not become input tokens", + usage: map[string]any{ + "total_tokens": json.Number("120"), + jsonOutputTokensKey: json.Number("20"), + }, + expected: model.TokenUsage{ContextWindow: 0, ContextTokens: 0, InputTokens: 100, OutputTokens: 20}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + assert.Equal(t, test.expected, usageFromObject(test.usage)) + }) + } +} + +func TestMergeUsagePreservesEstimatedContextWindow(t *testing.T) { + t.Parallel() + + estimated := model.TokenUsage{ContextWindow: 1000, ContextTokens: 200, InputTokens: 200, OutputTokens: 0} + reported := model.TokenUsage{ContextWindow: 0, ContextTokens: 0, InputTokens: 150, OutputTokens: 25} + + assert.Equal(t, model.TokenUsage{ + ContextWindow: 1000, + ContextTokens: 200, + InputTokens: 150, + OutputTokens: 25, + }, mergeUsage(estimated, reported)) +} + +func TestMergeUsageNeverShrinksEstimatedContext(t *testing.T) { + t.Parallel() + + estimated := model.TokenUsage{ContextWindow: 100_000, ContextTokens: 14_000, InputTokens: 14_000, OutputTokens: 0} + reported := model.TokenUsage{ContextWindow: 0, ContextTokens: 12_000, InputTokens: 12_000, OutputTokens: 700} + + assert.Equal(t, model.TokenUsage{ + ContextWindow: 100_000, + ContextTokens: 14_000, + InputTokens: 12_000, + OutputTokens: 700, + }, mergeUsage(estimated, reported)) +} + +func TestParseSSEResultPreservesUsageWhenItemsProvideText(t *testing.T) { + t.Parallel() + + stream := strings.Join([]string{ + `data: {"response":{"usage":{"input_tokens":12,"output_tokens":7}}}`, + `data: {"item":{"id":"msg_1","type":"message","content":[{"type":"output_text","text":"hello"}]}}`, + `data: [DONE]`, + ``, + }, "\n") + + result, err := parseSSEResult(strings.NewReader(stream), nil) + require.NoError(t, err) + assert.Equal(t, model.TokenUsage{ + ContextWindow: 0, + ContextTokens: 0, + InputTokens: 12, + OutputTokens: 7, + }, result.Usage) + assert.Equal(t, "hello", result.Text) +} + +type usageParseTest struct { + usage map[string]any + name string + expected model.TokenUsage +} diff --git a/internal/model/usage.go b/internal/model/usage.go new file mode 100644 index 0000000..81f3414 --- /dev/null +++ b/internal/model/usage.go @@ -0,0 +1,37 @@ +package model + +// TokenUsage tracks model context and request/response token counts. +type TokenUsage struct { + ContextWindow int `json:"context_window,omitempty"` + ContextTokens int `json:"context_tokens,omitempty"` + InputTokens int `json:"input_tokens,omitempty"` + OutputTokens int `json:"output_tokens,omitempty"` +} + +// EmptyTokenUsage returns a zero-value token usage with explicit fields. +func EmptyTokenUsage() TokenUsage { + return TokenUsage{ContextWindow: 0, ContextTokens: 0, InputTokens: 0, OutputTokens: 0} +} + +// TotalTokens returns input plus output tokens reported for the turn. +func (usage TokenUsage) TotalTokens() int { + return usage.InputTokens + usage.OutputTokens +} + +// HasAny reports whether any usage field is populated. +func (usage TokenUsage) HasAny() bool { + return usage.ContextWindow > 0 || usage.ContextTokens > 0 || usage.InputTokens > 0 || usage.OutputTokens > 0 +} + +// ContextPercent returns the context-window usage percentage, if known. +func (usage TokenUsage) ContextPercent() int { + if usage.ContextWindow <= 0 || usage.ContextTokens <= 0 { + return 0 + } + percent := usage.ContextTokens * 100 / usage.ContextWindow + if percent > 100 { + return 100 + } + + return percent +} diff --git a/internal/terminal/app.go b/internal/terminal/app.go index 3eb0daa..efb3991 100644 --- a/internal/terminal/app.go +++ b/internal/terminal/app.go @@ -122,11 +122,11 @@ type App struct { selectedPanelKind panelKind sessionID string statusMessage string - mode appMode streamingText string streamingThinkingText string cwd string promptHistoryDraft string + mode appMode resources core.ResourceSnapshot messageLineCache []cachedRenderedMessage streamingBlockLineCache []cachedRenderedMessage @@ -140,6 +140,7 @@ type App struct { messageLineCacheState messageLineCacheState streamingBlockLineCacheState messageLineCacheState selection mouseSelection + tokenUsage model.TokenUsage promptSequence uint64 workFrame int lastMessageMaxRows int @@ -252,6 +253,7 @@ func newApp(screen tcell.Screen, options *RunOptions) *App { promptHistoryIndex: 0, promptSequence: 0, statusMessage: "", + tokenUsage: model.EmptyTokenUsage(), selectedPanelKind: "", streamingText: "", streamingThinkingText: "", @@ -488,7 +490,7 @@ func (app *App) shouldDrawImmediately(event tcell.Event) bool { if !ok { return true } - payload, ok := interrupt.Data().(asyncEvent) + payload, ok := interrupt.Data().(*asyncEvent) if !ok { return true } @@ -501,7 +503,8 @@ func isHighVolumePromptStreamEvent(kind asyncEventKind) bool { case asyncEventPromptDelta, asyncEventPromptThinkingDelta, asyncEventPromptToolStart, - asyncEventPromptToolResult: + asyncEventPromptToolResult, + asyncEventPromptUsage: return true case asyncEventAuthURL, asyncEventAuthDone, diff --git a/internal/terminal/async_events.go b/internal/terminal/async_events.go index 0c8edfd..b748e65 100644 --- a/internal/terminal/async_events.go +++ b/internal/terminal/async_events.go @@ -8,6 +8,7 @@ import ( "github.com/omarluq/librecode/internal/assistant" "github.com/omarluq/librecode/internal/database" + "github.com/omarluq/librecode/internal/model" ) type asyncEventKind string @@ -23,12 +24,14 @@ const ( asyncEventPromptToolStart asyncEventKind = "prompt_tool_start" asyncEventPromptToolResult asyncEventKind = "prompt_tool_result" asyncEventPromptRetry asyncEventKind = "prompt_retry" + asyncEventPromptUsage asyncEventKind = "prompt_usage" asyncEventPromptError asyncEventKind = "prompt_error" ) type asyncEvent struct { Response *assistant.PromptResponse ToolEvent *assistant.ToolEvent + Usage *model.TokenUsage Kind asyncEventKind Provider string Text string @@ -37,9 +40,10 @@ type asyncEvent struct { func (app *App) promptUserEntryHandler(ctx context.Context, promptID uint64) func(assistant.PromptUserEntryEvent) { return func(event assistant.PromptUserEntryEvent) { - app.postAsyncEvent(ctx, asyncEvent{ + app.postAsyncEvent(ctx, &asyncEvent{ Response: nil, ToolEvent: nil, + Usage: nil, Kind: asyncEventPromptUserEntry, Provider: event.SessionID, Text: event.EntryID, @@ -54,9 +58,10 @@ func (app *App) promptRetryHandler(ctx context.Context, promptID uint64) assista if event.Kind == assistant.RetryEventStart { text = "retrying model request in " + event.Delay.Round(time.Second).String() } - app.postAsyncEvent(ctx, asyncEvent{ + app.postAsyncEvent(ctx, &asyncEvent{ Response: nil, ToolEvent: nil, + Usage: nil, Kind: asyncEventPromptRetry, Provider: string(event.Kind), Text: text, @@ -69,46 +74,60 @@ func (app *App) promptStreamHandler(ctx context.Context, promptID uint64) func(a return func(event assistant.StreamEvent) { switch event.Kind { case assistant.StreamEventTextDelta: - app.postAsyncEvent(ctx, asyncEvent{ + app.postAsyncEvent(ctx, &asyncEvent{ Response: nil, ToolEvent: nil, + Usage: nil, Kind: asyncEventPromptDelta, Provider: "", Text: event.Text, PromptID: promptID, }) case assistant.StreamEventThinkingDelta: - app.postAsyncEvent(ctx, asyncEvent{ + app.postAsyncEvent(ctx, &asyncEvent{ Response: nil, ToolEvent: nil, + Usage: nil, Kind: asyncEventPromptThinkingDelta, Provider: "", Text: event.Text, PromptID: promptID, }) case assistant.StreamEventToolStart: - app.postAsyncEvent(ctx, asyncEvent{ + app.postAsyncEvent(ctx, &asyncEvent{ Response: nil, ToolEvent: nil, + Usage: nil, Kind: asyncEventPromptToolStart, Provider: "", Text: event.Text, PromptID: promptID, }) case assistant.StreamEventToolResult, assistant.StreamEventSkillLoaded: - app.postAsyncEvent(ctx, asyncEvent{ + app.postAsyncEvent(ctx, &asyncEvent{ Response: nil, ToolEvent: event.ToolEvent, + Usage: nil, Kind: asyncEventPromptToolResult, Provider: "", Text: "", PromptID: promptID, }) + case assistant.StreamEventUsage: + app.postAsyncEvent(ctx, &asyncEvent{ + Response: nil, + ToolEvent: nil, + Usage: event.Usage, + Kind: asyncEventPromptUsage, + Provider: "", + Text: "", + PromptID: promptID, + }) } } } -func (app *App) postAsyncEvent(ctx context.Context, event asyncEvent) { +func (app *App) postAsyncEvent(ctx context.Context, event *asyncEvent) { defer func() { panicValue := recover() if panicValue != nil { @@ -122,7 +141,7 @@ func (app *App) postAsyncEvent(ctx context.Context, event asyncEvent) { } func (app *App) handleInterrupt(ctx context.Context, event *tcell.EventInterrupt) (bool, error) { - payload, ok := event.Data().(asyncEvent) + payload, ok := event.Data().(*asyncEvent) if !ok { return false, nil } @@ -134,7 +153,7 @@ func (app *App) handleInterrupt(ctx context.Context, event *tcell.EventInterrupt return false, nil } -func (app *App) handleAuthAsyncEvent(payload asyncEvent) bool { +func (app *App) handleAuthAsyncEvent(payload *asyncEvent) bool { switch payload.Kind { case asyncEventAuthURL: app.addMessage(database.RoleCustom, payload.Text) @@ -157,6 +176,7 @@ func (app *App) handleAuthAsyncEvent(payload asyncEvent) bool { asyncEventPromptToolStart, asyncEventPromptToolResult, asyncEventPromptRetry, + asyncEventPromptUsage, asyncEventPromptError: return false } @@ -164,7 +184,7 @@ func (app *App) handleAuthAsyncEvent(payload asyncEvent) bool { return false } -func (app *App) handlePromptAsyncEvent(ctx context.Context, payload asyncEvent) { +func (app *App) handlePromptAsyncEvent(ctx context.Context, payload *asyncEvent) { if app.ignorePromptEvent(payload) { return } @@ -174,7 +194,7 @@ func (app *App) handlePromptAsyncEvent(ctx context.Context, payload asyncEvent) app.handlePromptStreamEvent(ctx, payload) } -func (app *App) ignorePromptEvent(payload asyncEvent) bool { +func (app *App) ignorePromptEvent(payload *asyncEvent) bool { if !isPromptAsyncEvent(payload.Kind) { return false } @@ -195,6 +215,7 @@ func isPromptAsyncEvent(kind asyncEventKind) bool { asyncEventPromptToolStart, asyncEventPromptToolResult, asyncEventPromptRetry, + asyncEventPromptUsage, asyncEventPromptError: return true case asyncEventAuthURL, asyncEventAuthDone, asyncEventAuthError: @@ -204,7 +225,7 @@ func isPromptAsyncEvent(kind asyncEventKind) bool { return false } -func (app *App) handlePromptLifecycleEvent(ctx context.Context, payload asyncEvent) bool { +func (app *App) handlePromptLifecycleEvent(ctx context.Context, payload *asyncEvent) bool { switch payload.Kind { case asyncEventPromptDone: app.emitExtensionRuntimeEventOrMessage(ctx, extensionEventPromptDone, promptDoneExtensionData(payload.Response)) @@ -228,14 +249,18 @@ func (app *App) handlePromptLifecycleEvent(ctx context.Context, payload asyncEve return true case asyncEventAuthURL, asyncEventAuthDone, asyncEventAuthError: return true - case asyncEventPromptDelta, asyncEventPromptThinkingDelta, asyncEventPromptToolStart, asyncEventPromptToolResult: + case asyncEventPromptDelta, + asyncEventPromptThinkingDelta, + asyncEventPromptToolStart, + asyncEventPromptToolResult, + asyncEventPromptUsage: return false } return false } -func (app *App) handlePromptStreamEvent(ctx context.Context, payload asyncEvent) { +func (app *App) handlePromptStreamEvent(ctx context.Context, payload *asyncEvent) { if app.activePrompt != nil && app.activePrompt.Canceled { return } @@ -264,6 +289,9 @@ func (app *App) handlePromptStreamEvent(ctx context.Context, payload asyncEvent) map[string]any{extensionDataName: payload.Text}, ) return + case asyncEventPromptUsage: + app.applyTokenUsage(payload.Usage) + return case asyncEventPromptDone, asyncEventPromptUserEntry, asyncEventPromptRetry, @@ -275,7 +303,7 @@ func (app *App) handlePromptStreamEvent(ctx context.Context, payload asyncEvent) } } -func (app *App) emitPromptRetryExtensionEvent(ctx context.Context, payload asyncEvent) { +func (app *App) emitPromptRetryExtensionEvent(ctx context.Context, payload *asyncEvent) { eventName := extensionEventRetryStart if payload.Provider == string(assistant.RetryEventEnd) { eventName = extensionEventRetryEnd diff --git a/internal/terminal/auth_commands.go b/internal/terminal/auth_commands.go index d5c332c..9a50594 100644 --- a/internal/terminal/auth_commands.go +++ b/internal/terminal/auth_commands.go @@ -284,9 +284,10 @@ func (app *App) loginOAuthProvider(ctx context.Context, config oauthLoginConfig) func (app *App) runOAuthLogin(ctx context.Context, config oauthLoginConfig) { credential, err := config.LoginFunc(ctx, func(info auth.OAuthAuthInfo) { - app.postAsyncEvent(ctx, asyncEvent{ + app.postAsyncEvent(ctx, &asyncEvent{ Response: nil, ToolEvent: nil, + Usage: nil, Kind: asyncEventAuthURL, Provider: config.Provider, Text: authInfoText(config.DisplayName, info), @@ -301,9 +302,10 @@ func (app *App) runOAuthLogin(ctx context.Context, config oauthLoginConfig) { app.postOAuthLoginError(ctx, config, err) return } - app.postAsyncEvent(ctx, asyncEvent{ + app.postAsyncEvent(ctx, &asyncEvent{ Response: nil, ToolEvent: nil, + Usage: nil, Kind: asyncEventAuthDone, Provider: config.Provider, Text: "", @@ -312,9 +314,10 @@ func (app *App) runOAuthLogin(ctx context.Context, config oauthLoginConfig) { } func (app *App) postOAuthLoginError(ctx context.Context, config oauthLoginConfig, err error) { - app.postAsyncEvent(ctx, asyncEvent{ + app.postAsyncEvent(ctx, &asyncEvent{ Response: nil, ToolEvent: nil, + Usage: nil, Kind: asyncEventAuthError, Provider: config.Provider, Text: config.LoginFailed + err.Error(), diff --git a/internal/terminal/input.go b/internal/terminal/input.go index 42eced7..71cb4e0 100644 --- a/internal/terminal/input.go +++ b/internal/terminal/input.go @@ -334,6 +334,7 @@ func (app *App) sendPrompt(ctx context.Context, text string) { app.scrollOffset = 0 app.streamingText = "" app.streamingThinkingText = "" + app.tokenUsage = model.EmptyTokenUsage() app.resetStreamingBlocks() app.streamedToolEvents = 0 app.activePrompt = &activePromptState{ @@ -354,9 +355,10 @@ func (app *App) sendPrompt(ctx context.Context, text string) { defer cancel() response, err := app.runtime.Prompt(promptCtx, request) if err != nil { - app.postAsyncEvent(ctx, asyncEvent{ + app.postAsyncEvent(ctx, &asyncEvent{ Response: nil, ToolEvent: nil, + Usage: nil, Kind: asyncEventPromptError, Provider: "", Text: err.Error(), @@ -364,9 +366,10 @@ func (app *App) sendPrompt(ctx context.Context, text string) { }) return } - app.postAsyncEvent(ctx, asyncEvent{ + app.postAsyncEvent(ctx, &asyncEvent{ Response: response, ToolEvent: nil, + Usage: nil, Kind: asyncEventPromptDone, Provider: "", Text: "", @@ -394,6 +397,7 @@ func (app *App) applyPromptResponse(ctx context.Context, response *assistant.Pro return } app.sessionID = response.SessionID + app.applyTokenUsage(&response.Usage) app.applyPromptResponseSideEffects(response, streamingBlocks) app.streamedToolEvents = 0 app.addMessage(database.RoleAssistant, response.Text) diff --git a/internal/terminal/render_parity_test.go b/internal/terminal/render_parity_test.go index f97c0a4..dee69a1 100644 --- a/internal/terminal/render_parity_test.go +++ b/internal/terminal/render_parity_test.go @@ -14,6 +14,7 @@ import ( "github.com/omarluq/librecode/internal/assistant" "github.com/omarluq/librecode/internal/config" "github.com/omarluq/librecode/internal/database" + "github.com/omarluq/librecode/internal/model" ) func TestRenderParityComposerFrame(t *testing.T) { @@ -60,6 +61,28 @@ func TestRenderParityStatuslineFrame(t *testing.T) { } } +func TestRenderParityStatuslineTokenUsage(t *testing.T) { + t.Parallel() + + app := newRenderTestApp(t) + app.tokenUsage = model.TokenUsage{ + ContextWindow: 1000, + ContextTokens: 250, + InputTokens: 0, + OutputTokens: 0, + } + + layout := app.defaultRuntimeLayout(80, 12) + app.frame = newCellBuffer(layout.Width, layout.Height, tcell.StyleDefault) + app.drawStatusWindow(&layout) + + second := frameRowText(app.frame, layout.Status.Y+1) + assertFrameContainsAll(t, second, "ctx 250/1.0k 25%") + if strings.Contains(second, "↑") || strings.Contains(second, "↓") { + t.Fatalf("status should not include input/output tokens: %q", second) + } +} + func TestRenderParityToolBlockFrame(t *testing.T) { t.Parallel() diff --git a/internal/terminal/render_test.go b/internal/terminal/render_test.go index 4cbc2f6..3679e73 100644 --- a/internal/terminal/render_test.go +++ b/internal/terminal/render_test.go @@ -9,10 +9,12 @@ import ( "github.com/gdamore/tcell/v3" cellcolor "github.com/gdamore/tcell/v3/color" + "github.com/stretchr/testify/assert" "github.com/omarluq/librecode/internal/assistant" "github.com/omarluq/librecode/internal/database" "github.com/omarluq/librecode/internal/extension" + "github.com/omarluq/librecode/internal/model" ) func TestRenderStreamingMessageUsesTextColor(t *testing.T) { @@ -357,6 +359,41 @@ func TestFlushFrameHighlightsSelection(t *testing.T) { } } +func TestApplyPromptResponsePreservesLargerStreamedContextUsage(t *testing.T) { + t.Parallel() + + app := newRenderTestApp(t) + app.applyTokenUsage(&model.TokenUsage{ + ContextWindow: 100_000, + ContextTokens: 14_000, + InputTokens: 14_000, + OutputTokens: 0, + }) + + app.applyPromptResponse(context.Background(), &assistant.PromptResponse{ + SessionID: "test-session", + UserEntryID: "user", + AssistantEntryID: "assistant", + Text: "ok", + Thinking: nil, + ToolEvents: nil, + Usage: model.TokenUsage{ + ContextWindow: 100_000, + ContextTokens: 12_000, + InputTokens: 12_000, + OutputTokens: 700, + }, + Cached: false, + }, 0) + + assert.Equal(t, model.TokenUsage{ + ContextWindow: 100_000, + ContextTokens: 14_000, + InputTokens: 0, + OutputTokens: 0, + }, app.tokenUsage) +} + func TestHighVolumeStreamEventsDoNotForceImmediateDraw(t *testing.T) { t.Parallel() @@ -366,6 +403,7 @@ func TestHighVolumeStreamEventsDoNotForceImmediateDraw(t *testing.T) { asyncEventPromptThinkingDelta, asyncEventPromptToolStart, asyncEventPromptToolResult, + asyncEventPromptUsage, } { event := tcell.NewEventInterrupt(newTestAsyncEvent(kind, "")) if app.shouldDrawImmediately(event) { @@ -887,10 +925,11 @@ func newRenderTestApp(t *testing.T) *App { return app } -func newTestAsyncEvent(kind asyncEventKind, text string) asyncEvent { - return asyncEvent{ +func newTestAsyncEvent(kind asyncEventKind, text string) *asyncEvent { + return &asyncEvent{ Response: nil, ToolEvent: nil, + Usage: nil, Kind: kind, Provider: "", Text: text, diff --git a/internal/terminal/runtime_buffers.go b/internal/terminal/runtime_buffers.go index 666815f..a1a8367 100644 --- a/internal/terminal/runtime_buffers.go +++ b/internal/terminal/runtime_buffers.go @@ -239,6 +239,9 @@ func (app *App) defaultStatusLineTexts() []string { if app.currentThinkingLevel() != "" { modelText += " • " + app.currentThinkingLevel() } + if tokenText := app.tokenStatusText(); tokenText != "" { + modelText += " • " + tokenText + } return []string{pathLine, modelText} } diff --git a/internal/terminal/token_usage.go b/internal/terminal/token_usage.go new file mode 100644 index 0000000..4ac3baf --- /dev/null +++ b/internal/terminal/token_usage.go @@ -0,0 +1,73 @@ +package terminal + +import ( + "fmt" + + "github.com/omarluq/librecode/internal/model" +) + +func (app *App) applyTokenUsage(usage *model.TokenUsage) { + if usage == nil || !usage.HasAny() { + return + } + app.tokenUsage = mergeTerminalUsage(app.tokenUsage, *usage) +} + +func mergeTerminalUsage(current, next model.TokenUsage) model.TokenUsage { + if next.ContextWindow > 0 { + current.ContextWindow = next.ContextWindow + } + if next.ContextTokens > current.ContextTokens { + current.ContextTokens = next.ContextTokens + } + + return current +} + +func (app *App) tokenStatusText() string { + return formatTokenStatus(app.tokenUsage) +} + +func formatTokenStatus(usage model.TokenUsage) string { + if !usage.HasAny() { + return "" + } + contextText := formatContextUsage(usage) + if contextText == "" { + return "" + } + + return contextText +} + +func formatContextUsage(usage model.TokenUsage) string { + switch { + case usage.ContextTokens > 0 && usage.ContextWindow > 0: + return fmt.Sprintf( + "ctx %s/%s %d%%", + compactCount(usage.ContextTokens), + compactCount(usage.ContextWindow), + usage.ContextPercent(), + ) + case usage.ContextTokens > 0: + return "ctx " + compactCount(usage.ContextTokens) + case usage.ContextWindow > 0: + return "ctx 0/" + compactCount(usage.ContextWindow) + default: + return "" + } +} + +func compactCount(value int) string { + if value >= 1_000_000 { + return fmt.Sprintf("%.1fm", float64(value)/1_000_000) + } + if value >= 10_000 { + return fmt.Sprintf("%dk", value/1_000) + } + if value >= 1_000 { + return fmt.Sprintf("%.1fk", float64(value)/1_000) + } + + return fmt.Sprintf("%d", value) +} diff --git a/internal/terminal/token_usage_export_test.go b/internal/terminal/token_usage_export_test.go new file mode 100644 index 0000000..7e2b747 --- /dev/null +++ b/internal/terminal/token_usage_export_test.go @@ -0,0 +1,7 @@ +package terminal + +import "github.com/omarluq/librecode/internal/model" + +func MergeTerminalUsageForTest(current, next model.TokenUsage) model.TokenUsage { + return mergeTerminalUsage(current, next) +} diff --git a/internal/terminal/token_usage_test.go b/internal/terminal/token_usage_test.go new file mode 100644 index 0000000..3727293 --- /dev/null +++ b/internal/terminal/token_usage_test.go @@ -0,0 +1,38 @@ +package terminal_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/omarluq/librecode/internal/model" + "github.com/omarluq/librecode/internal/terminal" +) + +func TestMergeTerminalUsageIgnoresInputOutputTokens(t *testing.T) { + t.Parallel() + + current := model.TokenUsage{ContextWindow: 1_000_000, ContextTokens: 0, InputTokens: 0, OutputTokens: 0} + next := model.TokenUsage{ContextWindow: 0, ContextTokens: 0, InputTokens: 12_000, OutputTokens: 700} + + assert.Equal(t, model.TokenUsage{ + ContextWindow: 1_000_000, + ContextTokens: 0, + InputTokens: 0, + OutputTokens: 0, + }, terminal.MergeTerminalUsageForTest(current, next)) +} + +func TestMergeTerminalUsagePreservesEstimatedContext(t *testing.T) { + t.Parallel() + + current := model.TokenUsage{ContextWindow: 1_000_000, ContextTokens: 17_000, InputTokens: 17_000, OutputTokens: 0} + next := model.TokenUsage{ContextWindow: 0, ContextTokens: 0, InputTokens: 12_000, OutputTokens: 700} + + assert.Equal(t, model.TokenUsage{ + ContextWindow: 1_000_000, + ContextTokens: 17_000, + InputTokens: 17_000, + OutputTokens: 0, + }, terminal.MergeTerminalUsageForTest(current, next)) +} From b0ef5cbd5c5d9b0612aa1c1d4d7923dd8a5dc7c9 Mon Sep 17 00:00:00 2001 From: Omar Alani Date: Wed, 13 May 2026 10:09:08 -0500 Subject: [PATCH 2/3] fix(assistant): estimate active context only --- internal/assistant/runtime.go | 57 +++++++++++++++++++++-------- internal/assistant/runtime_test.go | 58 ++++++++++++++++++++++++++++++ internal/assistant/usage.go | 4 --- internal/assistant/usage_test.go | 14 ++++++++ 4 files changed, 115 insertions(+), 18 deletions(-) diff --git a/internal/assistant/runtime.go b/internal/assistant/runtime.go index 62a6f15..2577cb9 100644 --- a/internal/assistant/runtime.go +++ b/internal/assistant/runtime.go @@ -557,9 +557,9 @@ func (runtime *Runtime) modelResponse( With("provider", selectedModel.Provider). Wrapf(fmt.Errorf("%s", auth.Error), "resolve model auth") } - sessionMessages, err := runtime.sessions.Messages(ctx, sessionID) + messages, err := runtime.modelContextMessages(ctx, sessionID) if err != nil { - return nil, oops.In("assistant").Code("load_context").Wrapf(err, "load session context") + return nil, err } systemPrompt := defaultSystemPrompt(cwd) @@ -590,7 +590,6 @@ func (runtime *Runtime) modelResponse( } } - messages := messageEntities(sessionMessages) estimatedUsage := estimateTokenUsage(systemPrompt, messages, &selectedModel) runtime.emitUsage(ctx, onEvent, estimatedUsage) request := &CompletionRequest{ @@ -779,20 +778,50 @@ func activeSkillMatchPayload(matches []core.SkillActivationDiagnostic) []map[str return payload } -func messageEntities(messages []database.SessionMessageEntity) []database.MessageEntity { - converted := make([]database.MessageEntity, 0, len(messages)) +func (runtime *Runtime) modelContextMessages(ctx context.Context, sessionID string) ([]database.MessageEntity, error) { + leafEntry, _, err := runtime.sessions.LeafEntry(ctx, sessionID) + if err != nil { + return nil, oops.In("assistant").Code("load_context_leaf").Wrapf(err, "load session leaf") + } + leafID := "" + if leafEntry != nil { + leafID = leafEntry.ID + } + contextEntity, err := runtime.sessions.BuildContext(ctx, sessionID, leafID) + if err != nil { + return nil, oops.In("assistant").Code("load_context").Wrapf(err, "load session context") + } + + return modelFacingMessages(contextEntity.Messages), nil +} + +func modelFacingMessages(messages []database.MessageEntity) []database.MessageEntity { + filtered := make([]database.MessageEntity, 0, len(messages)) for index := range messages { - message := &messages[index] - converted = append(converted, database.MessageEntity{ - Timestamp: message.CreatedAt, - Role: message.Role, - Content: message.Content, - Provider: message.Provider, - Model: message.Model, - }) + message := messages[index] + if !isModelFacingRole(message.Role) || strings.TrimSpace(message.Content) == "" { + continue + } + filtered = append(filtered, message) + } + + return filtered +} + +func isModelFacingRole(role database.Role) bool { + switch role { + case database.RoleUser, database.RoleAssistant: + return true + case database.RoleToolResult, + database.RoleThinking, + database.RoleCustom, + database.RoleBashExecution, + database.RoleBranchSummary, + database.RoleCompactionSummary: + return false } - return converted + return false } func defaultSystemPrompt(cwd string) string { diff --git a/internal/assistant/runtime_test.go b/internal/assistant/runtime_test.go index 7590d7b..90ad37c 100644 --- a/internal/assistant/runtime_test.go +++ b/internal/assistant/runtime_test.go @@ -262,6 +262,54 @@ func TestRuntime_SlashSkillShowsContent(t *testing.T) { assert.Equal(t, assistant.StreamEventSkillLoaded, events[0].Kind) } +func TestRuntime_PromptEstimatesContextFromModelFacingBranch(t *testing.T) { + home := t.TempDir() + cwd := t.TempDir() + t.Setenv("HOME", home) + + _, repository := newTestRuntime(t) + ctx := context.Background() + session, err := repository.CreateSession(ctx, cwd, "usage", "") + require.NoError(t, err) + userEntry, err := repository.AppendMessage(ctx, session.ID, nil, &database.MessageEntity{ + Timestamp: time.Now().UTC(), + Role: database.RoleUser, + Content: "hello", + Provider: "", + Model: "", + }) + require.NoError(t, err) + _, err = repository.AppendMessage(ctx, session.ID, &userEntry.ID, &database.MessageEntity{ + Timestamp: time.Now().UTC(), + Role: database.RoleToolResult, + Content: strings.Repeat("tool output ", 10_000), + Provider: "", + Model: "", + }) + require.NoError(t, err) + client := &capturingCompletionClient{request: nil} + runtime, _ := newTestRuntimeWithRepositoryAndClient(t, repository, client) + + var usageEvents []assistant.StreamEvent + request := newRuntimePromptRequest(cwd, "next", "") + request.SessionID = session.ID + request.OnEvent = func(event assistant.StreamEvent) { + if event.Kind == assistant.StreamEventUsage { + usageEvents = append(usageEvents, event) + } + } + + _, err = runtime.Prompt(ctx, request) + require.NoError(t, err) + require.NotNil(t, client.request) + for _, message := range client.request.Messages { + assert.NotEqual(t, database.RoleToolResult, message.Role) + } + require.NotEmpty(t, usageEvents) + require.NotNil(t, usageEvents[0].Usage) + assert.Less(t, usageEvents[0].Usage.ContextTokens, 1000) +} + func TestRuntime_PromptIncludesDiscoveredSkills(t *testing.T) { home := t.TempDir() cwd := t.TempDir() @@ -317,6 +365,16 @@ func newTestRuntimeWithClient( require.NoError(t, database.Migrate(context.Background(), connection)) repository := database.NewSessionRepository(connection) + return newTestRuntimeWithRepositoryAndClient(t, repository, client) +} + +func newTestRuntimeWithRepositoryAndClient( + t *testing.T, + repository *database.SessionRepository, + client assistant.CompletionClient, +) (*assistant.Runtime, *database.SessionRepository) { + t.Helper() + manager := extension.NewManager(slog.New(slog.NewTextHandler(io.Discard, nil))) t.Cleanup(manager.Shutdown) cache := assistant.NewResponseCache(true, 32, time.Minute) diff --git a/internal/assistant/usage.go b/internal/assistant/usage.go index b42d022..3a877b3 100644 --- a/internal/assistant/usage.go +++ b/internal/assistant/usage.go @@ -61,10 +61,6 @@ func mergeUsage(estimated, reported model.TokenUsage) model.TokenUsage { if reported.OutputTokens > 0 { usage.OutputTokens = reported.OutputTokens } - if reportedTotal := reported.TotalTokens(); reportedTotal > usage.ContextTokens { - usage.ContextTokens = reportedTotal - } - return usage } diff --git a/internal/assistant/usage_test.go b/internal/assistant/usage_test.go index f43c822..360fa19 100644 --- a/internal/assistant/usage_test.go +++ b/internal/assistant/usage_test.go @@ -79,6 +79,20 @@ func TestMergeUsageNeverShrinksEstimatedContext(t *testing.T) { }, mergeUsage(estimated, reported)) } +func TestMergeUsageDoesNotPromoteProviderTotalToContext(t *testing.T) { + t.Parallel() + + estimated := model.TokenUsage{ContextWindow: 272_000, ContextTokens: 0, InputTokens: 0, OutputTokens: 0} + reported := model.TokenUsage{ContextWindow: 0, ContextTokens: 0, InputTokens: 13_000_000, OutputTokens: 100} + + assert.Equal(t, model.TokenUsage{ + ContextWindow: 272_000, + ContextTokens: 0, + InputTokens: 13_000_000, + OutputTokens: 100, + }, mergeUsage(estimated, reported)) +} + func TestParseSSEResultPreservesUsageWhenItemsProvideText(t *testing.T) { t.Parallel() From 0a542c6b928a4c4a501f58c456c363076d316bcc Mon Sep 17 00:00:00 2001 From: Omar Alani Date: Wed, 13 May 2026 10:22:20 -0500 Subject: [PATCH 3/3] fix(usage): address context status review --- internal/assistant/sse.go | 31 ++++++++++++---- internal/assistant/usage_test.go | 22 +++++++++++ internal/terminal/app.go | 2 + internal/terminal/token_usage_export_test.go | 39 +++++++++++++++++++- internal/terminal/token_usage_test.go | 23 ++++++++++++ 5 files changed, 109 insertions(+), 8 deletions(-) diff --git a/internal/assistant/sse.go b/internal/assistant/sse.go index d8fe6bd..2f62de1 100644 --- a/internal/assistant/sse.go +++ b/internal/assistant/sse.go @@ -28,13 +28,8 @@ func newSSEAccumulator() *sseAccumulator { } func (accumulator *sseAccumulator) add(event map[string]any, onEvent func(StreamEvent)) { - if response, ok := event["response"].(map[string]any); ok { - accumulator.finalResponse = response - } - if usage, ok := event["usage"].(map[string]any); ok { - accumulator.finalResponse = ensureSSEFinalResponse(accumulator.finalResponse) - accumulator.finalResponse["usage"] = usage - } + accumulator.addResponse(event) + accumulator.addUsage(event) if text, delta := thinkingTextFromSSEEvent(event); delta && text != "" { emitStreamEvent(onEvent, StreamEvent{ ToolEvent: nil, @@ -60,6 +55,28 @@ func (accumulator *sseAccumulator) add(event map[string]any, onEvent func(Stream } } +func (accumulator *sseAccumulator) addResponse(event map[string]any) { + response, ok := event["response"].(map[string]any) + if !ok { + return + } + if accumulator.finalResponse != nil { + if usage := accumulator.finalResponse["usage"]; usage != nil && response["usage"] == nil { + response["usage"] = usage + } + } + accumulator.finalResponse = response +} + +func (accumulator *sseAccumulator) addUsage(event map[string]any) { + usage, ok := event["usage"].(map[string]any) + if !ok { + return + } + accumulator.finalResponse = ensureSSEFinalResponse(accumulator.finalResponse) + accumulator.finalResponse["usage"] = usage +} + func (accumulator *sseAccumulator) addItem(item map[string]any) { itemID := stringValue(item["id"]) if itemID != "" { diff --git a/internal/assistant/usage_test.go b/internal/assistant/usage_test.go index 360fa19..ccfa3d1 100644 --- a/internal/assistant/usage_test.go +++ b/internal/assistant/usage_test.go @@ -114,6 +114,28 @@ func TestParseSSEResultPreservesUsageWhenItemsProvideText(t *testing.T) { assert.Equal(t, "hello", result.Text) } +func TestParseSSEResultPreservesUsageAcrossLaterResponseEvents(t *testing.T) { + t.Parallel() + + stream := strings.Join([]string{ + `data: {"usage":{"input_tokens":12,"output_tokens":7}}`, + `data: {"response":{"output":[{"id":"msg_1","type":"message",` + + `"content":[{"type":"output_text","text":"hello"}]}]}}`, + `data: [DONE]`, + ``, + }, "\n") + + result, err := parseSSEResult(strings.NewReader(stream), nil) + require.NoError(t, err) + assert.Equal(t, model.TokenUsage{ + ContextWindow: 0, + ContextTokens: 0, + InputTokens: 12, + OutputTokens: 7, + }, result.Usage) + assert.Equal(t, "hello", result.Text) +} + type usageParseTest struct { usage map[string]any name string diff --git a/internal/terminal/app.go b/internal/terminal/app.go index efb3991..27059da 100644 --- a/internal/terminal/app.go +++ b/internal/terminal/app.go @@ -579,6 +579,7 @@ func (app *App) resetMessages() { app.messageCacheWarmIndex = 0 app.messageCacheWarm = false app.messageCacheWarmQueued = false + app.tokenUsage = model.EmptyTokenUsage() app.resetPromptHistory() } @@ -590,6 +591,7 @@ func (app *App) truncateMessages(length int) { app.messageRowPrefixSums = nil app.messageCacheWarmIndex = 0 app.messageCacheWarm = false + app.tokenUsage = model.EmptyTokenUsage() } func (app *App) resetStreamingBlocks() { diff --git a/internal/terminal/token_usage_export_test.go b/internal/terminal/token_usage_export_test.go index 7e2b747..40e82f7 100644 --- a/internal/terminal/token_usage_export_test.go +++ b/internal/terminal/token_usage_export_test.go @@ -1,7 +1,44 @@ package terminal -import "github.com/omarluq/librecode/internal/model" +import ( + "github.com/omarluq/librecode/internal/database" + "github.com/omarluq/librecode/internal/model" +) func MergeTerminalUsageForTest(current, next model.TokenUsage) model.TokenUsage { return mergeTerminalUsage(current, next) } + +func NewAppForTest() *App { + return newApp(nil, &RunOptions{ + Extensions: nil, + Resources: nil, + Runtime: nil, + Settings: nil, + Models: nil, + Auth: nil, + Config: nil, + CWD: "", + SessionID: "", + }) +} + +func (app *App) SetTokenUsageForTest(usage model.TokenUsage) { + app.tokenUsage = usage +} + +func (app *App) TokenUsageForTest() model.TokenUsage { + return app.tokenUsage +} + +func (app *App) ResetMessagesForTest() { + app.resetMessages() +} + +func (app *App) TruncateMessagesForTest(length int) { + app.truncateMessages(length) +} + +func (app *App) AddMessageForTest(role, content string) { + app.addMessage(database.Role(role), content) +} diff --git a/internal/terminal/token_usage_test.go b/internal/terminal/token_usage_test.go index 3727293..eb0a25d 100644 --- a/internal/terminal/token_usage_test.go +++ b/internal/terminal/token_usage_test.go @@ -36,3 +36,26 @@ func TestMergeTerminalUsagePreservesEstimatedContext(t *testing.T) { OutputTokens: 0, }, terminal.MergeTerminalUsageForTest(current, next)) } + +func TestResetMessagesClearsTokenUsage(t *testing.T) { + t.Parallel() + + app := terminal.NewAppForTest() + app.SetTokenUsageForTest(model.TokenUsage{ContextWindow: 1000, ContextTokens: 250, InputTokens: 0, OutputTokens: 0}) + + app.ResetMessagesForTest() + + assert.Equal(t, model.EmptyTokenUsage(), app.TokenUsageForTest()) +} + +func TestTruncateMessagesClearsTokenUsage(t *testing.T) { + t.Parallel() + + app := terminal.NewAppForTest() + app.SetTokenUsageForTest(model.TokenUsage{ContextWindow: 1000, ContextTokens: 250, InputTokens: 0, OutputTokens: 0}) + app.AddMessageForTest("user", "hello") + + app.TruncateMessagesForTest(0) + + assert.Equal(t, model.EmptyTokenUsage(), app.TokenUsageForTest()) +}