diff --git a/src/transports/base.ts b/src/transports/base.ts index 4643a31d..f9b00f75 100644 --- a/src/transports/base.ts +++ b/src/transports/base.ts @@ -22,6 +22,8 @@ import type { Client } from "@mongodb-js/atlas-local"; import { VectorSearchEmbeddingsManager } from "../common/search/vectorSearchEmbeddingsManager.js"; import type { ToolBase, ToolConstructorParams } from "../tools/tool.js"; +type CreateSessionConfigFn = (userConfig: UserConfig) => Promise | UserConfig; + export type TransportRunnerConfig = { userConfig: UserConfig; createConnectionManager?: ConnectionManagerFactoryFn; @@ -30,6 +32,11 @@ export type TransportRunnerConfig = { additionalLoggers?: LoggerBase[]; telemetryProperties?: Partial; tools?: (new (params: ToolConstructorParams) => ToolBase)[]; + /** + * Hook which allows library consumers to fetch configuration from external sources (e.g., secrets managers, APIs) + * or modify the existing configuration before the session is created. + */ + createSessionConfig?: CreateSessionConfigFn; }; export abstract class TransportRunnerBase { @@ -41,6 +48,7 @@ export abstract class TransportRunnerBase { private readonly atlasLocalClient: Promise; private readonly telemetryProperties: Partial; private readonly tools?: (new (params: ToolConstructorParams) => ToolBase)[]; + private readonly createSessionConfig?: CreateSessionConfigFn; protected constructor({ userConfig, @@ -50,6 +58,7 @@ export abstract class TransportRunnerBase { additionalLoggers = [], telemetryProperties = {}, tools, + createSessionConfig, }: TransportRunnerConfig) { this.userConfig = userConfig; this.createConnectionManager = createConnectionManager; @@ -57,6 +66,7 @@ export abstract class TransportRunnerBase { this.atlasLocalClient = createAtlasLocalClient(); this.telemetryProperties = telemetryProperties; this.tools = tools; + this.createSessionConfig = createSessionConfig; const loggers: LoggerBase[] = [...additionalLoggers]; if (this.userConfig.loggers.includes("stderr")) { loggers.push(new ConsoleLogger(Keychain.root)); @@ -81,30 +91,34 @@ export abstract class TransportRunnerBase { } protected async setupServer(): Promise { + // Call the config provider hook if provided, allowing consumers to + // fetch or modify configuration before session initialization + const userConfig = this.createSessionConfig ? await this.createSessionConfig(this.userConfig) : this.userConfig; + const mcpServer = new McpServer({ name: packageInfo.mcpServerName, version: packageInfo.version, }); const logger = new CompositeLogger(this.logger); - const exportsManager = ExportsManager.init(this.userConfig, logger); + const exportsManager = ExportsManager.init(userConfig, logger); const connectionManager = await this.createConnectionManager({ logger, - userConfig: this.userConfig, + userConfig, deviceId: this.deviceId, }); const session = new Session({ - userConfig: this.userConfig, + userConfig, atlasLocalClient: await this.atlasLocalClient, logger, exportsManager, connectionManager, keychain: Keychain.root, - vectorSearchEmbeddingsManager: new VectorSearchEmbeddingsManager(this.userConfig, connectionManager), + vectorSearchEmbeddingsManager: new VectorSearchEmbeddingsManager(userConfig, connectionManager), }); - const telemetry = Telemetry.create(session, this.userConfig, this.deviceId, { + const telemetry = Telemetry.create(session, userConfig, this.deviceId, { commonProperties: this.telemetryProperties, }); @@ -114,7 +128,7 @@ export abstract class TransportRunnerBase { mcpServer, session, telemetry, - userConfig: this.userConfig, + userConfig, connectionErrorHandler: this.connectionErrorHandler, elicitation, tools: this.tools, @@ -122,7 +136,7 @@ export abstract class TransportRunnerBase { // We need to create the MCP logger after the server is constructed // because it needs the server instance - if (this.userConfig.loggers.includes("mcp")) { + if (userConfig.loggers.includes("mcp")) { logger.addLogger(new McpLogger(result, Keychain.root)); } diff --git a/tests/integration/transports/createSessionConfig.test.ts b/tests/integration/transports/createSessionConfig.test.ts new file mode 100644 index 00000000..a0b72dcd --- /dev/null +++ b/tests/integration/transports/createSessionConfig.test.ts @@ -0,0 +1,151 @@ +import { StreamableHttpRunner } from "../../../src/transports/streamableHttp.js"; +import { Client } from "@modelcontextprotocol/sdk/client/index.js"; +import { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/streamableHttp.js"; +import { describe, expect, it } from "vitest"; +import type { TransportRunnerConfig } from "../../../src/lib.js"; +import { defaultTestConfig } from "../helpers.js"; + +describe("createSessionConfig", () => { + const userConfig = defaultTestConfig; + let runner: StreamableHttpRunner; + + describe("basic functionality", () => { + it("should use the modified config from createSessionConfig", async () => { + const createSessionConfig: TransportRunnerConfig["createSessionConfig"] = async (userConfig) => { + return Promise.resolve({ + ...userConfig, + apiBaseUrl: "https://test-api.mongodb.com/", + }); + }; + userConfig.httpPort = 0; // Use a random port + runner = new StreamableHttpRunner({ + userConfig, + createSessionConfig, + }); + await runner.start(); + + const server = await runner["setupServer"](); + expect(server.userConfig.apiBaseUrl).toBe("https://test-api.mongodb.com/"); + + await runner.close(); + }); + + it("should work without a createSessionConfig", async () => { + userConfig.httpPort = 0; // Use a random port + runner = new StreamableHttpRunner({ + userConfig, + }); + await runner.start(); + + const server = await runner["setupServer"](); + expect(server.userConfig.apiBaseUrl).toBe(userConfig.apiBaseUrl); + + await runner.close(); + }); + }); + + describe("connection string modification", () => { + it("should allow modifying connection string via createSessionConfig", async () => { + const createSessionConfig: TransportRunnerConfig["createSessionConfig"] = async (userConfig) => { + // Simulate fetching connection string from environment or secrets + await new Promise((resolve) => setTimeout(resolve, 10)); + + return { + ...userConfig, + connectionString: "mongodb://test-server:27017/test-db", + }; + }; + + userConfig.httpPort = 0; // Use a random port + runner = new StreamableHttpRunner({ + userConfig: { ...userConfig, connectionString: undefined }, + createSessionConfig, + }); + await runner.start(); + + const server = await runner["setupServer"](); + expect(server.userConfig.connectionString).toBe("mongodb://test-server:27017/test-db"); + + await runner.close(); + }); + }); + + describe("server integration", () => { + let client: Client; + let transport: StreamableHTTPClientTransport; + + it("should successfully initialize server with createSessionConfig and serve requests", async () => { + const createSessionConfig: TransportRunnerConfig["createSessionConfig"] = async (userConfig) => { + // Simulate async config modification + await new Promise((resolve) => setTimeout(resolve, 10)); + return { + ...userConfig, + readOnly: true, // Enable read-only mode + }; + }; + + userConfig.httpPort = 0; // Use a random port + runner = new StreamableHttpRunner({ + userConfig, + createSessionConfig, + }); + await runner.start(); + + client = new Client({ + name: "test-client", + version: "1.0.0", + }); + transport = new StreamableHTTPClientTransport(new URL(`${runner.serverAddress}/mcp`)); + + await client.connect(transport); + const response = await client.listTools(); + + expect(response).toBeDefined(); + expect(response.tools).toBeDefined(); + expect(response.tools.length).toBeGreaterThan(0); + + // Verify read-only mode is applied - insert-one should not be available + const writeTools = response.tools.filter((tool) => tool.name === "insert-one"); + expect(writeTools.length).toBe(0); + + // Verify read tools are available + const readTools = response.tools.filter((tool) => tool.name === "find"); + expect(readTools.length).toBe(1); + + await client.close(); + await transport.close(); + await runner.close(); + }); + }); + + describe("error handling", () => { + it("should propagate errors from configProvider on client connection", async () => { + const createSessionConfig: TransportRunnerConfig["createSessionConfig"] = async () => { + return Promise.reject(new Error("Failed to fetch config")); + }; + + userConfig.httpPort = 0; // Use a random port + runner = new StreamableHttpRunner({ + userConfig, + createSessionConfig, + }); + + // Start succeeds because setupServer is only called when a client connects + await runner.start(); + + // Error should occur when a client tries to connect + const testClient = new Client({ + name: "test-client", + version: "1.0.0", + }); + const testTransport = new StreamableHTTPClientTransport(new URL(`${runner.serverAddress}/mcp`)); + + await expect(testClient.connect(testTransport)).rejects.toThrow(); + + await testClient.close(); + await testTransport.close(); + + await runner.close(); + }); + }); +});