From 19d7450f33f9f9fa443838975a97023180a0abb8 Mon Sep 17 00:00:00 2001 From: Hubert Zub Date: Tue, 5 May 2026 14:35:04 +0200 Subject: [PATCH 1/2] feat(appkit): supervisor api adapter --- apps/dev-playground/server/index.ts | 36 +- docs/docs/plugins/agents.md | 75 ++ packages/appkit/src/agents/supervisor-api.ts | 577 +++++++++++++++ .../src/agents/tests/supervisor-api.test.ts | 662 ++++++++++++++++++ packages/appkit/src/beta.ts | 10 + .../appkit/src/connectors/serving/client.ts | 63 +- packages/appkit/src/stream/index.ts | 1 + packages/appkit/src/stream/sse-reader.ts | 114 +++ .../src/stream/tests/sse-reader.test.ts | 182 +++++ 9 files changed, 1704 insertions(+), 16 deletions(-) create mode 100644 packages/appkit/src/agents/supervisor-api.ts create mode 100644 packages/appkit/src/agents/tests/supervisor-api.test.ts create mode 100644 packages/appkit/src/stream/sse-reader.ts create mode 100644 packages/appkit/src/stream/tests/sse-reader.test.ts diff --git a/apps/dev-playground/server/index.ts b/apps/dev-playground/server/index.ts index ecbd18e78..67187dcbe 100644 --- a/apps/dev-playground/server/index.ts +++ b/apps/dev-playground/server/index.ts @@ -11,7 +11,12 @@ import { serving, WRITE_ACTIONS, } from "@databricks/appkit"; -import { agents, createAgent, tool } from "@databricks/appkit/beta"; +import { + agents, + createAgent, + fromSupervisorApi, + tool, +} from "@databricks/appkit/beta"; import { WorkspaceClient } from "@databricks/sdk-experimental"; import { z } from "zod"; import { lakebaseExamples } from "./lakebase-examples-plugin"; @@ -68,6 +73,33 @@ const helper = createAgent({ }, }); +// Supervisor API demo agent. Tools are configured on the adapter (the SA +// endpoint executes them server-side), not on the createAgent definition. +// Uncomment a `supervisorTools.*` entry (and import 'supervisorTools' from +// '@databricks/appkit/beta') to give the model real powers. +// +// We `await` the factory at module init so a misconfigured workspace +// (missing host, bad credentials) fails fast with a clear error here +// instead of as an unhandled rejection. Top-level await is fine in this +// ESM module. +const supervisor = createAgent({ + instructions: + "You are an assistant powered by the Databricks Supervisor API.", + model: fromSupervisorApi({ + model: "databricks-claude-sonnet-4-5", + tools: [ + // supervisorTools.genieSpace( + // "01ABCDEF12345678", + // "NYC taxi trip records and zones", + // ), + // supervisorTools.ucFunction( + // "main.default.add", + // "Adds two integers and returns the sum.", + // ), + ], + }), +}); + /* * Smart-Dashboard agents. * @@ -385,7 +417,7 @@ createApp({ }), serving(), agents({ - agents: { helper, sql_analyst, dashboard_pilot }, + agents: { helper, sql_analyst, dashboard_pilot, supervisor }, // `query` (markdown dispatcher) + `sql_analyst` + `dashboard_pilot` // wire the /smart-dashboard route. `insights` and `anomaly` are // ephemeral markdown agents auto-fired by the route's AgentSidebar. diff --git a/docs/docs/plugins/agents.md b/docs/docs/plugins/agents.md index 0ba2ab301..c228551e2 100644 --- a/docs/docs/plugins/agents.md +++ b/docs/docs/plugins/agents.md @@ -16,6 +16,8 @@ This page covers the full lifecycle. For the hand-written primitives (`tool()`, The agents plugin drives the LLM over Server-Sent Events. Foundation Model APIs (Claude, Llama, GPT, etc.) and other chat-style endpoints support streaming and work out of the box. Custom model endpoints that return a single JSON response (e.g. typical `sklearn` or MLflow `pyfunc` deployments) do **not** stream — pointing an agent at one will fail with "Response body is null — streaming not supported" on the first turn. If you list a serving endpoint in `apps init`, pick one whose model implements the chat-completions streaming protocol; the agents plugin reads its name from `DATABRICKS_SERVING_ENDPOINT_NAME` whenever an agent doesn't pin `model:` itself. For the non-streaming path against a custom endpoint, use the `serving` plugin's `/invoke` route with `useServingInvoke` instead. + +Or skip serving-endpoint setup entirely with the managed [Supervisor API adapter](#managed-agents-the-supervisor-api-adapter) (beta). ::: ## Install @@ -217,6 +219,79 @@ const result = await runAgent(classifier, { Hosted tools (MCP) are still `agents()`-only since they require the live MCP client. Plugin tool dispatch in standalone mode runs as the service principal (no OBO) and **bypasses the agents-plugin approval gate** — treat standalone runAgent as a trusted-prompt environment (CI, batch eval, internal scripts), not as an exposed user-facing surface. +## Managed agents: the Supervisor API adapter + +`fromSupervisorApi` (beta) is the zero-config way to run an agent: instead of provisioning and pointing at a model-serving endpoint, you run the agentic loop in the Databricks workspace by targeting the AI Gateway Responses API (`/ai-gateway/mlflow/v1/responses`), which runs the LLM — and any hosted tools — as a managed service on Databricks. No `DATABRICKS_SERVING_ENDPOINT_NAME`, no stream-capability check, no JS tool plumbing for the common cases. + +The minimal agent is one extra line versus a markdown agent: + +```ts +import { createApp, createAgent } from "@databricks/appkit"; +import { agents, fromSupervisorApi } from "@databricks/appkit/beta"; + +await createApp({ + plugins: [ + agents({ + agents: { + assistant: createAgent({ + instructions: "You are a helpful assistant.", + model: fromSupervisorApi({ model: "databricks-claude-sonnet-4-5" }), + }), + }, + }), + ], +}); +``` + +`createAgent({ model })` already accepts adapters and adapter promises in addition to the model-name string used in earlier examples, so you can drop the factory result straight in. `fromSupervisorApi` resolves credentials through the SDK chain (`DATABRICKS_HOST`, OAuth, PAT, …); pass `workspaceClient` to reuse an existing client. + +### Hosted tools + +Expose Genie spaces, Unity Catalog functions/connections, Knowledge Assistants, or other AppKit apps to the model by listing them on the adapter — execution stays server-side, you write no tool code: + +```ts +import { fromSupervisorApi, supervisorTools } from "@databricks/appkit/beta"; + +const model = fromSupervisorApi({ + model: "databricks-claude-sonnet-4-5", + tools: [ + supervisorTools.genieSpace( + "01ABCDEF12345678", + "NYC taxi trip records and zones", + ), + supervisorTools.ucFunction( + "main.default.add", + "Adds two integers and returns the sum.", + ), + ], +}); +``` + +`description` is **required and non-empty** — the LLM uses it to route between tools, so two Genie spaces both labelled "Genie space" will be indistinguishable. + +| Factory | Tool kind | Identifier | +|---|---|---| +| `supervisorTools.genieSpace(id, description)` | Genie space | space id | +| `supervisorTools.ucFunction(name, description)` | Unity Catalog function | three-part name | +| `supervisorTools.knowledgeAssistant(id, description)` | Knowledge Assistant | assistant id | +| `supervisorTools.app(name, description)` | Databricks App | app name | +| `supervisorTools.ucConnection(name, description)` | UC connection | connection name | + +### What does *not* apply to Supervisor-API agents + +The managed runtime owns its own tool execution, so the adapter intentionally **ignores the agents-plugin tool index**. For any agent whose `model:` is a Supervisor adapter: + +- Tools wired via markdown `tools:` or the `tools(plugins)` function form are not exposed to the model — declare hosted tools via `fromSupervisorApi({ tools: […] })` instead. +- The **human-in-the-loop approval gate** does not fire (tool calls never enter the Node process; `effect: "destructive"` annotations on plugin tools are irrelevant here). +- `limits.maxToolCalls` is not enforced (the managed runtime accounts for its own calls). +- Per-call **OBO** does not apply to hosted tools; they run with the credentials the managed runtime uses for the target resource. + +Standard-adapter agents and Supervisor-API agents can coexist in the same `agents({ agents: { … } })` map and can be composed as sub-agents (Level 4) — only the agent whose `model:` points at a Supervisor adapter is exempt from the items above. + +:::note Recovery path for non-streaming tool turns +Some hosted tool kinds return their final assistant text without incremental `output_text.delta` events. The adapter has a recovery path that pulls the text out of `response.completed.output[]` so the turn is not silently empty. Set `DEBUG=appkit:agents:supervisor-api` to log the per-turn event-type histogram if you want to verify which path a turn took. +::: + ## Configuration reference ```ts diff --git a/packages/appkit/src/agents/supervisor-api.ts b/packages/appkit/src/agents/supervisor-api.ts new file mode 100644 index 000000000..228eb8be9 --- /dev/null +++ b/packages/appkit/src/agents/supervisor-api.ts @@ -0,0 +1,577 @@ +import type { + AgentAdapter, + AgentEvent, + AgentInput, + AgentRunContext, + Message, + ResponseStreamEvent, +} from "shared"; +import { type ApiClientLike, streamPath } from "../connectors/serving/client"; +import { createLogger } from "../logging/logger"; +import { readSseEvents } from "../stream"; + +const logger = createLogger("agents:supervisor-api"); + +/** + * Transport shim: given a request body, returns the raw SSE byte stream from + * the Supervisor API endpoint. Injected at construction time so callers can + * swap in the workspace SDK (the {@link fromSupervisorApi} factory), a bare + * `fetch` (a reverse proxy / mock), or a test fake. Mirrors `StreamBody` in + * `agents/databricks.ts` so both adapters share one transport surface. + */ +type StreamBody = ( + body: Record, + signal?: AbortSignal, +) => Promise>; + +/** + * Structural shape of a Databricks SDK client used by {@link fromSupervisorApi}. + * Only what we need: `apiClient.request` for streaming and + * `config.ensureResolved` to materialise the host/credentials. + */ +interface WorkspaceClientLike extends ApiClientLike { + config: { ensureResolved(): Promise }; +} + +// --------------------------------------------------------------------------- +// Supervisor API tool surface (wire format) +// --------------------------------------------------------------------------- + +/** + * Tools supported by the Databricks AI Gateway Responses API. The shapes match + * the wire format the endpoint expects, so the adapter passes the array + * straight into the request body. + * + * Prefer the {@link supervisorTools} factories — they fill in the + * SA-validation-bug workaround for `description` (must be non-empty). + */ +export type SupervisorTool = + | { type: "genie_space"; genie_space: { id: string; description: string } } + | { type: "uc_function"; uc_function: { name: string; description: string } } + | { + type: "knowledge_assistant"; + knowledge_assistant: { + knowledge_assistant_id: string; + description: string; + }; + } + | { type: "app"; app: { name: string; description: string } } + | { + type: "uc_connection"; + uc_connection: { name: string; description: string }; + }; + +/** + * Concise factories for declaring Supervisor API tools. + * + * `description` is required: SA's protobuf validation rejects `null`/`""`, + * AND the LLM running on SA reads this string to decide when to route to + * the tool. Two genie spaces both labelled "Genie space" give the model + * nothing to discriminate on, so callers always own the routing hint. + * + * @example + * ```ts + * fromSupervisorApi({ + * model: "databricks-claude-sonnet-4", + * tools: [ + * supervisorTools.genieSpace( + * "01ABCDEF12345678", + * "NYC taxi trip records and zones", + * ), + * supervisorTools.ucFunction( + * "main.default.add", + * "Adds two integers and returns the sum.", + * ), + * ], + * }); + * ``` + */ +export const supervisorTools = { + genieSpace: (id: string, description: string): SupervisorTool => ({ + type: "genie_space", + genie_space: { id, description }, + }), + ucFunction: (name: string, description: string): SupervisorTool => ({ + type: "uc_function", + uc_function: { name, description }, + }), + knowledgeAssistant: ( + knowledgeAssistantId: string, + description: string, + ): SupervisorTool => ({ + type: "knowledge_assistant", + knowledge_assistant: { + knowledge_assistant_id: knowledgeAssistantId, + description, + }, + }), + app: (name: string, description: string): SupervisorTool => ({ + type: "app", + app: { name, description }, + }), + ucConnection: (name: string, description: string): SupervisorTool => ({ + type: "uc_connection", + uc_connection: { name, description }, + }), +}; + +// --------------------------------------------------------------------------- +// Adapter +// --------------------------------------------------------------------------- + +export interface SupervisorApiAdapterOptions { + /** + * Model identifier to pass in the request body + * (e.g. "databricks-claude-sonnet-4"). + */ + model: string; + /** + * Hosted tools the SA endpoint should expose to the model. Use the + * {@link supervisorTools} factories for the most common shapes. + */ + tools?: SupervisorTool[]; + /** + * A WorkspaceClient (or structural equivalent) used for host resolution + * and per-request authentication. When omitted, a `WorkspaceClient({})` + * is created internally using the default SDK credential chain + * (`DATABRICKS_HOST`, OAuth, PAT, etc.). + */ + workspaceClient?: WorkspaceClientLike; +} + +export interface SupervisorApiAdapterCtorOptions { + streamBody: StreamBody; + model: string; + tools?: SupervisorTool[]; +} + +/** + * Adapter that calls the Databricks AI Gateway Responses API + * (`/ai-gateway/mlflow/v1/responses`). + * + * Streams SSE events in the OpenAI Responses API wire format and maps them + * to the AppKit `AgentEvent` protocol. Tool execution is handled + * server-side, so the adapter ignores the agents-plugin tool index. + * + * Authentication is handled via the Databricks SDK credential chain — the + * same mechanism used by `DatabricksAdapter.fromModelServing`. The transport + * is injected via {@link SupervisorApiAdapterCtorOptions.streamBody}; the + * {@link fromSupervisorApi} factory wires it through the SDK's + * `apiClient.request({ raw: true })`. + * + * Set `DEBUG=appkit:agents:supervisor-api` to log the outbound request + * shape (model, instructions length, input shape, tool count) and to be + * notified when the recovery path engages (no incremental deltas, text + * pulled from `response.completed.output[]`). The no-delta warning includes + * a per-turn event-type histogram and the SA-reported status/error/ + * incomplete_details, so it's already actionable without DEBUG. + * + * @example + * ```ts + * import { createApp, createAgent, agents } from "@databricks/appkit"; + * import { + * fromSupervisorApi, + * supervisorTools, + * } from "@databricks/appkit/agents/supervisor-api"; + * + * const adapter = await fromSupervisorApi({ + * model: "databricks-claude-sonnet-4", + * tools: [ + * supervisorTools.genieSpace( + * "01ABCDEF12345678", + * "NYC taxi trip records and zones", + * ), + * ], + * }); + * + * await createApp({ + * plugins: [ + * agents({ + * agents: { + * assistant: createAgent({ + * instructions: "You are a helpful assistant.", + * model: adapter, + * }), + * }, + * }), + * ], + * }); + * ``` + */ +export class SupervisorApiAdapter implements AgentAdapter { + private streamBody: StreamBody; + private model: string; + private tools: SupervisorTool[]; + + constructor(options: SupervisorApiAdapterCtorOptions) { + this.streamBody = options.streamBody; + this.model = options.model; + this.tools = options.tools ?? []; + } + + async *run( + input: AgentInput, + context: AgentRunContext, + ): AsyncGenerator { + if (context.signal?.aborted) return; + + yield { type: "status", status: "running" }; + + const { instructions, input: payloadInput } = this.buildInput( + input.messages, + ); + yield* this.streamResponse(instructions, payloadInput, context.signal); + } + + private async *streamResponse( + instructions: string | undefined, + input: ResponseInput, + signal?: AbortSignal, + ): AsyncGenerator { + const body: Record = { + model: this.model, + input, + stream: true, + }; + if (instructions) { + body.instructions = instructions; + } + // SA's protobuf validation rejects `tools: []` and `tools: null`. Only + // include the field when at least one tool is configured. + if (this.tools.length > 0) { + body.tools = this.tools; + } + + logger.debug( + "model=%s instructionsLen=%d inputType=%s tools=%d", + this.model, + instructions?.length ?? 0, + typeof input === "string" ? "string" : `array[${input.length}]`, + this.tools.length, + ); + + let stream: ReadableStream; + try { + stream = await this.streamBody(body, signal); + } catch (err) { + const message = err instanceof Error ? err.message : String(err); + logger.warn("Supervisor API request failed: %s", message); + yield { + type: "status", + status: "error", + error: `Supervisor API error: ${message}`, + }; + return; + } + + let receivedAnyDelta = false; + // Tracks `item_id`s we've already streamed text deltas for. Used by + // `mapEvent` to fall back to the final item text on `output_item.done` + // only when no incremental deltas streamed for that item — avoids + // double-emitting text when SA does both delta and done. + const streamedItemIds = new Set(); + // Histogram of received event types — surfaced in the no-delta warning + // so it's actionable without re-running with DEBUG. + const eventCounts = new Map(); + // Set to true once we've yielded a terminal `{status:"error"}` event so + // the recovery / completion / no-delta-warning blocks below all bail + // out — the consumer's already seen the terminal status, anything + // further would contradict the protocol's terminal-event semantics. + let terminated = false; + // Diagnostic snapshot of the last `response.completed` event. SA stuffs + // the final assistant message into `response.output[]` even when it + // didn't emit any deltas (e.g. when a tool failed or the model produced + // nothing). Keeping it lets us recover the text and surface useful + // errors instead of a silent empty turn. + let lastCompleted: + | { + status?: string; + output?: Array<{ + type?: string; + content?: Array<{ type?: string; text?: string }>; + }>; + error?: unknown; + incomplete_details?: unknown; + } + | undefined; + + for await (const { event, data } of readSseEvents(stream, signal)) { + if (data === "[DONE]") continue; + + let parsed: Record; + try { + parsed = JSON.parse(data); + } catch (err) { + logger.debug( + "Failed to parse SSE data line: %s (%O)", + data.slice(0, 200), + err, + ); + continue; + } + + const eventType = event || (parsed.type as string) || ""; + eventCounts.set(eventType, (eventCounts.get(eventType) ?? 0) + 1); + + // `response.completed` is held back until after the loop so we can + // synthesise a `message_delta` from `response.output[]` when the + // stream produced no incremental deltas (intermittent SA behaviour). + // Emitting `complete` first would let UIs finalise the turn before the + // recovered text arrives. + if (eventType === "response.completed") { + lastCompleted = parsed.response as typeof lastCompleted; + continue; + } + + const out = mapEvent(eventType, parsed, streamedItemIds); + if (out) { + if (out.type === "message_delta") receivedAnyDelta = true; + yield out; + if (out.type === "status" && out.status === "error") { + terminated = true; + break; + } + } + } + + if (signal?.aborted) return; + + if (eventCounts.size === 0) { + logger.warn( + "Supervisor API stream closed without emitting any SSE events.", + ); + return; + } + + if (terminated) return; + + // Recovery path: no deltas streamed but SA finished — pull the assistant + // text out of `response.completed.response.output[]`. + if (!receivedAnyDelta) { + const recovered = extractTextFromCompletedResponse(lastCompleted); + if (recovered) { + logger.debug( + "Recovered %d chars from response.completed.output[]", + recovered.length, + ); + yield { type: "message_delta", content: recovered }; + receivedAnyDelta = true; + } + } + + if (eventCounts.has("response.completed")) { + yield { type: "status", status: "complete" }; + } + + if (!receivedAnyDelta) { + const histogram = [...eventCounts.entries()] + .map(([t, n]) => `${t}=${n}`) + .join(", "); + const completedError = lastCompleted?.error + ? JSON.stringify(lastCompleted.error) + : undefined; + const completedIncomplete = lastCompleted?.incomplete_details + ? JSON.stringify(lastCompleted.incomplete_details) + : undefined; + logger.warn( + "Supervisor API stream completed without any output_text deltas. " + + "events={%s} completed.status=%s completed.error=%s completed.incomplete=%s", + histogram, + lastCompleted?.status ?? "", + completedError ?? "", + completedIncomplete ?? "", + ); + } + } + + /** + * Splits the agent's message list into a Responses-API payload. System + * messages are concatenated (in order) into the top-level `instructions` + * field; user/assistant turns become `input` (as a plain string for the + * common single-user-turn case, otherwise as `{role,content}[]`). Tool-role + * messages are skipped — SA owns its own tool history server-side, so + * re-feeding our tool-result records would only confuse it. + */ + private buildInput(messages: Message[]): { + instructions: string | undefined; + input: ResponseInput; + } { + const instructionsParts: string[] = []; + const turns: Array<{ + role: "user" | "assistant" | "system"; + content: string; + }> = []; + + for (const m of messages) { + if (m.role === "system") instructionsParts.push(m.content); + else if (m.role !== "tool") + turns.push({ role: m.role, content: m.content }); + } + + const instructions = instructionsParts.length + ? instructionsParts.join("\n\n") + : undefined; + + if (turns.length === 1 && turns[0].role === "user") { + return { instructions, input: turns[0].content }; + } + return { instructions, input: turns }; + } +} + +type ResponseInput = + | string + | Array<{ role: "user" | "assistant" | "system"; content: string }>; + +/** + * Pulls the final assistant text out of the `response` payload attached to a + * `response.completed` event. SA always materialises the full response there, + * so this is our last-resort recovery path when the stream produced neither + * `output_text.delta` nor an actionable `output_item.done` (observed + * intermittently with tool-enabled SA agents). + */ +function extractTextFromCompletedResponse( + response: + | { + output?: Array<{ + type?: string; + content?: Array<{ type?: string; text?: string }>; + }>; + } + | undefined, +): string { + if (!response?.output) return ""; + let text = ""; + for (const item of response.output) { + if (item?.type !== "message" || !Array.isArray(item.content)) continue; + for (const part of item.content) { + if (part?.type === "output_text" && typeof part.text === "string") { + text += part.text; + } + } + } + return text; +} + +function mapEvent( + eventType: string, + data: Record, + streamedItemIds: Set, +): AgentEvent | null { + // The cast restricts the switch domain to the closed wire-event union + // exported by `shared`, so typos in case clauses (e.g. `response.faled`) + // become compile errors instead of silent string mismatches. Unknown + // event names still fall through to `default` at runtime — we don't + // require exhaustive matching since SA emits more lifecycle events + // than we care to map. + switch (eventType as ResponseStreamEvent["type"]) { + case "response.output_text.delta": { + const itemId = data.item_id as string | undefined; + if (itemId) streamedItemIds.add(itemId); + return { type: "message_delta", content: (data.delta as string) ?? "" }; + } + + // `response.completed` is intentionally absent: `streamResponse` holds + // it back so it can synthesise a delta from `response.output[]` when + // the stream produced none, then emits `{status:"complete"}` itself. + + case "response.failed": + return { type: "status", status: "error", error: "Response failed" }; + + case "error": { + const errMsg = + typeof data.error === "string" + ? data.error + : JSON.stringify(data.error ?? "Unknown error"); + return { type: "status", status: "error", error: errMsg }; + } + + case "response.output_item.done": { + const item = data.item as + | { + id?: string; + type?: string; + content?: Array<{ text?: string; type?: string }>; + } + | undefined; + + if (item?.id === "error") { + const errText = item.content?.[0]?.text ?? "Unknown tool error from SA"; + return { type: "status", status: "error", error: errText }; + } + + // Fallback: when SA produces a tool-driven response (e.g. Genie space), + // it often omits `response.output_text.delta` events and only emits the + // final assistant message via `output_item.done`. Surface that text as + // a single delta so the UI sees the answer. + if ( + item?.type === "message" && + item.id && + !streamedItemIds.has(item.id) + ) { + const text = (item.content ?? []) + .map((c) => (c.type === "output_text" ? (c.text ?? "") : "")) + .join(""); + if (text.length > 0) { + streamedItemIds.add(item.id); + return { type: "message_delta", content: text }; + } + } + return null; + } + + // All other event types are intentionally ignored. Notable lifecycle + // events we drop on the floor: `response.created`, `response.in_progress`, + // `response.output_text.done`, `response.output_item.added`, + // `response.content_part.added`, `response.content_part.done`. + default: + return null; + } +} + +/** + * Creates an {@link AgentAdapter} backed by the Databricks AI Gateway + * Responses API (`/ai-gateway/mlflow/v1/responses`). + * + * Uses the SDK's default credential chain for auth (reads DATABRICKS_HOST, + * DATABRICKS_TOKEN, OAuth config, etc.). + * + * @example + * ```ts + * import { + * fromSupervisorApi, + * supervisorTools, + * } from "@databricks/appkit/agents/supervisor-api"; + * + * const adapter = await fromSupervisorApi({ + * model: "databricks-claude-sonnet-4", + * tools: [ + * supervisorTools.genieSpace( + * "01ABCDEF12345678", + * "NYC taxi trip records and zones", + * ), + * ], + * }); + * ``` + */ +export async function fromSupervisorApi( + options: SupervisorApiAdapterOptions, +): Promise { + let client = options.workspaceClient; + if (!client) { + const sdk = await import("@databricks/sdk-experimental"); + client = new sdk.WorkspaceClient({}) as unknown as WorkspaceClientLike; + } + + await client.config.ensureResolved(); + + // Capture the resolved client so the closure doesn't depend on the outer + // `let` binding being reassigned later. + const resolved = client; + return new SupervisorApiAdapter({ + streamBody: (body, signal) => + streamPath(resolved, "/ai-gateway/mlflow/v1/responses", body, signal), + model: options.model, + tools: options.tools ?? [], + }); +} diff --git a/packages/appkit/src/agents/tests/supervisor-api.test.ts b/packages/appkit/src/agents/tests/supervisor-api.test.ts new file mode 100644 index 000000000..9606b1c6a --- /dev/null +++ b/packages/appkit/src/agents/tests/supervisor-api.test.ts @@ -0,0 +1,662 @@ +import type { AgentEvent, AgentInput } from "shared"; +import { afterEach, describe, expect, test, vi } from "vitest"; +import { + fromSupervisorApi, + SupervisorApiAdapter, + type SupervisorTool, + supervisorTools, +} from "../supervisor-api"; + +function createReadableStream(chunks: string[]): ReadableStream { + const encoder = new TextEncoder(); + let i = 0; + return new ReadableStream({ + pull(controller) { + if (i < chunks.length) { + controller.enqueue(encoder.encode(chunks[i])); + i++; + } else { + controller.close(); + } + }, + }); +} + +function sseEvent(eventName: string, data: Record): string { + return `event: ${eventName}\ndata: ${JSON.stringify(data)}\n\n`; +} + +/** + * Captures the body the adapter posts and returns a fake stream of SSE + * chunks. Mirrors the `streamBody` test pattern used by `DatabricksAdapter`. + */ +function makeStreamBody(chunks: string[]): { + streamBody: ReturnType; + lastBody: () => Record | undefined; +} { + let captured: Record | undefined; + const streamBody = vi.fn(async (body: Record) => { + captured = body; + return createReadableStream(chunks); + }); + return { streamBody, lastBody: () => captured }; +} + +function createInput(): AgentInput { + return { + messages: [ + { id: "1", role: "user", content: "Hello", createdAt: new Date() }, + ], + tools: [], + threadId: "thread-1", + }; +} + +async function collect( + gen: AsyncGenerator, +): Promise { + const out: AgentEvent[] = []; + for await (const e of gen) out.push(e); + return out; +} + +describe("supervisorTools factories", () => { + test("genieSpace produces correct wire shape", () => { + expect(supervisorTools.genieSpace("space123", "NYC taxi data")).toEqual({ + type: "genie_space", + genie_space: { id: "space123", description: "NYC taxi data" }, + }); + }); + + test("ucFunction produces correct wire shape", () => { + expect( + supervisorTools.ucFunction("main.default.add", "Adds two integers."), + ).toEqual({ + type: "uc_function", + uc_function: { + name: "main.default.add", + description: "Adds two integers.", + }, + }); + }); + + test("knowledgeAssistant maps id into knowledge_assistant_id", () => { + expect( + supervisorTools.knowledgeAssistant("ka-1", "Internal docs Q&A"), + ).toEqual({ + type: "knowledge_assistant", + knowledge_assistant: { + knowledge_assistant_id: "ka-1", + description: "Internal docs Q&A", + }, + }); + }); + + test("app produces correct wire shape", () => { + expect(supervisorTools.app("my-app", "Demo Databricks app.")).toEqual({ + type: "app", + app: { name: "my-app", description: "Demo Databricks app." }, + }); + }); + + test("ucConnection produces correct wire shape", () => { + expect( + supervisorTools.ucConnection("my-conn", "Connection to external DB."), + ).toEqual({ + type: "uc_connection", + uc_connection: { + name: "my-conn", + description: "Connection to external DB.", + }, + }); + }); +}); + +describe("SupervisorApiAdapter", () => { + afterEach(() => { + vi.restoreAllMocks(); + }); + + test("posts model, input, tools, and stream:true through streamBody", async () => { + const { streamBody, lastBody } = makeStreamBody([ + sseEvent("response.output_text.delta", { delta: "Hi" }), + sseEvent("response.completed", {}), + ]); + + const tools: SupervisorTool[] = [ + supervisorTools.genieSpace("g1", "Test genie space"), + ]; + const adapter = new SupervisorApiAdapter({ + streamBody, + model: "databricks-claude-sonnet-4", + tools, + }); + + await collect(adapter.run(createInput(), { executeTool: vi.fn() })); + + expect(streamBody).toHaveBeenCalledTimes(1); + expect(lastBody()).toMatchObject({ + model: "databricks-claude-sonnet-4", + input: "Hello", + stream: true, + tools, + }); + }); + + test("omits the tools field entirely when no tools are configured", async () => { + const { streamBody, lastBody } = makeStreamBody([ + sseEvent("response.completed", {}), + ]); + const adapter = new SupervisorApiAdapter({ + streamBody, + model: "databricks-claude-sonnet-4", + }); + await collect(adapter.run(createInput(), { executeTool: vi.fn() })); + expect(lastBody()).not.toHaveProperty("tools"); + }); + + test("hoists system messages into the top-level instructions field", async () => { + const { streamBody, lastBody } = makeStreamBody([ + sseEvent("response.completed", {}), + ]); + const adapter = new SupervisorApiAdapter({ + streamBody, + model: "databricks-claude-sonnet-4", + }); + await collect( + adapter.run( + { + messages: [ + { + id: "s", + role: "system", + content: "Be terse.", + createdAt: new Date(), + }, + { id: "u", role: "user", content: "Hi", createdAt: new Date() }, + ], + tools: [], + threadId: "t", + }, + { executeTool: vi.fn() }, + ), + ); + const body = lastBody(); + expect(body?.instructions).toBe("Be terse."); + expect(body?.input).toBe("Hi"); + }); + + test("omits instructions when there is no system message", async () => { + const { streamBody, lastBody } = makeStreamBody([ + sseEvent("response.completed", {}), + ]); + const adapter = new SupervisorApiAdapter({ + streamBody, + model: "databricks-claude-sonnet-4", + }); + await collect(adapter.run(createInput(), { executeTool: vi.fn() })); + expect(lastBody()).not.toHaveProperty("instructions"); + }); + + test("emits message_delta and complete on the happy path", async () => { + const { streamBody } = makeStreamBody([ + sseEvent("response.output_text.delta", { delta: "Hello" }), + sseEvent("response.output_text.delta", { delta: " world" }), + sseEvent("response.completed", {}), + ]); + + const adapter = new SupervisorApiAdapter({ + streamBody, + model: "databricks-claude-sonnet-4", + }); + const events = await collect( + adapter.run(createInput(), { executeTool: vi.fn() }), + ); + + expect(events).toEqual([ + { type: "status", status: "running" }, + { type: "message_delta", content: "Hello" }, + { type: "message_delta", content: " world" }, + { type: "status", status: "complete" }, + ]); + }); + + test("maps response.failed to a status:error event", async () => { + const { streamBody } = makeStreamBody([sseEvent("response.failed", {})]); + const adapter = new SupervisorApiAdapter({ + streamBody, + model: "databricks-claude-sonnet-4", + }); + const events = await collect( + adapter.run(createInput(), { executeTool: vi.fn() }), + ); + expect(events).toContainEqual({ + type: "status", + status: "error", + error: "Response failed", + }); + }); + + test("maps top-level error events", async () => { + const { streamBody } = makeStreamBody([ + sseEvent("error", { error: "rate limited" }), + ]); + const adapter = new SupervisorApiAdapter({ + streamBody, + model: "databricks-claude-sonnet-4", + }); + const events = await collect( + adapter.run(createInput(), { executeTool: vi.fn() }), + ); + expect(events).toContainEqual({ + type: "status", + status: "error", + error: "rate limited", + }); + }); + + test("maps response.output_item.done with id:'error' to status:error", async () => { + const { streamBody } = makeStreamBody([ + sseEvent("response.output_item.done", { + item: { + id: "error", + content: [{ text: "Tool execution failed" }], + }, + }), + sseEvent("response.completed", {}), + ]); + const adapter = new SupervisorApiAdapter({ + streamBody, + model: "databricks-claude-sonnet-4", + }); + const events = await collect( + adapter.run(createInput(), { executeTool: vi.fn() }), + ); + expect(events).toContainEqual({ + type: "status", + status: "error", + error: "Tool execution failed", + }); + }); + + test("falls back to output_item.done text when no deltas streamed (tool-driven SA response)", async () => { + const { streamBody } = makeStreamBody([ + sseEvent("response.output_item.added", { + item: { type: "message", id: "msg-1", role: "assistant", content: [] }, + }), + sseEvent("response.output_item.done", { + item: { + type: "message", + id: "msg-1", + status: "completed", + role: "assistant", + content: [{ type: "output_text", text: "Genie says hi." }], + }, + }), + sseEvent("response.completed", {}), + ]); + const adapter = new SupervisorApiAdapter({ + streamBody, + model: "databricks-claude-sonnet-4", + }); + const events = await collect( + adapter.run(createInput(), { executeTool: vi.fn() }), + ); + expect(events).toEqual([ + { type: "status", status: "running" }, + { type: "message_delta", content: "Genie says hi." }, + { type: "status", status: "complete" }, + ]); + }); + + test("does not double-emit when both deltas and output_item.done arrive for the same item", async () => { + const { streamBody } = makeStreamBody([ + sseEvent("response.output_text.delta", { + item_id: "msg-1", + delta: "Hello", + }), + sseEvent("response.output_text.delta", { + item_id: "msg-1", + delta: " world", + }), + sseEvent("response.output_item.done", { + item: { + type: "message", + id: "msg-1", + status: "completed", + role: "assistant", + content: [{ type: "output_text", text: "Hello world" }], + }, + }), + sseEvent("response.completed", {}), + ]); + const adapter = new SupervisorApiAdapter({ + streamBody, + model: "databricks-claude-sonnet-4", + }); + const events = await collect( + adapter.run(createInput(), { executeTool: vi.fn() }), + ); + expect(events).toEqual([ + { type: "status", status: "running" }, + { type: "message_delta", content: "Hello" }, + { type: "message_delta", content: " world" }, + { type: "status", status: "complete" }, + ]); + }); + + test("emits status:error when the underlying streamBody throws", async () => { + const streamBody = vi + .fn() + .mockRejectedValue(new Error("Supervisor API error (500): boom")); + const adapter = new SupervisorApiAdapter({ + streamBody, + model: "databricks-claude-sonnet-4", + }); + const events = await collect( + adapter.run(createInput(), { executeTool: vi.fn() }), + ); + expect(events).toContainEqual({ + type: "status", + status: "error", + error: "Supervisor API error: Supervisor API error (500): boom", + }); + }); + + test("short-circuits when the signal is already aborted", async () => { + const streamBody = vi.fn(); + const adapter = new SupervisorApiAdapter({ + streamBody, + model: "databricks-claude-sonnet-4", + }); + + const controller = new AbortController(); + controller.abort(); + + const events = await collect( + adapter.run(createInput(), { + executeTool: vi.fn(), + signal: controller.signal, + }), + ); + + expect(events).toEqual([]); + expect(streamBody).not.toHaveBeenCalled(); + }); + + test("multi-turn input (user + assistant + user) is sent as a structured array", async () => { + const { streamBody, lastBody } = makeStreamBody([ + sseEvent("response.completed", {}), + ]); + const adapter = new SupervisorApiAdapter({ + streamBody, + model: "databricks-claude-sonnet-4", + }); + + await collect( + adapter.run( + { + messages: [ + { id: "u1", role: "user", content: "Hi", createdAt: new Date() }, + { + id: "a", + role: "assistant", + content: "Hello!", + createdAt: new Date(), + }, + { + id: "u2", + role: "user", + content: "Tell me more", + createdAt: new Date(), + }, + ], + tools: [], + threadId: "t", + }, + { executeTool: vi.fn() }, + ), + ); + + expect(lastBody()?.input).toEqual([ + { role: "user", content: "Hi" }, + { role: "assistant", content: "Hello!" }, + { role: "user", content: "Tell me more" }, + ]); + }); + + test("drops tool-role messages from the request payload", async () => { + const { streamBody, lastBody } = makeStreamBody([ + sseEvent("response.completed", {}), + ]); + const adapter = new SupervisorApiAdapter({ + streamBody, + model: "databricks-claude-sonnet-4", + }); + await collect( + adapter.run( + { + messages: [ + { id: "u", role: "user", content: "Hi", createdAt: new Date() }, + { + id: "t1", + role: "tool", + content: "(genie result)", + createdAt: new Date(), + }, + ], + tools: [], + threadId: "t", + }, + { executeTool: vi.fn() }, + ), + ); + expect(lastBody()?.input).toBe("Hi"); + }); + + test("recovers final assistant text from response.completed.output when no deltas streamed", async () => { + // Real-world flake: SA occasionally finishes a turn with zero + // `output_text.delta` events and no `output_item.done` for the message, + // but still mirrors the full assistant text in `response.completed`. + // Without recovery the UI sees a silent empty turn. + const { streamBody } = makeStreamBody([ + sseEvent("response.created", {}), + sseEvent("response.in_progress", {}), + sseEvent("response.completed", { + response: { + status: "completed", + output: [ + { + type: "message", + id: "msg-x", + role: "assistant", + content: [ + { type: "output_text", text: "Recovered " }, + { type: "output_text", text: "answer." }, + ], + }, + ], + }, + }), + ]); + + const adapter = new SupervisorApiAdapter({ + streamBody, + model: "databricks-claude-sonnet-4", + }); + const events = await collect( + adapter.run(createInput(), { executeTool: vi.fn() }), + ); + + expect(events).toEqual([ + { type: "status", status: "running" }, + { type: "message_delta", content: "Recovered answer." }, + { type: "status", status: "complete" }, + ]); + }); + + test("does not recover from response.completed when deltas already streamed", async () => { + const { streamBody } = makeStreamBody([ + sseEvent("response.output_text.delta", { + item_id: "msg-x", + delta: "Hi", + }), + sseEvent("response.completed", { + response: { + status: "completed", + output: [ + { + type: "message", + id: "msg-x", + role: "assistant", + content: [{ type: "output_text", text: "Hi" }], + }, + ], + }, + }), + ]); + const adapter = new SupervisorApiAdapter({ + streamBody, + model: "databricks-claude-sonnet-4", + }); + const events = await collect( + adapter.run(createInput(), { executeTool: vi.fn() }), + ); + const deltas = events.filter((e) => e.type === "message_delta"); + expect(deltas).toHaveLength(1); + expect(deltas[0]).toEqual({ type: "message_delta", content: "Hi" }); + }); + + test("treats response.failed as terminal: no events follow the error", async () => { + // SA may keep sending events after `response.failed` (and even a stray + // `response.completed`). The adapter must stop yielding once it has + // surfaced a terminal `status: error` so consumers don't see contradictory + // `message_delta`/`complete` events after the failure. + const { streamBody } = makeStreamBody([ + sseEvent("response.failed", {}), + sseEvent("response.output_text.delta", { delta: "ignored" }), + sseEvent("response.completed", {}), + ]); + const adapter = new SupervisorApiAdapter({ + streamBody, + model: "databricks-claude-sonnet-4", + }); + const events = await collect( + adapter.run(createInput(), { executeTool: vi.fn() }), + ); + expect(events).toEqual([ + { type: "status", status: "running" }, + { type: "status", status: "error", error: "Response failed" }, + ]); + }); + + test("does not yield complete when the consumer aborts mid-stream", async () => { + // Stream that yields one delta, then waits forever — the consumer aborts + // after the first event arrives. The adapter must NOT subsequently yield + // a synthesised `complete` from a buffered `response.completed`. + const controller = new AbortController(); + const encoder = new TextEncoder(); + const stream = new ReadableStream({ + start(c) { + c.enqueue( + encoder.encode( + sseEvent("response.output_text.delta", { delta: "Hi" }), + ), + ); + }, + pull() { + return new Promise(() => { + /* never resolves until cancel() */ + }); + }, + }); + + const adapter = new SupervisorApiAdapter({ + streamBody: async () => stream, + model: "databricks-claude-sonnet-4", + }); + + const events: AgentEvent[] = []; + for await (const e of adapter.run(createInput(), { + executeTool: vi.fn(), + signal: controller.signal, + })) { + events.push(e); + if (e.type === "message_delta") controller.abort(); + } + + expect(events).toEqual([ + { type: "status", status: "running" }, + { type: "message_delta", content: "Hi" }, + ]); + }); + + test("recovers when event: and data: lines arrive in separate chunks", async () => { + const { streamBody } = makeStreamBody([ + "event: response.output_text.delta\n", + `data: ${JSON.stringify({ delta: "split" })}\n\n`, + "event: response.completed\ndata: {}\n\n", + ]); + + const adapter = new SupervisorApiAdapter({ + streamBody, + model: "databricks-claude-sonnet-4", + }); + const events = await collect( + adapter.run(createInput(), { executeTool: vi.fn() }), + ); + expect(events).toContainEqual({ + type: "message_delta", + content: "split", + }); + expect(events).toContainEqual({ type: "status", status: "complete" }); + }); +}); + +describe("fromSupervisorApi", () => { + test("calls ensureResolved on the supplied workspace client", async () => { + const ensureResolved = vi.fn(async () => {}); + const adapter = await fromSupervisorApi({ + model: "databricks-claude-sonnet-4", + workspaceClient: { + config: { ensureResolved }, + apiClient: { request: vi.fn() }, + }, + }); + expect(ensureResolved).toHaveBeenCalledTimes(1); + expect(adapter).toBeInstanceOf(SupervisorApiAdapter); + }); + + test("routes streaming through apiClient.request with the SA path", async () => { + const encoder = new TextEncoder(); + const contents = new ReadableStream({ + start(controller) { + controller.enqueue(encoder.encode(sseEvent("response.completed", {}))); + controller.close(); + }, + }); + const request = vi.fn().mockResolvedValue({ contents }); + + const adapter = await fromSupervisorApi({ + model: "databricks-claude-sonnet-4", + workspaceClient: { + config: { ensureResolved: vi.fn(async () => {}) }, + apiClient: { request }, + }, + }); + + await collect(adapter.run(createInput(), { executeTool: vi.fn() })); + + expect(request).toHaveBeenCalledTimes(1); + const [requestArgs] = request.mock.calls[0]; + expect(requestArgs.path).toBe("/ai-gateway/mlflow/v1/responses"); + expect(requestArgs.method).toBe("POST"); + expect(requestArgs.raw).toBe(true); + expect(requestArgs.payload).toMatchObject({ + model: "databricks-claude-sonnet-4", + input: "Hello", + stream: true, + }); + expect(requestArgs.payload).not.toHaveProperty("tools"); + }); +}); diff --git a/packages/appkit/src/beta.ts b/packages/appkit/src/beta.ts index 3f5bba80c..7ccc77c5b 100644 --- a/packages/appkit/src/beta.ts +++ b/packages/appkit/src/beta.ts @@ -19,6 +19,16 @@ export type { ToolProvider, } from "shared"; export { DatabricksAdapter, parseTextToolCalls } from "./agents/databricks"; +export type { + SupervisorApiAdapterCtorOptions, + SupervisorApiAdapterOptions, + SupervisorTool, +} from "./agents/supervisor-api"; +export { + fromSupervisorApi, + SupervisorApiAdapter, + supervisorTools, +} from "./agents/supervisor-api"; // Agent runtime export { createAgent } from "./core/agent/create-agent"; diff --git a/packages/appkit/src/connectors/serving/client.ts b/packages/appkit/src/connectors/serving/client.ts index 83f065e69..f75993a39 100644 --- a/packages/appkit/src/connectors/serving/client.ts +++ b/packages/appkit/src/connectors/serving/client.ts @@ -41,6 +41,20 @@ function cancellationTokenFromAbortSignal( }; } +/** + * Structural shape of a Databricks SDK client we need for the low-level + * `apiClient.request` call. Lets `streamPath` be reused by adapters that + * don't want a hard dependency on the concrete `WorkspaceClient` type. + */ +export interface ApiClientLike { + apiClient: { + request( + options: Record, + context?: unknown, + ): Promise; + }; +} + /** * Invokes a serving endpoint using the SDK's high-level query API. * Returns a typed QueryEndpointResponse. @@ -62,22 +76,23 @@ export async function invoke( } /** - * Returns the raw SSE byte stream from a serving endpoint. - * No parsing is performed — bytes are passed through as-is. + * POSTs `body` as JSON to an arbitrary workspace API path and returns the raw + * SSE byte stream. No parsing is performed — bytes are passed through as-is. + * + * Uses the SDK's low-level `apiClient.request({ raw: true })` so callers + * inherit URL resolution, the SDK credential chain (PAT/OAuth/OIDC), and + * any future retries/telemetry baked into the SDK transport. * - * Uses the SDK's low-level `apiClient.request({ raw: true })` because - * the high-level `servingEndpoints.query()` returns `Promise` - * and does not support SSE streaming. + * When `signal` is provided it is bridged to the SDK's `Context` / + * `CancellationToken` so aborts cancel the outbound HTTP request. */ -export async function stream( - client: WorkspaceClient, - endpointName: string, +export async function streamPath( + client: ApiClientLike, + path: string, body: Record, signal?: AbortSignal, ): Promise> { - const { stream: _stream, ...cleanBody } = body; - - logger.debug("Streaming from endpoint %s", endpointName); + logger.debug("Streaming from path %s", path); const context = signal ? new Context({ @@ -87,17 +102,17 @@ export async function stream( const response = (await client.apiClient.request( { - path: `/serving-endpoints/${encodeURIComponent(endpointName)}/invocations`, + path, method: "POST", headers: new Headers({ "Content-Type": "application/json", Accept: "text/event-stream", }), - payload: { ...cleanBody, stream: true }, + payload: body, raw: true, }, context, - )) as { contents: ReadableStream }; + )) as { contents: ReadableStream | null }; if (!response.contents) { throw new Error("Response body is null — streaming not supported"); @@ -105,3 +120,23 @@ export async function stream( return response.contents; } + +/** + * Returns the raw SSE byte stream from a serving endpoint. Thin wrapper over + * {@link streamPath} that handles serving-specific URL encoding and forces + * `stream: true` in the payload. + */ +export async function stream( + client: WorkspaceClient, + endpointName: string, + body: Record, + signal?: AbortSignal, +): Promise> { + const { stream: _stream, ...cleanBody } = body; + return streamPath( + client as unknown as ApiClientLike, + `/serving-endpoints/${encodeURIComponent(endpointName)}/invocations`, + { ...cleanBody, stream: true }, + signal, + ); +} diff --git a/packages/appkit/src/stream/index.ts b/packages/appkit/src/stream/index.ts index cc756130a..75ad8b5c4 100644 --- a/packages/appkit/src/stream/index.ts +++ b/packages/appkit/src/stream/index.ts @@ -1 +1,2 @@ +export { readSseEvents } from "./sse-reader"; export { StreamManager } from "./stream-manager"; diff --git a/packages/appkit/src/stream/sse-reader.ts b/packages/appkit/src/stream/sse-reader.ts new file mode 100644 index 000000000..091f132dc --- /dev/null +++ b/packages/appkit/src/stream/sse-reader.ts @@ -0,0 +1,114 @@ +/** + * One parsed Server-Sent Event. Field names follow the spec: + * https://html.spec.whatwg.org/multipage/server-sent-events.html + * + * The reader does not interpret `data` (no JSON parsing), so callers control + * the wire shape they expect. + */ +export interface SseEvent { + /** Value of the most recent `event:` field, or `""` for an unnamed event. */ + event: string; + /** Joined `data:` lines for the event (empty string when no data was set). */ + data: string; + /** Value of the most recent `id:` field, or `undefined` if none. */ + id?: string; +} + +/** + * Async-iterates Server-Sent Events from a UTF-8 byte stream. + * + * Block-oriented parser: events are delimited by blank lines (`\n\n` after + * CRLF normalization), so an `event:` line in chunk N pairs correctly with a + * `data:` line in chunk N+1 — no hoisted state needed. + * + * The reader passes through the sentinel string `[DONE]` as `event=""`, + * `data="[DONE]"`. Callers that care about it should match `data === "[DONE]"` + * after destructuring. + * + * Terminates when the stream closes or `signal` aborts; releases the reader + * lock in either case. + */ +export async function* readSseEvents( + stream: ReadableStream, + signal?: AbortSignal, +): AsyncGenerator { + const reader = stream.getReader(); + const decoder = new TextDecoder(); + let buffer = ""; + + // Cancel the reader on abort so an in-flight `reader.read()` returns + // immediately instead of waiting for the next chunk. Without this, an + // aborted consumer would only notice between reads — fine for chatty + // streams, but unbounded for an idle/heartbeat-less upstream. + const onAbort = () => { + reader.cancel().catch(() => { + // `cancel()` rejects if the stream is already errored/closed; ignore. + }); + }; + if (signal) { + if (signal.aborted) onAbort(); + else signal.addEventListener("abort", onAbort, { once: true }); + } + + try { + while (true) { + if (signal?.aborted) break; + const { done, value } = await reader.read(); + if (done) { + const tail = parseSseBlock(buffer); + if (tail) yield tail; + break; + } + + buffer += decoder.decode(value, { stream: true }); + + const normalized = buffer.replace(/\r\n/g, "\n"); + const blocks = normalized.split("\n\n"); + // Last entry is either an incomplete block or "" (when the chunk ended + // exactly on a boundary). Either way, keep it for the next iteration. + buffer = blocks.pop() ?? ""; + + for (const block of blocks) { + const event = parseSseBlock(block); + if (event) yield event; + } + } + } finally { + if (signal) signal.removeEventListener("abort", onAbort); + reader.releaseLock(); + } +} + +function parseSseBlock(block: string): SseEvent | null { + if (block.length === 0) return null; + const lines = block.split("\n"); + + let eventName = ""; + let id: string | undefined; + const dataLines: string[] = []; + + for (const rawLine of lines) { + const line = rawLine.replace(/\r$/, ""); + if (line === "" || line.startsWith(":")) continue; + + if (line.startsWith("event:")) { + eventName = line.slice(6).trimStart(); + } else if (line.startsWith("data:")) { + dataLines.push(line.slice(5).trimStart()); + } else if (line.startsWith("id:")) { + id = line.slice(3).trimStart(); + } + // Other fields (`retry:`, custom) are ignored by design. + } + + // Per the SSE spec, a block is only dispatched when the data buffer is + // non-empty. Blocks containing only `event:`/`id:` (or comments) do not + // surface as events. + if (dataLines.length === 0) return null; + + return { + event: eventName, + data: dataLines.join("\n"), + id, + }; +} diff --git a/packages/appkit/src/stream/tests/sse-reader.test.ts b/packages/appkit/src/stream/tests/sse-reader.test.ts new file mode 100644 index 000000000..6f7176b62 --- /dev/null +++ b/packages/appkit/src/stream/tests/sse-reader.test.ts @@ -0,0 +1,182 @@ +import { describe, expect, test } from "vitest"; +import { readSseEvents, type SseEvent } from "../sse-reader"; + +function streamOf(chunks: string[]): ReadableStream { + const encoder = new TextEncoder(); + let i = 0; + return new ReadableStream({ + pull(controller) { + if (i < chunks.length) { + controller.enqueue(encoder.encode(chunks[i])); + i++; + } else { + controller.close(); + } + }, + }); +} + +async function collect( + gen: AsyncGenerator, +): Promise { + const out: SseEvent[] = []; + for await (const e of gen) out.push(e); + return out; +} + +describe("readSseEvents", () => { + test("parses a single named event with JSON data", async () => { + const events = await collect( + readSseEvents( + streamOf(['event: response.completed\ndata: {"ok":true}\n\n']), + ), + ); + expect(events).toEqual([ + { event: "response.completed", data: '{"ok":true}', id: undefined }, + ]); + }); + + test("pairs event: and data: across chunk boundaries", async () => { + const events = await collect( + readSseEvents( + streamOf([ + "event: response.output_text.delta\n", + 'data: {"delta":"split"}\n\n', + ]), + ), + ); + expect(events).toEqual([ + { + event: "response.output_text.delta", + data: '{"delta":"split"}', + id: undefined, + }, + ]); + }); + + test("ignores blank lines, comment lines, and unknown fields", async () => { + const events = await collect( + readSseEvents( + streamOf([": heartbeat\n\nretry: 1000\nevent: ping\ndata: hi\n\n"]), + ), + ); + expect(events).toEqual([{ event: "ping", data: "hi", id: undefined }]); + }); + + test("captures id: when present", async () => { + const events = await collect( + readSseEvents(streamOf(["id: abc-123\nevent: ping\ndata: hi\n\n"])), + ); + expect(events).toEqual([{ event: "ping", data: "hi", id: "abc-123" }]); + }); + + test("falls back to empty event name when only data: is present", async () => { + const events = await collect(readSseEvents(streamOf(["data: 1\n\n"]))); + expect(events).toEqual([{ event: "", data: "1", id: undefined }]); + }); + + test("joins multi-line data: payloads with \\n", async () => { + const events = await collect( + readSseEvents(streamOf(["data: line1\ndata: line2\n\n"])), + ); + expect(events).toEqual([ + { event: "", data: "line1\nline2", id: undefined }, + ]); + }); + + test("normalises CRLF line endings", async () => { + const events = await collect( + readSseEvents(streamOf(["event: x\r\ndata: y\r\n\r\n"])), + ); + expect(events).toEqual([{ event: "x", data: "y", id: undefined }]); + }); + + test("emits a trailing event when the stream closes without a final blank line", async () => { + const events = await collect( + readSseEvents(streamOf(["event: ping\ndata: hi"])), + ); + expect(events).toEqual([{ event: "ping", data: "hi", id: undefined }]); + }); + + test("passes through [DONE] sentinels as data", async () => { + const events = await collect(readSseEvents(streamOf(["data: [DONE]\n\n"]))); + expect(events).toEqual([{ event: "", data: "[DONE]", id: undefined }]); + }); + + test("aborts when the signal fires before the next read", async () => { + const controller = new AbortController(); + let pulls = 0; + const stream = new ReadableStream({ + pull(c) { + pulls++; + if (pulls === 1) { + c.enqueue(new TextEncoder().encode("event: a\ndata: 1\n\n")); + } else { + controller.abort(); + c.enqueue(new TextEncoder().encode("event: b\ndata: 2\n\n")); + } + }, + }); + + const out: SseEvent[] = []; + for await (const e of readSseEvents(stream, controller.signal)) { + out.push(e); + if (out.length === 1) controller.abort(); + } + expect(out.map((e) => e.event)).toEqual(["a"]); + }); + + test("aborts an idle reader immediately via reader.cancel()", async () => { + // Stream that sends one event then never resolves further reads — models + // an upstream that has stopped sending data. Without `reader.cancel()` + // the consumer would block forever after aborting. + const controller = new AbortController(); + let cancelled = false; + const stream = new ReadableStream({ + start(c) { + c.enqueue(new TextEncoder().encode("event: a\ndata: 1\n\n")); + }, + pull() { + return new Promise(() => { + /* never resolves */ + }); + }, + cancel() { + cancelled = true; + }, + }); + + const out: SseEvent[] = []; + const iterator = readSseEvents(stream, controller.signal); + const first = await iterator.next(); + if (!first.done) out.push(first.value); + controller.abort(); + const second = await iterator.next(); + expect(second.done).toBe(true); + expect(out.map((e) => e.event)).toEqual(["a"]); + expect(cancelled).toBe(true); + }); + + test("does not dispatch a block whose only field is id: (spec compliance)", async () => { + const events = await collect( + readSseEvents(streamOf(["id: only\n\nevent: ping\ndata: hi\n\n"])), + ); + expect(events).toEqual([{ event: "ping", data: "hi", id: undefined }]); + }); + + test("decodes a multi-byte UTF-8 character split across chunks", async () => { + const checkBytes = new TextEncoder().encode("✓"); + expect(checkBytes.length).toBe(3); + const stream = new ReadableStream({ + start(c) { + c.enqueue(new TextEncoder().encode("data: ")); + c.enqueue(checkBytes.subarray(0, 1)); + c.enqueue(checkBytes.subarray(1)); + c.enqueue(new TextEncoder().encode("\n\n")); + c.close(); + }, + }); + const events = await collect(readSseEvents(stream)); + expect(events).toEqual([{ event: "", data: "✓", id: undefined }]); + }); +}); From 6514dc03ddc368a631ada0a01911a8135ad510ac Mon Sep 17 00:00:00 2001 From: Hubert Zub Date: Fri, 22 May 2026 09:45:15 +0200 Subject: [PATCH 2/2] fix(appkit): address PR #345 review findings (section 9) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Apply Mario's defensive/correctness fixes from the supervisor API adapter review without touching the public API shape (sections 1-8 will land in a stacked branch). Highlights: High - Route the three SSE error-leak sites in supervisor-api.ts (streamBody catch, mapEvent "error", output_item.done with id="error") through a single emitError helper that returns a stable client-facing code (`Supervisor API error (transport|upstream_failed|upstream_tool| upstream_unknown)`) and logs the verbose detail server-side only. Addresses CWE-209 verbatim-upstream-error-text leak. Medium - Gate the terminal {status:"complete"} emission on lastCompleted.status / .error / .incomplete_details so a `response.completed` with a nested failed status no longer silently succeeds; surface as upstream_failed instead. Regression tests added. - Skip the terminal error in the streamBody catch when signal.aborted — consumer-initiated aborts now end with a clean stop, not a contradictory terminal error event. Regression test added. - Tighten the output_item.done error match: require item.type === "error" (or pair the reserved id="error" with a non-message type) so a stray assistant message with id="error" is not mis-classified. - Add maxLineChars / maxBufferChars caps to readSseEvents with 1 MiB / 8 MiB defaults; throw on overflow. Addresses CWE-770. Tests added. - Docs: add a CWE-1427 callout warning that hosted-tool `description` is a prompt-injection sink — do not derive it from untrusted input. - Redact the no-delta warning log: summariseErrorPayload extracts a short `type: message` line; full payload only via DEBUG. Addresses CWE-532. - Gate the buffer-level CRLF normalize in sse-reader on `\r` presence to skip the regex on LF-only steady state. Low - mapEvent("error") fallback no longer wraps "Unknown error" with literal JSON quotes (uses string branch). - Drop the misleading "we await the factory at module init" comment in dev-playground; the code never awaits. - Fix @example imports in supervisor-api.ts JSDoc to use @databricks/appkit/beta (the actual public re-export). - Replace trimStart() with single-U+0020 strip in sse-reader per the SSE spec; remove the now-dead per-line `\r$` strip after the buffer-level CRLF normalise. - Flag streamPath as @internal in connectors/serving/client.ts noting the CWE-918 SSRF risk if it ever leaks to user-controlled input. - Add JSDoc warning to workspaceClient on SupervisorApiAdapterOptions: passing a per-request OBO client would leak identity across requests (CWE-664). Signed-off-by: Hubert Zub --- apps/dev-playground/server/index.ts | 8 +- docs/docs/plugins/agents.md | 6 + packages/appkit/src/agents/supervisor-api.ts | 141 ++++++++++--- .../src/agents/tests/supervisor-api.test.ts | 185 ++++++++++++++++-- .../appkit/src/connectors/serving/client.ts | 9 + packages/appkit/src/stream/sse-reader.ts | 77 +++++++- .../src/stream/tests/sse-reader.test.ts | 53 +++++ 7 files changed, 430 insertions(+), 49 deletions(-) diff --git a/apps/dev-playground/server/index.ts b/apps/dev-playground/server/index.ts index 67187dcbe..53f6cade8 100644 --- a/apps/dev-playground/server/index.ts +++ b/apps/dev-playground/server/index.ts @@ -78,10 +78,10 @@ const helper = createAgent({ // Uncomment a `supervisorTools.*` entry (and import 'supervisorTools' from // '@databricks/appkit/beta') to give the model real powers. // -// We `await` the factory at module init so a misconfigured workspace -// (missing host, bad credentials) fails fast with a clear error here -// instead of as an unhandled rejection. Top-level await is fine in this -// ESM module. +// `createAgent({ model })` accepts an adapter promise, so the factory's +// host/credential resolution is awaited lazily on first dispatch (via +// `resolveAdapter` in the agents plugin). A misconfigured workspace will +// surface at first chat request, not at module init. const supervisor = createAgent({ instructions: "You are an assistant powered by the Databricks Supervisor API.", diff --git a/docs/docs/plugins/agents.md b/docs/docs/plugins/agents.md index c228551e2..f7a9548a8 100644 --- a/docs/docs/plugins/agents.md +++ b/docs/docs/plugins/agents.md @@ -269,6 +269,12 @@ const model = fromSupervisorApi({ `description` is **required and non-empty** — the LLM uses it to route between tools, so two Genie spaces both labelled "Genie space" will be indistinguishable. +:::warning Hosted-tool descriptions are trusted application configuration (CWE-1427) +A hosted tool's `description` is read by the LLM to decide when to route to that tool. **Do not derive it from untrusted input** — user messages, request bodies, freeform fields from external systems, or any value an attacker could influence. Treat `description` (and `id`/`name`) as application-controlled, alongside the agent's `instructions`. Allowing a user-controlled string here is a prompt-injection sink: a hostile description can convince the model to route to (or away from) a tool for any future request handled by the agent. + +The same caution applies to MCP `description`s and to any other field the model reads at routing time. +::: + | Factory | Tool kind | Identifier | |---|---|---| | `supervisorTools.genieSpace(id, description)` | Genie space | space id | diff --git a/packages/appkit/src/agents/supervisor-api.ts b/packages/appkit/src/agents/supervisor-api.ts index 228eb8be9..4e7edbb83 100644 --- a/packages/appkit/src/agents/supervisor-api.ts +++ b/packages/appkit/src/agents/supervisor-api.ts @@ -12,6 +12,59 @@ import { readSseEvents } from "../stream"; const logger = createLogger("agents:supervisor-api"); +/** + * Stable client-facing error codes. We never surface raw upstream error + * strings to the client (CWE-209) — the helper logs the verbose detail + * server-side and returns one of these codes in the {@link AgentEvent}. + */ +type SupervisorErrorCode = + | "transport" + | "upstream_failed" + | "upstream_tool" + | "upstream_unknown"; + +/** + * Single sink for all error events emitted by the adapter. Logs the verbose + * detail (stack, upstream payload, etc.) at `warn` level and returns a + * sanitised {@link AgentEvent} carrying only a stable code so the client + * never sees raw upstream text. + */ +function emitError(code: SupervisorErrorCode, detail: unknown): AgentEvent { + logger.warn("supervisor-api error code=%s detail=%O", code, detail); + return { + type: "status", + status: "error", + error: `Supervisor API error (${code})`, + }; +} + +/** + * Renders an upstream error / incomplete_details payload as a short + * single-line string for log lines. Avoids dumping the full JSON tree + * (CWE-532): we keep the discriminator (`type`/`code`) plus a trimmed + * message, and that's it. Full payloads are still available via + * `DEBUG=appkit:agents:supervisor-api`. + */ +function summariseErrorPayload(payload: unknown): string { + if (payload == null) return ""; + if (typeof payload === "string") { + return payload.length > 80 ? `${payload.slice(0, 80)}…` : payload; + } + if (typeof payload !== "object") return String(payload); + const obj = payload as Record; + const kind = + (typeof obj.type === "string" && obj.type) || + (typeof obj.code === "string" && obj.code) || + (typeof obj.reason === "string" && obj.reason) || + "object"; + const message = + (typeof obj.message === "string" && obj.message) || + (typeof obj.detail === "string" && obj.detail) || + ""; + const trimmed = message.length > 80 ? `${message.slice(0, 80)}…` : message; + return trimmed ? `${kind}: ${trimmed}` : kind; +} + /** * Transport shim: given a request body, returns the raw SSE byte stream from * the Supervisor API endpoint. Injected at construction time so callers can @@ -135,6 +188,12 @@ export interface SupervisorApiAdapterOptions { * and per-request authentication. When omitted, a `WorkspaceClient({})` * is created internally using the default SDK credential chain * (`DATABRICKS_HOST`, OAuth, PAT, etc.). + * + * ⚠ The `workspaceClient` is captured at construction and reused across + * every request. Passing a per-request OBO (On-Behalf-Of) client here + * would silently leak the first request's identity into all subsequent + * requests served by this adapter instance. Use the default credential + * chain or pass a service-principal client. (CWE-664) */ workspaceClient?: WorkspaceClientLike; } @@ -168,11 +227,12 @@ export interface SupervisorApiAdapterCtorOptions { * * @example * ```ts - * import { createApp, createAgent, agents } from "@databricks/appkit"; + * import { createApp, createAgent } from "@databricks/appkit"; * import { + * agents, * fromSupervisorApi, * supervisorTools, - * } from "@databricks/appkit/agents/supervisor-api"; + * } from "@databricks/appkit/beta"; * * const adapter = await fromSupervisorApi({ * model: "databricks-claude-sonnet-4", @@ -254,13 +314,11 @@ export class SupervisorApiAdapter implements AgentAdapter { try { stream = await this.streamBody(body, signal); } catch (err) { - const message = err instanceof Error ? err.message : String(err); - logger.warn("Supervisor API request failed: %s", message); - yield { - type: "status", - status: "error", - error: `Supervisor API error: ${message}`, - }; + // Aborts surface as exceptions thrown by `fetch`/SDK transports when + // the consumer cancels mid-request. Treat as a clean stop so consumers + // don't see a contradictory terminal `error` after their own abort. + if (signal?.aborted) return; + yield emitError("transport", err); return; } @@ -360,6 +418,22 @@ export class SupervisorApiAdapter implements AgentAdapter { } if (eventCounts.has("response.completed")) { + // SA sometimes signals a failed turn via `response.completed` with a + // nested `status: "failed"` (or a populated `error`/`incomplete_details`) + // rather than emitting `response.failed`. Without this gate the + // adapter would silently yield `complete` on a server-side failure. + if ( + lastCompleted?.status === "failed" || + lastCompleted?.error != null || + lastCompleted?.incomplete_details != null + ) { + yield emitError("upstream_failed", { + status: lastCompleted?.status, + error: lastCompleted?.error, + incomplete_details: lastCompleted?.incomplete_details, + }); + return; + } yield { type: "status", status: "complete" }; } @@ -367,19 +441,18 @@ export class SupervisorApiAdapter implements AgentAdapter { const histogram = [...eventCounts.entries()] .map(([t, n]) => `${t}=${n}`) .join(", "); - const completedError = lastCompleted?.error - ? JSON.stringify(lastCompleted.error) - : undefined; - const completedIncomplete = lastCompleted?.incomplete_details - ? JSON.stringify(lastCompleted.incomplete_details) - : undefined; logger.warn( "Supervisor API stream completed without any output_text deltas. " + "events={%s} completed.status=%s completed.error=%s completed.incomplete=%s", histogram, lastCompleted?.status ?? "", - completedError ?? "", - completedIncomplete ?? "", + summariseErrorPayload(lastCompleted?.error), + summariseErrorPayload(lastCompleted?.incomplete_details), + ); + logger.debug( + "Supervisor API no-delta full payload: error=%O incomplete=%O", + lastCompleted?.error, + lastCompleted?.incomplete_details, ); } } @@ -476,14 +549,20 @@ function mapEvent( // the stream produced none, then emits `{status:"complete"}` itself. case "response.failed": - return { type: "status", status: "error", error: "Response failed" }; + return emitError("upstream_failed", data); case "error": { - const errMsg = + // Branch detail extraction so a missing `error` field doesn't surface + // the JSON-stringified literal `'"Unknown error"'` (with quotes) in + // server logs. The client never sees this string — `emitError` + // sanitises it to a stable code. + const detail = typeof data.error === "string" ? data.error - : JSON.stringify(data.error ?? "Unknown error"); - return { type: "status", status: "error", error: errMsg }; + : data.error == null + ? "Unknown error" + : data.error; + return emitError("upstream_unknown", detail); } case "response.output_item.done": { @@ -495,9 +574,15 @@ function mapEvent( } | undefined; - if (item?.id === "error") { - const errText = item.content?.[0]?.text ?? "Unknown tool error from SA"; - return { type: "status", status: "error", error: errText }; + // SA's contract reserves `item.id === "error"` for tool failures, but + // a 5-char identifier collision is too small a margin. Require either + // an explicit `type === "error"` or pair the reserved id with a + // non-message type (a normal assistant message uses `type: "message"`). + if ( + item?.type === "error" || + (item?.id === "error" && item?.type !== "message") + ) { + return emitError("upstream_tool", item); } // Fallback: when SA produces a tool-driven response (e.g. Genie space), @@ -541,7 +626,7 @@ function mapEvent( * import { * fromSupervisorApi, * supervisorTools, - * } from "@databricks/appkit/agents/supervisor-api"; + * } from "@databricks/appkit/beta"; * * const adapter = await fromSupervisorApi({ * model: "databricks-claude-sonnet-4", @@ -553,6 +638,12 @@ function mapEvent( * ], * }); * ``` + * + * @remarks + * ⚠ When passing your own `workspaceClient`, see the warning on + * {@link SupervisorApiAdapterOptions.workspaceClient} — the client is + * captured once and reused, so per-request OBO clients would leak + * identity across requests. */ export async function fromSupervisorApi( options: SupervisorApiAdapterOptions, diff --git a/packages/appkit/src/agents/tests/supervisor-api.test.ts b/packages/appkit/src/agents/tests/supervisor-api.test.ts index 9606b1c6a..9877808e4 100644 --- a/packages/appkit/src/agents/tests/supervisor-api.test.ts +++ b/packages/appkit/src/agents/tests/supervisor-api.test.ts @@ -221,8 +221,15 @@ describe("SupervisorApiAdapter", () => { ]); }); - test("maps response.failed to a status:error event", async () => { - const { streamBody } = makeStreamBody([sseEvent("response.failed", {})]); + test("maps response.failed to a sanitised status:error event", async () => { + // The verbose upstream payload must NOT reach the client (CWE-209) — + // only the stable `upstream_failed` code does. Server logs still keep + // the full detail via logger.warn. + const { streamBody } = makeStreamBody([ + sseEvent("response.failed", { + response: { error: { message: "secret-internal-stack-trace" } }, + }), + ]); const adapter = new SupervisorApiAdapter({ streamBody, model: "databricks-claude-sonnet-4", @@ -233,13 +240,19 @@ describe("SupervisorApiAdapter", () => { expect(events).toContainEqual({ type: "status", status: "error", - error: "Response failed", + error: "Supervisor API error (upstream_failed)", }); + // Belt-and-braces: the leaky upstream string is never in the event. + for (const e of events) { + if (e.type === "status" && "error" in e) { + expect(e.error).not.toContain("secret-internal-stack-trace"); + } + } }); - test("maps top-level error events", async () => { + test("maps top-level error events to sanitised upstream_unknown code", async () => { const { streamBody } = makeStreamBody([ - sseEvent("error", { error: "rate limited" }), + sseEvent("error", { error: "rate limited (workspace abc-123)" }), ]); const adapter = new SupervisorApiAdapter({ streamBody, @@ -251,16 +264,22 @@ describe("SupervisorApiAdapter", () => { expect(events).toContainEqual({ type: "status", status: "error", - error: "rate limited", + error: "Supervisor API error (upstream_unknown)", }); + for (const e of events) { + if (e.type === "status" && "error" in e) { + expect(e.error).not.toContain("workspace abc-123"); + } + } }); - test("maps response.output_item.done with id:'error' to status:error", async () => { + test("maps response.output_item.done error item to sanitised upstream_tool code", async () => { const { streamBody } = makeStreamBody([ sseEvent("response.output_item.done", { item: { id: "error", - content: [{ text: "Tool execution failed" }], + type: "error", + content: [{ text: "Tool stack trace with /home/user paths" }], }, }), sseEvent("response.completed", {}), @@ -275,8 +294,43 @@ describe("SupervisorApiAdapter", () => { expect(events).toContainEqual({ type: "status", status: "error", - error: "Tool execution failed", + error: "Supervisor API error (upstream_tool)", + }); + for (const e of events) { + if (e.type === "status" && "error" in e) { + expect(e.error).not.toContain("/home/user"); + } + } + }); + + test("does NOT treat output_item.done id:'error' as error when type:'message' (collision guard)", async () => { + // SA reserves `id === "error"` for tool failures, but the 5-char id + // could collide with a legitimately-id'd assistant message. The guard + // requires `type === "error"` (or a non-message type alongside the + // reserved id) so a stray message with id="error" is not mis-classified. + const { streamBody } = makeStreamBody([ + sseEvent("response.output_item.done", { + item: { + id: "error", + type: "message", + role: "assistant", + content: [{ type: "output_text", text: "hello from error-id msg" }], + }, + }), + sseEvent("response.completed", {}), + ]); + const adapter = new SupervisorApiAdapter({ + streamBody, + model: "databricks-claude-sonnet-4", }); + const events = await collect( + adapter.run(createInput(), { executeTool: vi.fn() }), + ); + expect(events).toEqual([ + { type: "status", status: "running" }, + { type: "message_delta", content: "hello from error-id msg" }, + { type: "status", status: "complete" }, + ]); }); test("falls back to output_item.done text when no deltas streamed (tool-driven SA response)", async () => { @@ -345,10 +399,14 @@ describe("SupervisorApiAdapter", () => { ]); }); - test("emits status:error when the underlying streamBody throws", async () => { + test("emits sanitised transport error when the underlying streamBody throws", async () => { const streamBody = vi .fn() - .mockRejectedValue(new Error("Supervisor API error (500): boom")); + .mockRejectedValue( + new Error( + "HTTP 500 from https://workspace-internal.foo: stack trace ...", + ), + ); const adapter = new SupervisorApiAdapter({ streamBody, model: "databricks-claude-sonnet-4", @@ -359,8 +417,45 @@ describe("SupervisorApiAdapter", () => { expect(events).toContainEqual({ type: "status", status: "error", - error: "Supervisor API error: Supervisor API error (500): boom", + error: "Supervisor API error (transport)", }); + for (const e of events) { + if (e.type === "status" && "error" in e) { + expect(e.error).not.toContain("workspace-internal.foo"); + expect(e.error).not.toContain("stack trace"); + } + } + }); + + test("does NOT emit a terminal error when the consumer aborts before streamBody resolves", async () => { + // Regression: previously the streamBody catch yielded a sanitised + // `{status:"error"}` even when the underlying rejection was an abort + // triggered by the consumer. Consumers that issued the abort must see + // a clean stop (zero further events after their abort), not a + // contradictory terminal error. + const controller = new AbortController(); + const streamBody = vi.fn(async (_body, signal?: AbortSignal) => { + controller.abort(); + // Simulate the SDK transport rejecting because the signal aborted. + // The catch path must observe `signal.aborted` and return silently. + throw new DOMException( + signal?.aborted ? "aborted" : "fetch failed", + "AbortError", + ); + }); + + const adapter = new SupervisorApiAdapter({ + streamBody, + model: "databricks-claude-sonnet-4", + }); + const events = await collect( + adapter.run(createInput(), { + executeTool: vi.fn(), + signal: controller.signal, + }), + ); + + expect(events).toEqual([{ type: "status", status: "running" }]); }); test("short-circuits when the signal is already aborted", async () => { @@ -546,8 +641,72 @@ describe("SupervisorApiAdapter", () => { ); expect(events).toEqual([ { type: "status", status: "running" }, - { type: "status", status: "error", error: "Response failed" }, + { + type: "status", + status: "error", + error: "Supervisor API error (upstream_failed)", + }, + ]); + }); + + test("does NOT yield complete when response.completed carries status:'failed'", async () => { + // Regression for the silent-success-on-failure bug: SA occasionally + // reports a failed turn via `response.completed.status = "failed"` + // (with optional `error`/`incomplete_details`) rather than emitting + // `response.failed`. The adapter must surface this as a terminal + // error and NOT yield `{status:"complete"}`. + const { streamBody } = makeStreamBody([ + sseEvent("response.completed", { + response: { + status: "failed", + error: { message: "tool timeout" }, + }, + }), + ]); + const adapter = new SupervisorApiAdapter({ + streamBody, + model: "databricks-claude-sonnet-4", + }); + const events = await collect( + adapter.run(createInput(), { executeTool: vi.fn() }), + ); + expect(events).toEqual([ + { type: "status", status: "running" }, + { + type: "status", + status: "error", + error: "Supervisor API error (upstream_failed)", + }, + ]); + }); + + test("does NOT yield complete when response.completed carries a populated error", async () => { + // Variant: status reported as "completed" but `error` is non-null. + // Treat as a terminal failure rather than silently completing. + const { streamBody } = makeStreamBody([ + sseEvent("response.completed", { + response: { + status: "completed", + error: { code: "internal" }, + }, + }), ]); + const adapter = new SupervisorApiAdapter({ + streamBody, + model: "databricks-claude-sonnet-4", + }); + const events = await collect( + adapter.run(createInput(), { executeTool: vi.fn() }), + ); + expect(events).toContainEqual({ + type: "status", + status: "error", + error: "Supervisor API error (upstream_failed)", + }); + expect(events).not.toContainEqual({ + type: "status", + status: "complete", + }); }); test("does not yield complete when the consumer aborts mid-stream", async () => { diff --git a/packages/appkit/src/connectors/serving/client.ts b/packages/appkit/src/connectors/serving/client.ts index f75993a39..3c556da79 100644 --- a/packages/appkit/src/connectors/serving/client.ts +++ b/packages/appkit/src/connectors/serving/client.ts @@ -85,6 +85,15 @@ export async function invoke( * * When `signal` is provided it is bridged to the SDK's `Context` / * `CancellationToken` so aborts cancel the outbound HTTP request. + * + * @internal + * + * Not part of the public AppKit surface. `path` is passed through to the + * SDK without any allowlist — exposing this to user-controlled input would + * turn it into workspace-credentialled SSRF (CWE-918). Internal callers + * must hard-code the path (or build it from a closed enum). New callers + * inside the package: keep this constraint, and do not re-export from + * `beta.ts` or any other entry point. */ export async function streamPath( client: ApiClientLike, diff --git a/packages/appkit/src/stream/sse-reader.ts b/packages/appkit/src/stream/sse-reader.ts index 091f132dc..f80f0738e 100644 --- a/packages/appkit/src/stream/sse-reader.ts +++ b/packages/appkit/src/stream/sse-reader.ts @@ -14,6 +14,33 @@ export interface SseEvent { id?: string; } +/** + * Configuration for {@link readSseEvents}. All limits are in UTF-16 code + * units (JS string `.length`) and exist as a DoS guard (CWE-770) for + * untrusted upstreams that might stream arbitrarily large lines or never + * emit a block terminator. + */ +interface ReadSseEventsOptions { + /** + * Maximum length of any single SSE event block (i.e. the text between + * two `\n\n` separators). Exceeding this throws. + * + * @default 1 MiB (1_048_576) + */ + maxLineChars?: number; + /** + * Maximum length of the rolling input buffer when no block terminator + * has been seen yet. Exceeding this throws — protects against an + * upstream that streams indefinitely without ever sending `\n\n`. + * + * @default 8 MiB (8_388_608) + */ + maxBufferChars?: number; +} + +const DEFAULT_MAX_SSE_LINE_CHARS = 1024 * 1024; +const DEFAULT_MAX_SSE_BUFFER_CHARS = 8 * 1024 * 1024; + /** * Async-iterates Server-Sent Events from a UTF-8 byte stream. * @@ -26,12 +53,18 @@ export interface SseEvent { * after destructuring. * * Terminates when the stream closes or `signal` aborts; releases the reader - * lock in either case. + * lock in either case. Throws when {@link ReadSseEventsOptions.maxLineChars} + * or {@link ReadSseEventsOptions.maxBufferChars} are exceeded. */ export async function* readSseEvents( stream: ReadableStream, signal?: AbortSignal, + options?: ReadSseEventsOptions, ): AsyncGenerator { + const maxLineChars = options?.maxLineChars ?? DEFAULT_MAX_SSE_LINE_CHARS; + const maxBufferChars = + options?.maxBufferChars ?? DEFAULT_MAX_SSE_BUFFER_CHARS; + const reader = stream.getReader(); const decoder = new TextDecoder(); let buffer = ""; @@ -55,6 +88,11 @@ export async function* readSseEvents( if (signal?.aborted) break; const { done, value } = await reader.read(); if (done) { + if (buffer.length > maxLineChars) { + throw new Error( + `readSseEvents: trailing SSE block exceeds maxLineChars (${maxLineChars} UTF-16 code units)`, + ); + } const tail = parseSseBlock(buffer); if (tail) yield tail; break; @@ -62,13 +100,27 @@ export async function* readSseEvents( buffer += decoder.decode(value, { stream: true }); - const normalized = buffer.replace(/\r\n/g, "\n"); + // Gate the CRLF normalize on `\r` presence — saves a full-buffer + // regex scan on every chunk for the common LF-only steady state. + const normalized = + buffer.indexOf("\r") !== -1 ? buffer.replace(/\r\n/g, "\n") : buffer; const blocks = normalized.split("\n\n"); // Last entry is either an incomplete block or "" (when the chunk ended // exactly on a boundary). Either way, keep it for the next iteration. buffer = blocks.pop() ?? ""; + if (buffer.length > maxBufferChars) { + throw new Error( + `readSseEvents: incomplete SSE block exceeds maxBufferChars (${maxBufferChars} UTF-16 code units) without a terminator`, + ); + } + for (const block of blocks) { + if (block.length > maxLineChars) { + throw new Error( + `readSseEvents: SSE block exceeds maxLineChars (${maxLineChars} UTF-16 code units)`, + ); + } const event = parseSseBlock(block); if (event) yield event; } @@ -79,24 +131,35 @@ export async function* readSseEvents( } } +/** + * Per the SSE spec, only a single leading `U+0020` is stripped from a field + * value — not arbitrary whitespace. `trimStart()` would also strip tabs, + * NBSP, etc.; for callers that feed binary or whitespace-prefixed payloads + * this is a footgun. + */ +function stripOneLeadingSpace(s: string): string { + return s.startsWith(" ") ? s.slice(1) : s; +} + function parseSseBlock(block: string): SseEvent | null { if (block.length === 0) return null; + // CRLF was already normalised at the buffer level, so each `line` here is + // already free of trailing `\r` — no per-line strip needed. const lines = block.split("\n"); let eventName = ""; let id: string | undefined; const dataLines: string[] = []; - for (const rawLine of lines) { - const line = rawLine.replace(/\r$/, ""); + for (const line of lines) { if (line === "" || line.startsWith(":")) continue; if (line.startsWith("event:")) { - eventName = line.slice(6).trimStart(); + eventName = stripOneLeadingSpace(line.slice(6)); } else if (line.startsWith("data:")) { - dataLines.push(line.slice(5).trimStart()); + dataLines.push(stripOneLeadingSpace(line.slice(5))); } else if (line.startsWith("id:")) { - id = line.slice(3).trimStart(); + id = stripOneLeadingSpace(line.slice(3)); } // Other fields (`retry:`, custom) are ignored by design. } diff --git a/packages/appkit/src/stream/tests/sse-reader.test.ts b/packages/appkit/src/stream/tests/sse-reader.test.ts index 6f7176b62..d83ba26f2 100644 --- a/packages/appkit/src/stream/tests/sse-reader.test.ts +++ b/packages/appkit/src/stream/tests/sse-reader.test.ts @@ -179,4 +179,57 @@ describe("readSseEvents", () => { const events = await collect(readSseEvents(stream)); expect(events).toEqual([{ event: "", data: "✓", id: undefined }]); }); + + test("throws when a single block exceeds maxLineChars (DoS guard)", async () => { + // A complete block whose total length exceeds the cap must throw rather + // than silently propagate to the consumer — protects callers from + // upstreams that stream arbitrarily large payloads (CWE-770). + const huge = `data: ${"x".repeat(200)}\n\n`; + await expect(async () => { + for await (const _ of readSseEvents(streamOf([huge]), undefined, { + maxLineChars: 100, + })) { + /* iterate */ + } + }).rejects.toThrow(/exceeds maxLineChars/); + }); + + test("throws when the rolling buffer exceeds maxBufferChars without a terminator", async () => { + // An upstream that streams forever without ever sending the `\n\n` + // block separator must not grow the buffer unboundedly — throw once + // the cap is exceeded. + const stream = new ReadableStream({ + pull(c) { + c.enqueue(new TextEncoder().encode("x".repeat(50))); + // No close() — keep feeding until the cap fires. + }, + }); + await expect(async () => { + for await (const _ of readSseEvents(stream, undefined, { + maxBufferChars: 200, + maxLineChars: 10_000, + })) { + /* iterate */ + } + }).rejects.toThrow(/exceeds maxBufferChars/); + }); + + test("strips only a single leading U+0020 from field values (spec compliance)", async () => { + // `trimStart()` would strip tabs / NBSP / multi-space prefixes, which + // is wrong per the SSE spec — only one leading U+0020 may be removed. + const events = await collect( + readSseEvents(streamOf(["data: with-leading-spaces\n\n"])), + ); + // First space is stripped; second is preserved. + expect(events).toEqual([ + { event: "", data: " with-leading-spaces", id: undefined }, + ]); + }); + + test("preserves tab-prefixed data values (trimStart would have stripped)", async () => { + const events = await collect( + readSseEvents(streamOf(["data:\t\tvalue\n\n"])), + ); + expect(events).toEqual([{ event: "", data: "\t\tvalue", id: undefined }]); + }); });