From 114150b75ed7925fa1dcabb981ed1ef2b8aaae4f Mon Sep 17 00:00:00 2001 From: MarioCadenas Date: Thu, 16 Apr 2026 16:16:49 +0200 Subject: [PATCH] feat(appkit): add PluginContext mediator for inter-plugin communication Introduce a PluginContext class that the AppKit core creates and passes to every plugin, mediating all inter-plugin communication: - Route buffering: addRoute() buffers until server registers via registerAsRouteTarget(), then flushes via addExtension() - ToolProvider registry: typed, live discovery (not frozen snapshots) - executeTool(): user-scoped tool execution with automatic telemetry - Lifecycle hooks: setup:complete, server:ready, shutdown coordination Migrates agent plugin from config.plugins to context methods. Updates layered prompt to use context.getPluginNames() for base prompt. Migrates server plugin from config.plugins to context.getPlugins(). Removes the deprecated plugins extraData for deferred plugins. Signed-off-by: MarioCadenas --- packages/appkit/src/core/appkit.ts | 20 +- packages/appkit/src/core/plugin-context.ts | 287 ++++++++++++++++ .../appkit/src/core/tests/databricks.test.ts | 15 +- .../src/core/tests/plugin-context.test.ts | 325 ++++++++++++++++++ packages/appkit/src/plugin/plugin.ts | 5 + packages/appkit/src/plugins/agent/agent.ts | 136 ++++---- .../src/plugins/agent/tests/agent.test.ts | 87 +++-- packages/appkit/src/plugins/agent/types.ts | 9 +- packages/appkit/src/plugins/server/index.ts | 11 +- .../src/plugins/server/tests/server.test.ts | 42 ++- packages/appkit/src/plugins/server/types.ts | 2 - 11 files changed, 812 insertions(+), 127 deletions(-) create mode 100644 packages/appkit/src/core/plugin-context.ts create mode 100644 packages/appkit/src/core/tests/plugin-context.test.ts diff --git a/packages/appkit/src/core/appkit.ts b/packages/appkit/src/core/appkit.ts index a2cba994..9e00da1d 100644 --- a/packages/appkit/src/core/appkit.ts +++ b/packages/appkit/src/core/appkit.ts @@ -13,14 +13,18 @@ import { ServiceContext } from "../context"; import { ResourceRegistry, ResourceType } from "../registry"; import type { TelemetryConfig } from "../telemetry"; import { TelemetryManager } from "../telemetry"; +import { isToolProvider, PluginContext } from "./plugin-context"; export class AppKit { #pluginInstances: Record = {}; #setupPromises: Promise[] = []; + #context: PluginContext; private constructor(config: { plugins: TPlugins }) { const { plugins, ...globalConfig } = config; + this.#context = new PluginContext(); + const pluginEntries = Object.entries(plugins); const corePlugins = pluginEntries.filter(([_, p]) => { @@ -35,20 +39,24 @@ export class AppKit { for (const [name, pluginData] of corePlugins) { if (pluginData) { - this.createAndRegisterPlugin(globalConfig, name, pluginData); + this.createAndRegisterPlugin(globalConfig, name, pluginData, { + context: this.#context, + }); } } for (const [name, pluginData] of normalPlugins) { if (pluginData) { - this.createAndRegisterPlugin(globalConfig, name, pluginData); + this.createAndRegisterPlugin(globalConfig, name, pluginData, { + context: this.#context, + }); } } for (const [name, pluginData] of deferredPlugins) { if (pluginData) { this.createAndRegisterPlugin(globalConfig, name, pluginData, { - plugins: this.#pluginInstances, + context: this.#context, }); } } @@ -72,6 +80,11 @@ export class AppKit { this.#pluginInstances[name] = pluginInstance; + this.#context.registerPlugin(name, pluginInstance); + if (isToolProvider(pluginInstance)) { + this.#context.registerToolProvider(name, pluginInstance); + } + this.#setupPromises.push(pluginInstance.setup()); const self = this; @@ -199,6 +212,7 @@ export class AppKit { const instance = new AppKit(mergedConfig); await Promise.all(instance.#setupPromises); + await instance.#context.emitLifecycle("setup:complete"); return instance as unknown as PluginMap; } diff --git a/packages/appkit/src/core/plugin-context.ts b/packages/appkit/src/core/plugin-context.ts new file mode 100644 index 00000000..c2801585 --- /dev/null +++ b/packages/appkit/src/core/plugin-context.ts @@ -0,0 +1,287 @@ +import type express from "express"; +import type { BasePlugin, ToolProvider } from "shared"; +import { createLogger } from "../logging/logger"; +import { TelemetryManager } from "../telemetry"; + +const logger = createLogger("plugin-context"); + +interface BufferedRoute { + method: string; + path: string; + handlers: express.RequestHandler[]; +} + +interface RouteTarget { + addExtension(fn: (app: express.Application) => void): void; +} + +interface ToolProviderEntry { + plugin: BasePlugin & ToolProvider; + name: string; +} + +type LifecycleEvent = "setup:complete" | "server:ready" | "shutdown"; + +/** + * Mediator for inter-plugin communication. + * + * Created by AppKit core and passed to every plugin. Plugins request + * capabilities from the context instead of holding direct references + * to sibling plugin instances. + * + * Capabilities: + * - Route mounting with buffering (order-independent) + * - Typed ToolProvider registry (live, not snapshot-based) + * - User-scoped tool execution with automatic telemetry + * - Lifecycle hooks for plugin coordination + */ +export class PluginContext { + private routeBuffer: BufferedRoute[] = []; + private routeTarget: RouteTarget | null = null; + private toolProviders = new Map(); + private plugins = new Map(); + private lifecycleHooks = new Map< + LifecycleEvent, + Set<() => void | Promise> + >(); + private telemetry = TelemetryManager.getProvider("plugin-context"); + + /** + * Register a route on the root Express application. + * + * If a route target (server plugin) has registered, the route is applied + * immediately. Otherwise it is buffered and flushed when a route target + * becomes available. + */ + addRoute( + method: string, + path: string, + ...handlers: express.RequestHandler[] + ): void { + if (this.routeTarget) { + this.applyRoute({ method, path, handlers }); + } else { + this.routeBuffer.push({ method, path, handlers }); + } + } + + /** + * Register middleware on the root Express application. + * + * Same buffering semantics as `addRoute`. + */ + addMiddleware(path: string, ...handlers: express.RequestHandler[]): void { + if (this.routeTarget) { + this.applyMiddleware(path, handlers); + } else { + this.routeBuffer.push({ method: "use", path, handlers }); + } + } + + /** + * Called by the server plugin to opt in as the route target. + * Flushes all buffered routes via the server's `addExtension`. + */ + registerAsRouteTarget(target: RouteTarget): void { + this.routeTarget = target; + + for (const route of this.routeBuffer) { + if (route.method === "use") { + this.applyMiddleware(route.path, route.handlers); + } else { + this.applyRoute(route); + } + } + this.routeBuffer = []; + } + + /** + * Register a plugin that implements the ToolProvider interface. + * Called by AppKit core after constructing each plugin. + */ + registerToolProvider(name: string, plugin: BasePlugin & ToolProvider): void { + this.toolProviders.set(name, { plugin, name }); + } + + /** + * Register a plugin instance. + * Called by AppKit core after constructing each plugin. + */ + registerPlugin(name: string, instance: BasePlugin): void { + this.plugins.set(name, instance); + } + + /** + * Returns all registered plugin instances keyed by name. + * Used by the server plugin for route injection, client config, + * and shutdown coordination. + */ + getPlugins(): Map { + return this.plugins; + } + + /** + * Returns all registered ToolProvider plugins. + * Always returns the current set — not a frozen snapshot. + */ + getToolProviders(): Array<{ name: string; provider: ToolProvider }> { + return Array.from(this.toolProviders.values()).map((entry) => ({ + name: entry.name, + provider: entry.plugin, + })); + } + + /** + * Execute a tool on a ToolProvider plugin with automatic user scoping + * and telemetry. + * + * The context: + * 1. Resolves the plugin by name + * 2. Calls `asUser(req)` for user-scoped execution + * 3. Wraps the call in a telemetry span with a 30s timeout + */ + async executeTool( + req: express.Request, + pluginName: string, + toolName: string, + args: unknown, + signal?: AbortSignal, + ): Promise { + const entry = this.toolProviders.get(pluginName); + if (!entry) { + throw new Error( + `PluginContext: unknown plugin "${pluginName}". Available: ${Array.from(this.toolProviders.keys()).join(", ")}`, + ); + } + + const tracer = this.telemetry.getTracer(); + const operationName = `executeTool:${pluginName}.${toolName}`; + + return tracer.startActiveSpan(operationName, async (span) => { + const timeout = 30_000; + const timeoutSignal = AbortSignal.timeout(timeout); + const combinedSignal = signal + ? AbortSignal.any([signal, timeoutSignal]) + : timeoutSignal; + + try { + const userPlugin = (entry.plugin as any).asUser(req); + const result = await (userPlugin as ToolProvider).executeAgentTool( + toolName, + args, + combinedSignal, + ); + span.setStatus({ code: 0 }); + return result; + } catch (error) { + span.setStatus({ + code: 2, + message: + error instanceof Error ? error.message : "Tool execution failed", + }); + span.recordException( + error instanceof Error ? error : new Error(String(error)), + ); + throw error; + } finally { + span.end(); + } + }); + } + + /** + * Register a lifecycle hook callback. + */ + onLifecycle(event: LifecycleEvent, fn: () => void | Promise): void { + let hooks = this.lifecycleHooks.get(event); + if (!hooks) { + hooks = new Set(); + this.lifecycleHooks.set(event, hooks); + } + hooks.add(fn); + } + + /** + * Emit a lifecycle event, calling all registered callbacks. + * Errors in individual callbacks are logged but do not prevent + * other callbacks from running. + * + * @internal Called by AppKit core only. + */ + async emitLifecycle(event: LifecycleEvent): Promise { + const hooks = this.lifecycleHooks.get(event); + if (!hooks) return; + + if ( + event === "setup:complete" && + this.routeBuffer.length > 0 && + !this.routeTarget + ) { + logger.warn( + "%d buffered routes were never applied — no server plugin registered as route target", + this.routeBuffer.length, + ); + } + + for (const fn of hooks) { + try { + await fn(); + } catch (error) { + logger.error("Lifecycle hook '%s' failed: %O", event, error); + } + } + } + + /** + * Returns all registered plugin names. + */ + getPluginNames(): string[] { + return Array.from(this.plugins.keys()); + } + + /** + * Check if a plugin with the given name is registered. + */ + hasPlugin(name: string): boolean { + return this.plugins.has(name); + } + + private applyRoute(route: BufferedRoute): void { + if (!this.routeTarget) return; + this.routeTarget.addExtension((app) => { + const method = route.method.toLowerCase() as keyof express.Application; + if (typeof app[method] === "function") { + (app[method] as (...a: unknown[]) => void)( + route.path, + ...route.handlers, + ); + } + }); + } + + private applyMiddleware( + path: string, + handlers: express.RequestHandler[], + ): void { + if (!this.routeTarget) return; + this.routeTarget.addExtension((app) => { + app.use(path, ...handlers); + }); + } +} + +/** + * Type guard: checks whether a plugin implements the ToolProvider interface. + */ +export function isToolProvider( + plugin: unknown, +): plugin is BasePlugin & ToolProvider { + return ( + typeof plugin === "object" && + plugin !== null && + "getAgentTools" in plugin && + typeof (plugin as ToolProvider).getAgentTools === "function" && + "executeAgentTool" in plugin && + typeof (plugin as ToolProvider).executeAgentTool === "function" + ); +} diff --git a/packages/appkit/src/core/tests/databricks.test.ts b/packages/appkit/src/core/tests/databricks.test.ts index c05345a6..9d3fe5f8 100644 --- a/packages/appkit/src/core/tests/databricks.test.ts +++ b/packages/appkit/src/core/tests/databricks.test.ts @@ -109,11 +109,11 @@ class DeferredTestPlugin implements BasePlugin { name = "deferredTest"; setupCalled = false; injectedConfig: any; - injectedPlugins: any; + injectedContext: any; constructor(config: any) { this.injectedConfig = config; - this.injectedPlugins = config.plugins; + this.injectedContext = config.context; } async setup() { @@ -130,7 +130,7 @@ class DeferredTestPlugin implements BasePlugin { return { setupCalled: this.setupCalled, injectedConfig: this.injectedConfig, - injectedPlugins: this.injectedPlugins, + injectedContext: this.injectedContext, }; } } @@ -276,7 +276,7 @@ describe("AppKit", () => { expect(setupOrder).toEqual(["core", "normal", "deferred"]); }); - test("should provide plugin instances to deferred plugins", async () => { + test("should provide PluginContext to deferred plugins", async () => { const pluginData = [ { plugin: CoreTestPlugin, config: {}, name: "coreTest" }, { plugin: DeferredTestPlugin, config: {}, name: "deferredTest" }, @@ -284,10 +284,9 @@ describe("AppKit", () => { const instance = (await createApp({ plugins: pluginData })) as any; - // Deferred plugins receive plugin instances (not SDKs) for internal use - expect(instance.deferredTest.injectedPlugins).toBeDefined(); - expect(instance.deferredTest.injectedPlugins.coreTest).toBeInstanceOf( - CoreTestPlugin, + expect(instance.deferredTest.injectedContext).toBeDefined(); + expect(instance.deferredTest.injectedContext.hasPlugin("coreTest")).toBe( + true, ); }); diff --git a/packages/appkit/src/core/tests/plugin-context.test.ts b/packages/appkit/src/core/tests/plugin-context.test.ts new file mode 100644 index 00000000..276c5502 --- /dev/null +++ b/packages/appkit/src/core/tests/plugin-context.test.ts @@ -0,0 +1,325 @@ +import type { AgentToolDefinition } from "shared"; +import { beforeEach, describe, expect, test, vi } from "vitest"; +import { isToolProvider, PluginContext } from "../plugin-context"; + +vi.mock("../../telemetry", () => ({ + TelemetryManager: { + getProvider: () => ({ + getTracer: () => ({ + startActiveSpan: (_name: string, fn: (span: any) => any) => { + const span = { + setStatus: vi.fn(), + recordException: vi.fn(), + end: vi.fn(), + }; + return fn(span); + }, + }), + }), + }, +})); + +vi.mock("../../logging/logger", () => ({ + createLogger: () => ({ + info: vi.fn(), + warn: vi.fn(), + error: vi.fn(), + debug: vi.fn(), + }), +})); + +function createMockToolProvider(tools: AgentToolDefinition[] = []) { + const mock = { + name: "mock-plugin", + setup: vi.fn().mockResolvedValue(undefined), + injectRoutes: vi.fn(), + getEndpoints: vi.fn().mockReturnValue({}), + getAgentTools: vi.fn().mockReturnValue(tools), + executeAgentTool: vi.fn().mockResolvedValue("tool-result"), + asUser: vi.fn().mockReturnThis(), + }; + return mock as any; +} + +describe("PluginContext", () => { + let ctx: PluginContext; + + beforeEach(() => { + ctx = new PluginContext(); + }); + + describe("route buffering", () => { + test("addRoute buffers when no route target exists", () => { + const handler = vi.fn(); + ctx.addRoute("post", "/invocations", handler); + + expect(ctx.getPluginNames()).toEqual([]); + }); + + test("flushRoutes applies buffered routes via addExtension", () => { + const handler = vi.fn(); + ctx.addRoute("post", "/invocations", handler); + + const addExtension = vi.fn(); + ctx.registerAsRouteTarget({ addExtension }); + + expect(addExtension).toHaveBeenCalledTimes(1); + const extensionFn = addExtension.mock.calls[0][0]; + + const mockApp = { post: vi.fn() }; + extensionFn(mockApp); + expect(mockApp.post).toHaveBeenCalledWith("/invocations", handler); + }); + + test("addRoute called after registerAsRouteTarget applies immediately", () => { + const addExtension = vi.fn(); + ctx.registerAsRouteTarget({ addExtension }); + + const handler = vi.fn(); + ctx.addRoute("get", "/health", handler); + + expect(addExtension).toHaveBeenCalledTimes(1); + const extensionFn = addExtension.mock.calls[0][0]; + + const mockApp = { get: vi.fn() }; + extensionFn(mockApp); + expect(mockApp.get).toHaveBeenCalledWith("/health", handler); + }); + + test("addRoute supports middleware chains", () => { + const auth = vi.fn(); + const handler = vi.fn(); + + const addExtension = vi.fn(); + ctx.registerAsRouteTarget({ addExtension }); + + ctx.addRoute("post", "/api", auth, handler); + + const extensionFn = addExtension.mock.calls[0][0]; + const mockApp = { post: vi.fn() }; + extensionFn(mockApp); + expect(mockApp.post).toHaveBeenCalledWith("/api", auth, handler); + }); + + test("addMiddleware buffers and applies via use()", () => { + const handler = vi.fn(); + ctx.addMiddleware("/api", handler); + + const addExtension = vi.fn(); + ctx.registerAsRouteTarget({ addExtension }); + + expect(addExtension).toHaveBeenCalledTimes(1); + const extensionFn = addExtension.mock.calls[0][0]; + + const mockApp = { use: vi.fn() }; + extensionFn(mockApp); + expect(mockApp.use).toHaveBeenCalledWith("/api", handler); + }); + + test("multiple buffered routes are all applied on registration", () => { + const h1 = vi.fn(); + const h2 = vi.fn(); + ctx.addRoute("post", "/a", h1); + ctx.addRoute("get", "/b", h2); + + const addExtension = vi.fn(); + ctx.registerAsRouteTarget({ addExtension }); + + expect(addExtension).toHaveBeenCalledTimes(2); + }); + }); + + describe("ToolProvider registry", () => { + test("registerToolProvider makes provider visible via getToolProviders", () => { + const provider = createMockToolProvider([ + { + name: "query", + description: "Run query", + parameters: { type: "object" }, + }, + ]); + + ctx.registerToolProvider("analytics", provider); + + const providers = ctx.getToolProviders(); + expect(providers).toHaveLength(1); + expect(providers[0].name).toBe("analytics"); + expect(providers[0].provider.getAgentTools()).toHaveLength(1); + }); + + test("getToolProviders returns all registered providers", () => { + ctx.registerToolProvider("analytics", createMockToolProvider()); + ctx.registerToolProvider("files", createMockToolProvider()); + ctx.registerToolProvider("genie", createMockToolProvider()); + + expect(ctx.getToolProviders()).toHaveLength(3); + }); + + test("getToolProviders returns current set, not snapshot", () => { + const before = ctx.getToolProviders(); + expect(before).toHaveLength(0); + + ctx.registerToolProvider("analytics", createMockToolProvider()); + + const after = ctx.getToolProviders(); + expect(after).toHaveLength(1); + }); + }); + + describe("executeTool", () => { + test("calls asUser(req).executeAgentTool on the correct plugin", async () => { + const provider = createMockToolProvider(); + ctx.registerToolProvider("analytics", provider); + + const mockReq = { headers: {} } as any; + await ctx.executeTool(mockReq, "analytics", "query", { sql: "SELECT 1" }); + + expect(provider.asUser).toHaveBeenCalledWith(mockReq); + expect(provider.executeAgentTool).toHaveBeenCalledWith( + "query", + { sql: "SELECT 1" }, + expect.any(Object), + ); + }); + + test("throws for unknown plugin name", async () => { + const mockReq = { headers: {} } as any; + + await expect( + ctx.executeTool(mockReq, "nonexistent", "query", {}), + ).rejects.toThrow('unknown plugin "nonexistent"'); + }); + + test("propagates tool execution errors", async () => { + const provider = createMockToolProvider(); + (provider.executeAgentTool as any).mockRejectedValue( + new Error("Query failed"), + ); + ctx.registerToolProvider("analytics", provider); + + const mockReq = { headers: {} } as any; + + await expect( + ctx.executeTool(mockReq, "analytics", "query", {}), + ).rejects.toThrow("Query failed"); + }); + + test("passes abort signal to executeAgentTool", async () => { + const provider = createMockToolProvider(); + ctx.registerToolProvider("analytics", provider); + + const controller = new AbortController(); + const mockReq = { headers: {} } as any; + + await ctx.executeTool( + mockReq, + "analytics", + "query", + {}, + controller.signal, + ); + + const callArgs = (provider.executeAgentTool as any).mock.calls[0]; + expect(callArgs[2]).toBeDefined(); + }); + }); + + describe("lifecycle hooks", () => { + test("onLifecycle registers callback, emitLifecycle invokes it", async () => { + const fn = vi.fn(); + ctx.onLifecycle("setup:complete", fn); + + await ctx.emitLifecycle("setup:complete"); + + expect(fn).toHaveBeenCalledTimes(1); + }); + + test("multiple callbacks for the same event all fire", async () => { + const fn1 = vi.fn(); + const fn2 = vi.fn(); + ctx.onLifecycle("setup:complete", fn1); + ctx.onLifecycle("setup:complete", fn2); + + await ctx.emitLifecycle("setup:complete"); + + expect(fn1).toHaveBeenCalledTimes(1); + expect(fn2).toHaveBeenCalledTimes(1); + }); + + test("callback error does not prevent other callbacks from running", async () => { + const fn1 = vi.fn().mockRejectedValue(new Error("fail")); + const fn2 = vi.fn(); + ctx.onLifecycle("shutdown", fn1); + ctx.onLifecycle("shutdown", fn2); + + await ctx.emitLifecycle("shutdown"); + + expect(fn1).toHaveBeenCalled(); + expect(fn2).toHaveBeenCalled(); + }); + + test("emitLifecycle with no registered hooks does nothing", async () => { + await expect(ctx.emitLifecycle("server:ready")).resolves.toBeUndefined(); + }); + }); + + describe("plugin metadata", () => { + const stubPlugin = { name: "stub" } as any; + + test("getPluginNames returns all registered names", () => { + ctx.registerPlugin("analytics", stubPlugin); + ctx.registerPlugin("server", stubPlugin); + ctx.registerPlugin("agent", stubPlugin); + + const names = ctx.getPluginNames(); + expect(names).toContain("analytics"); + expect(names).toContain("server"); + expect(names).toContain("agent"); + expect(names).toHaveLength(3); + }); + + test("hasPlugin returns true for registered plugins", () => { + ctx.registerPlugin("analytics", stubPlugin); + + expect(ctx.hasPlugin("analytics")).toBe(true); + expect(ctx.hasPlugin("nonexistent")).toBe(false); + }); + + test("getPlugins returns all registered instances", () => { + const p1 = { name: "analytics" } as any; + const p2 = { name: "server" } as any; + ctx.registerPlugin("analytics", p1); + ctx.registerPlugin("server", p2); + + const plugins = ctx.getPlugins(); + expect(plugins.size).toBe(2); + expect(plugins.get("analytics")).toBe(p1); + expect(plugins.get("server")).toBe(p2); + }); + }); +}); + +describe("isToolProvider", () => { + test("returns true for objects with getAgentTools and executeAgentTool", () => { + const provider = createMockToolProvider(); + expect(isToolProvider(provider)).toBe(true); + }); + + test("returns false for null", () => { + expect(isToolProvider(null)).toBe(false); + }); + + test("returns false for objects missing executeAgentTool", () => { + expect(isToolProvider({ getAgentTools: vi.fn() })).toBe(false); + }); + + test("returns false for objects missing getAgentTools", () => { + expect(isToolProvider({ executeAgentTool: vi.fn() })).toBe(false); + }); + + test("returns false for non-objects", () => { + expect(isToolProvider("string")).toBe(false); + expect(isToolProvider(42)).toBe(false); + expect(isToolProvider(undefined)).toBe(false); + }); +}); diff --git a/packages/appkit/src/plugin/plugin.ts b/packages/appkit/src/plugin/plugin.ts index 5173cb61..1581dc53 100644 --- a/packages/appkit/src/plugin/plugin.ts +++ b/packages/appkit/src/plugin/plugin.ts @@ -19,6 +19,7 @@ import { ServiceContext, type UserContext, } from "../context"; +import type { PluginContext } from "../core/plugin-context"; import { AppKitError, AuthenticationError } from "../errors"; import { createLogger } from "../logging/logger"; import { StreamManager } from "../stream"; @@ -168,6 +169,7 @@ export abstract class Plugin< protected devFileReader: DevFileReader; protected streamManager: StreamManager; protected telemetry: ITelemetry; + protected context?: PluginContext; /** Registered endpoints for this plugin */ private registeredEndpoints: PluginEndpointMap = {}; @@ -198,6 +200,9 @@ export abstract class Plugin< this.cache = CacheManager.getInstanceSync(); this.app = new AppManager(); this.devFileReader = DevFileReader.getInstance(); + this.context = (config as Record).context as + | PluginContext + | undefined; this.isReady = true; } diff --git a/packages/appkit/src/plugins/agent/agent.ts b/packages/appkit/src/plugins/agent/agent.ts index 15827cad..9873291c 100644 --- a/packages/appkit/src/plugins/agent/agent.ts +++ b/packages/appkit/src/plugins/agent/agent.ts @@ -9,7 +9,6 @@ import type { Message, PluginPhase, ResponseStreamEvent, - ToolProvider, } from "shared"; import { createLogger } from "../../logging/logger"; import { Plugin, toPlugin } from "../../plugin"; @@ -33,17 +32,6 @@ import type { AgentPluginConfig, RegisteredAgent, ToolEntry } from "./types"; const logger = createLogger("agent"); -function isToolProvider(obj: unknown): obj is ToolProvider { - return ( - typeof obj === "object" && - obj !== null && - "getAgentTools" in obj && - typeof (obj as any).getAgentTools === "function" && - "executeAgentTool" in obj && - typeof (obj as any).executeAgentTool === "function" - ); -} - export class AgentPlugin extends Plugin { static manifest = manifest as PluginManifest<"agent">; static phase: PluginPhase = "deferred"; @@ -161,40 +149,32 @@ export class AgentPlugin extends Plugin { } private mountInvocationsRoute() { - const serverPlugin = this.config.plugins?.server as - | { addExtension?: (fn: (app: any) => void) => void } - | undefined; - - if (!serverPlugin?.addExtension) return; + if (!this.context) return; - serverPlugin.addExtension((app: import("express").Application) => { - app.post( - "/invocations", - (req: express.Request, res: express.Response) => { - this._handleInvocations(req, res); - }, - ); - }); + this.context.addRoute( + "post", + "/invocations", + (req: express.Request, res: express.Response) => { + this._handleInvocations(req, res); + }, + ); - logger.info("Mounted POST /invocations route"); + logger.info("Registered POST /invocations route via PluginContext"); } private async collectTools() { - // 1. Auto-discover from sibling ToolProvider plugins - const plugins = this.config.plugins; - if (plugins) { - for (const [pluginName, pluginInstance] of Object.entries(plugins)) { - if (pluginName === "agent") continue; - if (!isToolProvider(pluginInstance)) continue; - - const tools = (pluginInstance as ToolProvider).getAgentTools(); + // 1. Auto-discover from sibling ToolProvider plugins via PluginContext + if (this.context) { + for (const { + name: pluginName, + provider, + } of this.context.getToolProviders()) { + const tools = provider.getAgentTools(); for (const tool of tools) { const qualifiedName = `${pluginName}.${tool.name}`; this.toolIndex.set(qualifiedName, { source: "plugin", - plugin: pluginInstance as ToolProvider & { - asUser(req: any): any; - }, + pluginName, def: { ...tool, name: qualifiedName }, localName: tool.name, }); @@ -473,41 +453,51 @@ export class AgentPlugin extends Plugin { const entry = self.toolIndex.get(qualifiedName); if (!entry) throw new Error(`Unknown tool: ${qualifiedName}`); - const result = await self.execute( - async (execSignal) => { - switch (entry.source) { - case "plugin": { - const target = (entry.plugin as any).asUser(req); - return (target as ToolProvider).executeAgentTool( - entry.localName, - args, - execSignal, - ); - } - case "function": - return entry.functionTool.execute( - args as Record, - ); - case "mcp": { - if (!self.mcpClient) { - throw new Error("MCP client not connected"); + let result: unknown; + + if (entry.source === "plugin" && self.context) { + result = await self.context.executeTool( + req, + entry.pluginName, + entry.localName, + args, + signal, + ); + } else { + result = await self.execute( + async (_execSignal) => { + switch (entry.source) { + case "plugin": + throw new Error("Plugin tool execution requires PluginContext"); + case "function": + return entry.functionTool.execute( + args as Record, + ); + case "mcp": { + if (!self.mcpClient) { + throw new Error("MCP client not connected"); + } + const oboToken = req.headers["x-forwarded-access-token"]; + const mcpAuth = + typeof oboToken === "string" + ? { Authorization: `Bearer ${oboToken}` } + : undefined; + return self.mcpClient.callTool( + entry.mcpToolName, + args, + mcpAuth, + ); } - const oboToken = req.headers["x-forwarded-access-token"]; - const mcpAuth = - typeof oboToken === "string" - ? { Authorization: `Bearer ${oboToken}` } - : undefined; - return self.mcpClient.callTool(entry.mcpToolName, args, mcpAuth); } - } - }, - { - default: { - telemetryInterceptor: { enabled: true }, - timeout: 30_000, }, - }, - ); + { + default: { + telemetryInterceptor: { enabled: true }, + timeout: 30_000, + }, + }, + ); + } if (result === undefined) { return `Error: Tool "${qualifiedName}" execution failed`; @@ -537,10 +527,10 @@ export class AgentPlugin extends Plugin { yield evt; } - const pluginNames = self.config.plugins - ? Object.keys(self.config.plugins).filter( - (n) => n !== "agent" && n !== "server", - ) + const pluginNames = self.context + ? self.context + .getPluginNames() + .filter((n) => n !== "agent" && n !== "server") : []; const basePrompt = buildBaseSystemPrompt(pluginNames); const fullPrompt = composeSystemPrompt( diff --git a/packages/appkit/src/plugins/agent/tests/agent.test.ts b/packages/appkit/src/plugins/agent/tests/agent.test.ts index 357d68d0..128e6157 100644 --- a/packages/appkit/src/plugins/agent/tests/agent.test.ts +++ b/packages/appkit/src/plugins/agent/tests/agent.test.ts @@ -13,6 +13,7 @@ import type { ToolProvider, } from "shared"; import { beforeEach, describe, expect, test, vi } from "vitest"; +import { PluginContext } from "../../../core/plugin-context"; import { AgentPlugin } from "../agent"; vi.mock("../../../cache", () => ({ @@ -42,7 +43,16 @@ vi.mock("../../../context", async (importOriginal) => { vi.mock("../../../telemetry", () => ({ TelemetryManager: { getProvider: vi.fn(() => ({ - getTracer: vi.fn(), + getTracer: vi.fn(() => ({ + startActiveSpan: (_name: string, fn: (span: any) => any) => { + const span = { + setStatus: vi.fn(), + recordException: vi.fn(), + end: vi.fn(), + }; + return fn(span); + }, + })), getMeter: vi.fn(), getLogger: vi.fn(), emit: vi.fn(), @@ -61,10 +71,25 @@ function createMockToolProvider( tools: AgentToolDefinition[], ): ToolProvider & { asUser: any } { return { + name: "mock-plugin", getAgentTools: () => tools, executeAgentTool: vi.fn().mockResolvedValue({ result: "ok" }), asUser: vi.fn().mockReturnThis(), - }; + } as any; +} + +function createMockContext( + providers: Array<{ + name: string; + provider: ToolProvider & { asUser: any }; + }> = [], +): PluginContext { + const ctx = new PluginContext(); + for (const { name, provider } of providers) { + ctx.registerToolProvider(name, provider as any); + ctx.registerPlugin(name, provider as any); + } + return ctx; } async function* mockAdapterRun(): AsyncGenerator { @@ -83,7 +108,7 @@ describe("AgentPlugin", () => { setupDatabricksEnv(); }); - test("collectTools discovers ToolProvider plugins", async () => { + test("collectTools discovers ToolProvider plugins via context", async () => { const mockProvider = createMockToolProvider([ { name: "query", @@ -92,10 +117,14 @@ describe("AgentPlugin", () => { }, ]); + const context = createMockContext([ + { name: "analytics", provider: mockProvider }, + ]); + const plugin = new AgentPlugin({ name: "agent", - plugins: { analytics: mockProvider }, - }); + context, + } as any); await plugin.setup(); @@ -106,20 +135,14 @@ describe("AgentPlugin", () => { expect(tools[0].name).toBe("analytics.query"); }); - test("skips non-ToolProvider plugins", async () => { + test("works with no context (backward compat)", async () => { const plugin = new AgentPlugin({ name: "agent", - plugins: { - server: { name: "server" }, - analytics: createMockToolProvider([ - { name: "query", description: "q", parameters: { type: "object" } }, - ]), - }, }); await plugin.setup(); const tools = plugin.exports().getTools(); - expect(tools).toHaveLength(1); + expect(tools).toEqual([]); }); test("registerAgent and resolveAgent", () => { @@ -128,7 +151,6 @@ describe("AgentPlugin", () => { plugin.exports().registerAgent("assistant", adapter); - // The first registered agent becomes the default const tools = plugin.exports().getTools(); expect(tools).toEqual([]); }); @@ -177,15 +199,39 @@ describe("AgentPlugin", () => { expect(tools[0].name).toBe("myTool"); }); - test("executeTool always calls asUser(req) for plugin tools, even without requiresUserContext", async () => { + test("mountInvocationsRoute registers via context.addRoute", async () => { + const context = createMockContext(); + const addRouteSpy = vi.spyOn(context, "addRoute"); + + const plugin = new AgentPlugin({ + name: "agent", + agents: { assistant: createMockAdapter() }, + context, + } as any); + + await plugin.setup(); + + expect(addRouteSpy).toHaveBeenCalledWith( + "post", + "/invocations", + expect.any(Function), + ); + }); + + test("executeTool calls context.executeTool for plugin tools", async () => { const mockProvider = createMockToolProvider([ { name: "action", - description: "An action without requiresUserContext", + description: "An action", parameters: { type: "object", properties: {} }, }, ]); + const context = createMockContext([ + { name: "testplugin", provider: mockProvider }, + ]); + const executeToolSpy = vi.spyOn(context, "executeTool"); + function createToolCallingAdapter(): AgentAdapter { return { async *run( @@ -201,8 +247,8 @@ describe("AgentPlugin", () => { const plugin = new AgentPlugin({ name: "agent", agents: { assistant: createToolCallingAdapter() }, - plugins: { testplugin: mockProvider }, - }); + context, + } as any); await plugin.setup(); const { router, getHandler } = createMockRouter(); @@ -220,8 +266,9 @@ describe("AgentPlugin", () => { await handler(req, res); - expect(mockProvider.asUser).toHaveBeenCalledWith(req); - expect(mockProvider.executeAgentTool).toHaveBeenCalledWith( + expect(executeToolSpy).toHaveBeenCalledWith( + req, + "testplugin", "action", {}, expect.anything(), diff --git a/packages/appkit/src/plugins/agent/types.ts b/packages/appkit/src/plugins/agent/types.ts index 67cc2c8b..5ab8fc91 100644 --- a/packages/appkit/src/plugins/agent/types.ts +++ b/packages/appkit/src/plugins/agent/types.ts @@ -3,7 +3,6 @@ import type { AgentToolDefinition, BasePluginConfig, ThreadStore, - ToolProvider, } from "shared"; import type { FunctionTool } from "./tools/function-tool"; import type { HostedTool } from "./tools/hosted-tools"; @@ -29,7 +28,7 @@ export interface AgentPluginConfig extends BasePluginConfig { export type ToolEntry = | { source: "plugin"; - plugin: ToolProvider & { asUser(req: any): any }; + pluginName: string; def: AgentToolDefinition; localName: string; } @@ -50,8 +49,4 @@ export type RegisteredAgent = { systemPrompt?: string; }; -export type { - AgentAdapter, - AgentToolDefinition, - ToolProvider, -} from "shared"; +export type { AgentAdapter, AgentToolDefinition } from "shared"; diff --git a/packages/appkit/src/plugins/server/index.ts b/packages/appkit/src/plugins/server/index.ts index 75d3e1d0..b6efd016 100644 --- a/packages/appkit/src/plugins/server/index.ts +++ b/packages/appkit/src/plugins/server/index.ts @@ -63,6 +63,7 @@ export class ServerPlugin extends Plugin { instrumentations.http, instrumentations.express, ]); + this.context?.registerAsRouteTarget(this); } /** Setup the server plugin. */ @@ -203,14 +204,15 @@ export class ServerPlugin extends Plugin { const endpoints: PluginEndpoints = {}; const pluginConfigs: PluginClientConfigs = {}; - if (!this.config.plugins) return { endpoints, pluginConfigs }; + const plugins = this.context?.getPlugins(); + if (!plugins || plugins.size === 0) return { endpoints, pluginConfigs }; this.serverApplication.get("/health", (_, res) => { res.status(200).json({ status: "ok" }); }); this.registerEndpoint("health", "/health"); - for (const plugin of Object.values(this.config.plugins)) { + for (const plugin of plugins.values()) { if (EXCLUDED_PLUGINS.includes(plugin.name)) continue; if (plugin?.injectRoutes && typeof plugin.injectRoutes === "function") { @@ -359,8 +361,9 @@ export class ServerPlugin extends Plugin { } // 1. abort active operations from plugins - if (this.config.plugins) { - for (const plugin of Object.values(this.config.plugins)) { + const shutdownPlugins = this.context?.getPlugins(); + if (shutdownPlugins) { + for (const plugin of shutdownPlugins.values()) { if (plugin.abortActiveOperations) { try { plugin.abortActiveOperations(); diff --git a/packages/appkit/src/plugins/server/tests/server.test.ts b/packages/appkit/src/plugins/server/tests/server.test.ts index 22f18129..52d15845 100644 --- a/packages/appkit/src/plugins/server/tests/server.test.ts +++ b/packages/appkit/src/plugins/server/tests/server.test.ts @@ -1,4 +1,6 @@ +import type { BasePlugin } from "shared"; import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; +import { PluginContext } from "../../../core/plugin-context"; // Use vi.hoisted for mocks that need to be available before module loading const { @@ -171,6 +173,14 @@ import { RemoteTunnelController } from "../remote-tunnel/remote-tunnel-controlle import { StaticServer } from "../static-server"; import { ViteDevServer } from "../vite-dev-server"; +function createContextWithPlugins(plugins: Record): PluginContext { + const ctx = new PluginContext(); + for (const [name, instance] of Object.entries(plugins)) { + ctx.registerPlugin(name, instance as BasePlugin); + } + return ctx; +} + describe("ServerPlugin", () => { let originalEnv: NodeJS.ProcessEnv; @@ -340,7 +350,7 @@ describe("ServerPlugin", () => { process.env.NODE_ENV = "production"; const injectRoutes = vi.fn(); - const plugins: any = { + const testPlugins: any = { "test-plugin": { name: "test-plugin", injectRoutes, @@ -348,7 +358,10 @@ describe("ServerPlugin", () => { }, }; - const plugin = new ServerPlugin({ autoStart: false, plugins }); + const plugin = new ServerPlugin({ + autoStart: false, + context: createContextWithPlugins(testPlugins), + } as any); await plugin.start(); const routerFn = (express as any).Router as ReturnType; @@ -386,7 +399,10 @@ describe("ServerPlugin", () => { }, }; - const plugin = new ServerPlugin({ autoStart: false, plugins }); + const plugin = new ServerPlugin({ + autoStart: false, + context: createContextWithPlugins(plugins), + } as any); await plugin.start(); expect(plugins["plugin-a"].clientConfig).toHaveBeenCalled(); @@ -413,7 +429,10 @@ describe("ServerPlugin", () => { }, }; - const plugin = new ServerPlugin({ autoStart: false, plugins }); + const plugin = new ServerPlugin({ + autoStart: false, + context: createContextWithPlugins(plugins), + } as any); await plugin.start(); expect(plugins["plugin-null"].clientConfig).toHaveBeenCalled(); @@ -444,7 +463,10 @@ describe("ServerPlugin", () => { }, }; - const plugin = new ServerPlugin({ autoStart: false, plugins }); + const plugin = new ServerPlugin({ + autoStart: false, + context: createContextWithPlugins(plugins), + } as any); await expect(plugin.start()).resolves.toBeDefined(); expect(mockLoggerError).toHaveBeenCalledWith( "Plugin '%s' clientConfig() failed, skipping its config: %O", @@ -608,19 +630,19 @@ describe("ServerPlugin", () => { const plugin = new ServerPlugin({ autoStart: false, - plugins: { + context: createContextWithPlugins({ ok: { name: "ok", abortActiveOperations: vi.fn(), - } as any, + }, bad: { name: "bad", abortActiveOperations: vi.fn(() => { throw new Error("boom"); }), - } as any, - }, - }); + }, + }), + } as any); // pretend started (plugin as any).server = mockHttpServer; diff --git a/packages/appkit/src/plugins/server/types.ts b/packages/appkit/src/plugins/server/types.ts index e187cacc..84a2327e 100644 --- a/packages/appkit/src/plugins/server/types.ts +++ b/packages/appkit/src/plugins/server/types.ts @@ -1,9 +1,7 @@ import type { BasePluginConfig } from "shared"; -import type { Plugin } from "../../plugin"; export interface ServerConfig extends BasePluginConfig { port?: number; - plugins?: Record; staticPath?: string; autoStart?: boolean; host?: string;