Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 21 additions & 7 deletions src/transports/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ import type { Client } from "@mongodb-js/atlas-local";
import { VectorSearchEmbeddingsManager } from "../common/search/vectorSearchEmbeddingsManager.js";
import type { ToolBase, ToolConstructorParams } from "../tools/tool.js";

type CreateSessionConfigFn = (userConfig: UserConfig) => Promise<UserConfig> | UserConfig;

export type TransportRunnerConfig = {
userConfig: UserConfig;
createConnectionManager?: ConnectionManagerFactoryFn;
Expand All @@ -30,6 +32,11 @@ export type TransportRunnerConfig = {
additionalLoggers?: LoggerBase[];
telemetryProperties?: Partial<CommonProperties>;
tools?: (new (params: ToolConstructorParams) => ToolBase)[];
/**
* Hook which allows library consumers to fetch configuration from external sources (e.g., secrets managers, APIs)
* or modify the existing configuration before the session is created.
*/
createSessionConfig?: CreateSessionConfigFn;
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not too sure if I understood the usecase for this correctly so let me know if this should be used differently.

};

export abstract class TransportRunnerBase {
Expand All @@ -41,6 +48,7 @@ export abstract class TransportRunnerBase {
private readonly atlasLocalClient: Promise<Client | undefined>;
private readonly telemetryProperties: Partial<CommonProperties>;
private readonly tools?: (new (params: ToolConstructorParams) => ToolBase)[];
private readonly createSessionConfig?: CreateSessionConfigFn;

protected constructor({
userConfig,
Expand All @@ -50,13 +58,15 @@ export abstract class TransportRunnerBase {
additionalLoggers = [],
telemetryProperties = {},
tools,
createSessionConfig,
}: TransportRunnerConfig) {
this.userConfig = userConfig;
this.createConnectionManager = createConnectionManager;
this.connectionErrorHandler = connectionErrorHandler;
this.atlasLocalClient = createAtlasLocalClient();
this.telemetryProperties = telemetryProperties;
this.tools = tools;
this.createSessionConfig = createSessionConfig;
const loggers: LoggerBase[] = [...additionalLoggers];
if (this.userConfig.loggers.includes("stderr")) {
loggers.push(new ConsoleLogger(Keychain.root));
Expand All @@ -81,30 +91,34 @@ export abstract class TransportRunnerBase {
}

protected async setupServer(): Promise<Server> {
// Call the config provider hook if provided, allowing consumers to
// fetch or modify configuration before session initialization
const userConfig = this.createSessionConfig ? await this.createSessionConfig(this.userConfig) : this.userConfig;

const mcpServer = new McpServer({
name: packageInfo.mcpServerName,
version: packageInfo.version,
});

const logger = new CompositeLogger(this.logger);
const exportsManager = ExportsManager.init(this.userConfig, logger);
const exportsManager = ExportsManager.init(userConfig, logger);
const connectionManager = await this.createConnectionManager({
logger,
userConfig: this.userConfig,
userConfig,
deviceId: this.deviceId,
});

const session = new Session({
userConfig: this.userConfig,
userConfig,
atlasLocalClient: await this.atlasLocalClient,
logger,
exportsManager,
connectionManager,
keychain: Keychain.root,
vectorSearchEmbeddingsManager: new VectorSearchEmbeddingsManager(this.userConfig, connectionManager),
vectorSearchEmbeddingsManager: new VectorSearchEmbeddingsManager(userConfig, connectionManager),
});

const telemetry = Telemetry.create(session, this.userConfig, this.deviceId, {
const telemetry = Telemetry.create(session, userConfig, this.deviceId, {
commonProperties: this.telemetryProperties,
});

Expand All @@ -114,15 +128,15 @@ export abstract class TransportRunnerBase {
mcpServer,
session,
telemetry,
userConfig: this.userConfig,
userConfig,
connectionErrorHandler: this.connectionErrorHandler,
elicitation,
tools: this.tools,
});

// We need to create the MCP logger after the server is constructed
// because it needs the server instance
if (this.userConfig.loggers.includes("mcp")) {
if (userConfig.loggers.includes("mcp")) {
logger.addLogger(new McpLogger(result, Keychain.root));
}

Expand Down
151 changes: 151 additions & 0 deletions tests/integration/transports/createSessionConfig.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
import { StreamableHttpRunner } from "../../../src/transports/streamableHttp.js";
import { Client } from "@modelcontextprotocol/sdk/client/index.js";
import { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/streamableHttp.js";
import { describe, expect, it } from "vitest";
import type { TransportRunnerConfig } from "../../../src/lib.js";
import { defaultTestConfig } from "../helpers.js";

describe("createSessionConfig", () => {
const userConfig = defaultTestConfig;
let runner: StreamableHttpRunner;

describe("basic functionality", () => {
it("should use the modified config from createSessionConfig", async () => {
const createSessionConfig: TransportRunnerConfig["createSessionConfig"] = async (userConfig) => {
return Promise.resolve({
...userConfig,
apiBaseUrl: "https://test-api.mongodb.com/",
});
};
userConfig.httpPort = 0; // Use a random port
runner = new StreamableHttpRunner({
userConfig,
createSessionConfig,
});
await runner.start();

const server = await runner["setupServer"]();
expect(server.userConfig.apiBaseUrl).toBe("https://test-api.mongodb.com/");

await runner.close();
});

it("should work without a createSessionConfig", async () => {
userConfig.httpPort = 0; // Use a random port
runner = new StreamableHttpRunner({
userConfig,
});
await runner.start();

const server = await runner["setupServer"]();
expect(server.userConfig.apiBaseUrl).toBe(userConfig.apiBaseUrl);

await runner.close();
});
});

describe("connection string modification", () => {
it("should allow modifying connection string via createSessionConfig", async () => {
const createSessionConfig: TransportRunnerConfig["createSessionConfig"] = async (userConfig) => {
// Simulate fetching connection string from environment or secrets
await new Promise((resolve) => setTimeout(resolve, 10));

return {
...userConfig,
connectionString: "mongodb://test-server:27017/test-db",
};
};

userConfig.httpPort = 0; // Use a random port
runner = new StreamableHttpRunner({
userConfig: { ...userConfig, connectionString: undefined },
createSessionConfig,
});
await runner.start();

const server = await runner["setupServer"]();
expect(server.userConfig.connectionString).toBe("mongodb://test-server:27017/test-db");

await runner.close();
});
});

describe("server integration", () => {
let client: Client;
let transport: StreamableHTTPClientTransport;

it("should successfully initialize server with createSessionConfig and serve requests", async () => {
const createSessionConfig: TransportRunnerConfig["createSessionConfig"] = async (userConfig) => {
// Simulate async config modification
await new Promise((resolve) => setTimeout(resolve, 10));
return {
...userConfig,
readOnly: true, // Enable read-only mode
};
};

userConfig.httpPort = 0; // Use a random port
runner = new StreamableHttpRunner({
userConfig,
createSessionConfig,
});
await runner.start();

client = new Client({
name: "test-client",
version: "1.0.0",
});
transport = new StreamableHTTPClientTransport(new URL(`${runner.serverAddress}/mcp`));

await client.connect(transport);
const response = await client.listTools();

expect(response).toBeDefined();
expect(response.tools).toBeDefined();
expect(response.tools.length).toBeGreaterThan(0);

// Verify read-only mode is applied - insert-one should not be available
const writeTools = response.tools.filter((tool) => tool.name === "insert-one");
expect(writeTools.length).toBe(0);

// Verify read tools are available
const readTools = response.tools.filter((tool) => tool.name === "find");
expect(readTools.length).toBe(1);

await client.close();
await transport.close();
await runner.close();
});
});

describe("error handling", () => {
it("should propagate errors from configProvider on client connection", async () => {
const createSessionConfig: TransportRunnerConfig["createSessionConfig"] = async () => {
return Promise.reject(new Error("Failed to fetch config"));
};

userConfig.httpPort = 0; // Use a random port
runner = new StreamableHttpRunner({
userConfig,
createSessionConfig,
});

// Start succeeds because setupServer is only called when a client connects
await runner.start();

// Error should occur when a client tries to connect
const testClient = new Client({
name: "test-client",
version: "1.0.0",
});
const testTransport = new StreamableHTTPClientTransport(new URL(`${runner.serverAddress}/mcp`));

await expect(testClient.connect(testTransport)).rejects.toThrow();

await testClient.close();
await testTransport.close();

await runner.close();
});
});
});
Loading