Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 29 additions & 6 deletions src/common/atlas/cluster.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
import type { ClusterDescription20240805, FlexClusterDescription20241113 } from "./openapi.js";
import type {
ClusterConnectionStrings,
ClusterDescription20240805,
FlexClusterDescription20241113,
} from "./openapi.js";
import type { ApiClient } from "./apiClient.js";
import { LogId } from "../logger.js";
import { ConnectionString } from "mongodb-connection-string-url";
Expand All @@ -18,19 +22,18 @@ export interface Cluster {
instanceSize?: string;
state?: "IDLE" | "CREATING" | "UPDATING" | "DELETING" | "REPAIRING";
mongoDBVersion?: string;
connectionString?: string;
connectionStrings?: ClusterConnectionStrings;
processIds?: Array<string>;
}

export function formatFlexCluster(cluster: FlexClusterDescription20241113): Cluster {
const connectionString = cluster.connectionStrings?.standardSrv || cluster.connectionStrings?.standard;
return {
name: cluster.name,
instanceType: "FLEX",
instanceSize: undefined,
state: cluster.stateName,
mongoDBVersion: cluster.mongoDBVersion,
connectionString,
connectionStrings: cluster.connectionStrings,
processIds: extractProcessIds(cluster.connectionStrings?.standard ?? ""),
};
}
Expand Down Expand Up @@ -65,15 +68,14 @@ export function formatCluster(cluster: ClusterDescription20240805): Cluster {

const instanceSize = regionConfigs[0]?.instanceSize ?? "UNKNOWN";
const clusterInstanceType = instanceSize === "M0" ? "FREE" : "DEDICATED";
const connectionString = cluster.connectionStrings?.standardSrv || cluster.connectionStrings?.standard;

return {
name: cluster.name,
instanceType: clusterInstanceType,
instanceSize: clusterInstanceType === "DEDICATED" ? instanceSize : undefined,
state: cluster.stateName,
mongoDBVersion: cluster.mongoDBVersion,
connectionString,
connectionStrings: cluster.connectionStrings,
processIds: extractProcessIds(cluster.connectionStrings?.standard ?? ""),
};
}
Expand Down Expand Up @@ -112,6 +114,27 @@ export async function inspectCluster(apiClient: ApiClient, projectId: string, cl
}
}

/**
* Returns a connection string for the specified connectionType.
* For "privateEndpoint", it returns the first private endpoint connection string available.
*/
export function getConnectionString(
connectionStrings: ClusterConnectionStrings,
connectionType: "standard" | "private" | "privateEndpoint"
): string | undefined {
switch (connectionType) {
case "standard":
return connectionStrings.standardSrv || connectionStrings.standard;
case "private":
return connectionStrings.privateSrv || connectionStrings.private;
case "privateEndpoint":
return (
connectionStrings.privateEndpoint?.[0]?.srvConnectionString ||
connectionStrings.privateEndpoint?.[0]?.connectionString
);
}
}

