-
Notifications
You must be signed in to change notification settings - Fork 133
chore: Add new session-level service for getting embeddings of a specific collection MCP-246 #626
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
8ac71ba
cb52116
082fce9
ed7a16e
d68deee
32fe96d
2e013f8
998cf1b
81f9ddd
0a1c789
c68e4ad
539c4a5
44a3ce8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,186 @@ | ||||||||||||||||||||||||
import type { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver"; | ||||||||||||||||||||||||
import { BSON, type Document } from "bson"; | ||||||||||||||||||||||||
import type { UserConfig } from "../config.js"; | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
export type VectorFieldIndexDefinition = { | ||||||||||||||||||||||||
type: "vector"; | ||||||||||||||||||||||||
path: string; | ||||||||||||||||||||||||
numDimensions: number; | ||||||||||||||||||||||||
quantization: "none" | "scalar" | "binary"; | ||||||||||||||||||||||||
similarity: "euclidean" | "cosine" | "dotProduct"; | ||||||||||||||||||||||||
}; | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
export type EmbeddingNamespace = `${string}.${string}`; | ||||||||||||||||||||||||
export class VectorSearchEmbeddings { | ||||||||||||||||||||||||
constructor( | ||||||||||||||||||||||||
private readonly config: UserConfig, | ||||||||||||||||||||||||
private readonly embeddings: Map<EmbeddingNamespace, VectorFieldIndexDefinition[]> = new Map(), | ||||||||||||||||||||||||
private readonly atlasSearchStatus: Map<string, boolean> = new Map() | ||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We probably want to listen to connection changes (here or in Session) to clear the orphan entries. |
||||||||||||||||||||||||
) {} | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
cleanupEmbeddingsForNamespace({ database, collection }: { database: string; collection: string }): void { | ||||||||||||||||||||||||
const embeddingDefKey: EmbeddingNamespace = `${database}.${collection}`; | ||||||||||||||||||||||||
this.embeddings.delete(embeddingDefKey); | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
async embeddingsForNamespace({ | ||||||||||||||||||||||||
database, | ||||||||||||||||||||||||
collection, | ||||||||||||||||||||||||
provider, | ||||||||||||||||||||||||
}: { | ||||||||||||||||||||||||
database: string; | ||||||||||||||||||||||||
collection: string; | ||||||||||||||||||||||||
provider: NodeDriverServiceProvider; | ||||||||||||||||||||||||
}): Promise<VectorFieldIndexDefinition[]> { | ||||||||||||||||||||||||
if (!(await this.isAtlasSearchAvailable(provider))) { | ||||||||||||||||||||||||
return []; | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
// We only need the embeddings for validation now, so don't query them if | ||||||||||||||||||||||||
// validation is disabled. | ||||||||||||||||||||||||
if (this.config.disableEmbeddingsValidation) { | ||||||||||||||||||||||||
return []; | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
const embeddingDefKey: EmbeddingNamespace = `${database}.${collection}`; | ||||||||||||||||||||||||
const definition = this.embeddings.get(embeddingDefKey); | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
if (!definition) { | ||||||||||||||||||||||||
const allSearchIndexes = await provider.getSearchIndexes(database, collection); | ||||||||||||||||||||||||
const vectorSearchIndexes = allSearchIndexes.filter((index) => index.type === "vectorSearch"); | ||||||||||||||||||||||||
const vectorFields = vectorSearchIndexes | ||||||||||||||||||||||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-member-access | ||||||||||||||||||||||||
.flatMap<Document>((index) => (index.latestDefinition?.fields as Document) ?? []) | ||||||||||||||||||||||||
.filter((field) => this.isVectorFieldIndexDefinition(field)); | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
this.embeddings.set(embeddingDefKey, vectorFields); | ||||||||||||||||||||||||
return vectorFields; | ||||||||||||||||||||||||
} else { | ||||||||||||||||||||||||
return definition; | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
Comment on lines
+58
to
+60
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||
} | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
async findFieldsWithWrongEmbeddings( | ||||||||||||||||||||||||
{ | ||||||||||||||||||||||||
database, | ||||||||||||||||||||||||
collection, | ||||||||||||||||||||||||
provider, | ||||||||||||||||||||||||
}: { | ||||||||||||||||||||||||
database: string; | ||||||||||||||||||||||||
collection: string; | ||||||||||||||||||||||||
provider: NodeDriverServiceProvider; | ||||||||||||||||||||||||
}, | ||||||||||||||||||||||||
document: Document | ||||||||||||||||||||||||
): Promise<VectorFieldIndexDefinition[]> { | ||||||||||||||||||||||||
if (!(await this.isAtlasSearchAvailable(provider))) { | ||||||||||||||||||||||||
return []; | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
// While we can do our best effort to ensure that the embedding validation is correct | ||||||||||||||||||||||||
// based on https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-quantization/ | ||||||||||||||||||||||||
// it's a complex process so we will also give the user the ability to disable this validation | ||||||||||||||||||||||||
if (this.config.disableEmbeddingsValidation) { | ||||||||||||||||||||||||
return []; | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
const embeddings = await this.embeddingsForNamespace({ database, collection, provider }); | ||||||||||||||||||||||||
return embeddings.filter((emb) => !this.documentPassesEmbeddingValidation(emb, document)); | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
async isAtlasSearchAvailable(provider: NodeDriverServiceProvider): Promise<boolean> { | ||||||||||||||||||||||||
const providerUri = provider.getURI(); | ||||||||||||||||||||||||
if (!providerUri) { | ||||||||||||||||||||||||
// no URI? can't be cached | ||||||||||||||||||||||||
return await this.canListAtlasSearchIndexes(provider); | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
if (this.atlasSearchStatus.has(providerUri)) { | ||||||||||||||||||||||||
// has should ensure that get is always defined | ||||||||||||||||||||||||
return this.atlasSearchStatus.get(providerUri) ?? false; | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
const availability = await this.canListAtlasSearchIndexes(provider); | ||||||||||||||||||||||||
this.atlasSearchStatus.set(providerUri, availability); | ||||||||||||||||||||||||
return availability; | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
private isVectorFieldIndexDefinition(doc: Document): doc is VectorFieldIndexDefinition { | ||||||||||||||||||||||||
return doc["type"] === "vector"; | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
private documentPassesEmbeddingValidation(definition: VectorFieldIndexDefinition, document: Document): boolean { | ||||||||||||||||||||||||
const fieldPath = definition.path.split("."); | ||||||||||||||||||||||||
let fieldRef: unknown = document; | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
for (const field of fieldPath) { | ||||||||||||||||||||||||
if (fieldRef && typeof fieldRef === "object" && field in fieldRef) { | ||||||||||||||||||||||||
fieldRef = (fieldRef as Record<string, unknown>)[field]; | ||||||||||||||||||||||||
} else { | ||||||||||||||||||||||||
return true; | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
switch (definition.quantization) { | ||||||||||||||||||||||||
case "none": | ||||||||||||||||||||||||
return true; | ||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The validation always returns true for 'none' quantization without checking array dimensions or type. This could allow invalid embeddings to pass validation.
Suggested change
Copilot uses AI. Check for mistakes. Positive FeedbackNegative Feedback There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As we are not implementing all the proper validations per the scope of the project, and this actual validation is wrong (not only arrays are valid, also different types of BinData) we will need to trust the user on this one. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Its worth adding this as a comment. |
||||||||||||||||||||||||
case "scalar": | ||||||||||||||||||||||||
case "binary": | ||||||||||||||||||||||||
if (fieldRef instanceof BSON.Binary) { | ||||||||||||||||||||||||
try { | ||||||||||||||||||||||||
const elements = fieldRef.toFloat32Array(); | ||||||||||||||||||||||||
return elements.length === definition.numDimensions; | ||||||||||||||||||||||||
} catch { | ||||||||||||||||||||||||
// bits are also supported | ||||||||||||||||||||||||
try { | ||||||||||||||||||||||||
const bits = fieldRef.toBits(); | ||||||||||||||||||||||||
return bits.length === definition.numDimensions; | ||||||||||||||||||||||||
} catch { | ||||||||||||||||||||||||
return false; | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
} else { | ||||||||||||||||||||||||
if (!Array.isArray(fieldRef)) { | ||||||||||||||||||||||||
return false; | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
if (fieldRef.length !== definition.numDimensions) { | ||||||||||||||||||||||||
return false; | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
if (!fieldRef.every((e) => this.isANumber(e))) { | ||||||||||||||||||||||||
return false; | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
break; | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
return true; | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
private async canListAtlasSearchIndexes(provider: NodeDriverServiceProvider): Promise<boolean> { | ||||||||||||||||||||||||
try { | ||||||||||||||||||||||||
await provider.getSearchIndexes("test", "test"); | ||||||||||||||||||||||||
return true; | ||||||||||||||||||||||||
} catch { | ||||||||||||||||||||||||
return false; | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
Comment on lines
+162
to
+167
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we consider that atlas local deployments might take time to boot up search management service. This might become a common case as soon as we provide local-atlas tools. |
||||||||||||||||||||||||
} | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
private isANumber(value: unknown): boolean { | ||||||||||||||||||||||||
if (typeof value === "number") { | ||||||||||||||||||||||||
return true; | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
if ( | ||||||||||||||||||||||||
value instanceof BSON.Int32 || | ||||||||||||||||||||||||
value instanceof BSON.Decimal128 || | ||||||||||||||||||||||||
value instanceof BSON.Double || | ||||||||||||||||||||||||
value instanceof BSON.Long | ||||||||||||||||||||||||
) { | ||||||||||||||||||||||||
return true; | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
return false; | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
import { z } from "zod"; | ||
import type { CallToolResult } from "@modelcontextprotocol/sdk/types.js"; | ||
import { DbOperationArgs, MongoDBToolBase } from "../mongodbTool.js"; | ||
import type { ToolArgs, OperationType } from "../../tool.js"; | ||
import { type ToolArgs, type OperationType, formatUntrustedData } from "../../tool.js"; | ||
import { zEJSON } from "../../args.js"; | ||
|
||
export class InsertManyTool extends MongoDBToolBase { | ||
|
@@ -23,19 +23,42 @@ export class InsertManyTool extends MongoDBToolBase { | |
documents, | ||
}: ToolArgs<typeof this.argsShape>): Promise<CallToolResult> { | ||
const provider = await this.ensureConnected(); | ||
const result = await provider.insertMany(database, collection, documents); | ||
|
||
const embeddingValidations = new Set( | ||
...(await Promise.all( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A small note about Promise here, the only reason When we're already in a tool call, is there a possibility to assume that Just to be clear - what we have is not a problem. I see a small scope of improvement because I had to re-look what async task |
||
documents.flatMap((document) => | ||
this.session.vectorSearchEmbeddings.findFieldsWithWrongEmbeddings( | ||
{ database, collection, provider }, | ||
document | ||
) | ||
) | ||
)) | ||
kmruiz marked this conversation as resolved.
Show resolved
Hide resolved
kmruiz marked this conversation as resolved.
Show resolved
Hide resolved
|
||
); | ||
|
||
if (embeddingValidations.size > 0) { | ||
// tell the LLM what happened | ||
const embeddingValidationMessages = [...embeddingValidations].map( | ||
(validation) => | ||
`- Field ${validation.path} is an embedding with ${validation.numDimensions} dimensions and ${validation.quantization} quantization, and the provided value is not compatible.` | ||
); | ||
|
||
return { | ||
content: formatUntrustedData( | ||
"There were errors when inserting documents. No document was inserted.", | ||
...embeddingValidationMessages | ||
), | ||
isError: true, | ||
}; | ||
} | ||
|
||
const result = await provider.insertMany(database, collection, documents); | ||
const content = formatUntrustedData( | ||
"Documents were inserted successfully.", | ||
`Inserted \`${result.insertedCount}\` document(s) into ${database}.${collection}.`, | ||
`Inserted IDs: ${Object.values(result.insertedIds).join(", ")}` | ||
); | ||
return { | ||
content: [ | ||
{ | ||
text: `Inserted \`${result.insertedCount}\` document(s) into collection "${collection}"`, | ||
type: "text", | ||
}, | ||
{ | ||
text: `Inserted IDs: ${Object.values(result.insertedIds).join(", ")}`, | ||
type: "text", | ||
}, | ||
], | ||
content, | ||
}; | ||
} | ||
} |
Uh oh!
There was an error while loading. Please reload this page.