diff --git a/tests/accuracy/createIndex.test.ts b/tests/accuracy/createIndex.test.ts index f3c600eaf..66f330148 100644 --- a/tests/accuracy/createIndex.test.ts +++ b/tests/accuracy/createIndex.test.ts @@ -1,140 +1,145 @@ import { describeAccuracyTests } from "./sdk/describeAccuracyTests.js"; import { Matcher } from "./sdk/matcher.js"; -// TODO: supply this with a proper config API once we refactor describeAccuracyTests to support it -process.env.MDB_VOYAGE_API_KEY = "valid-key"; - -describeAccuracyTests([ - { - prompt: "Create an index that covers the following query on 'mflix.movies' namespace - { \"release_year\": 1992 }", - expectedToolCalls: [ - { - toolName: "create-index", - parameters: { - database: "mflix", - collection: "movies", - name: Matcher.anyOf(Matcher.undefined, Matcher.string()), - definition: [ - { - type: "classic", - keys: { - release_year: 1, +describeAccuracyTests( + [ + { + prompt: "Create an index that covers the following query on 'mflix.movies' namespace - { \"release_year\": 1992 }", + expectedToolCalls: [ + { + toolName: "create-index", + parameters: { + database: "mflix", + collection: "movies", + name: Matcher.anyOf(Matcher.undefined, Matcher.string()), + definition: [ + { + type: "classic", + keys: { + release_year: 1, + }, }, - }, - ], + ], + }, }, - }, - ], - }, - { - prompt: "Create a text index on title field in 'mflix.movies' namespace", - expectedToolCalls: [ - { - toolName: "create-index", - parameters: { - database: "mflix", - collection: "movies", - name: Matcher.anyOf(Matcher.undefined, Matcher.string()), - definition: [ - { - type: "classic", - keys: { - title: "text", + ], + }, + { + prompt: "Create a text index on title field in 'mflix.movies' namespace", + expectedToolCalls: [ + { + toolName: "create-index", + parameters: { + database: "mflix", + collection: "movies", + name: Matcher.anyOf(Matcher.undefined, Matcher.string()), + definition: [ + { + type: "classic", + keys: { + title: "text", + }, }, - }, - ], + ], + }, }, - }, - ], - }, - { - prompt: "Create a vector search index on 'mydb.movies' namespace on the 'plotSummary' field. The index should use 1024 dimensions.", - expectedToolCalls: [ - { - toolName: "create-index", - parameters: { - database: "mydb", - collection: "movies", - name: Matcher.anyOf(Matcher.undefined, Matcher.string()), - definition: [ - { - type: "vectorSearch", - fields: [ - { - type: "vector", - path: "plotSummary", - numDimensions: 1024, - }, - ], - }, - ], + ], + }, + { + prompt: "Create a vector search index on 'mflix.movies' namespace on the 'plotSummary' field. The index should use 1024 dimensions.", + expectedToolCalls: [ + { + toolName: "create-index", + parameters: { + database: "mflix", + collection: "movies", + name: Matcher.anyOf(Matcher.undefined, Matcher.string()), + definition: [ + { + type: "vectorSearch", + fields: [ + { + type: "vector", + path: "plotSummary", + numDimensions: 1024, + }, + ], + }, + ], + }, }, - }, - ], - }, - { - prompt: "Create a vector search index on 'mydb.movies' namespace with on the 'plotSummary' field and 'genre' field, both of which contain vector embeddings. Pick a sensible number of dimensions for a voyage 3.5 model.", - expectedToolCalls: [ - { - toolName: "create-index", - parameters: { - database: "mydb", - collection: "movies", - name: Matcher.anyOf(Matcher.undefined, Matcher.string()), - definition: [ - { - type: "vectorSearch", - fields: [ - { - type: "vector", - path: "plotSummary", - numDimensions: Matcher.number( - (value) => value % 2 === 0 && value >= 256 && value <= 8192 - ), - similarity: Matcher.anyOf(Matcher.undefined, Matcher.string()), - }, - { - type: "vector", - path: "genre", - numDimensions: Matcher.number( - (value) => value % 2 === 0 && value >= 256 && value <= 8192 - ), - similarity: Matcher.anyOf(Matcher.undefined, Matcher.string()), - }, - ], - }, - ], + ], + }, + { + prompt: "Create a vector search index on 'mflix.movies' namespace with on the 'plotSummary' field and 'genre' field, both of which contain vector embeddings. Pick a sensible number of dimensions for a voyage 3.5 model.", + expectedToolCalls: [ + { + toolName: "create-index", + parameters: { + database: "mflix", + collection: "movies", + name: Matcher.anyOf(Matcher.undefined, Matcher.string()), + definition: [ + { + type: "vectorSearch", + fields: [ + { + type: "vector", + path: "plotSummary", + numDimensions: Matcher.number( + (value) => value % 2 === 0 && value >= 256 && value <= 8192 + ), + similarity: Matcher.anyOf(Matcher.undefined, Matcher.string()), + }, + { + type: "vector", + path: "genre", + numDimensions: Matcher.number( + (value) => value % 2 === 0 && value >= 256 && value <= 8192 + ), + similarity: Matcher.anyOf(Matcher.undefined, Matcher.string()), + }, + ], + }, + ], + }, }, - }, - ], - }, - { - prompt: "Create a vector search index on 'mydb.movies' namespace where the 'plotSummary' field is indexed as a 1024-dimensional vector and the 'releaseDate' field is indexed as a regular field.", - expectedToolCalls: [ - { - toolName: "create-index", - parameters: { - database: "mydb", - collection: "movies", - name: Matcher.anyOf(Matcher.undefined, Matcher.string()), - definition: [ - { - type: "vectorSearch", - fields: [ - { - type: "vector", - path: "plotSummary", - numDimensions: 1024, - }, - { - type: "filter", - path: "releaseDate", - }, - ], - }, - ], + ], + }, + { + prompt: "Create a vector search index on 'mflix.movies' namespace where the 'plotSummary' field is indexed as a 1024-dimensional vector and the 'releaseDate' field is indexed as a regular field.", + expectedToolCalls: [ + { + toolName: "create-index", + parameters: { + database: "mflix", + collection: "movies", + name: Matcher.anyOf(Matcher.undefined, Matcher.string()), + definition: [ + { + type: "vectorSearch", + fields: [ + { + type: "vector", + path: "plotSummary", + numDimensions: 1024, + }, + { + type: "filter", + path: "releaseDate", + }, + ], + }, + ], + }, }, - }, - ], - }, -]); + ], + }, + ], + { + userConfig: { voyageApiKey: "valid-key" }, + clusterConfig: { + search: true, + }, + } +); diff --git a/tests/accuracy/createIndex.vectorSearchDisabled.test.ts b/tests/accuracy/createIndex.vectorSearchDisabled.test.ts new file mode 100644 index 000000000..eb5fd3ebe --- /dev/null +++ b/tests/accuracy/createIndex.vectorSearchDisabled.test.ts @@ -0,0 +1,57 @@ +/** + * Accuracy tests for when the vector search feature flag is disabled. + * + * TODO: Remove this file once we permanently enable the vector search feature. + */ +import { describeAccuracyTests } from "./sdk/describeAccuracyTests.js"; +import { Matcher } from "./sdk/matcher.js"; + +describeAccuracyTests( + [ + { + prompt: "(vectorSearchDisabled) Create an index that covers the following query on 'mflix.movies' namespace - { \"release_year\": 1992 }", + expectedToolCalls: [ + { + toolName: "create-index", + parameters: { + database: "mflix", + collection: "movies", + name: Matcher.anyOf(Matcher.undefined, Matcher.string()), + definition: [ + { + type: Matcher.anyOf(Matcher.undefined, Matcher.value("classic")), + keys: { + release_year: 1, + }, + }, + ], + }, + }, + ], + }, + { + prompt: "(vectorSearchDisabled) Create a text index on title field in 'mflix.movies' namespace", + expectedToolCalls: [ + { + toolName: "create-index", + parameters: { + database: "mflix", + collection: "movies", + name: Matcher.anyOf(Matcher.undefined, Matcher.string()), + definition: [ + { + type: Matcher.anyOf(Matcher.undefined, Matcher.value("classic")), + keys: { + title: "text", + }, + }, + ], + }, + }, + ], + }, + ], + { + userConfig: { voyageApiKey: "" }, + } +); diff --git a/tests/accuracy/sdk/accuracyTestingClient.ts b/tests/accuracy/sdk/accuracyTestingClient.ts index 48cba3b2c..6ebed6878 100644 --- a/tests/accuracy/sdk/accuracyTestingClient.ts +++ b/tests/accuracy/sdk/accuracyTestingClient.ts @@ -6,6 +6,7 @@ import { StdioClientTransport } from "@modelcontextprotocol/sdk/client/stdio.js" import { MCP_SERVER_CLI_SCRIPT } from "./constants.js"; import type { LLMToolCall } from "./accuracyResultStorage/resultStorage.js"; import type { VercelMCPClient, VercelMCPClientTools } from "./agent.js"; +import type { UserConfig } from "../../../src/lib.js"; type ToolResultGeneratorFn = (parameters: Record) => CallToolResult | Promise; export type MockedTools = Record; @@ -81,18 +82,13 @@ export class AccuracyTestingClient { static async initializeClient( mdbConnectionString: string, - atlasApiClientId?: string, - atlasApiClientSecret?: string, - voyageApiKey?: string + userConfig: Partial<{ [k in keyof UserConfig]: string }> = {} ): Promise { - const args = [ - MCP_SERVER_CLI_SCRIPT, - "--connectionString", - mdbConnectionString, - ...(atlasApiClientId ? ["--apiClientId", atlasApiClientId] : []), - ...(atlasApiClientSecret ? ["--apiClientSecret", atlasApiClientSecret] : []), - ...(voyageApiKey ? ["--voyageApiKey", voyageApiKey] : []), - ]; + const additionalArgs = Object.entries(userConfig).flatMap(([key, value]) => { + return [`--${key}`, value]; + }); + + const args = [MCP_SERVER_CLI_SCRIPT, "--connectionString", mdbConnectionString, ...additionalArgs]; const clientTransport = new StdioClientTransport({ command: process.execPath, diff --git a/tests/accuracy/sdk/describeAccuracyTests.ts b/tests/accuracy/sdk/describeAccuracyTests.ts index 4c39e9623..adf75d7da 100644 --- a/tests/accuracy/sdk/describeAccuracyTests.ts +++ b/tests/accuracy/sdk/describeAccuracyTests.ts @@ -10,6 +10,11 @@ import type { AccuracyResultStorage, ExpectedToolCall, LLMToolCall } from "./acc import { getAccuracyResultStorage } from "./accuracyResultStorage/getAccuracyResultStorage.js"; import { getCommitSHA } from "./gitInfo.js"; import type { MongoClient } from "mongodb"; +import type { UserConfig } from "../../../src/lib.js"; +import { + MongoDBClusterProcess, + type MongoClusterConfiguration, +} from "../../integration/tools/mongodb/mongodbClusterProcess.js"; export interface AccuracyTestConfig { /** The prompt to be provided to LLM for evaluation. */ @@ -48,7 +53,13 @@ export interface AccuracyTestConfig { ) => Promise | number; } -export function describeAccuracyTests(accuracyTestConfigs: AccuracyTestConfig[]): void { +export function describeAccuracyTests( + accuracyTestConfigs: AccuracyTestConfig[], + { + userConfig: partialUserConfig, + clusterConfig, + }: { userConfig?: Partial<{ [k in keyof UserConfig]: string }>; clusterConfig?: MongoClusterConfiguration } = {} +): void { if (!process.env.MDB_ACCURACY_RUN_ID) { throw new Error("MDB_ACCURACY_RUN_ID env variable is required for accuracy test runs!"); } @@ -58,17 +69,22 @@ export function describeAccuracyTests(accuracyTestConfigs: AccuracyTestConfig[]) throw new Error("No models available to test. Ensure that the API keys are properly setup!"); } - const eachModel = describe.each(models); + const shouldSkip = clusterConfig && !MongoDBClusterProcess.isConfigurationSupportedInCurrentEnv(clusterConfig); + + const eachModel = describe.skipIf(shouldSkip).each(models); eachModel(`$displayName`, function (model) { const configsWithDescriptions = getConfigsWithDescriptions(accuracyTestConfigs); const accuracyRunId = `${process.env.MDB_ACCURACY_RUN_ID}`; - const mdbIntegration = setupMongoDBIntegrationTest(); + const mdbIntegration = setupMongoDBIntegrationTest(clusterConfig); const { populateTestData, cleanupTestDatabases } = prepareTestData(mdbIntegration); - const atlasApiClientId = process.env.MDB_MCP_API_CLIENT_ID; - const atlasApiClientSecret = process.env.MDB_MCP_API_CLIENT_SECRET; - const voyageApiKey = process.env.MDB_VOYAGE_API_KEY; + const userConfig: Partial<{ [k in keyof UserConfig]: string }> = { + apiClientId: process.env.MDB_MCP_API_CLIENT_ID, + apiClientSecret: process.env.MDB_MCP_API_CLIENT_SECRET, + voyageApiKey: process.env.MDB_VOYAGE_API_KEY, + ...partialUserConfig, + }; let commitSHA: string; let accuracyResultStorage: AccuracyResultStorage; @@ -83,12 +99,7 @@ export function describeAccuracyTests(accuracyTestConfigs: AccuracyTestConfig[]) commitSHA = retrievedCommitSHA; accuracyResultStorage = getAccuracyResultStorage(); - testMCPClient = await AccuracyTestingClient.initializeClient( - mdbIntegration.connectionString(), - atlasApiClientId, - atlasApiClientSecret, - voyageApiKey - ); + testMCPClient = await AccuracyTestingClient.initializeClient(mdbIntegration.connectionString(), userConfig); agent = getVercelToolCallingAgent(); }); diff --git a/tests/integration/tools/mongodb/mongodbClusterProcess.ts b/tests/integration/tools/mongodb/mongodbClusterProcess.ts index b0f7ee863..bd0da659f 100644 --- a/tests/integration/tools/mongodb/mongodbClusterProcess.ts +++ b/tests/integration/tools/mongodb/mongodbClusterProcess.ts @@ -27,7 +27,8 @@ export class MongoDBClusterProcess { return new MongoDBClusterProcess( () => runningContainer.stop(), - () => `mongodb://${runningContainer.getHost()}:${runningContainer.getMappedPort(27017)}` + () => + `mongodb://${runningContainer.getHost()}:${runningContainer.getMappedPort(27017)}/?directConnection=true` ); } else if (MongoDBClusterProcess.isMongoRunnerOptions(config)) { const { downloadOptions, serverArgs } = config;