export async function getProcessIdsFromCluster(
apiClient: ApiClient,
projectId: string,
Expand Down
3 changes: 3 additions & 0 deletions src/tools/args.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ export const AtlasArgs = {
.max(64, "Cluster name must be 64 characters or less")
.regex(ALLOWED_CLUSTER_NAME_CHARACTERS_REGEX, ALLOWED_CLUSTER_NAME_CHARACTERS_ERROR),

connectionType: (): z.ZodDefault<z.ZodEnum<["standard", "private", "privateEndpoint"]>> =>
z.enum(["standard", "private", "privateEndpoint"]).default("standard"),

projectName: (): z.ZodString =>
z
.string()
Expand Down
32 changes: 25 additions & 7 deletions src/tools/atlas/connect/connectCluster.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { type OperationType, type ToolArgs } from "../../tool.js";
import { AtlasToolBase } from "../atlasTool.js";
import { generateSecurePassword } from "../../../helpers/generatePassword.js";
import { LogId } from "../../../common/logger.js";
import { inspectCluster } from "../../../common/atlas/cluster.js";
import { getConnectionString, inspectCluster } from "../../../common/atlas/cluster.js";
import { ensureCurrentIpInAccessList } from "../../../common/atlas/accessListUtils.js";
import type { AtlasClusterConnectionInfo } from "../../../common/connectionManager.js";
import { getDefaultRoleFromConfig } from "../../../common/atlas/roles.js";
Expand All @@ -22,6 +22,9 @@ function sleep(ms: number): Promise<void> {
export const ConnectClusterArgs = {
projectId: AtlasArgs.projectId().describe("Atlas project ID"),
clusterName: AtlasArgs.clusterName().describe("Atlas cluster name"),
connectionType: AtlasArgs.connectionType().describe(
"Type of connection (standard, private, or privateEndpoint) to an Atlas cluster"
),
};

export class ConnectClusterTool extends AtlasToolBase {
Expand Down Expand Up @@ -69,12 +72,19 @@ export class ConnectClusterTool extends AtlasToolBase {

private async prepareClusterConnection(
projectId: string,
clusterName: string
clusterName: string,
connectionType: "standard" | "private" | "privateEndpoint" | undefined = "standard"
): Promise<{ connectionString: string; atlas: AtlasClusterConnectionInfo }> {
const cluster = await inspectCluster(this.session.apiClient, projectId, clusterName);

if (!cluster.connectionString) {
throw new Error("Connection string not available");
if (cluster.connectionStrings === undefined) {
throw new Error("Connection strings not available");
}
const connectionString = getConnectionString(cluster.connectionStrings, connectionType);
if (connectionString === undefined) {
throw new Error(
`Connection string for connection type "${connectionType}" is not available. Please ensure this connection type is set up in Atlas. See https://www.mongodb.com/docs/atlas/connect-to-database-deployment/#connect-to-an-atlas-cluster.`
);
}

const username = `mcpUser${Math.floor(Math.random() * 100000)}`;
Expand Down Expand Up @@ -113,7 +123,7 @@ export class ConnectClusterTool extends AtlasToolBase {
expiryDate,
};

const cn = new URL(cluster.connectionString);
const cn = new URL(connectionString);
cn.username = username;
cn.password = password;
cn.searchParams.set("authSource", "admin");
Expand Down Expand Up @@ -200,7 +210,11 @@ export class ConnectClusterTool extends AtlasToolBase {
});
}

protected async execute({ projectId, clusterName }: ToolArgs<typeof this.argsShape>): Promise<CallToolResult> {
protected async execute({
projectId,
clusterName,
connectionType,
}: ToolArgs<typeof this.argsShape>): Promise<CallToolResult> {
const ipAccessListUpdated = await ensureCurrentIpInAccessList(this.session.apiClient, projectId);
let createdUser = false;

Expand Down Expand Up @@ -239,7 +253,11 @@ export class ConnectClusterTool extends AtlasToolBase {
case "disconnected":
default: {
await this.session.disconnect();
const { connectionString, atlas } = await this.prepareClusterConnection(projectId, clusterName);
const { connectionString, atlas } = await this.prepareClusterConnection(
projectId,
clusterName,
connectionType
);

createdUser = true;
// try to connect for about 5 minutes asynchronously
Expand Down
2 changes: 1 addition & 1 deletion src/tools/atlas/read/inspectCluster.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ export class InspectClusterTool extends AtlasToolBase {
"Cluster details:",
`Cluster Name | Cluster Type | Tier | State | MongoDB Version | Connection String
----------------|----------------|----------------|----------------|----------------|----------------
${formattedCluster.name || "Unknown"} | ${formattedCluster.instanceType} | ${formattedCluster.instanceSize || "N/A"} | ${formattedCluster.state || "UNKNOWN"} | ${formattedCluster.mongoDBVersion || "N/A"} | ${formattedCluster.connectionString || "N/A"}`
${formattedCluster.name || "Unknown"} | ${formattedCluster.instanceType} | ${formattedCluster.instanceSize || "N/A"} | ${formattedCluster.state || "UNKNOWN"} | ${formattedCluster.mongoDBVersion || "N/A"} | ${formattedCluster.connectionStrings?.standardSrv || formattedCluster.connectionStrings?.standard || "N/A"}`
),
};
}
Expand Down
2 changes: 1 addition & 1 deletion src/tools/atlas/read/listClusters.ts
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ ${rows}`,
----------------|----------------|----------------|----------------|----------------|----------------
${allClusters
.map((formattedCluster) => {
return `${formattedCluster.name || "Unknown"} | ${formattedCluster.instanceType} | ${formattedCluster.instanceSize || "N/A"} | ${formattedCluster.state || "UNKNOWN"} | ${formattedCluster.mongoDBVersion || "N/A"} | ${formattedCluster.connectionString || "N/A"}`;
return `${formattedCluster.name || "Unknown"} | ${formattedCluster.instanceType} | ${formattedCluster.instanceSize || "N/A"} | ${formattedCluster.state || "UNKNOWN"} | ${formattedCluster.mongoDBVersion || "N/A"} | ${formattedCluster.connectionStrings?.standardSrv || formattedCluster.connectionStrings?.standard || "N/A"}`;
})
.join("\n")}`
),
Expand Down
4 changes: 3 additions & 1 deletion tests/integration/tools/atlas/clusters.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -150,16 +150,18 @@ describeWithAtlas("clusters", (integration) => {
expectDefined(connectCluster.inputSchema.properties);
expect(connectCluster.inputSchema.properties).toHaveProperty("projectId");
expect(connectCluster.inputSchema.properties).toHaveProperty("clusterName");
expect(connectCluster.inputSchema.properties).toHaveProperty("connectionType");
});

it("connects to cluster", async () => {
const projectId = getProjectId();
const connectionType = "standard";
let connected = false;

for (let i = 0; i < 10; i++) {
const response = await integration.mcpClient().callTool({
name: "atlas-connect-cluster",
arguments: { projectId, clusterName },
arguments: { projectId, clusterName, connectionType },
});

const elements = getResponseElements(response.content);
Expand Down
Loading