diff --git a/src/tools/atlas/connect/connectCluster.ts b/src/tools/atlas/connect/connectCluster.ts index 3ba519fc..77ee17f8 100644 --- a/src/tools/atlas/connect/connectCluster.ts +++ b/src/tools/atlas/connect/connectCluster.ts @@ -218,6 +218,39 @@ export class ConnectClusterTool extends AtlasToolBase { const ipAccessListUpdated = await ensureCurrentIpInAccessList(this.session.apiClient, projectId); let createdUser = false; + const state = this.queryConnection(projectId, clusterName); + switch (state) { + case "connected-to-other-cluster": + case "disconnected": { + await this.session.disconnect(); + + const { connectionString, atlas } = await this.prepareClusterConnection( + projectId, + clusterName, + connectionType + ); + + createdUser = true; + + // try to connect for about 5 minutes asynchronously + void this.connectToCluster(connectionString, atlas).catch((err: unknown) => { + const error = err instanceof Error ? err : new Error(String(err)); + this.session.logger.error({ + id: LogId.atlasConnectFailure, + context: "atlas-connect-cluster", + message: `error connecting to cluster: ${error.message}`, + }); + }); + break; + } + case "connecting": + case "connected": + case "unknown": + default: { + break; + } + } + for (let i = 0; i < 60; i++) { const state = this.queryConnection(projectId, clusterName); switch (state) { @@ -246,34 +279,15 @@ export class ConnectClusterTool extends AtlasToolBase { return { content }; } case "connecting": - case "unknown": { - break; - } + case "unknown": case "connected-to-other-cluster": case "disconnected": default: { - await this.session.disconnect(); - const { connectionString, atlas } = await this.prepareClusterConnection( - projectId, - clusterName, - connectionType - ); - - createdUser = true; - // try to connect for about 5 minutes asynchronously - void this.connectToCluster(connectionString, atlas).catch((err: unknown) => { - const error = err instanceof Error ? err : new Error(String(err)); - this.session.logger.error({ - id: LogId.atlasConnectFailure, - context: "atlas-connect-cluster", - message: `error connecting to cluster: ${error.message}`, - }); - }); break; } } - await sleep(500); + await sleep(500); // wait 500ms before checking the connection state again } const content: CallToolResult["content"] = [ diff --git a/tests/integration/tools/atlas/atlasHelpers.ts b/tests/integration/tools/atlas/atlasHelpers.ts index ab807e93..308fdc06 100644 --- a/tests/integration/tools/atlas/atlasHelpers.ts +++ b/tests/integration/tools/atlas/atlasHelpers.ts @@ -33,8 +33,16 @@ interface ProjectTestArgs { getIpAddress: () => string; } +interface ClusterTestArgs { + getProjectId: () => string; + getIpAddress: () => string; + getClusterName: () => string; +} + type ProjectTestFunction = (args: ProjectTestArgs) => void; +type ClusterTestFunction = (args: ClusterTestArgs) => void; + export function withCredentials(integration: IntegrationTest, fn: IntegrationTestFunction): SuiteCollector { const describeFn = !process.env.MDB_MCP_API_CLIENT_ID?.length || !process.env.MDB_MCP_API_CLIENT_SECRET?.length @@ -71,25 +79,25 @@ export function withProject(integration: IntegrationTest, fn: ProjectTestFunctio } }); - afterAll(() => { + afterAll(async () => { if (!projectId) { return; } const apiClient = integration.mcpServer().session.apiClient; - // send the delete request and ignore errors - apiClient - .deleteGroup({ + try { + await apiClient.deleteGroup({ params: { path: { groupId: projectId, }, }, - }) - .catch((error) => { - console.log("Failed to delete project:", error); }); + } catch (error) { + // send the delete request and ignore errors + console.log("Failed to delete group:", error); + } }); const args = { @@ -101,10 +109,12 @@ export function withProject(integration: IntegrationTest, fn: ProjectTestFunctio }); } -export const randomId = new ObjectId().toString(); +export function randomId(): string { + return new ObjectId().toString(); +} async function createGroup(apiClient: ApiClient): Promise>> { - const projectName: string = `testProj-` + randomId; + const projectName: string = `testProj-` + randomId(); const orgs = await apiClient.listOrgs(); if (!orgs?.results?.length || !orgs.results[0]?.id) { @@ -229,3 +239,78 @@ export async function waitCluster( `Cluster wait timeout: ${clusterName} did not meet condition within ${maxPollingIterations} iterations` ); } + +export function withCluster(integration: IntegrationTest, fn: ClusterTestFunction): SuiteCollector { + return withProject(integration, ({ getProjectId, getIpAddress }) => { + describe("with cluster", () => { + const clusterName: string = `test-cluster-${randomId()}`; + + beforeAll(async () => { + const apiClient = integration.mcpServer().session.apiClient; + + const projectId = getProjectId(); + + const input = { + groupId: projectId, + name: clusterName, + clusterType: "REPLICASET", + replicationSpecs: [ + { + zoneName: "Zone 1", + regionConfigs: [ + { + providerName: "TENANT", + backingProviderName: "AWS", + regionName: "US_EAST_1", + electableSpecs: { + instanceSize: "M0", + }, + }, + ], + }, + ], + terminationProtectionEnabled: false, + } as unknown as ClusterDescription20240805; + + await apiClient.createCluster({ + params: { + path: { + groupId: projectId, + }, + }, + body: input, + }); + + await waitCluster(integration.mcpServer().session, projectId, clusterName, (cluster) => { + return cluster.stateName === "IDLE"; + }); + }); + + afterAll(async () => { + const apiClient = integration.mcpServer().session.apiClient; + + try { + // send the delete request and ignore errors + await apiClient.deleteCluster({ + params: { + path: { + groupId: getProjectId(), + clusterName, + }, + }, + }); + } catch (error) { + console.log("Failed to delete cluster:", error); + } + }); + + const args = { + getProjectId: (): string => getProjectId(), + getIpAddress: (): string => getIpAddress(), + getClusterName: (): string => clusterName, + }; + + fn(args); + }); + }); +} diff --git a/tests/integration/tools/atlas/clusters.test.ts b/tests/integration/tools/atlas/clusters.test.ts index c8c31101..a06a8523 100644 --- a/tests/integration/tools/atlas/clusters.test.ts +++ b/tests/integration/tools/atlas/clusters.test.ts @@ -1,11 +1,19 @@ import type { Session } from "../../../../src/common/session.js"; import { expectDefined, getResponseContent } from "../../helpers.js"; -import { describeWithAtlas, withProject, randomId, deleteCluster, waitCluster, sleep } from "./atlasHelpers.js"; -import { afterAll, beforeAll, describe, expect, it } from "vitest"; +import { + describeWithAtlas, + withProject, + withCluster, + randomId, + deleteCluster, + waitCluster, + sleep, +} from "./atlasHelpers.js"; +import { afterAll, beforeAll, describe, expect, it, vitest } from "vitest"; describeWithAtlas("clusters", (integration) => { withProject(integration, ({ getProjectId, getIpAddress }) => { - const clusterName = "ClusterTest-" + randomId; + const clusterName = "ClusterTest-" + randomId(); afterAll(async () => { const projectId = getProjectId(); @@ -142,6 +150,11 @@ describeWithAtlas("clusters", (integration) => { }); it("connects to cluster", async () => { + const createDatabaseUserSpy = vitest.spyOn( + integration.mcpServer().session.apiClient, + "createDatabaseUser" + ); + const projectId = getProjectId(); const connectionType = "standard"; let connected = false; @@ -158,6 +171,8 @@ describeWithAtlas("clusters", (integration) => { if (content.includes(`Connected to cluster "${clusterName}"`)) { connected = true; + expect(createDatabaseUserSpy).toHaveBeenCalledTimes(1); + // assert that some of the element s have the message expect(content).toContain( "Note: A temporary user has been created to enable secure connection to the cluster. For more information, see https://dochub.mongodb.org/core/mongodb-mcp-server-tools-considerations" @@ -172,6 +187,58 @@ describeWithAtlas("clusters", (integration) => { expect(connected).toBe(true); }); + describe("when connected", () => { + withCluster( + integration, + ({ getProjectId: getSecondaryProjectId, getClusterName: getSecondaryClusterName }) => { + beforeAll(async () => { + let connected = false; + for (let i = 0; i < 10; i++) { + const response = await integration.mcpClient().callTool({ + name: "atlas-connect-cluster", + arguments: { + projectId: getSecondaryProjectId(), + clusterName: getSecondaryClusterName(), + connectionType: "standard", + }, + }); + + const content = getResponseContent(response.content); + + if (content.includes(`Connected to cluster "${getSecondaryClusterName()}"`)) { + connected = true; + break; + } + + await sleep(500); + } + + if (!connected) { + throw new Error("Could not connect to cluster before tests"); + } + }); + + it("disconnects and deletes the database user before connecting to another cluster", async () => { + const deleteDatabaseUserSpy = vitest.spyOn( + integration.mcpServer().session.apiClient, + "deleteDatabaseUser" + ); + + await integration.mcpClient().callTool({ + name: "atlas-connect-cluster", + arguments: { + projectId: getProjectId(), + clusterName: clusterName, + connectionType: "standard", + }, + }); + + expect(deleteDatabaseUserSpy).toHaveBeenCalledTimes(1); + }); + } + ); + }); + describe("when not connected", () => { it("prompts for atlas-connect-cluster when querying mongodb", async () => { const response = await integration.mcpClient().callTool({ diff --git a/tests/integration/tools/atlas/dbUsers.test.ts b/tests/integration/tools/atlas/dbUsers.test.ts index fa46aaa6..5bba1e2d 100644 --- a/tests/integration/tools/atlas/dbUsers.test.ts +++ b/tests/integration/tools/atlas/dbUsers.test.ts @@ -8,7 +8,7 @@ describeWithAtlas("db users", (integration) => { withProject(integration, ({ getProjectId }) => { let userName: string; beforeEach(() => { - userName = "testuser-" + randomId; + userName = "testuser-" + randomId(); }); const createUserWithMCP = async (password?: string): Promise => { diff --git a/tests/integration/tools/atlas/performanceAdvisor.test.ts b/tests/integration/tools/atlas/performanceAdvisor.test.ts index 9f9cc73a..f8b5ec24 100644 --- a/tests/integration/tools/atlas/performanceAdvisor.test.ts +++ b/tests/integration/tools/atlas/performanceAdvisor.test.ts @@ -18,7 +18,7 @@ import type { BaseEvent, ToolEvent } from "../../../../src/telemetry/types.js"; describeWithAtlas("performanceAdvisor", (integration) => { withProject(integration, ({ getProjectId }) => { - const clusterName = "ClusterTest-" + randomId; + const clusterName = "ClusterTest-" + randomId(); afterAll(async () => { const projectId = getProjectId();