feat: track token usage while streaming responses for openai models #750
Conversation
🦋 Changeset detected. Latest commit: c9b3f7c. The changes in this PR will be included in the next version bump. This PR includes changesets to release 5 packages.
seratch left a comment:
This is a nice enhancement, but we'd like to hold off on adding new events for this purpose. How about adding the usage data this way instead? Introducing a new event could simplify your code, but we'd like to start with this primitive first.
diff --git a/packages/agents-core/src/run.ts b/packages/agents-core/src/run.ts
index f376673..4c09307 100644
--- a/packages/agents-core/src/run.ts
+++ b/packages/agents-core/src/run.ts
@@ -1183,6 +1183,7 @@ export class Runner extends RunHooks<any, AgentOutputType<unknown>> {
output: parsed.response.output,
responseId: parsed.response.id,
};
+ result.state._context.usage.add(finalResponse.usage);
}
if (result.cancelled) {
// When the user's code exits a loop to consume the stream, we need to break
diff --git a/packages/agents-core/test/run.stream.test.ts b/packages/agents-core/test/run.stream.test.ts
index 1a4893e..692dea0 100644
--- a/packages/agents-core/test/run.stream.test.ts
+++ b/packages/agents-core/test/run.stream.test.ts
@@ -210,6 +210,181 @@ describe('Runner.run (streaming)', () => {
expect(runnerEndEvents[0].output).toBe('Final output');
});
+ it('updates cumulative usage during streaming responses', async () => {
+ const testTool = tool({
+ name: 'calculator',
+ description: 'Does math',
+ parameters: z.object({ value: z.number() }),
+ execute: async ({ value }) => `result: ${value * 2}`,
+ });
+
+ const firstResponse: ModelResponse = {
+ output: [
+ {
+ type: 'function_call',
+ id: 'fc_1',
+ callId: 'call_1',
+ name: 'calculator',
+ status: 'completed',
+ arguments: JSON.stringify({ value: 5 }),
+ } as protocol.FunctionCallItem,
+ ],
+ usage: new Usage({ inputTokens: 10, outputTokens: 5, totalTokens: 15 }),
+ };
+
+ const secondResponse: ModelResponse = {
+ output: [fakeModelMessage('The answer is 10')],
+ usage: new Usage({ inputTokens: 20, outputTokens: 10, totalTokens: 30 }),
+ };
+
+ class MultiTurnStreamingModel implements Model {
+ #callCount = 0;
+
+ async getResponse(_req: ModelRequest): Promise<ModelResponse> {
+ const current = this.#callCount++;
+ return current === 0 ? firstResponse : secondResponse;
+ }
+
+ async *getStreamedResponse(
+ req: ModelRequest,
+ ): AsyncIterable<StreamEvent> {
+ const response = await this.getResponse(req);
+ yield {
+ type: 'response_done',
+ response: {
+ id: `r_${this.#callCount}`,
+ usage: {
+ requests: 1,
+ inputTokens: response.usage.inputTokens,
+ outputTokens: response.usage.outputTokens,
+ totalTokens: response.usage.totalTokens,
+ },
+ output: response.output,
+ },
+ } as any;
+ }
+ }
+
+ const agent = new Agent({
+ name: 'UsageTracker',
+ model: new MultiTurnStreamingModel(),
+ tools: [testTool],
+ });
+
+ const runner = new Runner();
+ const result = await runner.run(agent, 'calculate', { stream: true });
+
+ const totals: number[] = [];
+ for await (const event of result.toStream()) {
+ if (
+ event.type === 'raw_model_stream_event' &&
+ event.data.type === 'response_done'
+ ) {
+ totals.push(result.state.usage.totalTokens);
+ }
+ }
+ await result.completed;
+
+ expect(totals).toEqual([15, 45]);
+ expect(result.state.usage.inputTokens).toBe(30);
+ expect(result.state.usage.outputTokens).toBe(15);
+ expect(result.state.usage.requestUsageEntries?.length).toBe(2);
+ expect(result.finalOutput).toBe('The answer is 10');
+ });
+
+ it('allows aborting a stream based on cumulative usage', async () => {
+ const testTool = tool({
+ name: 'expensive',
+ description: 'Uses lots of tokens',
+ parameters: z.object({}),
+ execute: async () => 'expensive result',
+ });
+
+ const responses: ModelResponse[] = [
+ {
+ output: [
+ {
+ type: 'function_call',
+ id: 'fc_1',
+ callId: 'call_1',
+ name: 'expensive',
+ status: 'completed',
+ arguments: '{}',
+ } as protocol.FunctionCallItem,
+ ],
+ usage: new Usage({
+ inputTokens: 5000,
+ outputTokens: 2000,
+ totalTokens: 7000,
+ }),
+ },
+ {
+ output: [fakeModelMessage('continuing...')],
+ usage: new Usage({
+ inputTokens: 6000,
+ outputTokens: 3000,
+ totalTokens: 9000,
+ }),
+ },
+ ];
+
+ class ExpensiveStreamingModel implements Model {
+ #callCount = 0;
+
+ async getResponse(_req: ModelRequest): Promise<ModelResponse> {
+ return responses[this.#callCount++] ?? responses[responses.length - 1];
+ }
+
+ async *getStreamedResponse(
+ req: ModelRequest,
+ ): AsyncIterable<StreamEvent> {
+ const response = await this.getResponse(req);
+ yield {
+ type: 'response_done',
+ response: {
+ id: `r_${this.#callCount}`,
+ usage: {
+ requests: 1,
+ inputTokens: response.usage.inputTokens,
+ outputTokens: response.usage.outputTokens,
+ totalTokens: response.usage.totalTokens,
+ },
+ output: response.output,
+ },
+ } as any;
+ }
+ }
+
+ const agent = new Agent({
+ name: 'ExpensiveAgent',
+ model: new ExpensiveStreamingModel(),
+ tools: [testTool],
+ });
+
+ const runner = new Runner();
+ const result = await runner.run(agent, 'do expensive work', {
+ stream: true,
+ });
+
+ const MAX_TOKENS = 10_000;
+ let aborted = false;
+
+ for await (const event of result.toStream()) {
+ if (
+ event.type === 'raw_model_stream_event' &&
+ event.data.type === 'response_done' &&
+ result.state.usage.totalTokens > MAX_TOKENS
+ ) {
+ aborted = true;
+ break;
+ }
+ }
+
+ expect(aborted).toBe(true);
+ expect(result.state.usage.totalTokens).toBe(16_000);
+ expect(result.finalOutput).toBeUndefined();
+ });
+
it('streams tool_called before the tool finishes executing', async () => {
let releaseTool: (() => void) | undefined;
const toolExecuted = vi.fn();
diff --git a/packages/agents-openai/src/openaiChatCompletionsModel.ts b/packages/agents-openai/src/openaiChatCompletionsModel.ts
index 34d77fc..efa434b 100644
--- a/packages/agents-openai/src/openaiChatCompletionsModel.ts
+++ b/packages/agents-openai/src/openaiChatCompletionsModel.ts
@@ -344,6 +344,7 @@ export class OpenAIChatCompletionsModel implements Model {
response_format: responseFormat,
parallel_tool_calls: parallelToolCalls,
stream,
+ stream_options: stream ? { include_usage: true } : undefined,
store: request.modelSettings.store,
prompt_cache_retention: request.modelSettings.promptCacheRetention,
...providerData,
True, makes sense. The new event is unnecessary as long as the state context usage gets updated, since the values can be read from the generator function. This works nicely in my code too, since I already have something similar set up.
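Roughly, the consumer-side pattern looks like this (a sketch mirroring the tests in the suggested diff above; the agent here is just a placeholder):

```ts
import { Agent, Runner } from '@openai/agents';

// Placeholder agent; any model that reports usage in its response_done
// events should behave the same way.
const agent = new Agent({ name: 'UsageTracker', instructions: 'Be brief.' });

const runner = new Runner();
const result = await runner.run(agent, 'How many tokens is this?', {
  stream: true,
});

for await (const event of result.toStream()) {
  // With the suggested run-state change, cumulative usage is already
  // updated by the time each response_done event reaches the consumer.
  if (
    event.type === 'raw_model_stream_event' &&
    event.data.type === 'response_done'
  ) {
    console.log('total tokens so far:', result.state.usage.totalTokens);
  }
}
await result.completed;
```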
Emits a new event while streaming responses for tracking token usage in real time. Currently confirmed working with OpenAI models only. I tested OpenAI and LM Studio models through the AI SDK. I suspect it will work with any provider that implements the necessary endpoints of the OpenAI API spec, though I have not confirmed this since this is all I need to continue my project. It doesn't actually calculate token usage for my models on LM Studio (it returns zeroes across the `UsageData` object), but it may still work as a place to put my own usage-tracking code when the events get fired. There are a few new tests for this implementation.

Edit: I have confirmed the usage tracking does work for the LM Studio provider in this package, which helps confirm my suspicion. It started working after I added the `includeUsage: true` option to the `createOpenAICompatible` function from the AI SDK before passing it to the agent class as the model.
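For reference, roughly what that LM Studio setup looks like on my end (the base URL, model id, and the `aisdk()` adapter from `@openai/agents-extensions` are specific to my local setup, not something this PR touches):

```ts
import { Agent } from '@openai/agents';
import { aisdk } from '@openai/agents-extensions';
import { createOpenAICompatible } from '@ai-sdk/openai-compatible';

// Local LM Studio endpoint; baseURL and model id are placeholders.
const lmstudio = createOpenAICompatible({
  name: 'lmstudio',
  baseURL: 'http://localhost:1234/v1',
  // Asks the provider to include usage data in streamed chunks
  // (the equivalent of stream_options: { include_usage: true }).
  includeUsage: true,
});

const agent = new Agent({
  name: 'Assistant',
  model: aisdk(lmstudio('qwen2.5-7b-instruct')),
});
```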
I also removed the `ToolOptions` export from my earlier PR (#704), since the `ToolOptions` type requires a few additional generic types that are not exported from the main package. When I tried using `ToolOptions<any>`, it prevented the params for the `execute` field from being inferred from the zod object passed to the `parameters` field of the same `ToolOptions` object. I messed around with it for some time to find a better solution in a separate PR, but I couldn't figure it out, so it seemed best to remove the export entirely. Since providing a proper default type didn't work, and the generic types required were not exported from the main package either, it was ultimately more annoying than it was helpful. When I made that PR I hadn't figured out a way to test my forked package locally in my project yet, so I was assuming it would work. I can move this change to a separate PR if preferred, but since it's only one line I included it here.
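To illustrate the inference issue (a simplified sketch; the schema and tool here are made up):

```ts
import { tool } from '@openai/agents';
import { z } from 'zod';

// Passing the options object directly to tool() lets TypeScript infer the
// execute() arguments from the zod schema given as `parameters`.
const add = tool({
  name: 'add',
  description: 'Adds two numbers',
  parameters: z.object({ a: z.number(), b: z.number() }),
  execute: async ({ a, b }) => String(a + b), // a and b are inferred as numbers
});

// By contrast, annotating the options object as ToolOptions<any> (were it
// still exported) widens `parameters`, so the shape of the execute()
// arguments can no longer be inferred from the schema above.
```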