diff --git a/packages/appkit/src/index.ts b/packages/appkit/src/index.ts index b5155834..033e6d09 100644 --- a/packages/appkit/src/index.ts +++ b/packages/appkit/src/index.ts @@ -62,14 +62,25 @@ export { type ToPlugin, toPlugin, } from "./plugin"; -export { analytics, files, genie, lakebase, server, serving } from "./plugins"; +export { + agent, + analytics, + files, + genie, + lakebase, + server, + serving, +} from "./plugins"; export { type FunctionTool, type HostedTool, + isFunctionTool, + isHostedTool, mcpServer, type ToolConfig, tool, } from "./plugins/agent/tools"; +export type { AgentTool } from "./plugins/agent/types"; export type { EndpointConfig, ServingEndpointEntry, diff --git a/packages/appkit/src/plugins/agent/agent.ts b/packages/appkit/src/plugins/agent/agent.ts new file mode 100644 index 00000000..15827cad --- /dev/null +++ b/packages/appkit/src/plugins/agent/agent.ts @@ -0,0 +1,767 @@ +import { randomUUID } from "node:crypto"; +import path from "node:path"; +import type express from "express"; +import pc from "picocolors"; +import type { + AgentAdapter, + AgentToolDefinition, + IAppRouter, + Message, + PluginPhase, + ResponseStreamEvent, + ToolProvider, +} from "shared"; +import { createLogger } from "../../logging/logger"; +import { Plugin, toPlugin } from "../../plugin"; +import type { PluginManifest } from "../../registry"; +import { loadAgentConfigs } from "./config-loader"; +import { agentStreamDefaults } from "./defaults"; +import { AgentEventTranslator } from "./event-translator"; +import manifest from "./manifest.json"; +import { chatRequestSchema, invocationsRequestSchema } from "./schemas"; +import { buildBaseSystemPrompt, composeSystemPrompt } from "./system-prompt"; +import { InMemoryThreadStore } from "./thread-store"; +import { + AppKitMcpClient, + type FunctionTool, + functionToolToDefinition, + isFunctionTool, + isHostedTool, + resolveHostedTools, +} from "./tools"; +import type { AgentPluginConfig, RegisteredAgent, ToolEntry } from "./types"; + +const logger = createLogger("agent"); + +function isToolProvider(obj: unknown): obj is ToolProvider { + return ( + typeof obj === "object" && + obj !== null && + "getAgentTools" in obj && + typeof (obj as any).getAgentTools === "function" && + "executeAgentTool" in obj && + typeof (obj as any).executeAgentTool === "function" + ); +} + +export class AgentPlugin extends Plugin { + static manifest = manifest as PluginManifest<"agent">; + static phase: PluginPhase = "deferred"; + + protected declare config: AgentPluginConfig; + + private agents = new Map(); + private defaultAgentName: string | null = null; + private toolIndex = new Map(); + private threadStore; + private activeStreams = new Map(); + private mcpClient: AppKitMcpClient | null = null; + + constructor(config: AgentPluginConfig) { + super(config); + this.config = config; + this.threadStore = config.threadStore ?? new InMemoryThreadStore(); + } + + async setup() { + await this.collectTools(); + await this.loadAgents(); + this.mountInvocationsRoute(); + } + + private async loadAgents() { + // 1. Load config-file agents first + const agentsDir = + this.config.agentsDir ?? path.join(process.cwd(), "config/agents"); + const fileConfigs = loadAgentConfigs(agentsDir); + + for (const fc of fileConfigs) { + try { + const { DatabricksAdapter } = await import("../../agents/databricks"); + const adapter = await DatabricksAdapter.fromModelServing(fc.endpoint, { + maxSteps: fc.maxSteps, + maxTokens: fc.maxTokens, + }); + this.agents.set(fc.name, { + name: fc.name, + adapter, + systemPrompt: fc.systemPrompt || undefined, + }); + if (fc.default && !this.defaultAgentName) { + this.defaultAgentName = fc.name; + } + if (!this.defaultAgentName) { + this.defaultAgentName = fc.name; + } + } catch (error) { + logger.error( + "Failed to create agent '%s' from config: %O", + fc.name, + error, + ); + } + } + + // 2. Code-defined agents override config-file agents per-name + if (this.config.agents) { + const entries = Object.entries(this.config.agents); + for (const [name, entry] of entries) { + if ( + this.agents.has(name) && + fileConfigs.some((fc) => fc.name === name) + ) { + logger.warn( + "Agent '%s' defined in both code and config file. Code takes precedence.", + name, + ); + } + + const { adapter, systemPrompt } = await this.resolveAgentEntry(entry); + this.agents.set(name, { name, adapter, systemPrompt }); + if (!this.defaultAgentName) { + this.defaultAgentName = name; + } + } + } + + if (this.config.defaultAgent) { + this.defaultAgentName = this.config.defaultAgent; + } + + if (fileConfigs.length > 0) { + logger.info( + "Loaded %d agent(s) from config files: %s", + fileConfigs.length, + fileConfigs.map((c) => c.name).join(", "), + ); + } + } + + private async resolveAgentEntry( + entry: import("./types").AgentEntry, + ): Promise<{ adapter: AgentAdapter; systemPrompt?: string }> { + if (this.isAgentDefinition(entry)) { + const adapter = await entry.adapter; + return { adapter, systemPrompt: entry.systemPrompt }; + } + const adapter = await (entry as AgentAdapter | Promise); + return { adapter }; + } + + private isAgentDefinition( + entry: unknown, + ): entry is import("./types").AgentDefinition { + return typeof entry === "object" && entry !== null && "adapter" in entry; + } + + async reloadAgents() { + this.agents.clear(); + this.defaultAgentName = null; + await this.loadAgents(); + } + + private mountInvocationsRoute() { + const serverPlugin = this.config.plugins?.server as + | { addExtension?: (fn: (app: any) => void) => void } + | undefined; + + if (!serverPlugin?.addExtension) return; + + serverPlugin.addExtension((app: import("express").Application) => { + app.post( + "/invocations", + (req: express.Request, res: express.Response) => { + this._handleInvocations(req, res); + }, + ); + }); + + logger.info("Mounted POST /invocations route"); + } + + private async collectTools() { + // 1. Auto-discover from sibling ToolProvider plugins + const plugins = this.config.plugins; + if (plugins) { + for (const [pluginName, pluginInstance] of Object.entries(plugins)) { + if (pluginName === "agent") continue; + if (!isToolProvider(pluginInstance)) continue; + + const tools = (pluginInstance as ToolProvider).getAgentTools(); + for (const tool of tools) { + const qualifiedName = `${pluginName}.${tool.name}`; + this.toolIndex.set(qualifiedName, { + source: "plugin", + plugin: pluginInstance as ToolProvider & { + asUser(req: any): any; + }, + def: { ...tool, name: qualifiedName }, + localName: tool.name, + }); + } + + logger.info( + "Collected %d tools from plugin %s", + tools.length, + pluginName, + ); + } + } + + // 2. Process explicit tools from config + if (this.config.tools) { + const hostedTools = this.config.tools.filter(isHostedTool); + const functionTools = this.config.tools.filter(isFunctionTool); + + // 2a. Resolve HostedTools via MCP client + if (hostedTools.length > 0) { + await this.connectHostedTools(hostedTools); + } + + // 2b. Add FunctionTools + for (const ft of functionTools) { + this.addFunctionToolToIndex(ft); + } + } + + this.printTools(); + } + + private printTools() { + const entries = Array.from(this.toolIndex.values()); + if (entries.length === 0) return; + + const SOURCE_COLORS: Record string> = { + plugin: pc.blue, + function: pc.yellow, + mcp: pc.magenta, + }; + + const rows = entries + .map((e) => ({ + source: e.source, + name: e.def.name, + description: e.def.description.slice(0, 60), + })) + .sort( + (a, b) => + a.source.localeCompare(b.source) || a.name.localeCompare(b.name), + ); + + const maxSourceLen = Math.max(...rows.map((r) => r.source.length)); + const maxNameLen = Math.min( + 40, + Math.max(...rows.map((r) => r.name.length)), + ); + const separator = pc.dim("─".repeat(60)); + + console.log(""); + console.log(` ${pc.bold("Agent Tools")} ${pc.dim(`(${rows.length})`)}`); + console.log(` ${separator}`); + + for (const { source, name, description } of rows) { + const colorize = SOURCE_COLORS[source] ?? pc.white; + const sourceStr = colorize(pc.bold(source.padEnd(maxSourceLen))); + const nameStr = + name.length > maxNameLen + ? `${name.slice(0, maxNameLen - 1)}…` + : name.padEnd(maxNameLen); + console.log(` ${sourceStr} ${nameStr} ${pc.dim(description)}`); + } + + console.log(` ${separator}`); + console.log(""); + } + + private async connectHostedTools( + hostedTools: import("./tools/hosted-tools").HostedTool[], + ) { + let host: string | undefined; + let authenticate: () => Promise>; + + try { + const { getWorkspaceClient } = await import("../../context"); + const wsClient = getWorkspaceClient(); + await wsClient.config.ensureResolved(); + host = wsClient.config.host; + authenticate = async (): Promise> => { + const headers = new Headers(); + await wsClient.config.authenticate(headers); + return Object.fromEntries(headers.entries()); + }; + } catch { + host = process.env.DATABRICKS_HOST; + authenticate = async (): Promise> => { + const token = process.env.DATABRICKS_TOKEN; + if (token) return { Authorization: `Bearer ${token}` }; + return {}; + }; + } + + if (!host) { + logger.warn( + "No Databricks host available — skipping %d hosted tools", + hostedTools.length, + ); + return; + } + + this.mcpClient = new AppKitMcpClient(host, authenticate); + + const endpoints = resolveHostedTools(hostedTools); + await this.mcpClient.connectAll(endpoints); + + for (const def of this.mcpClient.getAllToolDefinitions()) { + this.toolIndex.set(def.name, { + source: "mcp", + mcpToolName: def.name, + def, + }); + } + + logger.info( + "Connected %d MCP tools from %d hosted tool(s)", + this.mcpClient.getAllToolDefinitions().length, + hostedTools.length, + ); + } + + private addFunctionToolToIndex(ft: FunctionTool) { + const def = functionToolToDefinition(ft); + this.toolIndex.set(ft.name, { + source: "function", + functionTool: ft, + def, + }); + } + + addTools(tools: FunctionTool[]) { + for (const ft of tools) { + this.addFunctionToolToIndex(ft); + } + logger.info( + "Added %d function tools, total: %d", + tools.length, + this.toolIndex.size, + ); + } + + injectRoutes(router: IAppRouter) { + this.route(router, { + name: "chat", + method: "post", + path: "/chat", + handler: async (req, res) => this._handleChat(req, res), + }); + + this.route(router, { + name: "cancel", + method: "post", + path: "/cancel", + handler: async (req, res) => this._handleCancel(req, res), + }); + + this.route(router, { + name: "threads", + method: "get", + path: "/threads", + handler: async (req, res) => this._handleListThreads(req, res), + }); + + this.route(router, { + name: "thread", + method: "get", + path: "/threads/:threadId", + handler: async (req, res) => this._handleGetThread(req, res), + }); + + this.route(router, { + name: "deleteThread", + method: "delete", + path: "/threads/:threadId", + handler: async (req, res) => this._handleDeleteThread(req, res), + }); + + this.route(router, { + name: "info", + method: "get", + path: "/info", + handler: async (_req, res) => { + res.json({ + toolCount: this.toolIndex.size, + tools: this.getAllToolDefinitions(), + agents: Array.from(this.agents.keys()), + defaultAgent: this.defaultAgentName, + }); + }, + }); + } + + clientConfig(): Record { + return { + tools: this.getAllToolDefinitions(), + agents: Array.from(this.agents.keys()), + defaultAgent: this.defaultAgentName, + }; + } + + private async _handleChat( + req: express.Request, + res: express.Response, + ): Promise { + const parsed = chatRequestSchema.safeParse(req.body); + if (!parsed.success) { + res.status(400).json({ + error: "Invalid request", + details: parsed.error.flatten().fieldErrors, + }); + return; + } + + const { message, threadId, agent: agentName } = parsed.data; + + const resolvedAgent = this.resolveAgent(agentName); + if (!resolvedAgent) { + res.status(400).json({ + error: agentName + ? `Agent "${agentName}" not found` + : "No agent registered", + }); + return; + } + + const userId = this.resolveUserId(req); + + let thread = threadId ? await this.threadStore.get(threadId, userId) : null; + + if (threadId && !thread) { + res.status(404).json({ error: `Thread ${threadId} not found` }); + return; + } + + if (!thread) { + thread = await this.threadStore.create(userId); + } + + const userMessage: Message = { + id: randomUUID(), + role: "user", + content: message, + createdAt: new Date(), + }; + await this.threadStore.addMessage(thread.id, userId, userMessage); + + return this._streamChat(req, res, resolvedAgent, thread, userId); + } + + private async _streamChat( + req: express.Request, + res: express.Response, + resolvedAgent: RegisteredAgent, + thread: import("shared").Thread, + userId: string, + ): Promise { + const tools = this.getAllToolDefinitions(); + const abortController = new AbortController(); + const signal = abortController.signal; + + const self = this; + const executeTool = async ( + qualifiedName: string, + args: unknown, + ): Promise => { + const entry = self.toolIndex.get(qualifiedName); + if (!entry) throw new Error(`Unknown tool: ${qualifiedName}`); + + const result = await self.execute( + async (execSignal) => { + switch (entry.source) { + case "plugin": { + const target = (entry.plugin as any).asUser(req); + return (target as ToolProvider).executeAgentTool( + entry.localName, + args, + execSignal, + ); + } + case "function": + return entry.functionTool.execute( + args as Record, + ); + case "mcp": { + if (!self.mcpClient) { + throw new Error("MCP client not connected"); + } + const oboToken = req.headers["x-forwarded-access-token"]; + const mcpAuth = + typeof oboToken === "string" + ? { Authorization: `Bearer ${oboToken}` } + : undefined; + return self.mcpClient.callTool(entry.mcpToolName, args, mcpAuth); + } + } + }, + { + default: { + telemetryInterceptor: { enabled: true }, + timeout: 30_000, + }, + }, + ); + + if (result === undefined) { + return `Error: Tool "${qualifiedName}" execution failed`; + } + + const MAX_TOOL_RESULT_CHARS = 50_000; + const serialized = + typeof result === "string" ? result : JSON.stringify(result); + if (serialized.length > MAX_TOOL_RESULT_CHARS) { + return `${serialized.slice(0, MAX_TOOL_RESULT_CHARS)}\n\n[Result truncated: ${serialized.length} chars exceeds ${MAX_TOOL_RESULT_CHARS} limit]`; + } + return result; + }; + + const requestId = randomUUID(); + this.activeStreams.set(requestId, abortController); + + await this.executeStream( + res, + async function* () { + const translator = new AgentEventTranslator(); + try { + for (const evt of translator.translate({ + type: "metadata", + data: { threadId: thread.id }, + })) { + yield evt; + } + + const pluginNames = self.config.plugins + ? Object.keys(self.config.plugins).filter( + (n) => n !== "agent" && n !== "server", + ) + : []; + const basePrompt = buildBaseSystemPrompt(pluginNames); + const fullPrompt = composeSystemPrompt( + basePrompt, + resolvedAgent.systemPrompt, + ); + + const messagesWithSystem: Message[] = [ + { + id: "system", + role: "system", + content: fullPrompt, + createdAt: new Date(), + }, + ...thread.messages, + ]; + + const stream = resolvedAgent.adapter.run( + { + messages: messagesWithSystem, + tools, + threadId: thread.id, + signal, + }, + { executeTool, signal }, + ); + + let fullContent = ""; + + for await (const event of stream) { + if (signal.aborted) break; + + if (event.type === "message_delta") { + fullContent += event.content; + } + + for (const translated of translator.translate(event)) { + yield translated; + } + } + + if (fullContent) { + const assistantMessage: Message = { + id: randomUUID(), + role: "assistant", + content: fullContent, + createdAt: new Date(), + }; + await self.threadStore.addMessage( + thread.id, + userId, + assistantMessage, + ); + } + + for (const evt of translator.finalize()) { + yield evt; + } + } catch (error) { + if (signal.aborted) return; + logger.error("Agent chat error: %O", error); + throw error; + } finally { + self.activeStreams.delete(requestId); + } + }, + { + ...agentStreamDefaults, + stream: { + ...agentStreamDefaults.stream, + streamId: requestId, + }, + }, + ); + } + + private async _handleInvocations( + req: express.Request, + res: express.Response, + ): Promise { + const parsed = invocationsRequestSchema.safeParse(req.body); + if (!parsed.success) { + res.status(400).json({ + error: "Invalid request", + details: parsed.error.flatten().fieldErrors, + }); + return; + } + + const { input } = parsed.data; + const resolvedAgent = this.resolveAgent(); + if (!resolvedAgent) { + res.status(400).json({ error: "No agent registered" }); + return; + } + + const userId = this.resolveUserId(req); + const thread = await this.threadStore.create(userId); + + if (typeof input === "string") { + const msg: Message = { + id: randomUUID(), + role: "user", + content: input, + createdAt: new Date(), + }; + await this.threadStore.addMessage(thread.id, userId, msg); + } else { + for (const item of input) { + const role = item.role ?? "user"; + const content = + typeof item.content === "string" + ? item.content + : JSON.stringify(item.content ?? ""); + if (!content) continue; + const msg: Message = { + id: randomUUID(), + role: role as Message["role"], + content, + createdAt: new Date(), + }; + await this.threadStore.addMessage(thread.id, userId, msg); + } + } + + return this._streamChat(req, res, resolvedAgent, thread, userId); + } + + private async _handleCancel( + req: express.Request, + res: express.Response, + ): Promise { + const { streamId } = req.body as { streamId?: string }; + if (!streamId) { + res.status(400).json({ error: "streamId is required" }); + return; + } + const controller = this.activeStreams.get(streamId); + if (controller) { + controller.abort("Cancelled by user"); + this.activeStreams.delete(streamId); + } + res.json({ cancelled: true }); + } + + private async _handleListThreads( + req: express.Request, + res: express.Response, + ): Promise { + const userId = this.resolveUserId(req); + const threads = await this.threadStore.list(userId); + res.json({ threads }); + } + + private async _handleGetThread( + req: express.Request, + res: express.Response, + ): Promise { + const userId = this.resolveUserId(req); + const thread = await this.threadStore.get(req.params.threadId, userId); + if (!thread) { + res.status(404).json({ error: "Thread not found" }); + return; + } + res.json(thread); + } + + private async _handleDeleteThread( + req: express.Request, + res: express.Response, + ): Promise { + const userId = this.resolveUserId(req); + const deleted = await this.threadStore.delete(req.params.threadId, userId); + if (!deleted) { + res.status(404).json({ error: "Thread not found" }); + return; + } + res.json({ deleted: true }); + } + + private resolveAgent(name?: string): RegisteredAgent | null { + if (name) return this.agents.get(name) ?? null; + if (this.defaultAgentName) { + return this.agents.get(this.defaultAgentName) ?? null; + } + const first = this.agents.values().next(); + return first.done ? null : first.value; + } + + private getAllToolDefinitions(): AgentToolDefinition[] { + return Array.from(this.toolIndex.values()).map((e) => e.def); + } + + async shutdown() { + if (this.mcpClient) { + await this.mcpClient.close(); + this.mcpClient = null; + } + } + + exports() { + return { + registerAgent: (name: string, adapter: AgentAdapter) => { + this.agents.set(name, { name, adapter }); + if (!this.defaultAgentName) { + this.defaultAgentName = name; + } + }, + addTools: (tools: FunctionTool[]) => this.addTools(tools), + getTools: () => this.getAllToolDefinitions(), + getThreads: (userId: string) => this.threadStore.list(userId), + getAgents: () => ({ + agents: Array.from(this.agents.keys()), + default: this.defaultAgentName, + }), + reloadAgents: () => this.reloadAgents(), + }; + } +} + +/** + * @internal + */ +export const agent = toPlugin(AgentPlugin); diff --git a/packages/appkit/src/plugins/agent/config-loader.ts b/packages/appkit/src/plugins/agent/config-loader.ts new file mode 100644 index 00000000..c9ae512a --- /dev/null +++ b/packages/appkit/src/plugins/agent/config-loader.ts @@ -0,0 +1,94 @@ +import fs from "node:fs"; +import path from "node:path"; +import { createLogger } from "../../logging/logger"; + +const logger = createLogger("agent:config"); + +export interface AgentFileConfig { + name: string; + endpoint?: string; + maxSteps?: number; + maxTokens?: number; + default?: boolean; + systemPrompt: string; +} + +/** + * Parse a frontmatter markdown string into data + content. + * Handles flat YAML key-value pairs (string, number, boolean). + */ +export function parseFrontmatter(raw: string): { + data: Record; + content: string; +} { + const match = raw.match(/^---\r?\n([\s\S]*?)\r?\n---\r?\n([\s\S]*)$/); + if (!match) return { data: {}, content: raw.trim() }; + + const data: Record = {}; + for (const line of match[1].split("\n")) { + const colonIdx = line.indexOf(":"); + if (colonIdx === -1) continue; + const key = line.slice(0, colonIdx).trim(); + const rawVal = line.slice(colonIdx + 1).trim(); + if (!key) continue; + + if (rawVal === "true") data[key] = true; + else if (rawVal === "false") data[key] = false; + else if (/^\d+$/.test(rawVal)) data[key] = Number(rawVal); + else if (/^\d+\.\d+$/.test(rawVal)) data[key] = Number(rawVal); + else data[key] = rawVal; + } + + return { data, content: match[2].trim() }; +} + +/** + * Load agent configs from a directory of frontmatter markdown files. + * Returns an empty array if the directory doesn't exist. + */ +export function loadAgentConfigs(agentsDir: string): AgentFileConfig[] { + if (!fs.existsSync(agentsDir)) return []; + + const files = fs.readdirSync(agentsDir).filter((f) => f.endsWith(".md")); + const configs: AgentFileConfig[] = []; + + for (const file of files) { + try { + const raw = fs.readFileSync(path.join(agentsDir, file), "utf-8"); + const { data, content } = parseFrontmatter(raw); + const name = path.basename(file, ".md"); + + const config: AgentFileConfig = { + name, + systemPrompt: content, + }; + + if (typeof data.endpoint === "string") config.endpoint = data.endpoint; + if (typeof data.maxSteps === "number") config.maxSteps = data.maxSteps; + if (typeof data.maxTokens === "number") config.maxTokens = data.maxTokens; + if (typeof data.default === "boolean") config.default = data.default; + + if (data.maxSteps !== undefined && typeof data.maxSteps !== "number") { + logger.warn( + "Agent '%s': maxSteps should be a number, got %s. Using default.", + name, + typeof data.maxSteps, + ); + } + + if (data.maxTokens !== undefined && typeof data.maxTokens !== "number") { + logger.warn( + "Agent '%s': maxTokens should be a number, got %s. Using default.", + name, + typeof data.maxTokens, + ); + } + + configs.push(config); + } catch (error) { + logger.error("Failed to load agent config '%s': %O", file, error); + } + } + + return configs; +} diff --git a/packages/appkit/src/plugins/agent/defaults.ts b/packages/appkit/src/plugins/agent/defaults.ts new file mode 100644 index 00000000..4da11bef --- /dev/null +++ b/packages/appkit/src/plugins/agent/defaults.ts @@ -0,0 +1,12 @@ +import type { StreamExecutionSettings } from "shared"; + +export const agentStreamDefaults: StreamExecutionSettings = { + default: { + cache: { enabled: false }, + retry: { enabled: false }, + timeout: 300_000, + }, + stream: { + bufferSize: 200, + }, +}; diff --git a/packages/appkit/src/plugins/agent/event-translator.ts b/packages/appkit/src/plugins/agent/event-translator.ts new file mode 100644 index 00000000..314f8066 --- /dev/null +++ b/packages/appkit/src/plugins/agent/event-translator.ts @@ -0,0 +1,230 @@ +import { randomUUID } from "node:crypto"; +import type { + AgentEvent, + ResponseFunctionCallOutput, + ResponseFunctionToolCall, + ResponseOutputMessage, + ResponseStreamEvent, +} from "shared"; + +/** + * Translates internal AgentEvent stream into Responses API SSE events. + * + * Stateful: one instance per streaming request. Tracks sequence numbers, + * output indices, and message accumulation state. + */ +export class AgentEventTranslator { + private seqNum = 0; + private outputIndex = 0; + private messageId: string | null = null; + private messageText = ""; + private finalized = false; + + translate(event: AgentEvent): ResponseStreamEvent[] { + switch (event.type) { + case "message_delta": + return this.handleMessageDelta(event.content); + case "message": + return this.handleFullMessage(event.content); + case "tool_call": + return this.handleToolCall(event.callId, event.name, event.args); + case "tool_result": + return this.handleToolResult(event.callId, event.result, event.error); + case "thinking": + return [ + { + type: "appkit.thinking", + content: event.content, + sequence_number: this.seqNum++, + }, + ]; + case "metadata": + return [ + { + type: "appkit.metadata", + data: event.data, + sequence_number: this.seqNum++, + }, + ]; + case "status": + return this.handleStatus(event.status, event.error); + } + } + + finalize(): ResponseStreamEvent[] { + if (this.finalized) return []; + this.finalized = true; + + const events: ResponseStreamEvent[] = []; + + if (this.messageId) { + const doneItem: ResponseOutputMessage = { + type: "message", + id: this.messageId, + status: "completed", + role: "assistant", + content: [{ type: "output_text", text: this.messageText }], + }; + events.push({ + type: "response.output_item.done", + output_index: 0, + item: doneItem, + sequence_number: this.seqNum++, + }); + } + + events.push({ + type: "response.completed", + sequence_number: this.seqNum++, + response: {}, + }); + + return events; + } + + private handleMessageDelta(content: string): ResponseStreamEvent[] { + const events: ResponseStreamEvent[] = []; + this.messageText += content; + + if (!this.messageId) { + this.messageId = `msg_${randomUUID()}`; + const item: ResponseOutputMessage = { + type: "message", + id: this.messageId, + status: "in_progress", + role: "assistant", + content: [], + }; + events.push({ + type: "response.output_item.added", + output_index: 0, + item, + sequence_number: this.seqNum++, + }); + } + + events.push({ + type: "response.output_text.delta", + item_id: this.messageId, + output_index: 0, + content_index: 0, + delta: content, + sequence_number: this.seqNum++, + }); + + return events; + } + + private handleFullMessage(content: string): ResponseStreamEvent[] { + if (!this.messageId) { + this.messageId = `msg_${randomUUID()}`; + } + this.messageText = content; + + const item: ResponseOutputMessage = { + type: "message", + id: this.messageId, + status: "completed", + role: "assistant", + content: [{ type: "output_text", text: content }], + }; + + return [ + { + type: "response.output_item.added", + output_index: 0, + item, + sequence_number: this.seqNum++, + }, + { + type: "response.output_item.done", + output_index: 0, + item, + sequence_number: this.seqNum++, + }, + ]; + } + + private handleToolCall( + callId: string, + name: string, + args: unknown, + ): ResponseStreamEvent[] { + this.outputIndex++; + const item: ResponseFunctionToolCall = { + type: "function_call", + id: `fc_${randomUUID()}`, + call_id: callId, + name, + arguments: typeof args === "string" ? args : JSON.stringify(args), + }; + + return [ + { + type: "response.output_item.added", + output_index: this.outputIndex, + item, + sequence_number: this.seqNum++, + }, + { + type: "response.output_item.done", + output_index: this.outputIndex, + item, + sequence_number: this.seqNum++, + }, + ]; + } + + private handleToolResult( + callId: string, + result: unknown, + error?: string, + ): ResponseStreamEvent[] { + this.outputIndex++; + const output = + error ?? (typeof result === "string" ? result : JSON.stringify(result)); + const item: ResponseFunctionCallOutput = { + type: "function_call_output", + id: `fc_output_${randomUUID()}`, + call_id: callId, + output, + }; + + return [ + { + type: "response.output_item.added", + output_index: this.outputIndex, + item, + sequence_number: this.seqNum++, + }, + { + type: "response.output_item.done", + output_index: this.outputIndex, + item, + sequence_number: this.seqNum++, + }, + ]; + } + + private handleStatus(status: string, error?: string): ResponseStreamEvent[] { + if (status === "error") { + return [ + { + type: "error", + error: error ?? "Unknown error", + sequence_number: this.seqNum++, + }, + { + type: "response.failed", + sequence_number: this.seqNum++, + }, + ]; + } + + if (status === "complete") { + return this.finalize(); + } + + return []; + } +} diff --git a/packages/appkit/src/plugins/agent/index.ts b/packages/appkit/src/plugins/agent/index.ts new file mode 100644 index 00000000..861a68cc --- /dev/null +++ b/packages/appkit/src/plugins/agent/index.ts @@ -0,0 +1 @@ +export { agent } from "./agent"; diff --git a/packages/appkit/src/plugins/agent/manifest.json b/packages/appkit/src/plugins/agent/manifest.json new file mode 100644 index 00000000..d73b94ea --- /dev/null +++ b/packages/appkit/src/plugins/agent/manifest.json @@ -0,0 +1,10 @@ +{ + "$schema": "https://databricks.github.io/appkit/schemas/plugin-manifest.schema.json", + "name": "agent", + "displayName": "Agent Plugin", + "description": "Framework-agnostic AI agent with auto-tool-discovery from all registered plugins", + "resources": { + "required": [], + "optional": [] + } +} diff --git a/packages/appkit/src/plugins/agent/schemas.ts b/packages/appkit/src/plugins/agent/schemas.ts new file mode 100644 index 00000000..84ab3b88 --- /dev/null +++ b/packages/appkit/src/plugins/agent/schemas.ts @@ -0,0 +1,19 @@ +import { z } from "zod"; + +export const chatRequestSchema = z.object({ + message: z.string().min(1, "message must not be empty"), + threadId: z.string().optional(), + agent: z.string().optional(), +}); + +const messageItemSchema = z.object({ + role: z.enum(["user", "assistant", "system"]).optional(), + content: z.union([z.string(), z.array(z.any())]).optional(), + type: z.string().optional(), +}); + +export const invocationsRequestSchema = z.object({ + input: z.union([z.string().min(1), z.array(messageItemSchema).min(1)]), + stream: z.boolean().optional().default(true), + model: z.string().optional(), +}); diff --git a/packages/appkit/src/plugins/agent/system-prompt.ts b/packages/appkit/src/plugins/agent/system-prompt.ts new file mode 100644 index 00000000..634f49c5 --- /dev/null +++ b/packages/appkit/src/plugins/agent/system-prompt.ts @@ -0,0 +1,40 @@ +/** + * Builds the AppKit base system prompt from active plugin names. + * + * The base prompt provides guidelines and app context. It does NOT + * include individual tool descriptions — those are sent via the + * structured `tools` API parameter to the LLM. + */ +export function buildBaseSystemPrompt(pluginNames: string[]): string { + const lines: string[] = [ + "You are an AI assistant running on Databricks AppKit.", + ]; + + if (pluginNames.length > 0) { + lines.push(""); + lines.push(`Active plugins: ${pluginNames.join(", ")}`); + } + + lines.push(""); + lines.push("Guidelines:"); + lines.push("- Use Databricks SQL syntax when writing queries"); + lines.push( + "- When results are large, summarize key findings rather than dumping raw data", + ); + lines.push("- If a tool call fails, explain the error clearly to the user"); + lines.push("- When browsing files, verify the path exists before reading"); + + return lines.join("\n"); +} + +/** + * Compose the full system prompt from the base prompt and an optional + * per-agent user prompt. + */ +export function composeSystemPrompt( + basePrompt: string, + agentPrompt?: string, +): string { + if (!agentPrompt) return basePrompt; + return `${basePrompt}\n\n${agentPrompt}`; +} diff --git a/packages/appkit/src/plugins/agent/tests/agent.test.ts b/packages/appkit/src/plugins/agent/tests/agent.test.ts new file mode 100644 index 00000000..357d68d0 --- /dev/null +++ b/packages/appkit/src/plugins/agent/tests/agent.test.ts @@ -0,0 +1,230 @@ +import { + createMockRequest, + createMockResponse, + createMockRouter, + setupDatabricksEnv, +} from "@tools/test-helpers"; +import type { + AgentAdapter, + AgentEvent, + AgentInput, + AgentRunContext, + AgentToolDefinition, + ToolProvider, +} from "shared"; +import { beforeEach, describe, expect, test, vi } from "vitest"; +import { AgentPlugin } from "../agent"; + +vi.mock("../../../cache", () => ({ + CacheManager: { + getInstanceSync: vi.fn(() => ({ + get: vi.fn(), + set: vi.fn(), + delete: vi.fn(), + getOrExecute: vi.fn(), + generateKey: vi.fn(), + })), + }, +})); + +vi.mock("../../../context", async (importOriginal) => { + const actual = await importOriginal(); + return { + ...actual, + getCurrentUserId: vi.fn(() => "test-user"), + getExecutionContext: vi.fn(() => ({ + userId: "test-user", + isUserContext: false, + })), + }; +}); + +vi.mock("../../../telemetry", () => ({ + TelemetryManager: { + getProvider: vi.fn(() => ({ + getTracer: vi.fn(), + getMeter: vi.fn(), + getLogger: vi.fn(), + emit: vi.fn(), + startActiveSpan: vi.fn(), + registerInstrumentations: vi.fn(), + })), + }, + normalizeTelemetryOptions: vi.fn(() => ({ + traces: false, + metrics: false, + logs: false, + })), +})); + +function createMockToolProvider( + tools: AgentToolDefinition[], +): ToolProvider & { asUser: any } { + return { + getAgentTools: () => tools, + executeAgentTool: vi.fn().mockResolvedValue({ result: "ok" }), + asUser: vi.fn().mockReturnThis(), + }; +} + +async function* mockAdapterRun(): AsyncGenerator { + yield { type: "message_delta", content: "Hello " }; + yield { type: "message_delta", content: "world" }; +} + +function createMockAdapter(): AgentAdapter { + return { + run: vi.fn().mockReturnValue(mockAdapterRun()), + }; +} + +describe("AgentPlugin", () => { + beforeEach(() => { + setupDatabricksEnv(); + }); + + test("collectTools discovers ToolProvider plugins", async () => { + const mockProvider = createMockToolProvider([ + { + name: "query", + description: "Run a query", + parameters: { type: "object", properties: {} }, + }, + ]); + + const plugin = new AgentPlugin({ + name: "agent", + plugins: { analytics: mockProvider }, + }); + + await plugin.setup(); + + const exports = plugin.exports(); + const tools = exports.getTools(); + + expect(tools).toHaveLength(1); + expect(tools[0].name).toBe("analytics.query"); + }); + + test("skips non-ToolProvider plugins", async () => { + const plugin = new AgentPlugin({ + name: "agent", + plugins: { + server: { name: "server" }, + analytics: createMockToolProvider([ + { name: "query", description: "q", parameters: { type: "object" } }, + ]), + }, + }); + + await plugin.setup(); + const tools = plugin.exports().getTools(); + expect(tools).toHaveLength(1); + }); + + test("registerAgent and resolveAgent", () => { + const plugin = new AgentPlugin({ name: "agent" }); + const adapter = createMockAdapter(); + + plugin.exports().registerAgent("assistant", adapter); + + // The first registered agent becomes the default + const tools = plugin.exports().getTools(); + expect(tools).toEqual([]); + }); + + test("injectRoutes registers chat, cancel, and thread routes", () => { + const plugin = new AgentPlugin({ name: "agent" }); + const { router, handlers } = createMockRouter(); + + plugin.injectRoutes(router); + + expect(handlers["POST:/chat"]).toBeDefined(); + expect(handlers["POST:/cancel"]).toBeDefined(); + expect(handlers["GET:/threads"]).toBeDefined(); + expect(handlers["GET:/threads/:threadId"]).toBeDefined(); + expect(handlers["DELETE:/threads/:threadId"]).toBeDefined(); + }); + + test("clientConfig exposes tools and agents", async () => { + const plugin = new AgentPlugin({ + name: "agent", + agents: { assistant: createMockAdapter() }, + }); + await plugin.setup(); + + const config = plugin.clientConfig(); + expect(config.tools).toEqual([]); + expect(config.agents).toEqual(["assistant"]); + expect(config.defaultAgent).toBe("assistant"); + }); + + test("exports().addTools adds function tools", () => { + const plugin = new AgentPlugin({ name: "agent" }); + + plugin.exports().addTools([ + { + type: "function" as const, + name: "myTool", + description: "A custom tool", + parameters: { type: "object", properties: {} }, + execute: async () => "result", + }, + ]); + + const tools = plugin.exports().getTools(); + expect(tools).toHaveLength(1); + expect(tools[0].name).toBe("myTool"); + }); + + test("executeTool always calls asUser(req) for plugin tools, even without requiresUserContext", async () => { + const mockProvider = createMockToolProvider([ + { + name: "action", + description: "An action without requiresUserContext", + parameters: { type: "object", properties: {} }, + }, + ]); + + function createToolCallingAdapter(): AgentAdapter { + return { + async *run( + _input: AgentInput, + context: AgentRunContext, + ): AsyncGenerator { + await context.executeTool("testplugin.action", {}); + yield { type: "message_delta", content: "done" }; + }, + }; + } + + const plugin = new AgentPlugin({ + name: "agent", + agents: { assistant: createToolCallingAdapter() }, + plugins: { testplugin: mockProvider }, + }); + await plugin.setup(); + + const { router, getHandler } = createMockRouter(); + plugin.injectRoutes(router); + const handler = getHandler("POST", "/chat"); + + const req = createMockRequest({ + body: { message: "hi" }, + headers: { + "x-forwarded-user": "test-user", + "x-forwarded-access-token": "test-token", + }, + }); + const res = createMockResponse(); + + await handler(req, res); + + expect(mockProvider.asUser).toHaveBeenCalledWith(req); + expect(mockProvider.executeAgentTool).toHaveBeenCalledWith( + "action", + {}, + expect.anything(), + ); + }); +}); diff --git a/packages/appkit/src/plugins/agent/tests/config-loader.test.ts b/packages/appkit/src/plugins/agent/tests/config-loader.test.ts new file mode 100644 index 00000000..19b69bdb --- /dev/null +++ b/packages/appkit/src/plugins/agent/tests/config-loader.test.ts @@ -0,0 +1,130 @@ +import fs from "node:fs"; +import path from "node:path"; +import { describe, expect, test, vi } from "vitest"; +import { loadAgentConfigs, parseFrontmatter } from "../config-loader"; + +vi.mock("../../../logging/logger", () => ({ + createLogger: () => ({ + info: vi.fn(), + warn: vi.fn(), + error: vi.fn(), + debug: vi.fn(), + }), +})); + +describe("parseFrontmatter", () => { + test("parses frontmatter + body", () => { + const result = parseFrontmatter( + "---\nendpoint: my-model\nmaxSteps: 10\n---\nYou are helpful.", + ); + expect(result.data).toEqual({ endpoint: "my-model", maxSteps: 10 }); + expect(result.content).toBe("You are helpful."); + }); + + test("returns full content when no frontmatter", () => { + const result = parseFrontmatter("Just a plain prompt."); + expect(result.data).toEqual({}); + expect(result.content).toBe("Just a plain prompt."); + }); + + test("parses boolean values", () => { + const result = parseFrontmatter("---\ndefault: true\n---\nPrompt."); + expect(result.data.default).toBe(true); + }); + + test("parses numeric values", () => { + const result = parseFrontmatter( + "---\nmaxSteps: 5\nmaxTokens: 2048\n---\nPrompt.", + ); + expect(result.data.maxSteps).toBe(5); + expect(result.data.maxTokens).toBe(2048); + }); + + test("handles empty body", () => { + const result = parseFrontmatter("---\nendpoint: model\n---\n"); + expect(result.data.endpoint).toBe("model"); + expect(result.content).toBe(""); + }); + + test("handles colons in values", () => { + const result = parseFrontmatter( + "---\nendpoint: https://host:443/path\n---\nPrompt.", + ); + expect(result.data.endpoint).toBe("https://host:443/path"); + }); +}); + +describe("loadAgentConfigs", () => { + test("returns empty array for non-existent directory", () => { + const result = loadAgentConfigs("/nonexistent/path"); + expect(result).toEqual([]); + }); + + test("parses .md files from directory", () => { + const tmpDir = fs.mkdtempSync(path.join(import.meta.dirname, "tmp-")); + + try { + fs.writeFileSync( + path.join(tmpDir, "assistant.md"), + "---\nendpoint: claude\ndefault: true\n---\nYou are helpful.", + ); + fs.writeFileSync( + path.join(tmpDir, "autocomplete.md"), + "---\nendpoint: gemini\nmaxSteps: 1\n---\nJust continue.", + ); + + const configs = loadAgentConfigs(tmpDir); + + expect(configs).toHaveLength(2); + + const assistant = configs.find((c) => c.name === "assistant"); + expect(assistant).toBeDefined(); + expect(assistant?.endpoint).toBe("claude"); + expect(assistant?.default).toBe(true); + expect(assistant?.systemPrompt).toBe("You are helpful."); + + const autocomplete = configs.find((c) => c.name === "autocomplete"); + expect(autocomplete).toBeDefined(); + expect(autocomplete?.endpoint).toBe("gemini"); + expect(autocomplete?.maxSteps).toBe(1); + expect(autocomplete?.systemPrompt).toBe("Just continue."); + } finally { + fs.rmSync(tmpDir, { recursive: true }); + } + }); + + test("handles file with no frontmatter", () => { + const tmpDir = fs.mkdtempSync(path.join(import.meta.dirname, "tmp-")); + + try { + fs.writeFileSync( + path.join(tmpDir, "simple.md"), + "Just a plain system prompt.", + ); + + const configs = loadAgentConfigs(tmpDir); + expect(configs).toHaveLength(1); + expect(configs[0].name).toBe("simple"); + expect(configs[0].endpoint).toBeUndefined(); + expect(configs[0].systemPrompt).toBe("Just a plain system prompt."); + } finally { + fs.rmSync(tmpDir, { recursive: true }); + } + }); + + test("ignores non-.md files", () => { + const tmpDir = fs.mkdtempSync(path.join(import.meta.dirname, "tmp-")); + + try { + fs.writeFileSync(path.join(tmpDir, "agent.md"), "Prompt."); + fs.writeFileSync(path.join(tmpDir, "config.yaml"), "key: value"); + fs.writeFileSync(path.join(tmpDir, "notes.txt"), "Notes."); + + const configs = loadAgentConfigs(tmpDir); + expect(configs).toHaveLength(1); + expect(configs[0].name).toBe("agent"); + } finally { + fs.rmSync(tmpDir, { recursive: true }); + } + }); +}); diff --git a/packages/appkit/src/plugins/agent/tests/define-tool.test.ts b/packages/appkit/src/plugins/agent/tests/define-tool.test.ts new file mode 100644 index 00000000..ef61e8c4 --- /dev/null +++ b/packages/appkit/src/plugins/agent/tests/define-tool.test.ts @@ -0,0 +1,133 @@ +import { describe, expect, test, vi } from "vitest"; +import { z } from "zod"; +import { + defineTool, + executeFromRegistry, + type ToolRegistry, + toolsFromRegistry, +} from "../tools/define-tool"; + +describe("defineTool()", () => { + test("returns an entry matching the input config", () => { + const entry = defineTool({ + description: "echo", + schema: z.object({ msg: z.string() }), + annotations: { readOnly: true }, + handler: ({ msg }) => msg, + }); + + expect(entry.description).toBe("echo"); + expect(entry.annotations).toEqual({ readOnly: true }); + expect(typeof entry.handler).toBe("function"); + }); +}); + +describe("executeFromRegistry", () => { + const registry: ToolRegistry = { + echo: defineTool({ + description: "echo", + schema: z.object({ msg: z.string() }), + handler: ({ msg }) => `got ${msg}`, + }), + }; + + test("validates args and calls handler on success", async () => { + const result = await executeFromRegistry(registry, "echo", { msg: "hi" }); + expect(result).toBe("got hi"); + }); + + test("returns formatted error string on validation failure", async () => { + const result = await executeFromRegistry(registry, "echo", {}); + expect(typeof result).toBe("string"); + expect(result).toContain("Invalid arguments for echo"); + expect(result).toContain("msg"); + }); + + test("throws for unknown tool names", async () => { + await expect(executeFromRegistry(registry, "missing", {})).rejects.toThrow( + /Unknown tool: missing/, + ); + }); + + test("forwards AbortSignal to the handler", async () => { + const handler = vi.fn(async (_args: { x: string }, signal?: AbortSignal) => + signal?.aborted ? "aborted" : "ok", + ); + const reg: ToolRegistry = { + t: defineTool({ + description: "t", + schema: z.object({ x: z.string() }), + handler, + }), + }; + + const controller = new AbortController(); + controller.abort(); + await executeFromRegistry(reg, "t", { x: "hi" }, controller.signal); + + expect(handler).toHaveBeenCalledTimes(1); + expect(handler.mock.calls[0][1]).toBe(controller.signal); + }); +}); + +describe("toolsFromRegistry", () => { + test("produces AgentToolDefinition[] with JSON Schema parameters", () => { + const registry: ToolRegistry = { + query: defineTool({ + description: "Execute a SQL query", + schema: z.object({ + query: z.string().describe("SQL query"), + }), + annotations: { readOnly: true, requiresUserContext: true }, + handler: () => "ok", + }), + }; + + const defs = toolsFromRegistry(registry); + expect(defs).toHaveLength(1); + expect(defs[0].name).toBe("query"); + expect(defs[0].description).toBe("Execute a SQL query"); + expect(defs[0].parameters).toMatchObject({ + type: "object", + properties: { + query: { type: "string", description: "SQL query" }, + }, + required: ["query"], + }); + expect(defs[0].annotations).toEqual({ + readOnly: true, + requiresUserContext: true, + }); + }); + + test("preserves dotted names like uploads.list from the registry keys", () => { + const registry: ToolRegistry = { + "uploads.list": defineTool({ + description: "list uploads", + schema: z.object({}), + handler: () => [], + }), + "documents.list": defineTool({ + description: "list documents", + schema: z.object({}), + handler: () => [], + }), + }; + + const names = toolsFromRegistry(registry).map((d) => d.name); + expect(names).toContain("uploads.list"); + expect(names).toContain("documents.list"); + }); + + test("omits annotations when none are provided", () => { + const registry: ToolRegistry = { + plain: defineTool({ + description: "plain", + schema: z.object({}), + handler: () => "ok", + }), + }; + const [def] = toolsFromRegistry(registry); + expect(def.annotations).toBeUndefined(); + }); +}); diff --git a/packages/appkit/src/plugins/agent/tests/event-translator.test.ts b/packages/appkit/src/plugins/agent/tests/event-translator.test.ts new file mode 100644 index 00000000..eda72ebb --- /dev/null +++ b/packages/appkit/src/plugins/agent/tests/event-translator.test.ts @@ -0,0 +1,204 @@ +import { describe, expect, test } from "vitest"; +import { AgentEventTranslator } from "../event-translator"; + +describe("AgentEventTranslator", () => { + test("translates message_delta to output_item.added + output_text.delta on first delta", () => { + const translator = new AgentEventTranslator(); + const events = translator.translate({ + type: "message_delta", + content: "Hello", + }); + + expect(events).toHaveLength(2); + expect(events[0].type).toBe("response.output_item.added"); + expect(events[1].type).toBe("response.output_text.delta"); + + if (events[1].type === "response.output_text.delta") { + expect(events[1].delta).toBe("Hello"); + } + }); + + test("subsequent message_delta only produces output_text.delta", () => { + const translator = new AgentEventTranslator(); + translator.translate({ type: "message_delta", content: "Hello" }); + const events = translator.translate({ + type: "message_delta", + content: " world", + }); + + expect(events).toHaveLength(1); + expect(events[0].type).toBe("response.output_text.delta"); + }); + + test("sequence_number is monotonically increasing", () => { + const translator = new AgentEventTranslator(); + const e1 = translator.translate({ type: "message_delta", content: "a" }); + const e2 = translator.translate({ type: "message_delta", content: "b" }); + const e3 = translator.finalize(); + + const allSeqs = [...e1, ...e2, ...e3].map((e) => + "sequence_number" in e ? e.sequence_number : -1, + ); + + for (let i = 1; i < allSeqs.length; i++) { + expect(allSeqs[i]).toBeGreaterThan(allSeqs[i - 1]); + } + }); + + test("translates tool_call to paired output_item.added + output_item.done", () => { + const translator = new AgentEventTranslator(); + const events = translator.translate({ + type: "tool_call", + callId: "call_1", + name: "analytics.query", + args: { sql: "SELECT 1" }, + }); + + expect(events).toHaveLength(2); + expect(events[0].type).toBe("response.output_item.added"); + expect(events[1].type).toBe("response.output_item.done"); + + if (events[0].type === "response.output_item.added") { + expect(events[0].item.type).toBe("function_call"); + if (events[0].item.type === "function_call") { + expect(events[0].item.name).toBe("analytics.query"); + expect(events[0].item.call_id).toBe("call_1"); + } + } + }); + + test("translates tool_result to paired output_item events", () => { + const translator = new AgentEventTranslator(); + const events = translator.translate({ + type: "tool_result", + callId: "call_1", + result: { rows: 42 }, + }); + + expect(events).toHaveLength(2); + expect(events[0].type).toBe("response.output_item.added"); + + if (events[0].type === "response.output_item.added") { + expect(events[0].item.type).toBe("function_call_output"); + } + }); + + test("translates tool_result error", () => { + const translator = new AgentEventTranslator(); + const events = translator.translate({ + type: "tool_result", + callId: "call_1", + result: null, + error: "Query failed", + }); + + if ( + events[0].type === "response.output_item.added" && + events[0].item.type === "function_call_output" + ) { + expect(events[0].item.output).toBe("Query failed"); + } + }); + + test("translates thinking to appkit.thinking extension event", () => { + const translator = new AgentEventTranslator(); + const events = translator.translate({ + type: "thinking", + content: "Let me think about this...", + }); + + expect(events).toHaveLength(1); + expect(events[0].type).toBe("appkit.thinking"); + if (events[0].type === "appkit.thinking") { + expect(events[0].content).toBe("Let me think about this..."); + } + }); + + test("translates metadata to appkit.metadata extension event", () => { + const translator = new AgentEventTranslator(); + const events = translator.translate({ + type: "metadata", + data: { threadId: "t-123" }, + }); + + expect(events).toHaveLength(1); + expect(events[0].type).toBe("appkit.metadata"); + if (events[0].type === "appkit.metadata") { + expect(events[0].data.threadId).toBe("t-123"); + } + }); + + test("status:complete triggers finalize with response.completed", () => { + const translator = new AgentEventTranslator(); + translator.translate({ type: "message_delta", content: "Hi" }); + const events = translator.translate({ type: "status", status: "complete" }); + + const types = events.map((e) => e.type); + expect(types).toContain("response.output_item.done"); + expect(types).toContain("response.completed"); + }); + + test("status:error emits error + response.failed", () => { + const translator = new AgentEventTranslator(); + const events = translator.translate({ + type: "status", + status: "error", + error: "Something broke", + }); + + expect(events).toHaveLength(2); + expect(events[0].type).toBe("error"); + expect(events[1].type).toBe("response.failed"); + + if (events[0].type === "error") { + expect(events[0].error).toBe("Something broke"); + } + }); + + test("finalize produces response.completed", () => { + const translator = new AgentEventTranslator(); + const events = translator.finalize(); + + expect(events.some((e) => e.type === "response.completed")).toBe(true); + }); + + test("finalize with accumulated message text produces output_item.done", () => { + const translator = new AgentEventTranslator(); + translator.translate({ type: "message_delta", content: "Hello " }); + translator.translate({ type: "message_delta", content: "world" }); + const events = translator.finalize(); + + const doneEvent = events.find( + (e) => e.type === "response.output_item.done", + ); + expect(doneEvent).toBeDefined(); + if ( + doneEvent?.type === "response.output_item.done" && + doneEvent.item.type === "message" + ) { + expect(doneEvent.item.content[0].text).toBe("Hello world"); + } + }); + + test("output_index increments for tool calls", () => { + const translator = new AgentEventTranslator(); + const e1 = translator.translate({ + type: "tool_call", + callId: "c1", + name: "tool1", + args: {}, + }); + const e2 = translator.translate({ + type: "tool_result", + callId: "c1", + result: "ok", + }); + + if ( + e1[0].type === "response.output_item.added" && + e2[0].type === "response.output_item.added" + ) { + expect(e2[0].output_index).toBeGreaterThan(e1[0].output_index); + } + }); +}); diff --git a/packages/appkit/src/plugins/agent/tests/system-prompt.test.ts b/packages/appkit/src/plugins/agent/tests/system-prompt.test.ts new file mode 100644 index 00000000..83bf8e19 --- /dev/null +++ b/packages/appkit/src/plugins/agent/tests/system-prompt.test.ts @@ -0,0 +1,45 @@ +import { describe, expect, test } from "vitest"; +import { buildBaseSystemPrompt, composeSystemPrompt } from "../system-prompt"; + +describe("buildBaseSystemPrompt", () => { + test("includes plugin names", () => { + const prompt = buildBaseSystemPrompt(["analytics", "files", "genie"]); + expect(prompt).toContain("Active plugins: analytics, files, genie"); + }); + + test("includes guidelines", () => { + const prompt = buildBaseSystemPrompt([]); + expect(prompt).toContain("Guidelines:"); + expect(prompt).toContain("Databricks SQL"); + expect(prompt).toContain("summarize key findings"); + }); + + test("works with no plugins", () => { + const prompt = buildBaseSystemPrompt([]); + expect(prompt).toContain("AI assistant running on Databricks AppKit"); + expect(prompt).not.toContain("Active plugins:"); + }); + + test("does NOT include individual tool names", () => { + const prompt = buildBaseSystemPrompt(["analytics"]); + expect(prompt).not.toContain("analytics.query"); + expect(prompt).not.toContain("Available tools:"); + }); +}); + +describe("composeSystemPrompt", () => { + test("concatenates base + agent prompt with double newline", () => { + const composed = composeSystemPrompt("Base prompt.", "Agent prompt."); + expect(composed).toBe("Base prompt.\n\nAgent prompt."); + }); + + test("returns base prompt alone when no agent prompt", () => { + const composed = composeSystemPrompt("Base prompt."); + expect(composed).toBe("Base prompt."); + }); + + test("returns base prompt when agent prompt is empty string", () => { + const composed = composeSystemPrompt("Base prompt.", ""); + expect(composed).toBe("Base prompt."); + }); +}); diff --git a/packages/appkit/src/plugins/agent/tests/thread-store.test.ts b/packages/appkit/src/plugins/agent/tests/thread-store.test.ts new file mode 100644 index 00000000..ed4f70ba --- /dev/null +++ b/packages/appkit/src/plugins/agent/tests/thread-store.test.ts @@ -0,0 +1,138 @@ +import { describe, expect, test } from "vitest"; +import { InMemoryThreadStore } from "../thread-store"; + +describe("InMemoryThreadStore", () => { + test("create() returns a new thread with the given userId", async () => { + const store = new InMemoryThreadStore(); + const thread = await store.create("user-1"); + + expect(thread.id).toBeDefined(); + expect(thread.userId).toBe("user-1"); + expect(thread.messages).toEqual([]); + expect(thread.createdAt).toBeInstanceOf(Date); + expect(thread.updatedAt).toBeInstanceOf(Date); + }); + + test("get() returns the thread for the correct user", async () => { + const store = new InMemoryThreadStore(); + const thread = await store.create("user-1"); + + const retrieved = await store.get(thread.id, "user-1"); + expect(retrieved).toEqual(thread); + }); + + test("get() returns null for wrong user", async () => { + const store = new InMemoryThreadStore(); + const thread = await store.create("user-1"); + + const retrieved = await store.get(thread.id, "user-2"); + expect(retrieved).toBeNull(); + }); + + test("get() returns null for non-existent thread", async () => { + const store = new InMemoryThreadStore(); + const retrieved = await store.get("non-existent", "user-1"); + expect(retrieved).toBeNull(); + }); + + test("list() returns threads sorted by updatedAt desc", async () => { + const store = new InMemoryThreadStore(); + const t1 = await store.create("user-1"); + const t2 = await store.create("user-1"); + + // Make t1 more recently updated + await store.addMessage(t1.id, "user-1", { + id: "msg-1", + role: "user", + content: "hello", + createdAt: new Date(), + }); + + const threads = await store.list("user-1"); + expect(threads).toHaveLength(2); + expect(threads[0].id).toBe(t1.id); + expect(threads[1].id).toBe(t2.id); + }); + + test("list() returns empty for unknown user", async () => { + const store = new InMemoryThreadStore(); + await store.create("user-1"); + + const threads = await store.list("user-2"); + expect(threads).toEqual([]); + }); + + test("addMessage() appends to thread and updates timestamp", async () => { + const store = new InMemoryThreadStore(); + const thread = await store.create("user-1"); + const originalUpdatedAt = thread.updatedAt; + + // Small delay to ensure timestamp differs + await new Promise((r) => setTimeout(r, 5)); + + await store.addMessage(thread.id, "user-1", { + id: "msg-1", + role: "user", + content: "hello", + createdAt: new Date(), + }); + + const updated = await store.get(thread.id, "user-1"); + expect(updated?.messages).toHaveLength(1); + expect(updated?.messages[0].content).toBe("hello"); + expect(updated?.updatedAt.getTime()).toBeGreaterThanOrEqual( + originalUpdatedAt.getTime(), + ); + }); + + test("addMessage() throws for non-existent thread", async () => { + const store = new InMemoryThreadStore(); + + await expect( + store.addMessage("non-existent", "user-1", { + id: "msg-1", + role: "user", + content: "hello", + createdAt: new Date(), + }), + ).rejects.toThrow("Thread non-existent not found"); + }); + + test("delete() removes a thread and returns true", async () => { + const store = new InMemoryThreadStore(); + const thread = await store.create("user-1"); + + const deleted = await store.delete(thread.id, "user-1"); + expect(deleted).toBe(true); + + const retrieved = await store.get(thread.id, "user-1"); + expect(retrieved).toBeNull(); + }); + + test("delete() returns false for non-existent thread", async () => { + const store = new InMemoryThreadStore(); + const deleted = await store.delete("non-existent", "user-1"); + expect(deleted).toBe(false); + }); + + test("delete() returns false for wrong user", async () => { + const store = new InMemoryThreadStore(); + const thread = await store.create("user-1"); + + const deleted = await store.delete(thread.id, "user-2"); + expect(deleted).toBe(false); + }); + + test("threads are isolated per user", async () => { + const store = new InMemoryThreadStore(); + await store.create("user-1"); + await store.create("user-1"); + await store.create("user-2"); + + const user1Threads = await store.list("user-1"); + const user2Threads = await store.list("user-2"); + + expect(user1Threads).toHaveLength(2); + expect(user2Threads).toHaveLength(1); + }); +}); diff --git a/packages/appkit/src/plugins/agent/thread-store.ts b/packages/appkit/src/plugins/agent/thread-store.ts new file mode 100644 index 00000000..f3ca0599 --- /dev/null +++ b/packages/appkit/src/plugins/agent/thread-store.ts @@ -0,0 +1,59 @@ +import { randomUUID } from "node:crypto"; +import type { Message, Thread, ThreadStore } from "shared"; + +/** + * In-memory thread store backed by a nested Map. + * + * Outer key: userId, inner key: threadId. + * Suitable for development and single-instance deployments. + */ +export class InMemoryThreadStore implements ThreadStore { + private store = new Map>(); + + async create(userId: string): Promise { + const now = new Date(); + const thread: Thread = { + id: randomUUID(), + userId, + messages: [], + createdAt: now, + updatedAt: now, + }; + this.userMap(userId).set(thread.id, thread); + return thread; + } + + async get(threadId: string, userId: string): Promise { + return this.userMap(userId).get(threadId) ?? null; + } + + async list(userId: string): Promise { + return Array.from(this.userMap(userId).values()).sort( + (a, b) => b.updatedAt.getTime() - a.updatedAt.getTime(), + ); + } + + async addMessage( + threadId: string, + userId: string, + message: Message, + ): Promise { + const thread = this.userMap(userId).get(threadId); + if (!thread) throw new Error(`Thread ${threadId} not found`); + thread.messages.push(message); + thread.updatedAt = new Date(); + } + + async delete(threadId: string, userId: string): Promise { + return this.userMap(userId).delete(threadId); + } + + private userMap(userId: string): Map { + let map = this.store.get(userId); + if (!map) { + map = new Map(); + this.store.set(userId, map); + } + return map; + } +} diff --git a/packages/appkit/src/plugins/agent/tools/define-tool.ts b/packages/appkit/src/plugins/agent/tools/define-tool.ts new file mode 100644 index 00000000..7c1f49e4 --- /dev/null +++ b/packages/appkit/src/plugins/agent/tools/define-tool.ts @@ -0,0 +1,83 @@ +import type { AgentToolDefinition, ToolAnnotations } from "shared"; +import { toJSONSchema, type z } from "zod"; +import { formatZodError } from "./tool"; + +/** + * Single-tool entry for a plugin's internal tool registry. + * + * Plugins collect these into a `Record` keyed by the tool's + * public name and dispatch via `executeFromRegistry`. + */ +export interface ToolEntry { + description: string; + schema: S; + annotations?: ToolAnnotations; + handler: ( + args: z.infer, + signal?: AbortSignal, + ) => unknown | Promise; +} + +export type ToolRegistry = Record; + +/** + * Defines a single tool entry for a plugin's internal registry. + * + * The generic `S` flows from `schema` through to the `handler` callback so + * `args` is fully typed from the Zod schema. Names are assigned by the + * registry key, so they are not repeated inside the entry. + */ +export function defineTool( + config: ToolEntry, +): ToolEntry { + return config; +} + +/** + * Validates tool-call arguments against the entry's schema and invokes its + * handler. On validation failure, returns an LLM-friendly error string + * (matching the behavior of `tool()`) rather than throwing, so the model + * can self-correct on its next turn. + */ +export async function executeFromRegistry( + registry: ToolRegistry, + name: string, + args: unknown, + signal?: AbortSignal, +): Promise { + const entry = registry[name]; + if (!entry) { + throw new Error(`Unknown tool: ${name}`); + } + const parsed = entry.schema.safeParse(args); + if (!parsed.success) { + return formatZodError(parsed.error, name); + } + return entry.handler(parsed.data, signal); +} + +/** + * Produces the `AgentToolDefinition[]` a ToolProvider exposes to the LLM, + * deriving `parameters` JSON Schema from each entry's Zod schema. + * + * Tool names come from registry keys (supports dotted names like + * `uploads.list` for dynamic plugins). + */ +export function toolsFromRegistry( + registry: ToolRegistry, +): AgentToolDefinition[] { + return Object.entries(registry).map(([name, entry]) => { + const parameters = toJSONSchema( + entry.schema, + ) as unknown as AgentToolDefinition["parameters"]; + const def: AgentToolDefinition = { + name, + description: entry.description, + parameters, + }; + if (entry.annotations) { + def.annotations = entry.annotations; + } + return def; + }); +} diff --git a/packages/appkit/src/plugins/agent/tools/index.ts b/packages/appkit/src/plugins/agent/tools/index.ts index 042f1958..7b779d1c 100644 --- a/packages/appkit/src/plugins/agent/tools/index.ts +++ b/packages/appkit/src/plugins/agent/tools/index.ts @@ -1,3 +1,10 @@ +export { + defineTool, + executeFromRegistry, + type ToolEntry, + type ToolRegistry, + toolsFromRegistry, +} from "./define-tool"; export { type FunctionTool, functionToolToDefinition, diff --git a/packages/appkit/src/plugins/agent/types.ts b/packages/appkit/src/plugins/agent/types.ts new file mode 100644 index 00000000..67cc2c8b --- /dev/null +++ b/packages/appkit/src/plugins/agent/types.ts @@ -0,0 +1,57 @@ +import type { + AgentAdapter, + AgentToolDefinition, + BasePluginConfig, + ThreadStore, + ToolProvider, +} from "shared"; +import type { FunctionTool } from "./tools/function-tool"; +import type { HostedTool } from "./tools/hosted-tools"; + +export type AgentTool = FunctionTool | HostedTool; + +export interface AgentDefinition { + adapter: AgentAdapter | Promise; + systemPrompt?: string; +} + +export type AgentEntry = AgentAdapter | AgentDefinition | Promise; + +export interface AgentPluginConfig extends BasePluginConfig { + agents?: Record; + defaultAgent?: string; + threadStore?: ThreadStore; + tools?: AgentTool[]; + agentsDir?: string; + plugins?: Record; +} + +export type ToolEntry = + | { + source: "plugin"; + plugin: ToolProvider & { asUser(req: any): any }; + def: AgentToolDefinition; + localName: string; + } + | { + source: "function"; + functionTool: FunctionTool; + def: AgentToolDefinition; + } + | { + source: "mcp"; + mcpToolName: string; + def: AgentToolDefinition; + }; + +export type RegisteredAgent = { + name: string; + adapter: AgentAdapter; + systemPrompt?: string; +}; + +export type { + AgentAdapter, + AgentToolDefinition, + ToolProvider, +} from "shared"; diff --git a/packages/appkit/src/plugins/analytics/analytics.ts b/packages/appkit/src/plugins/analytics/analytics.ts index a9c688da..1bd97d18 100644 --- a/packages/appkit/src/plugins/analytics/analytics.ts +++ b/packages/appkit/src/plugins/analytics/analytics.ts @@ -1,16 +1,24 @@ import type { WorkspaceClient } from "@databricks/sdk-experimental"; import type express from "express"; import type { + AgentToolDefinition, IAppRouter, PluginExecuteConfig, SQLTypeMarker, StreamExecutionSettings, + ToolProvider, } from "shared"; +import { z } from "zod"; import { SQLWarehouseConnector } from "../../connectors"; import { getWarehouseId, getWorkspaceClient } from "../../context"; import { createLogger } from "../../logging/logger"; import { Plugin, toPlugin } from "../../plugin"; import type { PluginManifest } from "../../registry"; +import { + defineTool, + executeFromRegistry, + toolsFromRegistry, +} from "../agent/tools/define-tool"; import { queryDefaults } from "./defaults"; import manifest from "./manifest.json"; import { QueryProcessor } from "./query"; @@ -22,7 +30,7 @@ import type { const logger = createLogger("analytics"); -export class AnalyticsPlugin extends Plugin { +export class AnalyticsPlugin extends Plugin implements ToolProvider { /** Plugin manifest declaring metadata and resource requirements */ static manifest = manifest as PluginManifest<"analytics">; @@ -262,6 +270,34 @@ export class AnalyticsPlugin extends Plugin { this.streamManager.abortAll(); } + private tools = { + query: defineTool({ + description: + "Execute a SQL query against the Databricks SQL warehouse. Returns the query results as JSON.", + schema: z.object({ + query: z.string().describe("The SQL query to execute"), + }), + annotations: { + readOnly: true, + requiresUserContext: true, + }, + handler: (args, signal) => + this.query(args.query, undefined, undefined, signal), + }), + }; + + getAgentTools(): AgentToolDefinition[] { + return toolsFromRegistry(this.tools); + } + + async executeAgentTool( + name: string, + args: unknown, + signal?: AbortSignal, + ): Promise { + return executeFromRegistry(this.tools, name, args, signal); + } + /** * Returns the public exports for the analytics plugin. * Note: `asUser()` is automatically added by AppKit. diff --git a/packages/appkit/src/plugins/files/plugin.ts b/packages/appkit/src/plugins/files/plugin.ts index 9344af85..2d92c198 100644 --- a/packages/appkit/src/plugins/files/plugin.ts +++ b/packages/appkit/src/plugins/files/plugin.ts @@ -2,7 +2,13 @@ import { STATUS_CODES } from "node:http"; import { Readable } from "node:stream"; import { ApiError } from "@databricks/sdk-experimental"; import type express from "express"; -import type { IAppRouter, PluginExecutionSettings } from "shared"; +import type { + AgentToolDefinition, + IAppRouter, + PluginExecutionSettings, + ToolProvider, +} from "shared"; +import { z } from "zod"; import { contentTypeFromPath, FilesConnector, @@ -15,6 +21,12 @@ import { createLogger } from "../../logging/logger"; import { Plugin, toPlugin } from "../../plugin"; import type { PluginManifest, ResourceRequirement } from "../../registry"; import { ResourceType } from "../../registry"; +import { + defineTool, + executeFromRegistry, + type ToolRegistry, + toolsFromRegistry, +} from "../agent/tools/define-tool"; import { FILES_DOWNLOAD_DEFAULTS, FILES_MAX_UPLOAD_SIZE, @@ -34,7 +46,7 @@ import type { const logger = createLogger("files"); -export class FilesPlugin extends Plugin { +export class FilesPlugin extends Plugin implements ToolProvider { name = "files"; /** Plugin manifest declaring metadata and resource requirements. */ @@ -45,6 +57,7 @@ export class FilesPlugin extends Plugin { private volumeConnectors: Record = {}; private volumeConfigs: Record = {}; private volumeKeys: string[] = []; + private tools: ToolRegistry = {}; /** * Scans `process.env` for `DATABRICKS_VOLUME_*` keys and merges them with @@ -148,6 +161,79 @@ export class FilesPlugin extends Plugin { customContentTypes: mergedConfig.customContentTypes, }); } + + for (const volumeKey of this.volumeKeys) { + Object.assign(this.tools, this._defineVolumeTools(volumeKey)); + } + } + + /** + * Builds the registry entries for a single volume. One set of tools per + * configured volume, keyed by `${volumeKey}.${method}`. + */ + private _defineVolumeTools(volumeKey: string): ToolRegistry { + const api = () => this.createVolumeAPI(volumeKey); + return { + [`${volumeKey}.list`]: defineTool({ + description: `List files and directories in the "${volumeKey}" volume`, + schema: z.object({ + path: z + .string() + .optional() + .describe("Directory path to list (optional, defaults to root)"), + }), + annotations: { readOnly: true, requiresUserContext: true }, + handler: (args) => api().list(args.path), + }), + [`${volumeKey}.read`]: defineTool({ + description: `Read a text file from the "${volumeKey}" volume`, + schema: z.object({ + path: z.string().describe("File path to read"), + }), + annotations: { readOnly: true, requiresUserContext: true }, + handler: (args) => api().read(args.path), + }), + [`${volumeKey}.exists`]: defineTool({ + description: `Check if a file or directory exists in the "${volumeKey}" volume`, + schema: z.object({ + path: z.string().describe("Path to check"), + }), + annotations: { readOnly: true, requiresUserContext: true }, + handler: (args) => api().exists(args.path), + }), + [`${volumeKey}.metadata`]: defineTool({ + description: `Get metadata (size, type, last modified) for a file in the "${volumeKey}" volume`, + schema: z.object({ + path: z.string().describe("File path"), + }), + annotations: { readOnly: true, requiresUserContext: true }, + handler: (args) => api().metadata(args.path), + }), + [`${volumeKey}.upload`]: defineTool({ + description: `Upload a text file to the "${volumeKey}" volume`, + schema: z.object({ + path: z.string().describe("Destination file path"), + contents: z.string().describe("File contents as a string"), + overwrite: z + .boolean() + .optional() + .describe("Whether to overwrite existing file"), + }), + annotations: { destructive: true, requiresUserContext: true }, + handler: (args) => + api().upload(args.path, args.contents, { + overwrite: args.overwrite, + }), + }), + [`${volumeKey}.delete`]: defineTool({ + description: `Delete a file from the "${volumeKey}" volume`, + schema: z.object({ + path: z.string().describe("File path to delete"), + }), + annotations: { destructive: true, requiresUserContext: true }, + handler: (args) => api().delete(args.path), + }), + }; } /** @@ -950,6 +1036,19 @@ export class FilesPlugin extends Plugin { * appKit.files("uploads").list() * ``` */ + + getAgentTools(): AgentToolDefinition[] { + return toolsFromRegistry(this.tools); + } + + async executeAgentTool( + name: string, + args: unknown, + signal?: AbortSignal, + ): Promise { + return executeFromRegistry(this.tools, name, args, signal); + } + exports(): FilesExport { const resolveVolume = (volumeKey: string): VolumeHandle => { if (!this.volumeKeys.includes(volumeKey)) { diff --git a/packages/appkit/src/plugins/files/tests/plugin.test.ts b/packages/appkit/src/plugins/files/tests/plugin.test.ts index 99e08b8c..17591a45 100644 --- a/packages/appkit/src/plugins/files/tests/plugin.test.ts +++ b/packages/appkit/src/plugins/files/tests/plugin.test.ts @@ -204,6 +204,62 @@ describe("FilesPlugin", () => { }); }); + describe("getAgentTools / executeAgentTool", () => { + test("produces independent tool entries per volume", () => { + const plugin = new FilesPlugin(VOLUMES_CONFIG); + const tools = plugin.getAgentTools(); + const names = tools.map((t) => t.name); + + expect(names).toContain("uploads.list"); + expect(names).toContain("uploads.read"); + expect(names).toContain("uploads.exists"); + expect(names).toContain("uploads.metadata"); + expect(names).toContain("uploads.upload"); + expect(names).toContain("uploads.delete"); + + expect(names).toContain("exports.list"); + expect(names).toContain("exports.read"); + expect(names).toContain("exports.delete"); + + expect(tools).toHaveLength(12); + }); + + test("dispatches to the correct volume API based on the tool name", async () => { + const plugin = new FilesPlugin(VOLUMES_CONFIG); + const asyncIterable = (items: { path: string }[]) => ({ + [Symbol.asyncIterator]: async function* () { + for (const item of items) yield item; + }, + }); + mockClient.files.listDirectoryContents.mockReturnValueOnce( + asyncIterable([{ path: "uploads-file" }]), + ); + mockClient.files.listDirectoryContents.mockReturnValueOnce( + asyncIterable([{ path: "exports-file" }]), + ); + + const uploadsResult = (await plugin.executeAgentTool( + "uploads.list", + {}, + )) as { path: string }[]; + const exportsResult = (await plugin.executeAgentTool( + "exports.list", + {}, + )) as { path: string }[]; + + expect(uploadsResult[0].path).toBe("uploads-file"); + expect(exportsResult[0].path).toBe("exports-file"); + }); + + test("returns LLM-friendly error string for invalid tool args", async () => { + const plugin = new FilesPlugin(VOLUMES_CONFIG); + const result = await plugin.executeAgentTool("uploads.read", {}); + expect(typeof result).toBe("string"); + expect(result).toContain("Invalid arguments for uploads.read"); + expect(result).toContain("path"); + }); + }); + describe("exports()", () => { test("returns a callable function with a .volume alias", () => { const plugin = new FilesPlugin(VOLUMES_CONFIG); diff --git a/packages/appkit/src/plugins/genie/genie.ts b/packages/appkit/src/plugins/genie/genie.ts index 712aadbf..96c34b64 100644 --- a/packages/appkit/src/plugins/genie/genie.ts +++ b/packages/appkit/src/plugins/genie/genie.ts @@ -1,11 +1,23 @@ import { randomUUID } from "node:crypto"; import type express from "express"; -import type { IAppRouter, StreamExecutionSettings } from "shared"; +import type { + AgentToolDefinition, + IAppRouter, + StreamExecutionSettings, + ToolProvider, +} from "shared"; +import { z } from "zod"; import { GenieConnector } from "../../connectors"; import { getWorkspaceClient } from "../../context"; import { createLogger } from "../../logging"; import { Plugin, toPlugin } from "../../plugin"; import type { PluginManifest } from "../../registry"; +import { + defineTool, + executeFromRegistry, + type ToolRegistry, + toolsFromRegistry, +} from "../agent/tools/define-tool"; import { genieStreamDefaults } from "./defaults"; import manifest from "./manifest.json"; import type { @@ -17,7 +29,7 @@ import type { const logger = createLogger("genie"); -export class GeniePlugin extends Plugin { +export class GeniePlugin extends Plugin implements ToolProvider { static manifest = manifest as PluginManifest<"genie">; protected static description = @@ -25,6 +37,7 @@ export class GeniePlugin extends Plugin { protected declare config: IGenieConfig; private readonly genieConnector: GenieConnector; + private tools: ToolRegistry = {}; constructor(config: IGenieConfig) { super(config); @@ -36,6 +49,53 @@ export class GeniePlugin extends Plugin { timeout: this.config.timeout, maxMessages: 200, }); + + for (const alias of Object.keys(this.config.spaces ?? {})) { + Object.assign(this.tools, this._defineSpaceTools(alias)); + } + } + + /** + * Builds the registry entries for a single Genie space alias. + * One set of tools per configured space, keyed by `${alias}.${method}`. + */ + private _defineSpaceTools(alias: string): ToolRegistry { + return { + [`${alias}.sendMessage`]: defineTool({ + description: `Send a natural language question to the Genie space "${alias}" and get data analysis results`, + schema: z.object({ + content: z.string().describe("The natural language question to ask"), + conversationId: z + .string() + .optional() + .describe( + "Optional conversation ID to continue an existing conversation", + ), + }), + annotations: { requiresUserContext: true }, + handler: async (args) => { + const events: GenieStreamEvent[] = []; + for await (const event of this.sendMessage( + alias, + args.content, + args.conversationId, + )) { + events.push(event); + } + return events; + }, + }), + [`${alias}.getConversation`]: defineTool({ + description: `Retrieve the conversation history from the Genie space "${alias}"`, + schema: z.object({ + conversationId: z + .string() + .describe("The conversation ID to retrieve"), + }), + annotations: { readOnly: true, requiresUserContext: true }, + handler: (args) => this.getConversation(alias, args.conversationId), + }), + }; } private defaultSpaces(): Record { @@ -287,6 +347,18 @@ export class GeniePlugin extends Plugin { this.streamManager.abortAll(); } + getAgentTools(): AgentToolDefinition[] { + return toolsFromRegistry(this.tools); + } + + async executeAgentTool( + name: string, + args: unknown, + signal?: AbortSignal, + ): Promise { + return executeFromRegistry(this.tools, name, args, signal); + } + exports() { return { sendMessage: this.sendMessage, diff --git a/packages/appkit/src/plugins/genie/tests/genie.test.ts b/packages/appkit/src/plugins/genie/tests/genie.test.ts index 3cf0784d..672e6242 100644 --- a/packages/appkit/src/plugins/genie/tests/genie.test.ts +++ b/packages/appkit/src/plugins/genie/tests/genie.test.ts @@ -187,6 +187,30 @@ describe("Genie Plugin", () => { }); }); + describe("getAgentTools / executeAgentTool", () => { + test("produces independent tool entries per configured space", () => { + const plugin = new GeniePlugin(config); + const names = plugin.getAgentTools().map((t) => t.name); + + expect(names).toContain("myspace.sendMessage"); + expect(names).toContain("myspace.getConversation"); + expect(names).toContain("salesbot.sendMessage"); + expect(names).toContain("salesbot.getConversation"); + expect(names).toHaveLength(4); + }); + + test("returns LLM-friendly error string for invalid tool args", async () => { + const plugin = new GeniePlugin(config); + const result = await plugin.executeAgentTool( + "myspace.getConversation", + {}, + ); + expect(typeof result).toBe("string"); + expect(result).toContain("Invalid arguments for myspace.getConversation"); + expect(result).toContain("conversationId"); + }); + }); + describe("space alias resolution", () => { test("should return 404 for unknown alias", async () => { const plugin = new GeniePlugin(config); diff --git a/packages/appkit/src/plugins/index.ts b/packages/appkit/src/plugins/index.ts index 4d58082f..17c92621 100644 --- a/packages/appkit/src/plugins/index.ts +++ b/packages/appkit/src/plugins/index.ts @@ -1,3 +1,4 @@ +export * from "./agent"; export * from "./analytics"; export * from "./files"; export * from "./genie"; diff --git a/packages/appkit/src/plugins/lakebase/lakebase.ts b/packages/appkit/src/plugins/lakebase/lakebase.ts index 3071d539..4ad3384e 100644 --- a/packages/appkit/src/plugins/lakebase/lakebase.ts +++ b/packages/appkit/src/plugins/lakebase/lakebase.ts @@ -1,4 +1,6 @@ import type { Pool, QueryResult, QueryResultRow } from "pg"; +import type { AgentToolDefinition, ToolProvider } from "shared"; +import { z } from "zod"; import { createLakebasePool, getLakebaseOrmConfig, @@ -8,6 +10,11 @@ import { import { createLogger } from "../../logging/logger"; import { Plugin, toPlugin } from "../../plugin"; import type { PluginManifest } from "../../registry"; +import { + defineTool, + executeFromRegistry, + toolsFromRegistry, +} from "../agent/tools/define-tool"; import manifest from "./manifest.json"; import type { ILakebaseConfig } from "./types"; @@ -30,7 +37,7 @@ const logger = createLogger("lakebase"); * const result = await AppKit.lakebase.query("SELECT * FROM users WHERE id = $1", [userId]); * ``` */ -class LakebasePlugin extends Plugin { +class LakebasePlugin extends Plugin implements ToolProvider { /** Plugin manifest declaring metadata and resource requirements */ static manifest = manifest as PluginManifest<"lakebase">; @@ -102,6 +109,46 @@ class LakebasePlugin extends Plugin { * - `getOrmConfig()` — Returns a config object compatible with Drizzle, TypeORM, Sequelize, etc. * - `getPgConfig()` — Returns a `pg.PoolConfig` object for manual pool construction */ + + private tools = { + query: defineTool({ + description: + "Execute a parameterized SQL query against the Lakebase PostgreSQL database. Use $1, $2, etc. as placeholders and pass values separately.", + schema: z.object({ + text: z + .string() + .describe( + "SQL query string with $1, $2, ... placeholders for parameters", + ), + values: z + .array(z.unknown()) + .optional() + .describe("Parameter values corresponding to placeholders"), + }), + annotations: { + readOnly: false, + destructive: false, + idempotent: false, + }, + handler: async (args) => { + const result = await this.query(args.text, args.values); + return result.rows; + }, + }), + }; + + getAgentTools(): AgentToolDefinition[] { + return toolsFromRegistry(this.tools); + } + + async executeAgentTool( + name: string, + args: unknown, + signal?: AbortSignal, + ): Promise { + return executeFromRegistry(this.tools, name, args, signal); + } + exports() { return { // biome-ignore lint/style/noNonNullAssertion: pool is guaranteed non-null after setup(), which AppKit always awaits before exposing the plugin API diff --git a/packages/appkit/src/plugins/server/index.ts b/packages/appkit/src/plugins/server/index.ts index e7b9b31a..75d3e1d0 100644 --- a/packages/appkit/src/plugins/server/index.ts +++ b/packages/appkit/src/plugins/server/index.ts @@ -179,6 +179,16 @@ export class ServerPlugin extends Plugin { return this; } + /** + * Register a server extension from another plugin during setup. + * Unlike extend(), this does not guard on autoStart — it's designed + * for internal plugin-to-plugin coordination where extensions are + * registered before the server starts listening. + */ + addExtension(fn: (app: express.Application) => void) { + this.serverExtensions.push(fn); + } + /** * Setup the routes with the plugins. * diff --git a/template/appkit.plugins.json b/template/appkit.plugins.json index d1420d2e..a9ca281d 100644 --- a/template/appkit.plugins.json +++ b/template/appkit.plugins.json @@ -2,6 +2,16 @@ "$schema": "https://databricks.github.io/appkit/schemas/template-plugins.schema.json", "version": "1.0", "plugins": { + "agent": { + "name": "agent", + "displayName": "Agent Plugin", + "description": "Framework-agnostic AI agent with auto-tool-discovery from all registered plugins", + "package": "@databricks/appkit", + "resources": { + "required": [], + "optional": [] + } + }, "analytics": { "name": "analytics", "displayName": "Analytics Plugin", @@ -149,30 +159,6 @@ "optional": [] }, "requiredByTemplate": true - }, - "serving": { - "name": "serving", - "displayName": "Model Serving Plugin", - "description": "Authenticated proxy to Databricks Model Serving endpoints", - "package": "@databricks/appkit", - "resources": { - "required": [ - { - "type": "serving_endpoint", - "alias": "Serving Endpoint", - "resourceKey": "serving-endpoint", - "description": "Model Serving endpoint for inference", - "permission": "CAN_QUERY", - "fields": { - "name": { - "env": "DATABRICKS_SERVING_ENDPOINT_NAME", - "description": "Serving endpoint name" - } - } - } - ], - "optional": [] - } } } }