diff --git a/package-lock.json b/package-lock.json index 290bccbb8..ea84bf9b8 100644 --- a/package-lock.json +++ b/package-lock.json @@ -14,6 +14,7 @@ "@mongodb-js/devtools-proxy-support": "^0.5.3", "@mongosh/arg-parser": "^3.19.0", "@mongosh/service-provider-node-driver": "^3.17.0", + "ai": "^5.0.72", "bson": "^6.10.4", "express": "^5.1.0", "lru-cache": "^11.1.0", @@ -26,6 +27,7 @@ "oauth4webapi": "^3.8.0", "openapi-fetch": "^0.14.0", "ts-levenshtein": "^1.0.7", + "voyage-ai-provider": "^2.0.0", "yargs-parser": "21.1.1", "zod": "^3.25.76" }, @@ -48,7 +50,6 @@ "@typescript-eslint/parser": "^8.44.0", "@vitest/coverage-v8": "^3.2.4", "@vitest/eslint-plugin": "^1.3.4", - "ai": "^5.0.72", "duplexpair": "^1.0.2", "eslint": "^9.34.0", "eslint-config-prettier": "^10.1.8", @@ -96,42 +97,10 @@ "zod": "^3.25.76 || ^4.1.8" } }, - "node_modules/@ai-sdk/azure/node_modules/@ai-sdk/provider": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/@ai-sdk/provider/-/provider-2.0.0.tgz", - "integrity": "sha512-6o7Y2SeO9vFKB8lArHXehNuusnpddKPk7xqL7T2/b+OvXMRIXUO1rR4wcv1hAFUAT9avGZshty3Wlua/XA7TvA==", - "dev": true, - "license": "Apache-2.0", - "dependencies": { - "json-schema": "^0.4.0" - }, - "engines": { - "node": ">=18" - } - }, - "node_modules/@ai-sdk/azure/node_modules/@ai-sdk/provider-utils": { - "version": "3.0.12", - "resolved": "https://registry.npmjs.org/@ai-sdk/provider-utils/-/provider-utils-3.0.12.tgz", - "integrity": "sha512-ZtbdvYxdMoria+2SlNarEk6Hlgyf+zzcznlD55EAl+7VZvJaSg2sqPvwArY7L6TfDEDJsnCq0fdhBSkYo0Xqdg==", - "dev": true, - "license": "Apache-2.0", - "dependencies": { - "@ai-sdk/provider": "2.0.0", - "@standard-schema/spec": "^1.0.0", - "eventsource-parser": "^3.0.5" - }, - "engines": { - "node": ">=18" - }, - "peerDependencies": { - "zod": "^3.25.76 || ^4.1.8" - } - }, "node_modules/@ai-sdk/gateway": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/@ai-sdk/gateway/-/gateway-2.0.0.tgz", "integrity": "sha512-Gj0PuawK7NkZuyYgO/h5kDK/l6hFOjhLdTq3/Lli1FTl47iGmwhH1IZQpAL3Z09BeFYWakcwUmn02ovIm2wy9g==", - "dev": true, "license": "Apache-2.0", "dependencies": { "@ai-sdk/provider": "2.0.0", @@ -145,37 +114,6 @@ "zod": "^3.25.76 || ^4.1.8" } }, - "node_modules/@ai-sdk/gateway/node_modules/@ai-sdk/provider": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/@ai-sdk/provider/-/provider-2.0.0.tgz", - "integrity": "sha512-6o7Y2SeO9vFKB8lArHXehNuusnpddKPk7xqL7T2/b+OvXMRIXUO1rR4wcv1hAFUAT9avGZshty3Wlua/XA7TvA==", - "dev": true, - "license": "Apache-2.0", - "dependencies": { - "json-schema": "^0.4.0" - }, - "engines": { - "node": ">=18" - } - }, - "node_modules/@ai-sdk/gateway/node_modules/@ai-sdk/provider-utils": { - "version": "3.0.12", - "resolved": "https://registry.npmjs.org/@ai-sdk/provider-utils/-/provider-utils-3.0.12.tgz", - "integrity": "sha512-ZtbdvYxdMoria+2SlNarEk6Hlgyf+zzcznlD55EAl+7VZvJaSg2sqPvwArY7L6TfDEDJsnCq0fdhBSkYo0Xqdg==", - "dev": true, - "license": "Apache-2.0", - "dependencies": { - "@ai-sdk/provider": "2.0.0", - "@standard-schema/spec": "^1.0.0", - "eventsource-parser": "^3.0.5" - }, - "engines": { - "node": ">=18" - }, - "peerDependencies": { - "zod": "^3.25.76 || ^4.1.8" - } - }, "node_modules/@ai-sdk/google": { "version": "2.0.23", "resolved": "https://registry.npmjs.org/@ai-sdk/google/-/google-2.0.23.tgz", @@ -193,37 +131,6 @@ "zod": "^3.25.76 || ^4.1.8" } }, - "node_modules/@ai-sdk/google/node_modules/@ai-sdk/provider": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/@ai-sdk/provider/-/provider-2.0.0.tgz", - "integrity": "sha512-6o7Y2SeO9vFKB8lArHXehNuusnpddKPk7xqL7T2/b+OvXMRIXUO1rR4wcv1hAFUAT9avGZshty3Wlua/XA7TvA==", - "dev": true, - "license": "Apache-2.0", - "dependencies": { - "json-schema": "^0.4.0" - }, - "engines": { - "node": ">=18" - } - }, - "node_modules/@ai-sdk/google/node_modules/@ai-sdk/provider-utils": { - "version": "3.0.12", - "resolved": "https://registry.npmjs.org/@ai-sdk/provider-utils/-/provider-utils-3.0.12.tgz", - "integrity": "sha512-ZtbdvYxdMoria+2SlNarEk6Hlgyf+zzcznlD55EAl+7VZvJaSg2sqPvwArY7L6TfDEDJsnCq0fdhBSkYo0Xqdg==", - "dev": true, - "license": "Apache-2.0", - "dependencies": { - "@ai-sdk/provider": "2.0.0", - "@standard-schema/spec": "^1.0.0", - "eventsource-parser": "^3.0.5" - }, - "engines": { - "node": ">=18" - }, - "peerDependencies": { - "zod": "^3.25.76 || ^4.1.8" - } - }, "node_modules/@ai-sdk/openai": { "version": "2.0.52", "resolved": "https://registry.npmjs.org/@ai-sdk/openai/-/openai-2.0.52.tgz", @@ -241,11 +148,10 @@ "zod": "^3.25.76 || ^4.1.8" } }, - "node_modules/@ai-sdk/openai/node_modules/@ai-sdk/provider": { + "node_modules/@ai-sdk/provider": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/@ai-sdk/provider/-/provider-2.0.0.tgz", "integrity": "sha512-6o7Y2SeO9vFKB8lArHXehNuusnpddKPk7xqL7T2/b+OvXMRIXUO1rR4wcv1hAFUAT9avGZshty3Wlua/XA7TvA==", - "dev": true, "license": "Apache-2.0", "dependencies": { "json-schema": "^0.4.0" @@ -254,11 +160,10 @@ "node": ">=18" } }, - "node_modules/@ai-sdk/openai/node_modules/@ai-sdk/provider-utils": { + "node_modules/@ai-sdk/provider-utils": { "version": "3.0.12", "resolved": "https://registry.npmjs.org/@ai-sdk/provider-utils/-/provider-utils-3.0.12.tgz", "integrity": "sha512-ZtbdvYxdMoria+2SlNarEk6Hlgyf+zzcznlD55EAl+7VZvJaSg2sqPvwArY7L6TfDEDJsnCq0fdhBSkYo0Xqdg==", - "dev": true, "license": "Apache-2.0", "dependencies": { "@ai-sdk/provider": "2.0.0", @@ -2023,7 +1928,6 @@ "version": "1.9.0", "resolved": "https://registry.npmjs.org/@opentelemetry/api/-/api-1.9.0.tgz", "integrity": "sha512-3giAOQvZiH5F9bMlMiv8+GSPMeqg0dbaeo58/0SlA9sxSqZhnUtxzX9/2FzyhS9sWQf5S0GJE0AKBrFqjpeYcg==", - "dev": true, "license": "Apache-2.0", "engines": { "node": ">=8.0.0" @@ -3995,7 +3899,6 @@ "version": "1.0.0", "resolved": "https://registry.npmjs.org/@standard-schema/spec/-/spec-1.0.0.tgz", "integrity": "sha512-m2bOd0f2RT9k8QJx1JN85cZYyH1RqFBdlwtkSlf4tBDYLCiiZnv1fIIwacK6cqwXavOydf0NPToMQgpKq+dVlA==", - "dev": true, "license": "MIT" }, "node_modules/@tootallnate/quickjs-emscripten": { @@ -4734,7 +4637,6 @@ "version": "3.0.3", "resolved": "https://registry.npmjs.org/@vercel/oidc/-/oidc-3.0.3.tgz", "integrity": "sha512-yNEQvPcVrK9sIe637+I0jD6leluPxzwJKx/Haw6F4H77CdDsszUn5V3o96LPziXkSNE2B83+Z3mjqGKBK/R6Gg==", - "dev": true, "license": "Apache-2.0", "engines": { "node": ">= 20" @@ -4985,10 +4887,9 @@ } }, "node_modules/ai": { - "version": "5.0.76", - "resolved": "https://registry.npmjs.org/ai/-/ai-5.0.76.tgz", - "integrity": "sha512-ZCxi1vrpyCUnDbtYrO/W8GLvyacV9689f00yshTIQ3mFFphbD7eIv40a2AOZBv3GGRA7SSRYIDnr56wcS/gyQg==", - "dev": true, + "version": "5.0.72", + "resolved": "https://registry.npmjs.org/ai/-/ai-5.0.72.tgz", + "integrity": "sha512-LB4APrlESLGHG/5x+VVdl0yYPpHPHpnGd5Gwl7AWVL+n7T0GYsNos/S/6dZ5CZzxLnPPEBkRgvJC4rupeZqyNg==", "license": "Apache-2.0", "dependencies": { "@ai-sdk/gateway": "2.0.0", @@ -5003,37 +4904,6 @@ "zod": "^3.25.76 || ^4.1.8" } }, - "node_modules/ai/node_modules/@ai-sdk/provider": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/@ai-sdk/provider/-/provider-2.0.0.tgz", - "integrity": "sha512-6o7Y2SeO9vFKB8lArHXehNuusnpddKPk7xqL7T2/b+OvXMRIXUO1rR4wcv1hAFUAT9avGZshty3Wlua/XA7TvA==", - "dev": true, - "license": "Apache-2.0", - "dependencies": { - "json-schema": "^0.4.0" - }, - "engines": { - "node": ">=18" - } - }, - "node_modules/ai/node_modules/@ai-sdk/provider-utils": { - "version": "3.0.12", - "resolved": "https://registry.npmjs.org/@ai-sdk/provider-utils/-/provider-utils-3.0.12.tgz", - "integrity": "sha512-ZtbdvYxdMoria+2SlNarEk6Hlgyf+zzcznlD55EAl+7VZvJaSg2sqPvwArY7L6TfDEDJsnCq0fdhBSkYo0Xqdg==", - "dev": true, - "license": "Apache-2.0", - "dependencies": { - "@ai-sdk/provider": "2.0.0", - "@standard-schema/spec": "^1.0.0", - "eventsource-parser": "^3.0.5" - }, - "engines": { - "node": ">=18" - }, - "peerDependencies": { - "zod": "^3.25.76 || ^4.1.8" - } - }, "node_modules/ajv": { "version": "6.12.6", "resolved": "https://registry.npmjs.org/ajv/-/ajv-6.12.6.tgz", @@ -9356,7 +9226,6 @@ "version": "0.4.0", "resolved": "https://registry.npmjs.org/json-schema/-/json-schema-0.4.0.tgz", "integrity": "sha512-es94M3nTIfsEPisRafak+HDLfHXnKBhV3vU5eqPcS3flIWqcxJWgXHXiey3YrpaNsanY5ei1VoYEbOzijuq9BA==", - "dev": true, "license": "(AFL-2.1 OR BSD-3-Clause)" }, "node_modules/json-schema-to-ts": { @@ -14640,6 +14509,24 @@ "url": "https://github.com/sponsors/jonschlinkert" } }, + "node_modules/voyage-ai-provider": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/voyage-ai-provider/-/voyage-ai-provider-2.0.0.tgz", + "integrity": "sha512-AX00egENhHOAfuHAhvmoBVQNG6+f717763CfyPefjahDTxbt6nCE0IlDXn5nkzLIu00JoM/PDFYDYQ17NYQqPw==", + "license": "MIT", + "dependencies": { + "@ai-sdk/provider": "^2.0.0", + "@ai-sdk/provider-utils": "^3.0.0" + }, + "peerDependencies": { + "zod": "^3.25.76 || ^4" + }, + "peerDependenciesMeta": { + "zod": { + "optional": true + } + } + }, "node_modules/walk-up-path": { "version": "4.0.0", "resolved": "https://registry.npmjs.org/walk-up-path/-/walk-up-path-4.0.0.tgz", diff --git a/package.json b/package.json index ca4e48e52..f31e1cdaa 100644 --- a/package.json +++ b/package.json @@ -76,7 +76,6 @@ "@typescript-eslint/parser": "^8.44.0", "@vitest/coverage-v8": "^3.2.4", "@vitest/eslint-plugin": "^1.3.4", - "ai": "^5.0.72", "duplexpair": "^1.0.2", "eslint": "^9.34.0", "eslint-config-prettier": "^10.1.8", @@ -104,6 +103,7 @@ "@mongodb-js/devtools-proxy-support": "^0.5.3", "@mongosh/arg-parser": "^3.19.0", "@mongosh/service-provider-node-driver": "^3.17.0", + "ai": "^5.0.72", "bson": "^6.10.4", "express": "^5.1.0", "lru-cache": "^11.1.0", @@ -116,6 +116,7 @@ "oauth4webapi": "^3.8.0", "openapi-fetch": "^0.14.0", "ts-levenshtein": "^1.0.7", + "voyage-ai-provider": "^2.0.0", "yargs-parser": "21.1.1", "zod": "^3.25.76" }, diff --git a/src/common/errors.ts b/src/common/errors.ts index 13779ee1c..5880eb781 100644 --- a/src/common/errors.ts +++ b/src/common/errors.ts @@ -4,6 +4,9 @@ export enum ErrorCodes { ForbiddenCollscan = 1_000_002, ForbiddenWriteOperation = 1_000_003, AtlasSearchNotSupported = 1_000_004, + NoEmbeddingsProviderConfigured = 1_000_005, + AtlasVectorSearchIndexNotFound = 1_000_006, + AtlasVectorSearchInvalidQuery = 1_000_007, } export class MongoDBError extends Error { diff --git a/src/common/search/embeddingsProvider.ts b/src/common/search/embeddingsProvider.ts new file mode 100644 index 000000000..efc93e436 --- /dev/null +++ b/src/common/search/embeddingsProvider.ts @@ -0,0 +1,87 @@ +import { createVoyage } from "voyage-ai-provider"; +import type { VoyageProvider } from "voyage-ai-provider"; +import { embedMany } from "ai"; +import type { UserConfig } from "../config.js"; +import assert from "assert"; +import { createFetch } from "@mongodb-js/devtools-proxy-support"; +import { z } from "zod"; + +type EmbeddingsInput = string; +type Embeddings = number[]; +export type EmbeddingParameters = { + inputType: "query" | "document"; +}; + +export interface EmbeddingsProvider< + SupportedModels extends string, + SupportedEmbeddingParameters extends EmbeddingParameters, +> { + embed( + modelId: SupportedModels, + content: EmbeddingsInput[], + parameters: SupportedEmbeddingParameters + ): Promise; +} + +export const zVoyageModels = z + .enum(["voyage-3-large", "voyage-3.5", "voyage-3.5-lite", "voyage-code-3"]) + .default("voyage-3-large"); + +export const zVoyageEmbeddingParameters = z.object({ + outputDimension: z + .union([z.literal(256), z.literal(512), z.literal(1024), z.literal(2048), z.literal(4096)]) + .optional() + .default(1024), + outputDType: z.enum(["float", "int8", "uint8", "binary", "ubinary"]).optional().default("float"), +}); + +type VoyageModels = z.infer; +type VoyageEmbeddingParameters = z.infer & EmbeddingParameters; + +class VoyageEmbeddingsProvider implements EmbeddingsProvider { + private readonly voyage: VoyageProvider; + + constructor({ voyageApiKey }: UserConfig, providedFetch?: typeof fetch) { + assert(voyageApiKey, "The VoyageAI API Key does not exist. This is likely a bug."); + + // We should always use, by default, any enterprise proxy that the user has configured. + // Direct requests to VoyageAI might get blocked by the network if they don't go through + // the provided proxy. + const customFetch: typeof fetch = (providedFetch ?? + createFetch({ useEnvironmentVariableProxies: true })) as unknown as typeof fetch; + + this.voyage = createVoyage({ apiKey: voyageApiKey, fetch: customFetch }); + } + + static isConfiguredIn({ voyageApiKey }: UserConfig): boolean { + return !!voyageApiKey; + } + + async embed( + modelId: Model, + content: EmbeddingsInput[], + parameters: VoyageEmbeddingParameters + ): Promise { + const model = this.voyage.textEmbeddingModel(modelId); + const { embeddings } = await embedMany({ + model, + values: content, + providerOptions: { voyage: parameters }, + }); + + return embeddings; + } +} + +export function getEmbeddingsProvider( + userConfig: UserConfig +): EmbeddingsProvider | undefined { + if (VoyageEmbeddingsProvider.isConfiguredIn(userConfig)) { + return new VoyageEmbeddingsProvider(userConfig); + } + + return undefined; +} + +export const zSupportedEmbeddingParameters = zVoyageEmbeddingParameters.extend({ model: zVoyageModels }); +export type SupportedEmbeddingParameters = z.infer; diff --git a/src/common/search/vectorSearchEmbeddingsManager.ts b/src/common/search/vectorSearchEmbeddingsManager.ts index b6c06e485..a86d70269 100644 --- a/src/common/search/vectorSearchEmbeddingsManager.ts +++ b/src/common/search/vectorSearchEmbeddingsManager.ts @@ -3,6 +3,9 @@ import { BSON, type Document } from "bson"; import type { UserConfig } from "../config.js"; import type { ConnectionManager } from "../connectionManager.js"; import z from "zod"; +import { ErrorCodes, MongoDBError } from "../errors.js"; +import { getEmbeddingsProvider } from "./embeddingsProvider.js"; +import type { EmbeddingParameters, SupportedEmbeddingParameters } from "./embeddingsProvider.js"; export const similarityEnum = z.enum(["cosine", "euclidean", "dotProduct"]); export type Similarity = z.infer; @@ -32,7 +35,8 @@ export class VectorSearchEmbeddingsManager { constructor( private readonly config: UserConfig, private readonly connectionManager: ConnectionManager, - private readonly embeddings: Map = new Map() + private readonly embeddings: Map = new Map(), + private readonly embeddingsProvider: typeof getEmbeddingsProvider = getEmbeddingsProvider ) { connectionManager.events.on("connection-close", () => { this.embeddings.clear(); @@ -51,7 +55,7 @@ export class VectorSearchEmbeddingsManager { database: string; collection: string; }): Promise { - const provider = await this.assertAtlasSearchIsAvailable(); + const provider = await this.atlasSearchEnabledProvider(); if (!provider) { return []; } @@ -90,7 +94,7 @@ export class VectorSearchEmbeddingsManager { }, document: Document ): Promise { - const provider = await this.assertAtlasSearchIsAvailable(); + const provider = await this.atlasSearchEnabledProvider(); if (!provider) { return []; } @@ -108,7 +112,7 @@ export class VectorSearchEmbeddingsManager { .filter((e) => e !== undefined); } - private async assertAtlasSearchIsAvailable(): Promise { + private async atlasSearchEnabledProvider(): Promise { const connectionState = this.connectionManager.currentConnectionState; if (connectionState.tag === "connected" && (await connectionState.isSearchSupported())) { return connectionState.serviceProvider; @@ -216,6 +220,57 @@ export class VectorSearchEmbeddingsManager { return undefined; } + public async generateEmbeddings({ + database, + collection, + path, + rawValues, + embeddingParameters, + inputType, + }: { + database: string; + collection: string; + path: string; + rawValues: string[]; + embeddingParameters: SupportedEmbeddingParameters; + inputType: EmbeddingParameters["inputType"]; + }): Promise { + const provider = await this.atlasSearchEnabledProvider(); + if (!provider) { + throw new MongoDBError( + ErrorCodes.AtlasSearchNotSupported, + "Atlas Search is not supported in this cluster." + ); + } + + const embeddingsProvider = this.embeddingsProvider(this.config); + + if (!embeddingsProvider) { + throw new MongoDBError(ErrorCodes.NoEmbeddingsProviderConfigured, "No embeddings provider configured."); + } + + if (this.config.disableEmbeddingsValidation) { + return await embeddingsProvider.embed(embeddingParameters.model, rawValues, { + inputType, + ...embeddingParameters, + }); + } + + const embeddingInfoForCollection = await this.embeddingsForNamespace({ database, collection }); + const embeddingInfoForPath = embeddingInfoForCollection.find((definition) => definition.path === path); + if (!embeddingInfoForPath) { + throw new MongoDBError( + ErrorCodes.AtlasVectorSearchIndexNotFound, + `No Vector Search index found for path "${path}" in namespace "${database}.${collection}"` + ); + } + + return await embeddingsProvider.embed(embeddingParameters.model, rawValues, { + inputType, + ...embeddingParameters, + }); + } + private isANumber(value: unknown): boolean { if (typeof value === "number") { return true; diff --git a/src/tools/mongodb/read/aggregate.ts b/src/tools/mongodb/read/aggregate.ts index 9ac18d357..c55786af9 100644 --- a/src/tools/mongodb/read/aggregate.ts +++ b/src/tools/mongodb/read/aggregate.ts @@ -13,9 +13,57 @@ import { operationWithFallback } from "../../../helpers/operationWithFallback.js import { AGG_COUNT_MAX_TIME_MS_CAP, ONE_MB, CURSOR_LIMITS_TO_LLM_TEXT } from "../../../helpers/constants.js"; import { zEJSON } from "../../args.js"; import { LogId } from "../../../common/logger.js"; +import { zSupportedEmbeddingParameters } from "../../../common/search/embeddingsProvider.js"; + +const AnyStage = zEJSON(); +const VectorSearchStage = z.object({ + $vectorSearch: z + .object({ + exact: z + .boolean() + .optional() + .default(false) + .describe( + "When true, uses an ENN algorithm, otherwise uses ANN. Using ENN is not compatible with numCandidates, in that case, numCandidates must be left empty." + ), + index: z.string().describe("Name of the index, as retrieved from the `collection-indexes` tool."), + path: z + .string() + .describe( + "Field, in dot notation, where to search. There must be a vector search index for that field. Note to LLM: When unsure, use the 'collection-indexes' tool to validate that the field is indexed with a vector search index." + ), + queryVector: z + .union([z.string(), z.array(z.number())]) + .describe( + "The content to search for. The embeddingParameters field is mandatory if the queryVector is a string, in that case, the tool generates the embedding automatically using the provided configuration." + ), + numCandidates: z + .number() + .int() + .positive() + .optional() + .describe("Number of candidates for the ANN algorithm. Mandatory when exact is false."), + limit: z.number().int().positive().optional().default(10), + filter: zEJSON() + .optional() + .describe( + "MQL filter that can only use pre-filter fields from the index definition. Note to LLM: If unsure, use the `collection-indexes` tool to learn which fields can be used for pre-filtering." + ), + embeddingParameters: zSupportedEmbeddingParameters + .optional() + .describe( + "The embedding model and its parameters to use to generate embeddings before searching. It is mandatory if queryVector is a string value. Note to LLM: If unsure, ask the user before providing one." + ), + }) + .passthrough(), +}); export const AggregateArgs = { - pipeline: z.array(zEJSON()).describe("An array of aggregation stages to execute"), + pipeline: z + .array(z.union([AnyStage, VectorSearchStage])) + .describe( + "An array of aggregation stages to execute. $vectorSearch can only appear as the first stage of the aggregation pipeline or as the first stage of a $unionWith subpipeline. When using $vectorSearch, unless the user explicitly asks for the embeddings, $unset any embedding field to avoid reaching context limits." + ), responseBytesLimit: z.number().optional().default(ONE_MB).describe(`\ The maximum number of bytes to return in the response. This value is capped by the server’s configured maxBytesPerQuery and cannot be exceeded. \ Note to LLM: If the entire aggregation result is required, use the "export" tool instead of increasing this limit.\ @@ -38,8 +86,7 @@ export class AggregateTool extends MongoDBToolBase { let aggregationCursor: AggregationCursor | undefined = undefined; try { const provider = await this.ensureConnected(); - - this.assertOnlyUsesPermittedStages(pipeline); + await this.assertOnlyUsesPermittedStages(pipeline); // Check if aggregate operation uses an index if enabled if (this.config.indexCheck) { @@ -50,6 +97,12 @@ export class AggregateTool extends MongoDBToolBase { }); } + pipeline = await this.replaceRawValuesWithEmbeddingsIfNecessary({ + database, + collection, + pipeline, + }); + const cappedResultsPipeline = [...pipeline]; if (this.config.maxDocumentsPerQuery > 0) { cappedResultsPipeline.push({ $limit: this.config.maxDocumentsPerQuery }); @@ -107,8 +160,10 @@ export class AggregateTool extends MongoDBToolBase { } } - private assertOnlyUsesPermittedStages(pipeline: Record[]): void { + private async assertOnlyUsesPermittedStages(pipeline: Record[]): Promise { const writeOperations: OperationType[] = ["update", "create", "delete"]; + const isSearchSupported = await this.session.isSearchSupported(); + let writeStageForbiddenError = ""; if (this.config.readOnly) { @@ -118,14 +173,22 @@ export class AggregateTool extends MongoDBToolBase { "When 'create', 'update', or 'delete' operations are disabled, you can not run pipelines with $out or $merge stages."; } - if (!writeStageForbiddenError) { - return; - } - for (const stage of pipeline) { - if (stage.$out || stage.$merge) { + // This validates that in readOnly mode or "write" operations are disabled, we can't use $out or $merge. + // This is really important because aggregates are the only "multi-faceted" tool in the MQL, where you + // can both read and write. + if ((stage.$out || stage.$merge) && writeStageForbiddenError) { throw new MongoDBError(ErrorCodes.ForbiddenWriteOperation, writeStageForbiddenError); } + + // This ensure that you can't use $vectorSearch if the cluster does not support MongoDB Search + // either in Atlas or in a local cluster. + if (stage.$vectorSearch && !isSearchSupported) { + throw new MongoDBError( + ErrorCodes.AtlasSearchNotSupported, + "Atlas Search is not supported in this cluster." + ); + } } } @@ -160,6 +223,52 @@ export class AggregateTool extends MongoDBToolBase { }, undefined); } + private async replaceRawValuesWithEmbeddingsIfNecessary({ + database, + collection, + pipeline, + }: { + database: string; + collection: string; + pipeline: Document[]; + }): Promise { + for (const stage of pipeline) { + if ("$vectorSearch" in stage) { + const { $vectorSearch: vectorSearchStage } = stage as z.infer; + + if (Array.isArray(vectorSearchStage.queryVector)) { + continue; + } + + if (!vectorSearchStage.embeddingParameters) { + throw new MongoDBError( + ErrorCodes.AtlasVectorSearchInvalidQuery, + "embeddingModel is mandatory if queryVector is a raw string." + ); + } + + const embeddingParameters = vectorSearchStage.embeddingParameters; + delete vectorSearchStage.embeddingParameters; + + const [embeddings] = await this.session.vectorSearchEmbeddingsManager.generateEmbeddings({ + database, + collection, + path: vectorSearchStage.path, + rawValues: [vectorSearchStage.queryVector], + embeddingParameters, + inputType: "query", + }); + + // $vectorSearch.queryVector can be a BSON.Binary: that it's not either number or an array. + // It's not exactly valid from the LLM perspective (they can't provide binaries). + // That's why we overwrite the stage in an untyped way, as what we expose and what LLMs can use is different. + vectorSearchStage.queryVector = embeddings as number[]; + } + } + + return pipeline; + } + private generateMessage({ aggResultsCount, documents, diff --git a/tests/accuracy/aggregate.test.ts b/tests/accuracy/aggregate.test.ts index 08b1ca613..85340a331 100644 --- a/tests/accuracy/aggregate.test.ts +++ b/tests/accuracy/aggregate.test.ts @@ -1,5 +1,6 @@ import { describeAccuracyTests } from "./sdk/describeAccuracyTests.js"; import { Matcher } from "./sdk/matcher.js"; +import type { CallToolResult } from "@modelcontextprotocol/sdk/types.js"; describeAccuracyTests([ { @@ -24,4 +25,193 @@ describeAccuracyTests([ }, ], }, + { + prompt: "Run a vectorSearch query on musicfy.songs on path 'title_embeddings' using the index 'titles' with the model voyage-3-large to find all 'hammer of justice' songs.", + expectedToolCalls: [ + { + toolName: "collection-indexes", + parameters: { + database: "musicfy", + collection: "songs", + }, + optional: true, + }, + { + toolName: "aggregate", + parameters: { + database: "musicfy", + collection: "songs", + pipeline: [ + { + $vectorSearch: { + exact: Matcher.anyOf(Matcher.undefined, Matcher.boolean(false)), + index: "titles", + path: "title_embeddings", + queryVector: "hammer of justice", + embeddingParameters: { + model: "voyage-3-large", + outputDimension: Matcher.anyOf( + Matcher.undefined, + Matcher.number((n) => n === 1024) + ), + }, + filter: Matcher.emptyObjectOrUndefined, + }, + }, + ], + responseBytesLimit: Matcher.anyOf(Matcher.number(), Matcher.undefined), + }, + }, + ], + mockedTools: { + "collection-indexes": (): CallToolResult => { + return { + content: [ + { + type: "text", + text: JSON.stringify({ + name: "titles", + type: "vectorSearch", + status: "READY", + queryable: true, + latestDefinition: { + type: "vector", + path: "title_embeddings", + numDimensions: 1024, + quantization: "none", + similarity: "euclidean", + }, + }), + }, + ], + }; + }, + }, + }, + { + prompt: "Run an exact vectorSearch query on musicfy.songs on path 'title_embeddings' using the index 'titles' with the model voyage-3-large to find 10 'hammer of justice' songs in any order.", + expectedToolCalls: [ + { + toolName: "collection-indexes", + parameters: { + database: "musicfy", + collection: "songs", + }, + optional: true, + }, + { + toolName: "aggregate", + parameters: { + database: "musicfy", + collection: "songs", + pipeline: [ + { + $vectorSearch: { + exact: Matcher.anyOf(Matcher.undefined, Matcher.boolean(true)), + index: "titles", + path: "title_embeddings", + queryVector: "hammer of justice", + limit: 10, + embeddingParameters: { + model: "voyage-3-large", + outputDimension: Matcher.anyOf( + Matcher.undefined, + Matcher.number((n) => n === 1024) + ), + }, + filter: Matcher.emptyObjectOrUndefined, + }, + }, + ], + responseBytesLimit: Matcher.anyOf(Matcher.number(), Matcher.undefined), + }, + }, + ], + mockedTools: { + "collection-indexes": (): CallToolResult => { + return { + content: [ + { + type: "text", + text: JSON.stringify({ + name: "titles", + type: "vectorSearch", + status: "READY", + queryable: true, + latestDefinition: { + type: "vector", + path: "title_embeddings", + numDimensions: 1024, + quantization: "none", + similarity: "euclidean", + }, + }), + }, + ], + }; + }, + }, + }, + { + prompt: "Run an approximate vectorSearch query on mflix.movies on path 'plot_embeddings' with the model voyage-3-large to find all 'sci-fy' movies.", + expectedToolCalls: [ + { + toolName: "collection-indexes", + parameters: { + database: "mflix", + collection: "movies", + }, + }, + { + toolName: "aggregate", + parameters: { + database: "mflix", + collection: "movies", + pipeline: [ + { + $vectorSearch: { + exact: Matcher.anyOf(Matcher.undefined, Matcher.boolean(false)), + index: "my-index", + path: "plot_embeddings", + queryVector: "sci-fy", + embeddingParameters: { + model: "voyage-3-large", + outputDimension: Matcher.anyOf( + Matcher.undefined, + Matcher.number((n) => n === 1024) + ), + }, + filter: Matcher.emptyObjectOrUndefined, + }, + }, + ], + responseBytesLimit: Matcher.anyOf(Matcher.number(), Matcher.undefined), + }, + }, + ], + mockedTools: { + "collection-indexes": (): CallToolResult => { + return { + content: [ + { + type: "text", + text: JSON.stringify({ + name: "my-index", + type: "vectorSearch", + status: "READY", + queryable: true, + latestDefinition: { + type: "vector", + path: "plot_embeddings", + numDimensions: 1024, + quantization: "none", + similarity: "euclidean", + }, + }), + }, + ], + }; + }, + }, + }, ]); diff --git a/tests/integration/tools/mongodb/read/aggregate.test.ts b/tests/integration/tools/mongodb/read/aggregate.test.ts index d585d5786..e167830ac 100644 --- a/tests/integration/tools/mongodb/read/aggregate.test.ts +++ b/tests/integration/tools/mongodb/read/aggregate.test.ts @@ -6,9 +6,16 @@ import { defaultTestConfig, } from "../../../helpers.js"; import { beforeEach, describe, expect, it, vi, afterEach } from "vitest"; -import { describeWithMongoDB, getDocsFromUntrustedContent, validateAutoConnectBehavior } from "../mongodbHelpers.js"; +import { + createVectorSearchIndexAndWait, + describeWithMongoDB, + getDocsFromUntrustedContent, + validateAutoConnectBehavior, + waitUntilSearchIsReady, +} from "../mongodbHelpers.js"; import * as constants from "../../../../../src/helpers/constants.js"; import { freshInsertDocuments } from "./find.test.js"; +import { BSON } from "bson"; describeWithMongoDB("aggregate tool", (integration) => { afterEach(() => { @@ -20,7 +27,8 @@ describeWithMongoDB("aggregate tool", (integration) => { ...databaseCollectionParameters, { name: "pipeline", - description: "An array of aggregation stages to execute", + description: + "An array of aggregation stages to execute. $vectorSearch can only appear as the first stage of the aggregation pipeline or as the first stage of a $unionWith subpipeline. When using $vectorSearch, unless the user explicitly asks for the embeddings, $unset any embedding field to avoid reaching context limits.", type: "array", required: true, }, @@ -377,3 +385,297 @@ describeWithMongoDB( getUserConfig: () => ({ ...defaultTestConfig, maxDocumentsPerQuery: -1, maxBytesPerQuery: -1 }), } ); + +import { DOCUMENT_EMBEDDINGS } from "./vyai/embeddings.js"; + +describeWithMongoDB( + "aggregate tool with atlas search enabled", + (integration) => { + beforeEach(async () => { + await integration.mongoClient().db(integration.randomDbName()).collection("databases").drop(); + }); + + for (const [dataType, embedding] of Object.entries(DOCUMENT_EMBEDDINGS)) { + for (const similarity of ["euclidean", "cosine", "dotProduct"]) { + describe.skipIf(!process.env.TEST_MDB_MCP_VOYAGE_API_KEY)( + `querying with dataType ${dataType} and similarity ${similarity}`, + () => { + it(`should be able to return elements from within a vector search query with data type ${dataType}`, async () => { + await waitUntilSearchIsReady(integration.mongoClient()); + + const collection = integration + .mongoClient() + .db(integration.randomDbName()) + .collection("databases"); + await collection.insertOne({ name: "mongodb", description_embedding: embedding }); + + await createVectorSearchIndexAndWait( + integration.mongoClient(), + integration.randomDbName(), + "databases", + [ + { + type: "vector", + path: "description_embedding", + numDimensions: 256, + similarity, + quantization: "none", + }, + ] + ); + + // now query the index + await integration.connectMcpClient(); + const response = await integration.mcpClient().callTool({ + name: "aggregate", + arguments: { + database: integration.randomDbName(), + collection: "databases", + pipeline: [ + { + $vectorSearch: { + index: "default", + path: "description_embedding", + queryVector: embedding, + numCandidates: 10, + limit: 10, + embeddingParameters: { + model: "voyage-3-large", + outputDimension: 256, + outputDType: dataType, + }, + }, + }, + { + $project: { + description_embedding: 0, + }, + }, + ], + }, + }); + + const responseContent = getResponseContent(response); + expect(responseContent).toContain( + "The aggregation resulted in 1 documents. Returning 1 documents." + ); + const untrustedDocs = getDocsFromUntrustedContent<{ name: string }>(responseContent); + expect(untrustedDocs).toHaveLength(1); + expect(untrustedDocs[0]?.name).toBe("mongodb"); + }); + + it("should be able to return elements from within a vector search query using binary encoding", async () => { + await waitUntilSearchIsReady(integration.mongoClient()); + + const collection = integration + .mongoClient() + .db(integration.randomDbName()) + .collection("databases"); + await collection.insertOne({ + name: "mongodb", + description_embedding: BSON.Binary.fromFloat32Array(new Float32Array(embedding)), + }); + + await createVectorSearchIndexAndWait( + integration.mongoClient(), + integration.randomDbName(), + "databases", + [ + { + type: "vector", + path: "description_embedding", + numDimensions: 256, + similarity, + quantization: "none", + }, + ] + ); + + // now query the index + await integration.connectMcpClient(); + const response = await integration.mcpClient().callTool({ + name: "aggregate", + arguments: { + database: integration.randomDbName(), + collection: "databases", + pipeline: [ + { + $vectorSearch: { + index: "default", + path: "description_embedding", + queryVector: embedding, + numCandidates: 10, + limit: 10, + embeddingParameters: { + model: "voyage-3-large", + outputDimension: 256, + outputDType: dataType, + }, + }, + }, + { + $project: { + description_embedding: 0, + }, + }, + ], + }, + }); + + const responseContent = getResponseContent(response); + expect(responseContent).toContain( + "The aggregation resulted in 1 documents. Returning 1 documents." + ); + const untrustedDocs = getDocsFromUntrustedContent<{ name: string }>(responseContent); + expect(untrustedDocs).toHaveLength(1); + expect(untrustedDocs[0]?.name).toBe("mongodb"); + }); + + it("should be able too return elements from within a vector search query using scalar quantization", async () => { + await waitUntilSearchIsReady(integration.mongoClient()); + + const collection = integration + .mongoClient() + .db(integration.randomDbName()) + .collection("databases"); + await collection.insertOne({ + name: "mongodb", + description_embedding: BSON.Binary.fromFloat32Array(new Float32Array(embedding)), + }); + + await createVectorSearchIndexAndWait( + integration.mongoClient(), + integration.randomDbName(), + "databases", + [ + { + type: "vector", + path: "description_embedding", + numDimensions: 256, + similarity, + quantization: "scalar", + }, + ] + ); + + // now query the index + await integration.connectMcpClient(); + const response = await integration.mcpClient().callTool({ + name: "aggregate", + arguments: { + database: integration.randomDbName(), + collection: "databases", + pipeline: [ + { + $vectorSearch: { + index: "default", + path: "description_embedding", + queryVector: embedding, + numCandidates: 10, + limit: 10, + embeddingParameters: { + model: "voyage-3-large", + outputDimension: 256, + outputDType: dataType, + }, + }, + }, + { + $project: { + description_embedding: 0, + }, + }, + ], + }, + }); + + const responseContent = getResponseContent(response); + expect(responseContent).toContain( + "The aggregation resulted in 1 documents. Returning 1 documents." + ); + const untrustedDocs = getDocsFromUntrustedContent<{ name: string }>(responseContent); + expect(untrustedDocs).toHaveLength(1); + expect(untrustedDocs[0]?.name).toBe("mongodb"); + }); + + it("should be able too return elements from within a vector search query using binary quantization", async () => { + await waitUntilSearchIsReady(integration.mongoClient()); + + const collection = integration + .mongoClient() + .db(integration.randomDbName()) + .collection("databases"); + await collection.insertOne({ + name: "mongodb", + description_embedding: BSON.Binary.fromFloat32Array(new Float32Array(embedding)), + }); + + await createVectorSearchIndexAndWait( + integration.mongoClient(), + integration.randomDbName(), + "databases", + [ + { + type: "vector", + path: "description_embedding", + numDimensions: 256, + similarity, + quantization: "binary", + }, + ] + ); + + // now query the index + await integration.connectMcpClient(); + const response = await integration.mcpClient().callTool({ + name: "aggregate", + arguments: { + database: integration.randomDbName(), + collection: "databases", + pipeline: [ + { + $vectorSearch: { + index: "default", + path: "description_embedding", + queryVector: embedding, + numCandidates: 10, + limit: 10, + embeddingParameters: { + model: "voyage-3-large", + outputDimension: 256, + outputDType: dataType, + }, + }, + }, + { + $project: { + description_embedding: 0, + }, + }, + ], + }, + }); + + const responseContent = getResponseContent(response); + expect(responseContent).toContain( + "The aggregation resulted in 1 documents. Returning 1 documents." + ); + const untrustedDocs = getDocsFromUntrustedContent<{ name: string }>(responseContent); + expect(untrustedDocs).toHaveLength(1); + expect(untrustedDocs[0]?.name).toBe("mongodb"); + }); + } + ); + } + } + }, + { + getUserConfig: () => ({ + ...defaultTestConfig, + voyageApiKey: process.env.TEST_MDB_MCP_VOYAGE_API_KEY ?? "", + maxDocumentsPerQuery: -1, + maxBytesPerQuery: -1, + }), + downloadOptions: { search: true }, + } +); diff --git a/tests/integration/tools/mongodb/read/vyai/embeddings.ts b/tests/integration/tools/mongodb/read/vyai/embeddings.ts new file mode 100644 index 000000000..f3d01d93a --- /dev/null +++ b/tests/integration/tools/mongodb/read/vyai/embeddings.ts @@ -0,0 +1,62 @@ +export const DOCUMENT_EMBEDDINGS = { + float: [ + -0.119673342, 0.028537489, 0.050937884, -0.093283832, 0.050631031, 0.008438504, -0.006635733, -0.082850769, + -0.056768127, -0.06443949, 2.637e-5, 0.014422172, -0.087146744, -0.04173224, 0.039277405, 0.029304625, + 0.004717892, 0.117832206, 0.031759463, 0.019945556, 0.031606037, -0.155882195, -0.02086612, -0.090828992, + -0.026849788, 0.010126205, 0.009512496, 0.130106404, 0.042039096, -0.06658747, -0.055847559, -0.038663693, + -0.072110862, 0.073338278, 0.034521155, -0.058302399, -0.052472156, -0.036975995, 0.004602821, -0.06443949, + -0.008093293, -0.061984655, 0.098807223, -0.1429943, 0.012197475, 0.003567186, -0.099420927, 0.087146744, + -0.085305609, 0.011737193, 0.02086612, -0.022707248, 0.04173224, 0.052779011, -0.005523385, 0.045721356, + -0.094511256, -0.09512496, 0.086533032, -0.028844343, -0.042039096, -0.006750803, -0.050324172, 0.125810429, + -0.052472156, -0.02147983, -0.013808462, 0.019945556, -0.072417714, 0.047869336, -0.03958426, -0.016877009, + 0.071804009, 0.017797574, 0.010816629, -0.144221723, -0.004986389, 0.089601576, -0.10985399, 0.101262063, + -0.022707248, 0.001006867, -0.002358946, 0.067508034, 0.124583013, -0.154654771, 0.031606037, -0.165701538, + 0.003202796, -0.009512496, 0.080395937, 0.106171735, 0.004756248, 0.123969309, -0.01396189, 0.024088096, + -0.013118039, 0.02792378, 0.026849788, 0.020098984, -0.113536254, 2.3973e-5, -0.111081414, 0.051858447, + -0.053392723, 0.060757235, 0.044800788, -0.049403612, -0.075179406, 0.03958426, -0.013808462, -0.013578322, + -0.079782225, -0.16447413, 0.007594654, 0.039277405, 0.042039096, -0.035595149, 0.034828011, 0.006022024, + 0.038356841, 0.045107644, 0.084078193, -0.044493936, 0.024548376, 0.008822073, 0.027003214, -0.0487899, + 0.067201182, -0.053392723, 0.108012855, 0.070883438, 0.022553822, 0.110467695, -0.055540707, -0.030685471, + -0.146676555, 0.064746343, -0.036669139, -0.046948772, 0.020559266, -0.142380595, -0.010049492, 0.015112595, + 0.091442712, 0.022707248, -0.050937884, 0.026849788, -0.075486265, 0.018181141, 0.014192032, 0.041118532, + -0.038049985, -0.011813907, 0.067201182, 0.005293244, -0.059222963, -0.088374153, -0.098193504, 0.012350903, + -0.030838897, 0.113536254, -0.035595149, 0.073338278, 0.146676555, -0.013271467, -0.043266516, -0.061984655, + -0.054006428, 0.120287046, 0.052472156, 0.022860678, -0.018948279, 0.007671368, -0.008822073, 0.021786686, + 0.033447165, -0.065666914, 0.025162086, 0.005715169, 0.042345952, 0.006520663, -0.025775796, 0.060757235, + -0.044800788, 0.052779011, 0.033140309, -0.033293735, -0.01856471, 0.045107644, -0.052779011, 0.038049985, + -0.086533032, -0.077327386, -0.051244736, -0.155882195, 0.010356346, -0.15956445, 0.019331846, -0.04756248, + -0.0145756, 0.130720109, -0.007096016, 0.041425385, -0.042652804, 0.005600099, -0.017030437, 0.002493195, + 0.032219745, -0.054313287, 0.044493936, -0.011813907, 0.025622368, 0.054006428, -0.010586488, -0.055847559, + 0.034981437, 0.077327386, 0.024548376, 0.106171735, 0.032066315, 0.069962874, 0.059836671, -0.031452607, + -0.00027569, -0.022246968, 0.058302399, -0.005369958, -0.101875767, 0.032986883, 0.09512496, -0.085919321, + 0.005408315, -0.037436277, 0.034367729, 0.077941097, -0.04756248, 0.000110276, -0.02792378, -0.059836671, + 0.02086612, 0.060450379, -0.045107644, 0.002627444, 0.081623361, 0.054313287, -0.022400394, 0.065053202, + 0.074565701, 0.04081168, -0.021786686, 0.044493936, 0.073338278, 0.003221974, 0.001419203, 0.00740287, + ], + int8: [ + 2, -29, 12, -11, 18, 0, -11, -43, -11, -38, 2, -4, -30, 16, 7, -5, 19, 37, 35, 18, 27, -32, -19, -40, -20, 2, + -13, 31, 28, 10, 11, 11, 0, 26, 9, -7, -7, 0, 4, -15, -15, -17, 33, -10, 9, -12, -35, 24, -11, -5, 9, -12, 20, + 20, -9, 11, 0, -33, 50, -29, -4, -5, 2, 55, -7, 8, 13, 17, -8, 16, 0, -15, 30, 14, 12, -27, -19, -6, -28, 43, + -3, 3, 22, 21, -15, -33, -16, -27, 16, -14, 24, 14, -27, 42, 14, 9, 6, 10, 21, -1, -31, -19, -25, 15, -1, 0, -5, + -17, -22, 17, -8, -9, -10, -58, 8, 7, 15, -25, 4, 5, 14, 8, 54, -12, 0, 11, -9, 6, 29, -1, 16, 4, 14, 41, -9, + -2, -32, 31, 1, 0, 0, -53, 5, 15, 14, 2, -5, 13, 0, -14, 1, 5, -9, -9, 13, 0, -1, -15, -20, 12, 14, 7, 17, 7, + 28, -10, 17, -20, -15, 7, 28, -10, -2, -11, -6, 12, -5, -12, 9, -18, -2, -21, -4, 0, -7, 14, 15, 13, -9, 3, -14, + -2, -12, -36, 1, -34, -11, -41, 10, -24, 6, 24, 10, 1, -10, 3, -10, 9, -4, -27, -6, 5, 5, 10, -5, -3, -13, 25, + 11, 23, 0, 4, 11, 28, -2, -3, 17, 8, -34, -5, 19, -43, -13, -32, 13, 0, -16, -15, -16, -1, -11, -23, 8, 1, 2, 1, + 8, 27, 31, 14, -3, 0, 12, 10, 7, 18, + ], + uint8: [ + 129, 98, 139, 115, 146, 127, 116, 84, 115, 88, 130, 123, 96, 143, 135, 122, 147, 165, 163, 146, 154, 94, 107, + 86, 106, 129, 114, 158, 155, 137, 138, 139, 127, 153, 137, 120, 119, 128, 132, 111, 112, 109, 161, 117, 136, + 115, 91, 152, 115, 121, 136, 115, 148, 148, 118, 139, 128, 94, 178, 98, 123, 121, 129, 182, 119, 135, 141, 145, + 118, 143, 128, 111, 158, 142, 139, 100, 107, 121, 99, 170, 123, 130, 150, 148, 111, 94, 110, 99, 143, 112, 151, + 142, 100, 169, 142, 137, 133, 138, 149, 125, 96, 107, 101, 143, 125, 128, 122, 109, 104, 145, 118, 117, 116, 68, + 136, 134, 143, 101, 131, 132, 142, 135, 182, 115, 127, 138, 118, 133, 157, 126, 144, 131, 142, 168, 117, 124, + 94, 158, 129, 127, 126, 73, 133, 142, 141, 129, 122, 141, 126, 113, 129, 133, 117, 117, 141, 128, 125, 112, 106, + 140, 141, 135, 145, 135, 155, 116, 144, 106, 111, 135, 156, 117, 124, 115, 121, 140, 122, 114, 136, 108, 125, + 105, 123, 127, 119, 142, 142, 141, 118, 131, 112, 125, 114, 90, 129, 93, 116, 85, 137, 102, 134, 152, 138, 128, + 117, 131, 117, 137, 122, 99, 120, 132, 132, 137, 122, 123, 114, 152, 139, 151, 127, 132, 138, 155, 125, 124, + 145, 135, 92, 121, 147, 84, 113, 94, 140, 126, 110, 111, 111, 126, 116, 104, 135, 129, 129, 129, 136, 154, 159, + 141, 123, 127, 140, 138, 134, 146, + ], +} as const; diff --git a/tests/unit/common/search/vectorSearchEmbeddingsManager.test.ts b/tests/unit/common/search/vectorSearchEmbeddingsManager.test.ts index ad6949668..fe5e23c61 100644 --- a/tests/unit/common/search/vectorSearchEmbeddingsManager.test.ts +++ b/tests/unit/common/search/vectorSearchEmbeddingsManager.test.ts @@ -13,6 +13,11 @@ import { ConnectionStateConnected } from "../../../../src/common/connectionManag import type { InsertOneResult } from "mongodb"; import type { DropDatabaseResult } from "@mongosh/service-provider-node-driver/lib/node-driver-service-provider.js"; import EventEmitter from "events"; +import { + type EmbeddingParameters, + type EmbeddingsProvider, + type getEmbeddingsProvider, +} from "../../../../src/common/search/embeddingsProvider.js"; type MockedServiceProvider = NodeDriverServiceProvider & { getSearchIndexes: MockedFunction; @@ -25,6 +30,10 @@ type MockedConnectionManager = ConnectionManager & { currentConnectionState: ConnectionStateConnected; }; +type MockedEmbeddingsProvider = EmbeddingsProvider & { + embed: MockedFunction["embed"]>; +}; + const database = "my" as const; const collection = "collection" as const; const mapKey = `${database}.${collection}` as EmbeddingNamespace; @@ -78,6 +87,14 @@ describe("VectorSearchEmbeddingsManager", () => { getURI: () => "mongodb://my-test", } as unknown as MockedServiceProvider; + const embeddingsProvider: MockedEmbeddingsProvider = { + embed: vi.fn(), + }; + + const getMockedEmbeddingsProvider: typeof getEmbeddingsProvider = () => { + return embeddingsProvider; + }; + const connectionManager: MockedConnectionManager = { currentConnectionState: new ConnectionStateConnected(provider), events: eventEmitter, @@ -85,6 +102,7 @@ describe("VectorSearchEmbeddingsManager", () => { beforeEach(() => { provider.getSearchIndexes.mockReset(); + embeddingsProvider.embed.mockReset(); provider.createSearchIndexes.mockResolvedValue([]); provider.insertOne.mockResolvedValue({} as unknown as InsertOneResult); @@ -371,4 +389,117 @@ describe("VectorSearchEmbeddingsManager", () => { }); }); }); + + describe("generate embeddings", () => { + const embeddingToGenerate = { + database: "mydb", + collection: "mycoll", + path: "embedding_field", + rawValues: ["oops"], + embeddingParameters: { model: "voyage-3-large", outputDimension: 1024, outputDType: "float" } as const, + inputType: "query" as const, + }; + + let embeddings: VectorSearchEmbeddingsManager; + + beforeEach(() => { + embeddings = new VectorSearchEmbeddingsManager( + embeddingValidationDisabled, + connectionManager, + new Map(), + getMockedEmbeddingsProvider + ); + }); + + describe("when atlas search is not available", () => { + beforeEach(() => { + embeddings = new VectorSearchEmbeddingsManager( + embeddingValidationEnabled, + connectionManager, + new Map(), + getMockedEmbeddingsProvider + ); + + provider.getSearchIndexes.mockRejectedValue(new Error()); + }); + + it("throws an exception", async () => { + await expect(embeddings.generateEmbeddings(embeddingToGenerate)).rejects.toThrowError(); + }); + }); + + describe("when atlas search is available", () => { + describe("when embedding validation is disabled", () => { + beforeEach(() => { + embeddings = new VectorSearchEmbeddingsManager( + embeddingValidationDisabled, + connectionManager, + new Map(), + getMockedEmbeddingsProvider + ); + }); + + describe("when no index is available for path", () => { + it("returns the embeddings as is", async () => { + embeddingsProvider.embed.mockResolvedValue([[0xc0ffee]]); + + const [result] = await embeddings.generateEmbeddings(embeddingToGenerate); + expect(result).toEqual([0xc0ffee]); + }); + }); + }); + + describe("when embedding validation is enabled", () => { + beforeEach(() => { + embeddings = new VectorSearchEmbeddingsManager( + embeddingValidationEnabled, + connectionManager, + new Map(), + getMockedEmbeddingsProvider + ); + }); + + describe("when no index is available for path", () => { + it("throws an exception", async () => { + await expect(embeddings.generateEmbeddings(embeddingToGenerate)).rejects.toThrowError(); + }); + }); + + describe("when index is available on path", () => { + beforeEach(() => { + provider.getSearchIndexes.mockResolvedValue([ + { + id: "65e8c766d0450e3e7ab9855f", + name: "vector-search-test", + type: "vectorSearch", + status: "READY", + queryable: true, + latestDefinition: { + fields: [ + { + type: "vector", + path: embeddingToGenerate.path, + numDimensions: 1024, + similarity: "euclidean", + }, + { type: "filter", path: "genres" }, + { type: "filter", path: "year" }, + ], + }, + }, + ]); + }); + + describe("when embedding validation is disabled", () => { + it("returns the embeddings as is", async () => { + embeddingsProvider.embed.mockResolvedValue([[0xc0ffee]]); + + const [result] = await embeddings.generateEmbeddings(embeddingToGenerate); + expect(result).toEqual([0xc0ffee]); + }); + }); + }); + }); + }); + }); });