diff --git a/src/common/atlas/cluster.ts b/src/common/atlas/cluster.ts index 1ea30286b..a153e7fea 100644 --- a/src/common/atlas/cluster.ts +++ b/src/common/atlas/cluster.ts @@ -1,4 +1,8 @@ -import type { ClusterDescription20240805, FlexClusterDescription20241113 } from "./openapi.js"; +import type { + ClusterConnectionStrings, + ClusterDescription20240805, + FlexClusterDescription20241113, +} from "./openapi.js"; import type { ApiClient } from "./apiClient.js"; import { LogId } from "../logger.js"; import { ConnectionString } from "mongodb-connection-string-url"; @@ -18,19 +22,18 @@ export interface Cluster { instanceSize?: string; state?: "IDLE" | "CREATING" | "UPDATING" | "DELETING" | "REPAIRING"; mongoDBVersion?: string; - connectionString?: string; + connectionStrings?: ClusterConnectionStrings; processIds?: Array; } export function formatFlexCluster(cluster: FlexClusterDescription20241113): Cluster { - const connectionString = cluster.connectionStrings?.standardSrv || cluster.connectionStrings?.standard; return { name: cluster.name, instanceType: "FLEX", instanceSize: undefined, state: cluster.stateName, mongoDBVersion: cluster.mongoDBVersion, - connectionString, + connectionStrings: cluster.connectionStrings, processIds: extractProcessIds(cluster.connectionStrings?.standard ?? ""), }; } @@ -65,7 +68,6 @@ export function formatCluster(cluster: ClusterDescription20240805): Cluster { const instanceSize = regionConfigs[0]?.instanceSize ?? "UNKNOWN"; const clusterInstanceType = instanceSize === "M0" ? "FREE" : "DEDICATED"; - const connectionString = cluster.connectionStrings?.standardSrv || cluster.connectionStrings?.standard; return { name: cluster.name, @@ -73,7 +75,7 @@ export function formatCluster(cluster: ClusterDescription20240805): Cluster { instanceSize: clusterInstanceType === "DEDICATED" ? instanceSize : undefined, state: cluster.stateName, mongoDBVersion: cluster.mongoDBVersion, - connectionString, + connectionStrings: cluster.connectionStrings, processIds: extractProcessIds(cluster.connectionStrings?.standard ?? ""), }; } @@ -112,6 +114,27 @@ export async function inspectCluster(apiClient: ApiClient, projectId: string, cl } } +/** + * Returns a connection string for the specified connectionType. + * For "privateEndpoint", it returns the first private endpoint connection string available. + */ +export function getConnectionString( + connectionStrings: ClusterConnectionStrings, + connectionType: "standard" | "private" | "privateEndpoint" +): string | undefined { + switch (connectionType) { + case "standard": + return connectionStrings.standardSrv || connectionStrings.standard; + case "private": + return connectionStrings.privateSrv || connectionStrings.private; + case "privateEndpoint": + return ( + connectionStrings.privateEndpoint?.[0]?.srvConnectionString || + connectionStrings.privateEndpoint?.[0]?.connectionString + ); + } +} + export async function getProcessIdsFromCluster( apiClient: ApiClient, projectId: string, diff --git a/src/tools/args.ts b/src/tools/args.ts index 653f72da2..11b5b8b80 100644 --- a/src/tools/args.ts +++ b/src/tools/args.ts @@ -41,6 +41,9 @@ export const AtlasArgs = { .max(64, "Cluster name must be 64 characters or less") .regex(ALLOWED_CLUSTER_NAME_CHARACTERS_REGEX, ALLOWED_CLUSTER_NAME_CHARACTERS_ERROR), + connectionType: (): z.ZodDefault> => + z.enum(["standard", "private", "privateEndpoint"]).default("standard"), + projectName: (): z.ZodString => z .string() diff --git a/src/tools/atlas/connect/connectCluster.ts b/src/tools/atlas/connect/connectCluster.ts index 54f3ae8bd..3ba519fc8 100644 --- a/src/tools/atlas/connect/connectCluster.ts +++ b/src/tools/atlas/connect/connectCluster.ts @@ -3,7 +3,7 @@ import { type OperationType, type ToolArgs } from "../../tool.js"; import { AtlasToolBase } from "../atlasTool.js"; import { generateSecurePassword } from "../../../helpers/generatePassword.js"; import { LogId } from "../../../common/logger.js"; -import { inspectCluster } from "../../../common/atlas/cluster.js"; +import { getConnectionString, inspectCluster } from "../../../common/atlas/cluster.js"; import { ensureCurrentIpInAccessList } from "../../../common/atlas/accessListUtils.js"; import type { AtlasClusterConnectionInfo } from "../../../common/connectionManager.js"; import { getDefaultRoleFromConfig } from "../../../common/atlas/roles.js"; @@ -22,6 +22,9 @@ function sleep(ms: number): Promise { export const ConnectClusterArgs = { projectId: AtlasArgs.projectId().describe("Atlas project ID"), clusterName: AtlasArgs.clusterName().describe("Atlas cluster name"), + connectionType: AtlasArgs.connectionType().describe( + "Type of connection (standard, private, or privateEndpoint) to an Atlas cluster" + ), }; export class ConnectClusterTool extends AtlasToolBase { @@ -69,12 +72,19 @@ export class ConnectClusterTool extends AtlasToolBase { private async prepareClusterConnection( projectId: string, - clusterName: string + clusterName: string, + connectionType: "standard" | "private" | "privateEndpoint" | undefined = "standard" ): Promise<{ connectionString: string; atlas: AtlasClusterConnectionInfo }> { const cluster = await inspectCluster(this.session.apiClient, projectId, clusterName); - if (!cluster.connectionString) { - throw new Error("Connection string not available"); + if (cluster.connectionStrings === undefined) { + throw new Error("Connection strings not available"); + } + const connectionString = getConnectionString(cluster.connectionStrings, connectionType); + if (connectionString === undefined) { + throw new Error( + `Connection string for connection type "${connectionType}" is not available. Please ensure this connection type is set up in Atlas. See https://www.mongodb.com/docs/atlas/connect-to-database-deployment/#connect-to-an-atlas-cluster.` + ); } const username = `mcpUser${Math.floor(Math.random() * 100000)}`; @@ -113,7 +123,7 @@ export class ConnectClusterTool extends AtlasToolBase { expiryDate, }; - const cn = new URL(cluster.connectionString); + const cn = new URL(connectionString); cn.username = username; cn.password = password; cn.searchParams.set("authSource", "admin"); @@ -200,7 +210,11 @@ export class ConnectClusterTool extends AtlasToolBase { }); } - protected async execute({ projectId, clusterName }: ToolArgs): Promise { + protected async execute({ + projectId, + clusterName, + connectionType, + }: ToolArgs): Promise { const ipAccessListUpdated = await ensureCurrentIpInAccessList(this.session.apiClient, projectId); let createdUser = false; @@ -239,7 +253,11 @@ export class ConnectClusterTool extends AtlasToolBase { case "disconnected": default: { await this.session.disconnect(); - const { connectionString, atlas } = await this.prepareClusterConnection(projectId, clusterName); + const { connectionString, atlas } = await this.prepareClusterConnection( + projectId, + clusterName, + connectionType + ); createdUser = true; // try to connect for about 5 minutes asynchronously diff --git a/src/tools/atlas/read/inspectCluster.ts b/src/tools/atlas/read/inspectCluster.ts index 56e1e5a8b..d4defcc92 100644 --- a/src/tools/atlas/read/inspectCluster.ts +++ b/src/tools/atlas/read/inspectCluster.ts @@ -30,7 +30,7 @@ export class InspectClusterTool extends AtlasToolBase { "Cluster details:", `Cluster Name | Cluster Type | Tier | State | MongoDB Version | Connection String ----------------|----------------|----------------|----------------|----------------|---------------- -${formattedCluster.name || "Unknown"} | ${formattedCluster.instanceType} | ${formattedCluster.instanceSize || "N/A"} | ${formattedCluster.state || "UNKNOWN"} | ${formattedCluster.mongoDBVersion || "N/A"} | ${formattedCluster.connectionString || "N/A"}` +${formattedCluster.name || "Unknown"} | ${formattedCluster.instanceType} | ${formattedCluster.instanceSize || "N/A"} | ${formattedCluster.state || "UNKNOWN"} | ${formattedCluster.mongoDBVersion || "N/A"} | ${formattedCluster.connectionStrings?.standardSrv || formattedCluster.connectionStrings?.standard || "N/A"}` ), }; } diff --git a/src/tools/atlas/read/listClusters.ts b/src/tools/atlas/read/listClusters.ts index 60344f7d3..1dfe626ea 100644 --- a/src/tools/atlas/read/listClusters.ts +++ b/src/tools/atlas/read/listClusters.ts @@ -105,7 +105,7 @@ ${rows}`, ----------------|----------------|----------------|----------------|----------------|---------------- ${allClusters .map((formattedCluster) => { - return `${formattedCluster.name || "Unknown"} | ${formattedCluster.instanceType} | ${formattedCluster.instanceSize || "N/A"} | ${formattedCluster.state || "UNKNOWN"} | ${formattedCluster.mongoDBVersion || "N/A"} | ${formattedCluster.connectionString || "N/A"}`; + return `${formattedCluster.name || "Unknown"} | ${formattedCluster.instanceType} | ${formattedCluster.instanceSize || "N/A"} | ${formattedCluster.state || "UNKNOWN"} | ${formattedCluster.mongoDBVersion || "N/A"} | ${formattedCluster.connectionStrings?.standardSrv || formattedCluster.connectionStrings?.standard || "N/A"}`; }) .join("\n")}` ), diff --git a/tests/integration/tools/atlas/clusters.test.ts b/tests/integration/tools/atlas/clusters.test.ts index f340dc08f..30f15bb96 100644 --- a/tests/integration/tools/atlas/clusters.test.ts +++ b/tests/integration/tools/atlas/clusters.test.ts @@ -150,16 +150,18 @@ describeWithAtlas("clusters", (integration) => { expectDefined(connectCluster.inputSchema.properties); expect(connectCluster.inputSchema.properties).toHaveProperty("projectId"); expect(connectCluster.inputSchema.properties).toHaveProperty("clusterName"); + expect(connectCluster.inputSchema.properties).toHaveProperty("connectionType"); }); it("connects to cluster", async () => { const projectId = getProjectId(); + const connectionType = "standard"; let connected = false; for (let i = 0; i < 10; i++) { const response = await integration.mcpClient().callTool({ name: "atlas-connect-cluster", - arguments: { projectId, clusterName }, + arguments: { projectId, clusterName, connectionType }, }); const elements = getResponseElements(response.content);