Skip to content

Commit 51e533d

Browse files
committed
🤖 fix: correct context usage display for multi-step tool calls
The Context Usage UI was showing inflated cachedInputTokens for plan messages with multi-step tool calls (e.g., ~150k instead of ~50k).

Root cause: the context-window display fell back to cumulative usage (summed across all steps) when contextUsage was undefined. For multi-step requests, cachedInputTokens gets summed because each step reads from the cache, but the actual context window only sees one step's worth.

Changes:
- Backend: Refactor getStreamMetadata() to fetch totalUsage (for costs) and contextUsage (last step, for the context window) separately from the AI SDK
- Backend: Add contextProviderMetadata from streamResult.providerMetadata for accurate cache-creation token display
- Frontend: Remove the fallback from contextUsage to usage — only use contextUsage for the context-window display

The fix ensures the context window shows the last step's inputTokens (the actual context size) while cost calculation still uses cumulative totals.
1 parent 4eef087 commit 51e533d

File tree

5 files changed

+48
-21
lines changed

5 files changed

+48
-21
lines changed

bun.lock

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
{
22
"lockfileVersion": 1,
3+
"configVersion": 0,
34
"workspaces": {
45
"": {
56
"name": "mux",

src/browser/stores/WorkspaceStore.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -486,7 +486,7 @@ export class WorkspaceStore {
486486
if (msg.metadata?.compacted) {
487487
continue;
488488
}
489-
const rawUsage = msg.metadata?.contextUsage ?? msg.metadata?.usage;
489+
const rawUsage = msg.metadata?.contextUsage;
490490
const providerMeta =
491491
msg.metadata?.contextProviderMetadata ?? msg.metadata?.providerMetadata;
492492
if (rawUsage) {

src/browser/stories/mockFactory.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,7 @@ export function createAssistantMessage(
199199
timestamp: opts.timestamp ?? STABLE_TIMESTAMP,
200200
model: opts.model ?? DEFAULT_MODEL,
201201
usage: { inputTokens: 100, outputTokens: 50, totalTokens: 150 },
202+
contextUsage: { inputTokens: 100, outputTokens: 50, totalTokens: 150 },
202203
duration: 1000,
203204
},
204205
};

src/common/orpc/schemas/message.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,9 @@ export const MuxMessageSchema = z.object({
7676
timestamp: z.number().optional(),
7777
model: z.string().optional(),
7878
usage: z.any().optional(),
79+
contextUsage: z.any().optional(),
7980
providerMetadata: z.record(z.string(), z.unknown()).optional(),
81+
contextProviderMetadata: z.record(z.string(), z.unknown()).optional(),
8082
duration: z.number().optional(),
8183
systemMessageTokens: z.number().optional(),
8284
muxMetadata: z.any().optional(),

src/node/services/streamManager.ts

Lines changed: 43 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -334,22 +334,42 @@ export class StreamManager extends EventEmitter {
334334
private async getStreamMetadata(
335335
streamInfo: WorkspaceStreamInfo,
336336
timeoutMs = 1000
337-
): Promise<{ usage?: LanguageModelV2Usage; duration: number }> {
338-
let usage = undefined;
337+
): Promise<{
338+
totalUsage?: LanguageModelV2Usage;
339+
contextUsage?: LanguageModelV2Usage;
340+
contextProviderMetadata?: Record<string, unknown>;
341+
duration: number;
342+
}> {
343+
let totalUsage: LanguageModelV2Usage | undefined;
344+
let contextUsage: LanguageModelV2Usage | undefined;
345+
let contextProviderMetadata: Record<string, unknown> | undefined;
346+
339347
try {
340-
// Race usage retrieval against timeout to prevent hanging on abort
341-
// CRITICAL: Use totalUsage (sum of all steps) not usage (last step only)
342-
// For multi-step tool calls, usage would severely undercount actual token consumption
343-
usage = await Promise.race([
344-
streamInfo.streamResult.totalUsage,
345-
new Promise<undefined>((resolve) => setTimeout(() => resolve(undefined), timeoutMs)),
348+
// Fetch all metadata in parallel with timeout
349+
// - totalUsage: sum of all steps (for cost calculation)
350+
// - usage: last step only (for context window display)
351+
// - providerMetadata: last step (for context window cache display)
352+
const [total, context, contextMeta] = await Promise.race([
353+
Promise.all([
354+
streamInfo.streamResult.totalUsage,
355+
streamInfo.streamResult.usage,
356+
streamInfo.streamResult.providerMetadata,
357+
]),
358+
new Promise<[undefined, undefined, undefined]>((resolve) =>
359+
setTimeout(() => resolve([undefined, undefined, undefined]), timeoutMs)
360+
),
346361
]);
362+
totalUsage = total;
363+
contextUsage = context;
364+
contextProviderMetadata = contextMeta;
347365
} catch (error) {
348-
log.debug("Could not retrieve usage:", error);
366+
log.debug("Could not retrieve stream metadata:", error);
349367
}
350368

351369
return {
352-
usage,
370+
totalUsage,
371+
contextUsage,
372+
contextProviderMetadata,
353373
duration: Date.now() - streamInfo.startTime,
354374
};
355375
}
@@ -1071,17 +1091,20 @@ export class StreamManager extends EventEmitter {
10711091

10721092
// Check if stream completed successfully
10731093
if (!streamInfo.abortController.signal.aborted) {
1074-
// Get usage, duration, and provider metadata from stream result
1075-
// CRITICAL: Use totalUsage (via getStreamMetadata) and aggregated providerMetadata
1076-
// to correctly account for all steps in multi-tool-call conversations
1077-
const { usage, duration } = await this.getStreamMetadata(streamInfo);
1094+
// Get all metadata from stream result in one call
1095+
// - totalUsage: sum of all steps (for cost calculation)
1096+
// - contextUsage: last step only (for context window display)
1097+
// - contextProviderMetadata: last step (for context window cache tokens)
1098+
// Falls back to tracked values from finish-step if streamResult fails/times out
1099+
const streamMeta = await this.getStreamMetadata(streamInfo);
1100+
const totalUsage = streamMeta.totalUsage;
1101+
const contextUsage = streamMeta.contextUsage ?? streamInfo.lastStepUsage;
1102+
const contextProviderMetadata =
1103+
streamMeta.contextProviderMetadata ?? streamInfo.lastStepProviderMetadata;
1104+
const duration = streamMeta.duration;
1105+
// Aggregated provider metadata across all steps (for cost calculation with cache tokens)
10781106
const providerMetadata = await this.getAggregatedProviderMetadata(streamInfo);
10791107

1080-
// For context window display, use last step's usage (inputTokens = current context size)
1081-
// This is stored in streamInfo during finish-step handling
1082-
const contextUsage = streamInfo.lastStepUsage;
1083-
const contextProviderMetadata = streamInfo.lastStepProviderMetadata;
1084-
10851108
// Emit stream end event with parts preserved in temporal order
10861109
const streamEndEvent: StreamEndEvent = {
10871110
type: "stream-end",
@@ -1090,7 +1113,7 @@ export class StreamManager extends EventEmitter {
10901113
metadata: {
10911114
...streamInfo.initialMetadata, // AIService-provided metadata (systemMessageTokens, etc)
10921115
model: streamInfo.model,
1093-
usage, // Total across all steps (for cost calculation)
1116+
usage: totalUsage, // Total across all steps (for cost calculation)
10941117
contextUsage, // Last step only (for context window display)
10951118
providerMetadata, // Aggregated (for cost calculation)
10961119
contextProviderMetadata, // Last step (for context window display)

0 commit comments

Comments (0)