From b1c742b461153ca113520fe7c1ef59cb44157e47 Mon Sep 17 00:00:00 2001 From: mchen Date: Mon, 16 Mar 2026 19:02:27 -0400 Subject: [PATCH 1/2] feat: add sessionAffinity setting for prefix-cache optimization Add sessionAffinity option to WorkersAIChatSettings that sends an x-session-affinity header with inference requests. This routes requests with the same key to the same backend replica, improving prefix-cache hit rates across conversation turns. - Binding path: sessionAffinity is passed as extraHeaders to binding.run() - REST path: extraHeaders are now forwarded in fetch headers instead of being discarded --- .changeset/session-affinity-header.md | 5 + packages/workers-ai-provider/src/utils.ts | 7 +- .../src/workersai-chat-language-model.ts | 6 +- .../src/workersai-chat-settings.ts | 6 + .../test/text-generation.test.ts | 105 ++++++++++++++++++ 5 files changed, 126 insertions(+), 3 deletions(-) create mode 100644 .changeset/session-affinity-header.md diff --git a/.changeset/session-affinity-header.md b/.changeset/session-affinity-header.md new file mode 100644 index 000000000..c2fdfeef1 --- /dev/null +++ b/.changeset/session-affinity-header.md @@ -0,0 +1,5 @@ +--- +"workers-ai-provider": minor +--- + +Add `sessionAffinity` setting to send `x-session-affinity` header for prefix-cache optimization. Also forward `extraHeaders` in the REST API path instead of discarding them. diff --git a/packages/workers-ai-provider/src/utils.ts b/packages/workers-ai-provider/src/utils.ts index 8bd5405f5..372399b16 100644 --- a/packages/workers-ai-provider/src/utils.ts +++ b/packages/workers-ai-provider/src/utils.ts @@ -110,7 +110,7 @@ export function createRun(config: CreateRunConfig): AiRun { const { gateway: _gateway, prefix: _prefix, - extraHeaders: _extraHeaders, + extraHeaders, returnRawResponse, signal, // AbortSignal — not serializable as a query parameter ...passthroughOptions @@ -141,9 +141,12 @@ export function createRun(config: CreateRunConfig): AiRun { queryString ? 
`?${queryString}` : "" }`; - const headers = { + const headers: Record = { Authorization: `Bearer ${apiKey}`, "Content-Type": "application/json", + ...(extraHeaders && typeof extraHeaders === "object" + ? (extraHeaders as Record) + : {}), }; const body = JSON.stringify(inputs); diff --git a/packages/workers-ai-provider/src/workersai-chat-language-model.ts b/packages/workers-ai-provider/src/workersai-chat-language-model.ts index 6c5147530..a434555c2 100644 --- a/packages/workers-ai-provider/src/workersai-chat-language-model.ts +++ b/packages/workers-ai-provider/src/workersai-chat-language-model.ts @@ -153,9 +153,13 @@ export class WorkersAIChatLanguageModel implements LanguageModelV3 { * Get passthrough options for binding.run() from settings. */ private getRunOptions() { - const { gateway, safePrompt: _safePrompt, ...passthroughOptions } = this.settings; + const { gateway, safePrompt: _safePrompt, sessionAffinity, ...passthroughOptions } = + this.settings; return { gateway: this.config.gateway ?? gateway, + ...(sessionAffinity + ? { extraHeaders: { "x-session-affinity": sessionAffinity } } + : {}), ...passthroughOptions, }; } diff --git a/packages/workers-ai-provider/src/workersai-chat-settings.ts b/packages/workers-ai-provider/src/workersai-chat-settings.ts index b360c13ab..c99370fc0 100644 --- a/packages/workers-ai-provider/src/workersai-chat-settings.ts +++ b/packages/workers-ai-provider/src/workersai-chat-settings.ts @@ -10,6 +10,12 @@ export type WorkersAIChatSettings = { */ gateway?: GatewayOptions; + /** + * Session affinity key for prefix-cache optimization. + * Routes requests with the same key to the same backend replica. + */ + sessionAffinity?: string; + /** * Passthrough settings that are provided directly to the run function. * Use this for any provider-specific options not covered by the typed fields. 
diff --git a/packages/workers-ai-provider/test/text-generation.test.ts b/packages/workers-ai-provider/test/text-generation.test.ts index 1ff33fc82..e6a2be5dc 100644 --- a/packages/workers-ai-provider/test/text-generation.test.ts +++ b/packages/workers-ai-provider/test/text-generation.test.ts @@ -73,6 +73,64 @@ describe("REST API - Text Generation Tests", () => { expect(capturedOptions).toHaveProperty("aNumber", "1"); }); + it("should send x-session-affinity header when sessionAffinity is set", async () => { + let capturedHeaders: Record = {}; + + const workersai = createWorkersAI({ + accountId: TEST_ACCOUNT_ID, + apiKey: TEST_API_KEY, + }); + + server.use( + http.post( + `https://api.cloudflare.com/client/v4/accounts/${TEST_ACCOUNT_ID}/ai/run/${TEST_MODEL}`, + async ({ request }) => { + capturedHeaders = Object.fromEntries(request.headers.entries()); + return HttpResponse.json({ result: { response: "Hello" } }); + }, + ), + ); + + const model = workersai(TEST_MODEL, { + sessionAffinity: "session-123", + }); + + const result = await generateText({ + model: model, + prompt: "Write a greeting", + }); + + expect(result.text).toBe("Hello"); + expect(capturedHeaders["x-session-affinity"]).toBe("session-123"); + }); + + it("should not send x-session-affinity header when sessionAffinity is not set", async () => { + let capturedHeaders: Record = {}; + + const workersai = createWorkersAI({ + accountId: TEST_ACCOUNT_ID, + apiKey: TEST_API_KEY, + }); + + server.use( + http.post( + `https://api.cloudflare.com/client/v4/accounts/${TEST_ACCOUNT_ID}/ai/run/${TEST_MODEL}`, + async ({ request }) => { + capturedHeaders = Object.fromEntries(request.headers.entries()); + return HttpResponse.json({ result: { response: "Hello" } }); + }, + ), + ); + + const result = await generateText({ + model: workersai(TEST_MODEL), + prompt: "Write a greeting", + }); + + expect(result.text).toBe("Hello"); + expect(capturedHeaders["x-session-affinity"]).toBeUndefined(); + }); + it("should throw if 
passthrough option cannot be coerced into a string", async () => { const workersai = createWorkersAI({ accountId: TEST_ACCOUNT_ID, @@ -226,6 +284,53 @@ describe("Binding - Text Generation Tests", () => { expect(capturedOptions).toHaveProperty("aNumber", 1); }); + it("should pass extraHeaders with x-session-affinity when sessionAffinity is set", async () => { + let capturedOptions: any = null; + + const workersai = createWorkersAI({ + binding: { + run: async (_modelName: string, _inputs: any, options?: any) => { + capturedOptions = options; + return { response: "Hello" }; + }, + }, + }); + + const model = workersai(TEST_MODEL, { + sessionAffinity: "session-456", + }); + + const result = await generateText({ + model: model, + prompt: "Write a greeting", + }); + + expect(result.text).toBe("Hello"); + expect(capturedOptions).toHaveProperty("extraHeaders"); + expect(capturedOptions.extraHeaders).toEqual({ "x-session-affinity": "session-456" }); + }); + + it("should not pass extraHeaders when sessionAffinity is not set", async () => { + let capturedOptions: any = null; + + const workersai = createWorkersAI({ + binding: { + run: async (_modelName: string, _inputs: any, options?: any) => { + capturedOptions = options; + return { response: "Hello" }; + }, + }, + }); + + const result = await generateText({ + model: workersai(TEST_MODEL), + prompt: "Write a greeting", + }); + + expect(result.text).toBe("Hello"); + expect(capturedOptions).not.toHaveProperty("extraHeaders"); + }); + it("should pass tool_choice to binding.run()", async () => { let capturedInputs: any = null; From 414b4d5eb9ecec3e69b39c03fb4532afc1fae015 Mon Sep 17 00:00:00 2001 From: Sunil Pai Date: Wed, 18 Mar 2026 19:58:37 +0000 Subject: [PATCH 2/2] Add sessionAffinity header support Add a sessionAffinity option across Workers AI adapters/providers to route requests with the same key to the same backend replica via the x-session-affinity header for prefix-cache optimization. 
Implementation details: - Extend WorkersAiAdapterConfig with an optional sessionAffinity string. - Propagate sessionAffinity as x-session-affinity to binding.run() via createWorkersAiBindingFetch(extraHeaders), to REST requests via defaultHeaders, and to gateway mode via the createGatewayFetch call. - Merge sessionAffinity with user-provided extraHeaders in the WorkersAI provider so both headers are forwarded together. Other changes: - Add and update tests covering binding.fetch, adapter behavior, and REST/binding header merging. - Update README docs for tanstack-ai and workers-ai-provider to document sessionAffinity usage. - Add changeset files to trigger a patch release for the relevant packages, and apply minor formatting updates to demos.json. --- .changeset/session-affinity-header.md | 2 +- .changeset/tanstack-session-affinity.md | 5 + demos.json | 222 +++++++++--------- packages/tanstack-ai/README.md | 9 + .../tanstack-ai/src/adapters/workers-ai.ts | 19 +- .../tanstack-ai/src/utils/create-fetcher.ts | 22 +- .../tanstack-ai/test/binding-fetch.test.ts | 38 +++ .../test/workers-ai-adapter.test.ts | 37 +++ packages/workers-ai-provider/README.md | 15 +- packages/workers-ai-provider/src/streaming.ts | 4 +- .../src/workersai-chat-language-model.ts | 21 +- .../test/text-generation.test.ts | 62 +++++ 12 files changed, 325 insertions(+), 131 deletions(-) create mode 100644 .changeset/tanstack-session-affinity.md diff --git a/.changeset/session-affinity-header.md b/.changeset/session-affinity-header.md index c2fdfeef1..d97cd5258 100644 --- a/.changeset/session-affinity-header.md +++ b/.changeset/session-affinity-header.md @@ -1,5 +1,5 @@ --- -"workers-ai-provider": minor +"workers-ai-provider": patch --- Add `sessionAffinity` setting to send `x-session-affinity` header for prefix-cache optimization. Also forward `extraHeaders` in the REST API path instead of discarding them.
diff --git a/.changeset/tanstack-session-affinity.md b/.changeset/tanstack-session-affinity.md new file mode 100644 index 000000000..db03b86a1 --- /dev/null +++ b/.changeset/tanstack-session-affinity.md @@ -0,0 +1,5 @@ +--- +"@cloudflare/tanstack-ai": patch +--- + +Add `sessionAffinity` option to `WorkersAiAdapterConfig` for prefix-cache optimization. Routes requests with the same key to the same backend replica via the `x-session-affinity` header. Supported across binding, REST, and gateway modes. diff --git a/demos.json b/demos.json index 458c57581..59685e00d 100644 --- a/demos.json +++ b/demos.json @@ -1,112 +1,112 @@ { - "demos": { - "./demos/agent-scheduler": { - "package_json_hash": "2fe8785345d56ff37e675baaa06380c1eed736ba" - }, - "./demos/agent-task-manager": { - "package_json_hash": "f2245bb30e95d0785aa195635a0928fba5b621ae" - }, - "./demos/agent-task-manager-human-in-the-loop": { - "package_json_hash": "d8856a52bd0bf9641b8d1ff98e06de650aece34a" - }, - "./demos/evaluator-optimiser": { - "package_json_hash": "8c4e9a71c91d806dcbef586b50a594e650d8f090" - }, - "./demos/image-generation": { - "package_json_hash": "697d55539ad024faa349faa3dcd3bcbdfacf37bc" - }, - "./demos/mcp-client": { - "package_json_hash": "5129c9edfcdd03c8625615c85329d85138fb9773" - }, - "./demos/mcp-server-bearer-auth": { - "package_json_hash": "8703a8f8992a06377cce9a139ce6709450b51b5c" - }, - "./demos/mcp-slack-oauth": { - "package_json_hash": "3134658fb11397626329bc344eba48fd57b21d46" - }, - "./demos/mcp-stytch-b2b-okr-manager": { - "package_json_hash": "48232b81779a5f5fb0253842b4243c7dad032c0c" - }, - "./demos/mcp-stytch-consumer-todo-list": { - "package_json_hash": "f53fe23dcebec62f51f9a6e332d2c192b8598cf6" - }, - "./demos/model-scraper": { - "package_json_hash": "5a20ad46b257699c313bdd7c0b520701d739ed12" - }, - "./demos/orchestrator-workers": { - "package_json_hash": "e159d1ce03c17bf13239ee4ac76c0290a210bc38" - }, - "./demos/parallelisation": { - "package_json_hash": 
"6dbc55c3277b3ea634776821e60642b3dd03d8c0" - }, - "./demos/prompt-chaining": { - "package_json_hash": "510159b05545a2d7f9c8cb240def56649cd25989" - }, - "./demos/remote-mcp-authkit": { - "package_json_hash": "d3a0122c45d27140db96df6859e191aa7d2f8ac1" - }, - "./demos/remote-mcp-github-oauth": { - "package_json_hash": "c59a2ecc4937d54c658383c3d7fe95e7c123f5c1" - }, - "./demos/remote-mcp-server": { - "package_json_hash": "6240672fd54010c3b03a8af553b420306e11bc78" - }, - "./demos/routing": { - "package_json_hash": "5f547b98f4e9a6167a2913e3a6c61681312986dd" - }, - "./demos/structured-output": { - "package_json_hash": "a66aacd49c57e74c0937bf4bea0986168086debb" - }, - "./demos/structured-output-node": { - "package_json_hash": "f64cc27508f9dda6fbb3bf4192c031dcc671e64a" - }, - "./demos/text-generation": { - "package_json_hash": "d52767521e285b05c3235eaf2c8cc0e47fdbf90d" - }, - "./demos/text-generation-stream": { - "package_json_hash": "f8272f5b1f5f1c83c53395dfc76646cab18a32b7" - }, - "./demos/tool-calling": { - "package_json_hash": "3a0b1d91022d706b96e7b429c1349116ba9373b5" - }, - "./demos/tool-calling-stream": { - "package_json_hash": "7c92250cda46aaac7eb6aeea0255828781c4abcb" - }, - "./demos/tool-calling-stream-traditional": { - "package_json_hash": "c610c334d5f53a6e399bddddf68098ca0dec96d7" - }, - "./demos/ui-worker": { - "package_json_hash": "831702fff4771ce9ce7d93afe6824ec6fa316125" - }, - "./demos/remote-mcp-cf-access": { - "package_json_hash": "1a09d449c88cfe3b989f352d18813385578b98ca" - }, - "./demos/remote-mcp-authless": { - "package_json_hash": "ba9953ce57a26cb271144e67609ed98fd1c1110e" - }, - "./demos/python-workers-mcp": { - "package_json_hash": "0e710d7b27bb34edba396dc2b3365db230c076cb" - }, - "./demos/vision": { - "package_json_hash": "e53450d50753f0574995feef3b2f845045fc3dc3" - }, - "./demos/remote-mcp-google-oauth": { - "package_json_hash": "21bdab2ebbbe336c5fe6fb032fde804373f1b489" - }, - "./demos/remote-mcp-logto": { - "package_json_hash": 
"a98a0cb367641ff86d89a7127f5e2551d2a1532f" - }, - "./demos/remote-mcp-server-descope-auth": { - "package_json_hash": "c5de845803aae734fa60185200d4bfa2e1d0fb23" - }, - "./demos/remote-mcp-server-autorag": { - "package_json_hash": "2b4e9b35192362b3be2743370469ce3a627a72b0" - }, - "./demos/use-mcp-inspector": { - "package_json_hash": "d1d084f1aa9a752ead5250e0a070f97a9114dcea" - }, - "./demos/hello-world": { - "package_json_hash": "ab24a12893c001fe3416fadea2a8bf5e7e68392e" - } - } -} \ No newline at end of file + "demos": { + "./demos/agent-scheduler": { + "package_json_hash": "2fe8785345d56ff37e675baaa06380c1eed736ba" + }, + "./demos/agent-task-manager": { + "package_json_hash": "f2245bb30e95d0785aa195635a0928fba5b621ae" + }, + "./demos/agent-task-manager-human-in-the-loop": { + "package_json_hash": "d8856a52bd0bf9641b8d1ff98e06de650aece34a" + }, + "./demos/evaluator-optimiser": { + "package_json_hash": "8c4e9a71c91d806dcbef586b50a594e650d8f090" + }, + "./demos/image-generation": { + "package_json_hash": "697d55539ad024faa349faa3dcd3bcbdfacf37bc" + }, + "./demos/mcp-client": { + "package_json_hash": "5129c9edfcdd03c8625615c85329d85138fb9773" + }, + "./demos/mcp-server-bearer-auth": { + "package_json_hash": "8703a8f8992a06377cce9a139ce6709450b51b5c" + }, + "./demos/mcp-slack-oauth": { + "package_json_hash": "3134658fb11397626329bc344eba48fd57b21d46" + }, + "./demos/mcp-stytch-b2b-okr-manager": { + "package_json_hash": "48232b81779a5f5fb0253842b4243c7dad032c0c" + }, + "./demos/mcp-stytch-consumer-todo-list": { + "package_json_hash": "f53fe23dcebec62f51f9a6e332d2c192b8598cf6" + }, + "./demos/model-scraper": { + "package_json_hash": "5a20ad46b257699c313bdd7c0b520701d739ed12" + }, + "./demos/orchestrator-workers": { + "package_json_hash": "e159d1ce03c17bf13239ee4ac76c0290a210bc38" + }, + "./demos/parallelisation": { + "package_json_hash": "6dbc55c3277b3ea634776821e60642b3dd03d8c0" + }, + "./demos/prompt-chaining": { + "package_json_hash": 
"510159b05545a2d7f9c8cb240def56649cd25989" + }, + "./demos/remote-mcp-authkit": { + "package_json_hash": "d3a0122c45d27140db96df6859e191aa7d2f8ac1" + }, + "./demos/remote-mcp-github-oauth": { + "package_json_hash": "c59a2ecc4937d54c658383c3d7fe95e7c123f5c1" + }, + "./demos/remote-mcp-server": { + "package_json_hash": "6240672fd54010c3b03a8af553b420306e11bc78" + }, + "./demos/routing": { + "package_json_hash": "5f547b98f4e9a6167a2913e3a6c61681312986dd" + }, + "./demos/structured-output": { + "package_json_hash": "a66aacd49c57e74c0937bf4bea0986168086debb" + }, + "./demos/structured-output-node": { + "package_json_hash": "f64cc27508f9dda6fbb3bf4192c031dcc671e64a" + }, + "./demos/text-generation": { + "package_json_hash": "d52767521e285b05c3235eaf2c8cc0e47fdbf90d" + }, + "./demos/text-generation-stream": { + "package_json_hash": "f8272f5b1f5f1c83c53395dfc76646cab18a32b7" + }, + "./demos/tool-calling": { + "package_json_hash": "3a0b1d91022d706b96e7b429c1349116ba9373b5" + }, + "./demos/tool-calling-stream": { + "package_json_hash": "7c92250cda46aaac7eb6aeea0255828781c4abcb" + }, + "./demos/tool-calling-stream-traditional": { + "package_json_hash": "c610c334d5f53a6e399bddddf68098ca0dec96d7" + }, + "./demos/ui-worker": { + "package_json_hash": "831702fff4771ce9ce7d93afe6824ec6fa316125" + }, + "./demos/remote-mcp-cf-access": { + "package_json_hash": "1a09d449c88cfe3b989f352d18813385578b98ca" + }, + "./demos/remote-mcp-authless": { + "package_json_hash": "ba9953ce57a26cb271144e67609ed98fd1c1110e" + }, + "./demos/python-workers-mcp": { + "package_json_hash": "0e710d7b27bb34edba396dc2b3365db230c076cb" + }, + "./demos/vision": { + "package_json_hash": "e53450d50753f0574995feef3b2f845045fc3dc3" + }, + "./demos/remote-mcp-google-oauth": { + "package_json_hash": "21bdab2ebbbe336c5fe6fb032fde804373f1b489" + }, + "./demos/remote-mcp-logto": { + "package_json_hash": "a98a0cb367641ff86d89a7127f5e2551d2a1532f" + }, + "./demos/remote-mcp-server-descope-auth": { + "package_json_hash": 
"c5de845803aae734fa60185200d4bfa2e1d0fb23" + }, + "./demos/remote-mcp-server-autorag": { + "package_json_hash": "2b4e9b35192362b3be2743370469ce3a627a72b0" + }, + "./demos/use-mcp-inspector": { + "package_json_hash": "d1d084f1aa9a752ead5250e0a070f97a9114dcea" + }, + "./demos/hello-world": { + "package_json_hash": "ab24a12893c001fe3416fadea2a8bf5e7e68392e" + } + } +} diff --git a/packages/tanstack-ai/README.md b/packages/tanstack-ai/README.md index 6767eb687..e1273d80f 100644 --- a/packages/tanstack-ai/README.md +++ b/packages/tanstack-ai/README.md @@ -275,6 +275,15 @@ Workers AI supports four configuration modes: Third-party providers (OpenAI, Anthropic, Gemini, Grok, OpenRouter) only support the gateway modes. +All Workers AI config modes also accept `sessionAffinity` to route requests with the same key to the same backend replica for prefix-cache optimization: + +```typescript +const adapter = createWorkersAiChat("@cf/meta/llama-3.3-70b-instruct-fp8-fast", { + binding: env.AI, + sessionAffinity: "my-unique-session-id", +}); +``` + ## Links - [TanStack AI Documentation](https://tanstack.com/ai) diff --git a/packages/tanstack-ai/src/adapters/workers-ai.ts b/packages/tanstack-ai/src/adapters/workers-ai.ts index 4432f8063..84f3bd12f 100644 --- a/packages/tanstack-ai/src/adapters/workers-ai.ts +++ b/packages/tanstack-ai/src/adapters/workers-ai.ts @@ -33,11 +33,18 @@ export type WorkersAiTextModel = function buildWorkersAiClient(config: WorkersAiAdapterConfig): OpenAI { validateWorkersAiConfig(config); + const sessionHeaders: Record | undefined = config.sessionAffinity + ? { "x-session-affinity": config.sessionAffinity } + : undefined; + if (isDirectBindingConfig(config)) { // Plain binding mode: shim translates OpenAI fetch calls to env.AI.run() return new OpenAI({ apiKey: "unused", - fetch: createWorkersAiBindingFetch(config.binding), + fetch: createWorkersAiBindingFetch( + config.binding, + sessionHeaders ? 
{ extraHeaders: sessionHeaders } : undefined, + ), }); } @@ -46,13 +53,14 @@ function buildWorkersAiClient(config: WorkersAiAdapterConfig): OpenAI { return new OpenAI({ baseURL: `https://api.cloudflare.com/client/v4/accounts/${config.accountId}/ai/v1`, apiKey: config.apiKey, + defaultHeaders: sessionHeaders, }); } // Gateway mode (existing): use createGatewayFetch const gatewayConfig = config as AiGatewayAdapterConfig; return new OpenAI({ - fetch: createGatewayFetch("workers-ai", gatewayConfig), + fetch: createGatewayFetch("workers-ai", gatewayConfig, sessionHeaders), apiKey: gatewayConfig.apiKey ?? "unused", }); } @@ -377,11 +385,8 @@ export class WorkersAiTextAdapter extends Bas // Reasoning content (used by models like QwQ, DeepSeek R1, Kimi K2.5) // The OpenAI SDK doesn't type this field, but models send it as an extension. - const reasoningContent = ((delta as Record) - .reasoning_content ?? - (delta as Record).reasoning) as - | string - | undefined; + const reasoningContent = ((delta as Record).reasoning_content ?? + (delta as Record).reasoning) as string | undefined; if (reasoningContent) { // RUN_STARTED is already guaranteed by the guard above if (!hasEmittedStepStarted) { diff --git a/packages/tanstack-ai/src/utils/create-fetcher.ts b/packages/tanstack-ai/src/utils/create-fetcher.ts index 1524f4c17..b5fb9baf9 100644 --- a/packages/tanstack-ai/src/utils/create-fetcher.ts +++ b/packages/tanstack-ai/src/utils/create-fetcher.ts @@ -105,10 +105,17 @@ export interface WorkersAiDirectCredentialsConfig { * upstream provider), distinct from `cfApiKey` (used in the `cf-aig-authorization` * header for authenticated gateways). */ -export type WorkersAiAdapterConfig = +export type WorkersAiAdapterConfig = ( | WorkersAiDirectBindingConfig | WorkersAiDirectCredentialsConfig - | (AiGatewayAdapterConfig & { apiKey?: string }); + | (AiGatewayAdapterConfig & { apiKey?: string }) +) & { + /** + * Session affinity key for prefix-cache optimization. 
+ * Routes requests with the same key to the same backend replica. + */ + sessionAffinity?: string; +}; // --------------------------------------------------------------------------- // Config detection helpers @@ -330,7 +337,10 @@ function sanitizeToolCallId(id: string): string { * request parameters are extracted from the JSON body, matching Workers AI's * `binding.run(model, inputs)` calling convention. */ -export function createWorkersAiBindingFetch(binding: WorkersAiBinding): typeof fetch { +export function createWorkersAiBindingFetch( + binding: WorkersAiBinding, + options?: { extraHeaders?: Record }, +): typeof fetch { return async (_input, init) => { if (!init?.body) { return new Response("No body", { status: 400 }); @@ -359,7 +369,11 @@ export function createWorkersAiBindingFetch(binding: WorkersAiBinding): typeof f if (body.response_format) inputs.response_format = body.response_format; if (stream) inputs.stream = true; - const result = await binding.run(model, inputs); + const result = await binding.run( + model, + inputs, + options?.extraHeaders ? { extraHeaders: options.extraHeaders } : undefined, + ); if (stream && result instanceof ReadableStream) { // Workers AI returns an SSE stream with `data: {"response":"chunk"}` format. 
diff --git a/packages/tanstack-ai/test/binding-fetch.test.ts b/packages/tanstack-ai/test/binding-fetch.test.ts index 594c25c99..c4e715983 100644 --- a/packages/tanstack-ai/test/binding-fetch.test.ts +++ b/packages/tanstack-ai/test/binding-fetch.test.ts @@ -545,6 +545,44 @@ describe("createWorkersAiBindingFetch", () => { expect(json.choices[0]!.finish_reason).toBe("stop"); }); + it("should forward extraHeaders to binding.run() when configured", async () => { + const binding = mockBinding(vi.fn().mockResolvedValue({ response: "ok" })); + + const fetcher = createWorkersAiBindingFetch(binding, { + extraHeaders: { "x-session-affinity": "session-123" }, + }); + + await fetcher("https://api.openai.com/v1/chat/completions", { + method: "POST", + body: JSON.stringify({ + model: "@cf/meta/llama-3.3-70b-instruct-fp8-fast", + messages: [{ role: "user", content: "Hi" }], + }), + }); + + expect(binding.run).toHaveBeenCalledOnce(); + const [, , options] = binding.run.mock.calls[0]!; + expect(options).toEqual({ extraHeaders: { "x-session-affinity": "session-123" } }); + }); + + it("should not pass extraHeaders to binding.run() when not configured", async () => { + const binding = mockBinding(vi.fn().mockResolvedValue({ response: "ok" })); + + const fetcher = createWorkersAiBindingFetch(binding); + + await fetcher("https://api.openai.com/v1/chat/completions", { + method: "POST", + body: JSON.stringify({ + model: "@cf/meta/llama-3.3-70b-instruct-fp8-fast", + messages: [{ role: "user", content: "Hi" }], + }), + }); + + expect(binding.run).toHaveBeenCalledOnce(); + const [, , options] = binding.run.mock.calls[0]!; + expect(options).toBeUndefined(); + }); + it("should pass response_format to binding for structured output", async () => { const binding = mockBinding(vi.fn().mockResolvedValue({ response: '{"name":"test"}' })); diff --git a/packages/tanstack-ai/test/workers-ai-adapter.test.ts b/packages/tanstack-ai/test/workers-ai-adapter.test.ts index f70cea52b..4439e7b53 100644 --- 
a/packages/tanstack-ai/test/workers-ai-adapter.test.ts +++ b/packages/tanstack-ai/test/workers-ai-adapter.test.ts @@ -951,4 +951,41 @@ describe("WorkersAiTextAdapter config modes", () => { const adapter = new WorkersAiTextAdapter("@cf/my-org/custom-model-v1", { binding }); expect(adapter).toBeDefined(); }); + + it("should pass sessionAffinity as extraHeaders to binding.run()", async () => { + const binding = createStreamingBinding(['data: {"response":"ok"}\n\n']); + const adapter = new WorkersAiTextAdapter(MODEL, { + binding, + sessionAffinity: "my-session-id", + }); + + await collectChunks( + adapter.chatStream({ + model: MODEL, + messages: [{ role: "user", content: "Hi" }], + } as any), + ); + + expect(binding.run).toHaveBeenCalledOnce(); + const [, , options] = binding.run.mock.calls[0]!; + expect(options).toEqual({ + extraHeaders: { "x-session-affinity": "my-session-id" }, + }); + }); + + it("should not pass extraHeaders when sessionAffinity is not set", async () => { + const binding = createStreamingBinding(['data: {"response":"ok"}\n\n']); + const adapter = new WorkersAiTextAdapter(MODEL, { binding }); + + await collectChunks( + adapter.chatStream({ + model: MODEL, + messages: [{ role: "user", content: "Hi" }], + } as any), + ); + + expect(binding.run).toHaveBeenCalledOnce(); + const [, , options] = binding.run.mock.calls[0]!; + expect(options).toBeUndefined(); + }); }); diff --git a/packages/workers-ai-provider/README.md b/packages/workers-ai-provider/README.md index d6d9c5876..0b46a803f 100644 --- a/packages/workers-ai-provider/README.md +++ b/packages/workers-ai-provider/README.md @@ -287,7 +287,20 @@ Streaming works the same way — use `streamText` instead of `generateText`. | `apiKey` | `string` | Cloudflare API token. Required with `accountId`. | | `gateway` | `GatewayOptions` | Optional [AI Gateway](https://developers.cloudflare.com/ai-gateway/) config. | -Returns a provider with model factories: +Returns a provider with model factories. 
Each factory accepts an optional second argument for per-model settings: + +```ts +workersai("@cf/meta/llama-3.3-70b-instruct-fp8-fast", { + sessionAffinity: "my-unique-session-id", +}); +``` + +| Setting | Type | Description | +| ----------------- | --------- | -------------------------------------------------------------------------------------------- | +| `safePrompt` | `boolean` | Inject a safety prompt before all conversations. | +| `sessionAffinity` | `string` | Routes requests with the same key to the same backend replica for prefix-cache optimization. | + +Model factories: ```ts // Chat — for generateText / streamText diff --git a/packages/workers-ai-provider/src/streaming.ts b/packages/workers-ai-provider/src/streaming.ts index f8538f0a1..4fd900aac 100644 --- a/packages/workers-ai-provider/src/streaming.ts +++ b/packages/workers-ai-provider/src/streaming.ts @@ -164,8 +164,8 @@ export function getMappedStream( const delta = choices[0].delta; const reasoningDelta = (delta.reasoning_content ?? delta.reasoning) as - | string - | undefined; + | string + | undefined; if (reasoningDelta && reasoningDelta.length > 0) { if (!reasoningId) { reasoningId = generateId(); diff --git a/packages/workers-ai-provider/src/workersai-chat-language-model.ts b/packages/workers-ai-provider/src/workersai-chat-language-model.ts index a434555c2..43e7febd6 100644 --- a/packages/workers-ai-provider/src/workersai-chat-language-model.ts +++ b/packages/workers-ai-provider/src/workersai-chat-language-model.ts @@ -153,13 +153,24 @@ export class WorkersAIChatLanguageModel implements LanguageModelV3 { * Get passthrough options for binding.run() from settings. 
*/ private getRunOptions() { - const { gateway, safePrompt: _safePrompt, sessionAffinity, ...passthroughOptions } = - this.settings; + const { + gateway, + safePrompt: _safePrompt, + sessionAffinity, + extraHeaders, + ...passthroughOptions + } = this.settings; + + const mergedHeaders = { + ...(extraHeaders && typeof extraHeaders === "object" + ? (extraHeaders as Record) + : {}), + ...(sessionAffinity ? { "x-session-affinity": sessionAffinity } : {}), + }; + return { gateway: this.config.gateway ?? gateway, - ...(sessionAffinity - ? { extraHeaders: { "x-session-affinity": sessionAffinity } } - : {}), + ...(Object.keys(mergedHeaders).length > 0 ? { extraHeaders: mergedHeaders } : {}), ...passthroughOptions, }; } diff --git a/packages/workers-ai-provider/test/text-generation.test.ts b/packages/workers-ai-provider/test/text-generation.test.ts index e6a2be5dc..7e77e7943 100644 --- a/packages/workers-ai-provider/test/text-generation.test.ts +++ b/packages/workers-ai-provider/test/text-generation.test.ts @@ -104,6 +104,39 @@ describe("REST API - Text Generation Tests", () => { expect(capturedHeaders["x-session-affinity"]).toBe("session-123"); }); + it("should merge sessionAffinity with user-provided extraHeaders", async () => { + let capturedHeaders: Record = {}; + + const workersai = createWorkersAI({ + accountId: TEST_ACCOUNT_ID, + apiKey: TEST_API_KEY, + }); + + server.use( + http.post( + `https://api.cloudflare.com/client/v4/accounts/${TEST_ACCOUNT_ID}/ai/run/${TEST_MODEL}`, + async ({ request }) => { + capturedHeaders = Object.fromEntries(request.headers.entries()); + return HttpResponse.json({ result: { response: "Hello" } }); + }, + ), + ); + + const model = workersai(TEST_MODEL, { + sessionAffinity: "session-123", + extraHeaders: { "x-custom-trace": "trace-abc" }, + }); + + const result = await generateText({ + model: model, + prompt: "Write a greeting", + }); + + expect(result.text).toBe("Hello"); + 
expect(capturedHeaders["x-session-affinity"]).toBe("session-123"); + expect(capturedHeaders["x-custom-trace"]).toBe("trace-abc"); + }); + it("should not send x-session-affinity header when sessionAffinity is not set", async () => { let capturedHeaders: Record = {}; @@ -310,6 +343,35 @@ describe("Binding - Text Generation Tests", () => { expect(capturedOptions.extraHeaders).toEqual({ "x-session-affinity": "session-456" }); }); + it("should merge sessionAffinity with user-provided extraHeaders", async () => { + let capturedOptions: any = null; + + const workersai = createWorkersAI({ + binding: { + run: async (_modelName: string, _inputs: any, options?: any) => { + capturedOptions = options; + return { response: "Hello" }; + }, + }, + }); + + const model = workersai(TEST_MODEL, { + sessionAffinity: "session-456", + extraHeaders: { "x-custom-trace": "trace-xyz" }, + }); + + const result = await generateText({ + model: model, + prompt: "Write a greeting", + }); + + expect(result.text).toBe("Hello"); + expect(capturedOptions.extraHeaders).toEqual({ + "x-custom-trace": "trace-xyz", + "x-session-affinity": "session-456", + }); + }); + it("should not pass extraHeaders when sessionAffinity is not set", async () => { let capturedOptions: any = null;