From bf9972a2eefa1acf5941244adb1e5eb00223cf79 Mon Sep 17 00:00:00 2001 From: Ben Perlmutter Date: Wed, 27 Aug 2025 10:38:54 -0400 Subject: [PATCH] support AI SDK stopWhen (loop behavior) --- .../src/responsesApi.test.ts | 71 +++++++++++++++++++ .../src/routes/responses/createResponse.ts | 4 +- 2 files changed, 73 insertions(+), 2 deletions(-) diff --git a/packages/chatbot-server-mongodb-public/src/responsesApi.test.ts b/packages/chatbot-server-mongodb-public/src/responsesApi.test.ts index 993a332d4..c347adb3a 100644 --- a/packages/chatbot-server-mongodb-public/src/responsesApi.test.ts +++ b/packages/chatbot-server-mongodb-public/src/responsesApi.test.ts @@ -7,11 +7,14 @@ import { createOpenAI, streamText, generateText, + stepCountIs, + tool, } from "mongodb-rag-core/aiSdk"; import { CREATE_RESPONSE_ERR_MSG } from "mongodb-chatbot-server"; import { OpenAI } from "mongodb-rag-core/openai"; import { makeTestApp } from "./test/testHelpers"; import { Logger, makeBraintrustLogger } from "mongodb-rag-core/braintrust"; +import { z } from "zod"; jest.setTimeout(100 * 1000); // 100 seconds @@ -414,6 +417,19 @@ describe("Responses API with OpenAI Client", () => { }); describe("AI SDK integration", () => { + const sampleToolName = "execute-code"; + const sampleToolResult = { + result: `[{id: 1, name: "Foo"}, {id: 2, name: "Bar"}]`, + }; + const sampleTool = tool({ + name: sampleToolName, + inputSchema: z.object({ + code: z.string(), + }), + execute: async () => { + return sampleToolResult; + }, + }); it("Should handle basic text streaming", async () => { const result = await streamText({ model: aiSDKClient.responses(MONGO_CHAT_MODEL), @@ -446,6 +462,61 @@ describe("Responses API with OpenAI Client", () => { expect(resultText.toLowerCase()).toContain("mongodb"); }); + it("should support stopWhen with multiple steps", async () => { + const result = streamText({ + model: aiSDKClient.responses(MONGO_CHAT_MODEL), + system: `Call the ${sampleToolName} when the user gives you code to execute in the subsequent message.`, + messages: [ + { + role: "user", + content: "Code to execute: db.users.find({}).limit(2).toArray()", + }, + ], + + tools: { + [sampleToolName]: sampleTool, + }, + stopWhen: [stepCountIs(2)], + toolChoice: { + type: "tool", + toolName: sampleToolName, + }, + prepareStep: ({ stepNumber }) => { + if (stepNumber > 0) { + return { + toolChoice: "auto", + }; + } + }, + }); + + const steps = await result.steps; + expect(steps.length).toBe(2); + const toolCallStepContent = steps[0].content; + expect(toolCallStepContent).toHaveLength(2); + const toolCall = toolCallStepContent[0]; + expect(toolCall).toMatchObject({ + type: "tool-call", + toolName: sampleToolName, + toolCallId: expect.any(String), + input: { + code: expect.any(String), + }, + }); + const toolResult = toolCallStepContent[1]; + expect(toolResult).toMatchObject({ + type: "tool-result", + toolCallId: expect.any(String), + output: sampleToolResult, + }); + const textStepContent = steps[1].content; + expect(textStepContent).toHaveLength(1); + const text = textStepContent[0]; + expect(text).toMatchObject({ + type: "text", + text: expect.any(String), + }); + }); it("Should throw an error when generating text since we don't support non-streaming generation", async () => { try { diff --git a/packages/mongodb-chatbot-server/src/routes/responses/createResponse.ts b/packages/mongodb-chatbot-server/src/routes/responses/createResponse.ts index d6d5c073d..6e3510704 100644 --- a/packages/mongodb-chatbot-server/src/routes/responses/createResponse.ts +++ b/packages/mongodb-chatbot-server/src/routes/responses/createResponse.ts @@ -134,7 +134,7 @@ const FunctionCallSchema = z.object({ arguments: z .string() .describe("JSON string of arguments passed to the function tool call"), - status: z.enum(["in_progress", "completed", "incomplete"]), + status: z.enum(["in_progress", "completed", "incomplete"]).optional(), }); const FunctionCallOutputSchema = z.object({ @@ -147,7 +147,7 @@ const FunctionCallOutputSchema = z.object({ .string() .describe("Unique ID of the function tool call generated by the model"), output: z.string().describe("JSON string of the function tool call"), - status: z.enum(["in_progress", "completed", "incomplete"]), + status: z.enum(["in_progress", "completed", "incomplete"]).optional(), }); const CreateResponseRequestBodySchema = z.object({