diff --git a/packages/appkit/src/core/appkit.ts b/packages/appkit/src/core/appkit.ts index a2cba994..9e00da1d 100644 --- a/packages/appkit/src/core/appkit.ts +++ b/packages/appkit/src/core/appkit.ts @@ -13,14 +13,18 @@ import { ServiceContext } from "../context"; import { ResourceRegistry, ResourceType } from "../registry"; import type { TelemetryConfig } from "../telemetry"; import { TelemetryManager } from "../telemetry"; +import { isToolProvider, PluginContext } from "./plugin-context"; export class AppKit { #pluginInstances: Record = {}; #setupPromises: Promise[] = []; + #context: PluginContext; private constructor(config: { plugins: TPlugins }) { const { plugins, ...globalConfig } = config; + this.#context = new PluginContext(); + const pluginEntries = Object.entries(plugins); const corePlugins = pluginEntries.filter(([_, p]) => { @@ -35,20 +39,24 @@ export class AppKit { for (const [name, pluginData] of corePlugins) { if (pluginData) { - this.createAndRegisterPlugin(globalConfig, name, pluginData); + this.createAndRegisterPlugin(globalConfig, name, pluginData, { + context: this.#context, + }); } } for (const [name, pluginData] of normalPlugins) { if (pluginData) { - this.createAndRegisterPlugin(globalConfig, name, pluginData); + this.createAndRegisterPlugin(globalConfig, name, pluginData, { + context: this.#context, + }); } } for (const [name, pluginData] of deferredPlugins) { if (pluginData) { this.createAndRegisterPlugin(globalConfig, name, pluginData, { - plugins: this.#pluginInstances, + context: this.#context, }); } } @@ -72,6 +80,11 @@ export class AppKit { this.#pluginInstances[name] = pluginInstance; + this.#context.registerPlugin(name, pluginInstance); + if (isToolProvider(pluginInstance)) { + this.#context.registerToolProvider(name, pluginInstance); + } + this.#setupPromises.push(pluginInstance.setup()); const self = this; @@ -199,6 +212,7 @@ export class AppKit { const instance = new AppKit(mergedConfig); await Promise.all(instance.#setupPromises); + await instance.#context.emitLifecycle("setup:complete"); return instance as unknown as PluginMap; } diff --git a/packages/appkit/src/core/plugin-context.ts b/packages/appkit/src/core/plugin-context.ts new file mode 100644 index 00000000..c2801585 --- /dev/null +++ b/packages/appkit/src/core/plugin-context.ts @@ -0,0 +1,287 @@ +import type express from "express"; +import type { BasePlugin, ToolProvider } from "shared"; +import { createLogger } from "../logging/logger"; +import { TelemetryManager } from "../telemetry"; + +const logger = createLogger("plugin-context"); + +interface BufferedRoute { + method: string; + path: string; + handlers: express.RequestHandler[]; +} + +interface RouteTarget { + addExtension(fn: (app: express.Application) => void): void; +} + +interface ToolProviderEntry { + plugin: BasePlugin & ToolProvider; + name: string; +} + +type LifecycleEvent = "setup:complete" | "server:ready" | "shutdown"; + +/** + * Mediator for inter-plugin communication. + * + * Created by AppKit core and passed to every plugin. Plugins request + * capabilities from the context instead of holding direct references + * to sibling plugin instances. + * + * Capabilities: + * - Route mounting with buffering (order-independent) + * - Typed ToolProvider registry (live, not snapshot-based) + * - User-scoped tool execution with automatic telemetry + * - Lifecycle hooks for plugin coordination + */ +export class PluginContext { + private routeBuffer: BufferedRoute[] = []; + private routeTarget: RouteTarget | null = null; + private toolProviders = new Map(); + private plugins = new Map(); + private lifecycleHooks = new Map< + LifecycleEvent, + Set<() => void | Promise> + >(); + private telemetry = TelemetryManager.getProvider("plugin-context"); + + /** + * Register a route on the root Express application. + * + * If a route target (server plugin) has registered, the route is applied + * immediately. Otherwise it is buffered and flushed when a route target + * becomes available. + */ + addRoute( + method: string, + path: string, + ...handlers: express.RequestHandler[] + ): void { + if (this.routeTarget) { + this.applyRoute({ method, path, handlers }); + } else { + this.routeBuffer.push({ method, path, handlers }); + } + } + + /** + * Register middleware on the root Express application. + * + * Same buffering semantics as `addRoute`. + */ + addMiddleware(path: string, ...handlers: express.RequestHandler[]): void { + if (this.routeTarget) { + this.applyMiddleware(path, handlers); + } else { + this.routeBuffer.push({ method: "use", path, handlers }); + } + } + + /** + * Called by the server plugin to opt in as the route target. + * Flushes all buffered routes via the server's `addExtension`. + */ + registerAsRouteTarget(target: RouteTarget): void { + this.routeTarget = target; + + for (const route of this.routeBuffer) { + if (route.method === "use") { + this.applyMiddleware(route.path, route.handlers); + } else { + this.applyRoute(route); + } + } + this.routeBuffer = []; + } + + /** + * Register a plugin that implements the ToolProvider interface. + * Called by AppKit core after constructing each plugin. + */ + registerToolProvider(name: string, plugin: BasePlugin & ToolProvider): void { + this.toolProviders.set(name, { plugin, name }); + } + + /** + * Register a plugin instance. + * Called by AppKit core after constructing each plugin. + */ + registerPlugin(name: string, instance: BasePlugin): void { + this.plugins.set(name, instance); + } + + /** + * Returns all registered plugin instances keyed by name. + * Used by the server plugin for route injection, client config, + * and shutdown coordination. + */ + getPlugins(): Map { + return this.plugins; + } + + /** + * Returns all registered ToolProvider plugins. + * Always returns the current set — not a frozen snapshot. + */ + getToolProviders(): Array<{ name: string; provider: ToolProvider }> { + return Array.from(this.toolProviders.values()).map((entry) => ({ + name: entry.name, + provider: entry.plugin, + })); + } + + /** + * Execute a tool on a ToolProvider plugin with automatic user scoping + * and telemetry. + * + * The context: + * 1. Resolves the plugin by name + * 2. Calls `asUser(req)` for user-scoped execution + * 3. Wraps the call in a telemetry span with a 30s timeout + */ + async executeTool( + req: express.Request, + pluginName: string, + toolName: string, + args: unknown, + signal?: AbortSignal, + ): Promise { + const entry = this.toolProviders.get(pluginName); + if (!entry) { + throw new Error( + `PluginContext: unknown plugin "${pluginName}". Available: ${Array.from(this.toolProviders.keys()).join(", ")}`, + ); + } + + const tracer = this.telemetry.getTracer(); + const operationName = `executeTool:${pluginName}.${toolName}`; + + return tracer.startActiveSpan(operationName, async (span) => { + const timeout = 30_000; + const timeoutSignal = AbortSignal.timeout(timeout); + const combinedSignal = signal + ? AbortSignal.any([signal, timeoutSignal]) + : timeoutSignal; + + try { + const userPlugin = (entry.plugin as any).asUser(req); + const result = await (userPlugin as ToolProvider).executeAgentTool( + toolName, + args, + combinedSignal, + ); + span.setStatus({ code: 0 }); + return result; + } catch (error) { + span.setStatus({ + code: 2, + message: + error instanceof Error ? error.message : "Tool execution failed", + }); + span.recordException( + error instanceof Error ? error : new Error(String(error)), + ); + throw error; + } finally { + span.end(); + } + }); + } + + /** + * Register a lifecycle hook callback. + */ + onLifecycle(event: LifecycleEvent, fn: () => void | Promise): void { + let hooks = this.lifecycleHooks.get(event); + if (!hooks) { + hooks = new Set(); + this.lifecycleHooks.set(event, hooks); + } + hooks.add(fn); + } + + /** + * Emit a lifecycle event, calling all registered callbacks. + * Errors in individual callbacks are logged but do not prevent + * other callbacks from running. + * + * @internal Called by AppKit core only. + */ + async emitLifecycle(event: LifecycleEvent): Promise { + const hooks = this.lifecycleHooks.get(event); + if (!hooks) return; + + if ( + event === "setup:complete" && + this.routeBuffer.length > 0 && + !this.routeTarget + ) { + logger.warn( + "%d buffered routes were never applied — no server plugin registered as route target", + this.routeBuffer.length, + ); + } + + for (const fn of hooks) { + try { + await fn(); + } catch (error) { + logger.error("Lifecycle hook '%s' failed: %O", event, error); + } + } + } + + /** + * Returns all registered plugin names. + */ + getPluginNames(): string[] { + return Array.from(this.plugins.keys()); + } + + /** + * Check if a plugin with the given name is registered. + */ + hasPlugin(name: string): boolean { + return this.plugins.has(name); + } + + private applyRoute(route: BufferedRoute): void { + if (!this.routeTarget) return; + this.routeTarget.addExtension((app) => { + const method = route.method.toLowerCase() as keyof express.Application; + if (typeof app[method] === "function") { + (app[method] as (...a: unknown[]) => void)( + route.path, + ...route.handlers, + ); + } + }); + } + + private applyMiddleware( + path: string, + handlers: express.RequestHandler[], + ): void { + if (!this.routeTarget) return; + this.routeTarget.addExtension((app) => { + app.use(path, ...handlers); + }); + } +} + +/** + * Type guard: checks whether a plugin implements the ToolProvider interface. + */ +export function isToolProvider( + plugin: unknown, +): plugin is BasePlugin & ToolProvider { + return ( + typeof plugin === "object" && + plugin !== null && + "getAgentTools" in plugin && + typeof (plugin as ToolProvider).getAgentTools === "function" && + "executeAgentTool" in plugin && + typeof (plugin as ToolProvider).executeAgentTool === "function" + ); +} diff --git a/packages/appkit/src/core/tests/databricks.test.ts b/packages/appkit/src/core/tests/databricks.test.ts index c05345a6..9d3fe5f8 100644 --- a/packages/appkit/src/core/tests/databricks.test.ts +++ b/packages/appkit/src/core/tests/databricks.test.ts @@ -109,11 +109,11 @@ class DeferredTestPlugin implements BasePlugin { name = "deferredTest"; setupCalled = false; injectedConfig: any; - injectedPlugins: any; + injectedContext: any; constructor(config: any) { this.injectedConfig = config; - this.injectedPlugins = config.plugins; + this.injectedContext = config.context; } async setup() { @@ -130,7 +130,7 @@ class DeferredTestPlugin implements BasePlugin { return { setupCalled: this.setupCalled, injectedConfig: this.injectedConfig, - injectedPlugins: this.injectedPlugins, + injectedContext: this.injectedContext, }; } } @@ -276,7 +276,7 @@ describe("AppKit", () => { expect(setupOrder).toEqual(["core", "normal", "deferred"]); }); - test("should provide plugin instances to deferred plugins", async () => { + test("should provide PluginContext to deferred plugins", async () => { const pluginData = [ { plugin: CoreTestPlugin, config: {}, name: "coreTest" }, { plugin: DeferredTestPlugin, config: {}, name: "deferredTest" }, @@ -284,10 +284,9 @@ describe("AppKit", () => { const instance = (await createApp({ plugins: pluginData })) as any; - // Deferred plugins receive plugin instances (not SDKs) for internal use - expect(instance.deferredTest.injectedPlugins).toBeDefined(); - expect(instance.deferredTest.injectedPlugins.coreTest).toBeInstanceOf( - CoreTestPlugin, + expect(instance.deferredTest.injectedContext).toBeDefined(); + expect(instance.deferredTest.injectedContext.hasPlugin("coreTest")).toBe( + true, ); }); diff --git a/packages/appkit/src/core/tests/plugin-context.test.ts b/packages/appkit/src/core/tests/plugin-context.test.ts new file mode 100644 index 00000000..276c5502 --- /dev/null +++ b/packages/appkit/src/core/tests/plugin-context.test.ts @@ -0,0 +1,325 @@ +import type { AgentToolDefinition } from "shared"; +import { beforeEach, describe, expect, test, vi } from "vitest"; +import { isToolProvider, PluginContext } from "../plugin-context"; + +vi.mock("../../telemetry", () => ({ + TelemetryManager: { + getProvider: () => ({ + getTracer: () => ({ + startActiveSpan: (_name: string, fn: (span: any) => any) => { + const span = { + setStatus: vi.fn(), + recordException: vi.fn(), + end: vi.fn(), + }; + return fn(span); + }, + }), + }), + }, +})); + +vi.mock("../../logging/logger", () => ({ + createLogger: () => ({ + info: vi.fn(), + warn: vi.fn(), + error: vi.fn(), + debug: vi.fn(), + }), +})); + +function createMockToolProvider(tools: AgentToolDefinition[] = []) { + const mock = { + name: "mock-plugin", + setup: vi.fn().mockResolvedValue(undefined), + injectRoutes: vi.fn(), + getEndpoints: vi.fn().mockReturnValue({}), + getAgentTools: vi.fn().mockReturnValue(tools), + executeAgentTool: vi.fn().mockResolvedValue("tool-result"), + asUser: vi.fn().mockReturnThis(), + }; + return mock as any; +} + +describe("PluginContext", () => { + let ctx: PluginContext; + + beforeEach(() => { + ctx = new PluginContext(); + }); + + describe("route buffering", () => { + test("addRoute buffers when no route target exists", () => { + const handler = vi.fn(); + ctx.addRoute("post", "/invocations", handler); + + expect(ctx.getPluginNames()).toEqual([]); + }); + + test("flushRoutes applies buffered routes via addExtension", () => { + const handler = vi.fn(); + ctx.addRoute("post", "/invocations", handler); + + const addExtension = vi.fn(); + ctx.registerAsRouteTarget({ addExtension }); + + expect(addExtension).toHaveBeenCalledTimes(1); + const extensionFn = addExtension.mock.calls[0][0]; + + const mockApp = { post: vi.fn() }; + extensionFn(mockApp); + expect(mockApp.post).toHaveBeenCalledWith("/invocations", handler); + }); + + test("addRoute called after registerAsRouteTarget applies immediately", () => { + const addExtension = vi.fn(); + ctx.registerAsRouteTarget({ addExtension }); + + const handler = vi.fn(); + ctx.addRoute("get", "/health", handler); + + expect(addExtension).toHaveBeenCalledTimes(1); + const extensionFn = addExtension.mock.calls[0][0]; + + const mockApp = { get: vi.fn() }; + extensionFn(mockApp); + expect(mockApp.get).toHaveBeenCalledWith("/health", handler); + }); + + test("addRoute supports middleware chains", () => { + const auth = vi.fn(); + const handler = vi.fn(); + + const addExtension = vi.fn(); + ctx.registerAsRouteTarget({ addExtension }); + + ctx.addRoute("post", "/api", auth, handler); + + const extensionFn = addExtension.mock.calls[0][0]; + const mockApp = { post: vi.fn() }; + extensionFn(mockApp); + expect(mockApp.post).toHaveBeenCalledWith("/api", auth, handler); + }); + + test("addMiddleware buffers and applies via use()", () => { + const handler = vi.fn(); + ctx.addMiddleware("/api", handler); + + const addExtension = vi.fn(); + ctx.registerAsRouteTarget({ addExtension }); + + expect(addExtension).toHaveBeenCalledTimes(1); + const extensionFn = addExtension.mock.calls[0][0]; + + const mockApp = { use: vi.fn() }; + extensionFn(mockApp); + expect(mockApp.use).toHaveBeenCalledWith("/api", handler); + }); + + test("multiple buffered routes are all applied on registration", () => { + const h1 = vi.fn(); + const h2 = vi.fn(); + ctx.addRoute("post", "/a", h1); + ctx.addRoute("get", "/b", h2); + + const addExtension = vi.fn(); + ctx.registerAsRouteTarget({ addExtension }); + + expect(addExtension).toHaveBeenCalledTimes(2); + }); + }); + + describe("ToolProvider registry", () => { + test("registerToolProvider makes provider visible via getToolProviders", () => { + const provider = createMockToolProvider([ + { + name: "query", + description: "Run query", + parameters: { type: "object" }, + }, + ]); + + ctx.registerToolProvider("analytics", provider); + + const providers = ctx.getToolProviders(); + expect(providers).toHaveLength(1); + expect(providers[0].name).toBe("analytics"); + expect(providers[0].provider.getAgentTools()).toHaveLength(1); + }); + + test("getToolProviders returns all registered providers", () => { + ctx.registerToolProvider("analytics", createMockToolProvider()); + ctx.registerToolProvider("files", createMockToolProvider()); + ctx.registerToolProvider("genie", createMockToolProvider()); + + expect(ctx.getToolProviders()).toHaveLength(3); + }); + + test("getToolProviders returns current set, not snapshot", () => { + const before = ctx.getToolProviders(); + expect(before).toHaveLength(0); + + ctx.registerToolProvider("analytics", createMockToolProvider()); + + const after = ctx.getToolProviders(); + expect(after).toHaveLength(1); + }); + }); + + describe("executeTool", () => { + test("calls asUser(req).executeAgentTool on the correct plugin", async () => { + const provider = createMockToolProvider(); + ctx.registerToolProvider("analytics", provider); + + const mockReq = { headers: {} } as any; + await ctx.executeTool(mockReq, "analytics", "query", { sql: "SELECT 1" }); + + expect(provider.asUser).toHaveBeenCalledWith(mockReq); + expect(provider.executeAgentTool).toHaveBeenCalledWith( + "query", + { sql: "SELECT 1" }, + expect.any(Object), + ); + }); + + test("throws for unknown plugin name", async () => { + const mockReq = { headers: {} } as any; + + await expect( + ctx.executeTool(mockReq, "nonexistent", "query", {}), + ).rejects.toThrow('unknown plugin "nonexistent"'); + }); + + test("propagates tool execution errors", async () => { + const provider = createMockToolProvider(); + (provider.executeAgentTool as any).mockRejectedValue( + new Error("Query failed"), + ); + ctx.registerToolProvider("analytics", provider); + + const mockReq = { headers: {} } as any; + + await expect( + ctx.executeTool(mockReq, "analytics", "query", {}), + ).rejects.toThrow("Query failed"); + }); + + test("passes abort signal to executeAgentTool", async () => { + const provider = createMockToolProvider(); + ctx.registerToolProvider("analytics", provider); + + const controller = new AbortController(); + const mockReq = { headers: {} } as any; + + await ctx.executeTool( + mockReq, + "analytics", + "query", + {}, + controller.signal, + ); + + const callArgs = (provider.executeAgentTool as any).mock.calls[0]; + expect(callArgs[2]).toBeDefined(); + }); + }); + + describe("lifecycle hooks", () => { + test("onLifecycle registers callback, emitLifecycle invokes it", async () => { + const fn = vi.fn(); + ctx.onLifecycle("setup:complete", fn); + + await ctx.emitLifecycle("setup:complete"); + + expect(fn).toHaveBeenCalledTimes(1); + }); + + test("multiple callbacks for the same event all fire", async () => { + const fn1 = vi.fn(); + const fn2 = vi.fn(); + ctx.onLifecycle("setup:complete", fn1); + ctx.onLifecycle("setup:complete", fn2); + + await ctx.emitLifecycle("setup:complete"); + + expect(fn1).toHaveBeenCalledTimes(1); + expect(fn2).toHaveBeenCalledTimes(1); + }); + + test("callback error does not prevent other callbacks from running", async () => { + const fn1 = vi.fn().mockRejectedValue(new Error("fail")); + const fn2 = vi.fn(); + ctx.onLifecycle("shutdown", fn1); + ctx.onLifecycle("shutdown", fn2); + + await ctx.emitLifecycle("shutdown"); + + expect(fn1).toHaveBeenCalled(); + expect(fn2).toHaveBeenCalled(); + }); + + test("emitLifecycle with no registered hooks does nothing", async () => { + await expect(ctx.emitLifecycle("server:ready")).resolves.toBeUndefined(); + }); + }); + + describe("plugin metadata", () => { + const stubPlugin = { name: "stub" } as any; + + test("getPluginNames returns all registered names", () => { + ctx.registerPlugin("analytics", stubPlugin); + ctx.registerPlugin("server", stubPlugin); + ctx.registerPlugin("agent", stubPlugin); + + const names = ctx.getPluginNames(); + expect(names).toContain("analytics"); + expect(names).toContain("server"); + expect(names).toContain("agent"); + expect(names).toHaveLength(3); + }); + + test("hasPlugin returns true for registered plugins", () => { + ctx.registerPlugin("analytics", stubPlugin); + + expect(ctx.hasPlugin("analytics")).toBe(true); + expect(ctx.hasPlugin("nonexistent")).toBe(false); + }); + + test("getPlugins returns all registered instances", () => { + const p1 = { name: "analytics" } as any; + const p2 = { name: "server" } as any; + ctx.registerPlugin("analytics", p1); + ctx.registerPlugin("server", p2); + + const plugins = ctx.getPlugins(); + expect(plugins.size).toBe(2); + expect(plugins.get("analytics")).toBe(p1); + expect(plugins.get("server")).toBe(p2); + }); + }); +}); + +describe("isToolProvider", () => { + test("returns true for objects with getAgentTools and executeAgentTool", () => { + const provider = createMockToolProvider(); + expect(isToolProvider(provider)).toBe(true); + }); + + test("returns false for null", () => { + expect(isToolProvider(null)).toBe(false); + }); + + test("returns false for objects missing executeAgentTool", () => { + expect(isToolProvider({ getAgentTools: vi.fn() })).toBe(false); + }); + + test("returns false for objects missing getAgentTools", () => { + expect(isToolProvider({ executeAgentTool: vi.fn() })).toBe(false); + }); + + test("returns false for non-objects", () => { + expect(isToolProvider("string")).toBe(false); + expect(isToolProvider(42)).toBe(false); + expect(isToolProvider(undefined)).toBe(false); + }); +}); diff --git a/packages/appkit/src/plugin/plugin.ts b/packages/appkit/src/plugin/plugin.ts index 5173cb61..1581dc53 100644 --- a/packages/appkit/src/plugin/plugin.ts +++ b/packages/appkit/src/plugin/plugin.ts @@ -19,6 +19,7 @@ import { ServiceContext, type UserContext, } from "../context"; +import type { PluginContext } from "../core/plugin-context"; import { AppKitError, AuthenticationError } from "../errors"; import { createLogger } from "../logging/logger"; import { StreamManager } from "../stream"; @@ -168,6 +169,7 @@ export abstract class Plugin< protected devFileReader: DevFileReader; protected streamManager: StreamManager; protected telemetry: ITelemetry; + protected context?: PluginContext; /** Registered endpoints for this plugin */ private registeredEndpoints: PluginEndpointMap = {}; @@ -198,6 +200,9 @@ export abstract class Plugin< this.cache = CacheManager.getInstanceSync(); this.app = new AppManager(); this.devFileReader = DevFileReader.getInstance(); + this.context = (config as Record).context as + | PluginContext + | undefined; this.isReady = true; } diff --git a/packages/appkit/src/plugins/agent/agent.ts b/packages/appkit/src/plugins/agent/agent.ts index 15827cad..9873291c 100644 --- a/packages/appkit/src/plugins/agent/agent.ts +++ b/packages/appkit/src/plugins/agent/agent.ts @@ -9,7 +9,6 @@ import type { Message, PluginPhase, ResponseStreamEvent, - ToolProvider, } from "shared"; import { createLogger } from "../../logging/logger"; import { Plugin, toPlugin } from "../../plugin"; @@ -33,17 +32,6 @@ import type { AgentPluginConfig, RegisteredAgent, ToolEntry } from "./types"; const logger = createLogger("agent"); -function isToolProvider(obj: unknown): obj is ToolProvider { - return ( - typeof obj === "object" && - obj !== null && - "getAgentTools" in obj && - typeof (obj as any).getAgentTools === "function" && - "executeAgentTool" in obj && - typeof (obj as any).executeAgentTool === "function" - ); -} - export class AgentPlugin extends Plugin { static manifest = manifest as PluginManifest<"agent">; static phase: PluginPhase = "deferred"; @@ -161,40 +149,32 @@ export class AgentPlugin extends Plugin { } private mountInvocationsRoute() { - const serverPlugin = this.config.plugins?.server as - | { addExtension?: (fn: (app: any) => void) => void } - | undefined; - - if (!serverPlugin?.addExtension) return; + if (!this.context) return; - serverPlugin.addExtension((app: import("express").Application) => { - app.post( - "/invocations", - (req: express.Request, res: express.Response) => { - this._handleInvocations(req, res); - }, - ); - }); + this.context.addRoute( + "post", + "/invocations", + (req: express.Request, res: express.Response) => { + this._handleInvocations(req, res); + }, + ); - logger.info("Mounted POST /invocations route"); + logger.info("Registered POST /invocations route via PluginContext"); } private async collectTools() { - // 1. Auto-discover from sibling ToolProvider plugins - const plugins = this.config.plugins; - if (plugins) { - for (const [pluginName, pluginInstance] of Object.entries(plugins)) { - if (pluginName === "agent") continue; - if (!isToolProvider(pluginInstance)) continue; - - const tools = (pluginInstance as ToolProvider).getAgentTools(); + // 1. Auto-discover from sibling ToolProvider plugins via PluginContext + if (this.context) { + for (const { + name: pluginName, + provider, + } of this.context.getToolProviders()) { + const tools = provider.getAgentTools(); for (const tool of tools) { const qualifiedName = `${pluginName}.${tool.name}`; this.toolIndex.set(qualifiedName, { source: "plugin", - plugin: pluginInstance as ToolProvider & { - asUser(req: any): any; - }, + pluginName, def: { ...tool, name: qualifiedName }, localName: tool.name, }); @@ -473,41 +453,51 @@ export class AgentPlugin extends Plugin { const entry = self.toolIndex.get(qualifiedName); if (!entry) throw new Error(`Unknown tool: ${qualifiedName}`); - const result = await self.execute( - async (execSignal) => { - switch (entry.source) { - case "plugin": { - const target = (entry.plugin as any).asUser(req); - return (target as ToolProvider).executeAgentTool( - entry.localName, - args, - execSignal, - ); - } - case "function": - return entry.functionTool.execute( - args as Record, - ); - case "mcp": { - if (!self.mcpClient) { - throw new Error("MCP client not connected"); + let result: unknown; + + if (entry.source === "plugin" && self.context) { + result = await self.context.executeTool( + req, + entry.pluginName, + entry.localName, + args, + signal, + ); + } else { + result = await self.execute( + async (_execSignal) => { + switch (entry.source) { + case "plugin": + throw new Error("Plugin tool execution requires PluginContext"); + case "function": + return entry.functionTool.execute( + args as Record, + ); + case "mcp": { + if (!self.mcpClient) { + throw new Error("MCP client not connected"); + } + const oboToken = req.headers["x-forwarded-access-token"]; + const mcpAuth = + typeof oboToken === "string" + ? { Authorization: `Bearer ${oboToken}` } + : undefined; + return self.mcpClient.callTool( + entry.mcpToolName, + args, + mcpAuth, + ); } - const oboToken = req.headers["x-forwarded-access-token"]; - const mcpAuth = - typeof oboToken === "string" - ? { Authorization: `Bearer ${oboToken}` } - : undefined; - return self.mcpClient.callTool(entry.mcpToolName, args, mcpAuth); } - } - }, - { - default: { - telemetryInterceptor: { enabled: true }, - timeout: 30_000, }, - }, - ); + { + default: { + telemetryInterceptor: { enabled: true }, + timeout: 30_000, + }, + }, + ); + } if (result === undefined) { return `Error: Tool "${qualifiedName}" execution failed`; @@ -537,10 +527,10 @@ export class AgentPlugin extends Plugin { yield evt; } - const pluginNames = self.config.plugins - ? Object.keys(self.config.plugins).filter( - (n) => n !== "agent" && n !== "server", - ) + const pluginNames = self.context + ? self.context + .getPluginNames() + .filter((n) => n !== "agent" && n !== "server") : []; const basePrompt = buildBaseSystemPrompt(pluginNames); const fullPrompt = composeSystemPrompt( diff --git a/packages/appkit/src/plugins/agent/tests/agent.test.ts b/packages/appkit/src/plugins/agent/tests/agent.test.ts index 357d68d0..128e6157 100644 --- a/packages/appkit/src/plugins/agent/tests/agent.test.ts +++ b/packages/appkit/src/plugins/agent/tests/agent.test.ts @@ -13,6 +13,7 @@ import type { ToolProvider, } from "shared"; import { beforeEach, describe, expect, test, vi } from "vitest"; +import { PluginContext } from "../../../core/plugin-context"; import { AgentPlugin } from "../agent"; vi.mock("../../../cache", () => ({ @@ -42,7 +43,16 @@ vi.mock("../../../context", async (importOriginal) => { vi.mock("../../../telemetry", () => ({ TelemetryManager: { getProvider: vi.fn(() => ({ - getTracer: vi.fn(), + getTracer: vi.fn(() => ({ + startActiveSpan: (_name: string, fn: (span: any) => any) => { + const span = { + setStatus: vi.fn(), + recordException: vi.fn(), + end: vi.fn(), + }; + return fn(span); + }, + })), getMeter: vi.fn(), getLogger: vi.fn(), emit: vi.fn(), @@ -61,10 +71,25 @@ function createMockToolProvider( tools: AgentToolDefinition[], ): ToolProvider & { asUser: any } { return { + name: "mock-plugin", getAgentTools: () => tools, executeAgentTool: vi.fn().mockResolvedValue({ result: "ok" }), asUser: vi.fn().mockReturnThis(), - }; + } as any; +} + +function createMockContext( + providers: Array<{ + name: string; + provider: ToolProvider & { asUser: any }; + }> = [], +): PluginContext { + const ctx = new PluginContext(); + for (const { name, provider } of providers) { + ctx.registerToolProvider(name, provider as any); + ctx.registerPlugin(name, provider as any); + } + return ctx; } async function* mockAdapterRun(): AsyncGenerator { @@ -83,7 +108,7 @@ describe("AgentPlugin", () => { setupDatabricksEnv(); }); - test("collectTools discovers ToolProvider plugins", async () => { + test("collectTools discovers ToolProvider plugins via context", async () => { const mockProvider = createMockToolProvider([ { name: "query", @@ -92,10 +117,14 @@ describe("AgentPlugin", () => { }, ]); + const context = createMockContext([ + { name: "analytics", provider: mockProvider }, + ]); + const plugin = new AgentPlugin({ name: "agent", - plugins: { analytics: mockProvider }, - }); + context, + } as any); await plugin.setup(); @@ -106,20 +135,14 @@ describe("AgentPlugin", () => { expect(tools[0].name).toBe("analytics.query"); }); - test("skips non-ToolProvider plugins", async () => { + test("works with no context (backward compat)", async () => { const plugin = new AgentPlugin({ name: "agent", - plugins: { - server: { name: "server" }, - analytics: createMockToolProvider([ - { name: "query", description: "q", parameters: { type: "object" } }, - ]), - }, }); await plugin.setup(); const tools = plugin.exports().getTools(); - expect(tools).toHaveLength(1); + expect(tools).toEqual([]); }); test("registerAgent and resolveAgent", () => { @@ -128,7 +151,6 @@ describe("AgentPlugin", () => { plugin.exports().registerAgent("assistant", adapter); - // The first registered agent becomes the default const tools = plugin.exports().getTools(); expect(tools).toEqual([]); }); @@ -177,15 +199,39 @@ describe("AgentPlugin", () => { expect(tools[0].name).toBe("myTool"); }); - test("executeTool always calls asUser(req) for plugin tools, even without requiresUserContext", async () => { + test("mountInvocationsRoute registers via context.addRoute", async () => { + const context = createMockContext(); + const addRouteSpy = vi.spyOn(context, "addRoute"); + + const plugin = new AgentPlugin({ + name: "agent", + agents: { assistant: createMockAdapter() }, + context, + } as any); + + await plugin.setup(); + + expect(addRouteSpy).toHaveBeenCalledWith( + "post", + "/invocations", + expect.any(Function), + ); + }); + + test("executeTool calls context.executeTool for plugin tools", async () => { const mockProvider = createMockToolProvider([ { name: "action", - description: "An action without requiresUserContext", + description: "An action", parameters: { type: "object", properties: {} }, }, ]); + const context = createMockContext([ + { name: "testplugin", provider: mockProvider }, + ]); + const executeToolSpy = vi.spyOn(context, "executeTool"); + function createToolCallingAdapter(): AgentAdapter { return { async *run( @@ -201,8 +247,8 @@ describe("AgentPlugin", () => { const plugin = new AgentPlugin({ name: "agent", agents: { assistant: createToolCallingAdapter() }, - plugins: { testplugin: mockProvider }, - }); + context, + } as any); await plugin.setup(); const { router, getHandler } = createMockRouter(); @@ -220,8 +266,9 @@ describe("AgentPlugin", () => { await handler(req, res); - expect(mockProvider.asUser).toHaveBeenCalledWith(req); - expect(mockProvider.executeAgentTool).toHaveBeenCalledWith( + expect(executeToolSpy).toHaveBeenCalledWith( + req, + "testplugin", "action", {}, expect.anything(), diff --git a/packages/appkit/src/plugins/agent/types.ts b/packages/appkit/src/plugins/agent/types.ts index 67cc2c8b..5ab8fc91 100644 --- a/packages/appkit/src/plugins/agent/types.ts +++ b/packages/appkit/src/plugins/agent/types.ts @@ -3,7 +3,6 @@ import type { AgentToolDefinition, BasePluginConfig, ThreadStore, - ToolProvider, } from "shared"; import type { FunctionTool } from "./tools/function-tool"; import type { HostedTool } from "./tools/hosted-tools"; @@ -29,7 +28,7 @@ export interface AgentPluginConfig extends BasePluginConfig { export type ToolEntry = | { source: "plugin"; - plugin: ToolProvider & { asUser(req: any): any }; + pluginName: string; def: AgentToolDefinition; localName: string; } @@ -50,8 +49,4 @@ export type RegisteredAgent = { systemPrompt?: string; }; -export type { - AgentAdapter, - AgentToolDefinition, - ToolProvider, -} from "shared"; +export type { AgentAdapter, AgentToolDefinition } from "shared"; diff --git a/packages/appkit/src/plugins/server/index.ts b/packages/appkit/src/plugins/server/index.ts index 75d3e1d0..b6efd016 100644 --- a/packages/appkit/src/plugins/server/index.ts +++ b/packages/appkit/src/plugins/server/index.ts @@ -63,6 +63,7 @@ export class ServerPlugin extends Plugin { instrumentations.http, instrumentations.express, ]); + this.context?.registerAsRouteTarget(this); } /** Setup the server plugin. */ @@ -203,14 +204,15 @@ export class ServerPlugin extends Plugin { const endpoints: PluginEndpoints = {}; const pluginConfigs: PluginClientConfigs = {}; - if (!this.config.plugins) return { endpoints, pluginConfigs }; + const plugins = this.context?.getPlugins(); + if (!plugins || plugins.size === 0) return { endpoints, pluginConfigs }; this.serverApplication.get("/health", (_, res) => { res.status(200).json({ status: "ok" }); }); this.registerEndpoint("health", "/health"); - for (const plugin of Object.values(this.config.plugins)) { + for (const plugin of plugins.values()) { if (EXCLUDED_PLUGINS.includes(plugin.name)) continue; if (plugin?.injectRoutes && typeof plugin.injectRoutes === "function") { @@ -359,8 +361,9 @@ export class ServerPlugin extends Plugin { } // 1. abort active operations from plugins - if (this.config.plugins) { - for (const plugin of Object.values(this.config.plugins)) { + const shutdownPlugins = this.context?.getPlugins(); + if (shutdownPlugins) { + for (const plugin of shutdownPlugins.values()) { if (plugin.abortActiveOperations) { try { plugin.abortActiveOperations(); diff --git a/packages/appkit/src/plugins/server/tests/server.test.ts b/packages/appkit/src/plugins/server/tests/server.test.ts index 22f18129..52d15845 100644 --- a/packages/appkit/src/plugins/server/tests/server.test.ts +++ b/packages/appkit/src/plugins/server/tests/server.test.ts @@ -1,4 +1,6 @@ +import type { BasePlugin } from "shared"; import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; +import { PluginContext } from "../../../core/plugin-context"; // Use vi.hoisted for mocks that need to be available before module loading const { @@ -171,6 +173,14 @@ import { RemoteTunnelController } from "../remote-tunnel/remote-tunnel-controlle import { StaticServer } from "../static-server"; import { ViteDevServer } from "../vite-dev-server"; +function createContextWithPlugins(plugins: Record): PluginContext { + const ctx = new PluginContext(); + for (const [name, instance] of Object.entries(plugins)) { + ctx.registerPlugin(name, instance as BasePlugin); + } + return ctx; +} + describe("ServerPlugin", () => { let originalEnv: NodeJS.ProcessEnv; @@ -340,7 +350,7 @@ describe("ServerPlugin", () => { process.env.NODE_ENV = "production"; const injectRoutes = vi.fn(); - const plugins: any = { + const testPlugins: any = { "test-plugin": { name: "test-plugin", injectRoutes, @@ -348,7 +358,10 @@ describe("ServerPlugin", () => { }, }; - const plugin = new ServerPlugin({ autoStart: false, plugins }); + const plugin = new ServerPlugin({ + autoStart: false, + context: createContextWithPlugins(testPlugins), + } as any); await plugin.start(); const routerFn = (express as any).Router as ReturnType; @@ -386,7 +399,10 @@ describe("ServerPlugin", () => { }, }; - const plugin = new ServerPlugin({ autoStart: false, plugins }); + const plugin = new ServerPlugin({ + autoStart: false, + context: createContextWithPlugins(plugins), + } as any); await plugin.start(); expect(plugins["plugin-a"].clientConfig).toHaveBeenCalled(); @@ -413,7 +429,10 @@ describe("ServerPlugin", () => { }, }; - const plugin = new ServerPlugin({ autoStart: false, plugins }); + const plugin = new ServerPlugin({ + autoStart: false, + context: createContextWithPlugins(plugins), + } as any); await plugin.start(); expect(plugins["plugin-null"].clientConfig).toHaveBeenCalled(); @@ -444,7 +463,10 @@ describe("ServerPlugin", () => { }, }; - const plugin = new ServerPlugin({ autoStart: false, plugins }); + const plugin = new ServerPlugin({ + autoStart: false, + context: createContextWithPlugins(plugins), + } as any); await expect(plugin.start()).resolves.toBeDefined(); expect(mockLoggerError).toHaveBeenCalledWith( "Plugin '%s' clientConfig() failed, skipping its config: %O", @@ -608,19 +630,19 @@ describe("ServerPlugin", () => { const plugin = new ServerPlugin({ autoStart: false, - plugins: { + context: createContextWithPlugins({ ok: { name: "ok", abortActiveOperations: vi.fn(), - } as any, + }, bad: { name: "bad", abortActiveOperations: vi.fn(() => { throw new Error("boom"); }), - } as any, - }, - }); + }, + }), + } as any); // pretend started (plugin as any).server = mockHttpServer; diff --git a/packages/appkit/src/plugins/server/types.ts b/packages/appkit/src/plugins/server/types.ts index e187cacc..84a2327e 100644 --- a/packages/appkit/src/plugins/server/types.ts +++ b/packages/appkit/src/plugins/server/types.ts @@ -1,9 +1,7 @@ import type { BasePluginConfig } from "shared"; -import type { Plugin } from "../../plugin"; export interface ServerConfig extends BasePluginConfig { port?: number; - plugins?: Record; staticPath?: string; autoStart?: boolean; host?: string;