Skip to content
53 changes: 39 additions & 14 deletions src/common/connectionManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,33 @@ export interface ConnectionState {
connectedAtlasCluster?: AtlasClusterConnectionInfo;
}

export interface ConnectionStateConnected extends ConnectionState {
tag: "connected";
serviceProvider: NodeDriverServiceProvider;
export class ConnectionStateConnected implements ConnectionState {
public tag = "connected" as const;

constructor(
public serviceProvider: NodeDriverServiceProvider,
public connectionStringAuthType?: ConnectionStringAuthType,
public connectedAtlasCluster?: AtlasClusterConnectionInfo
) {}

private _isSearchSupported?: boolean;

public async isSearchSupported(): Promise<boolean> {
if (this._isSearchSupported === undefined) {
try {
const dummyDatabase = `search-index-test-db-${Date.now()}`;
const dummyCollection = `search-index-test-coll-${Date.now()}`;
// If a cluster supports search indexes, the call below will succeed
// with a cursor otherwise will throw an Error
await this.serviceProvider.getSearchIndexes(dummyDatabase, dummyCollection);
this._isSearchSupported = true;
} catch {
this._isSearchSupported = false;
}
}

return this._isSearchSupported;
}
}

export interface ConnectionStateConnecting extends ConnectionState {
Expand Down Expand Up @@ -199,12 +223,10 @@ export class MCPConnectionManager extends ConnectionManager {
});
}

return this.changeState("connection-success", {
tag: "connected",
connectedAtlasCluster: settings.atlas,
serviceProvider: await serviceProvider,
connectionStringAuthType,
});
return this.changeState(
"connection-success",
new ConnectionStateConnected(await serviceProvider, connectionStringAuthType, settings.atlas)
);
} catch (error: unknown) {
const errorReason = error instanceof Error ? error.message : `${error as string}`;
this.changeState("connection-error", {
Expand Down Expand Up @@ -270,11 +292,14 @@ export class MCPConnectionManager extends ConnectionManager {
this.currentConnectionState.tag === "connecting" &&
this.currentConnectionState.connectionStringAuthType?.startsWith("oidc")
) {
this.changeState("connection-success", {
...this.currentConnectionState,
tag: "connected",
serviceProvider: await this.currentConnectionState.serviceProvider,
});
this.changeState(
"connection-success",
new ConnectionStateConnected(
await this.currentConnectionState.serviceProvider,
this.currentConnectionState.connectionStringAuthType,
this.currentConnectionState.connectedAtlasCluster
)
);
}

this.logger.info({
Expand Down
22 changes: 9 additions & 13 deletions src/common/session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,15 @@ export class Session extends EventEmitter<SessionEvents> {
return this.connectionManager.currentConnectionState.tag === "connected";
}

get isConnectedToMongot(): Promise<boolean> {
const state = this.connectionManager.currentConnectionState;
if (state.tag === "connected") {
return state.isSearchSupported();
}

return Promise.resolve(false);
}

get serviceProvider(): NodeDriverServiceProvider {
if (this.isConnectedToMongoDB) {
const state = this.connectionManager.currentConnectionState as ConnectionStateConnected;
Expand All @@ -153,17 +162,4 @@ export class Session extends EventEmitter<SessionEvents> {
get connectedAtlasCluster(): AtlasClusterConnectionInfo | undefined {
return this.connectionManager.currentConnectionState.connectedAtlasCluster;
}

async isSearchIndexSupported(): Promise<boolean> {
try {
const dummyDatabase = `search-index-test-db-${Date.now()}`;
const dummyCollection = `search-index-test-coll-${Date.now()}`;
// If a cluster supports search indexes, the call below will succeed
// with a cursor otherwise will throw an Error
await this.serviceProvider.getSearchIndexes(dummyDatabase, dummyCollection);
return true;
} catch {
return false;
}
}
}
2 changes: 1 addition & 1 deletion src/resources/common/debug.ts
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ export class DebugResource extends ReactiveResource<

switch (this.current.tag) {
case "connected": {
const searchIndexesSupported = await this.session.isSearchIndexSupported();
const searchIndexesSupported = await this.session.isConnectedToMongot;
result += `The user is connected to the MongoDB cluster${searchIndexesSupported ? " with support for search indexes" : " without any support for search indexes"}.`;
break;
}
Expand Down
132 changes: 123 additions & 9 deletions src/tools/mongodb/create/createIndex.ts
Original file line number Diff line number Diff line change
@@ -1,33 +1,147 @@
import { z } from "zod";
import type { CallToolResult } from "@modelcontextprotocol/sdk/types.js";
import { DbOperationArgs, MongoDBToolBase } from "../mongodbTool.js";
import type { ToolArgs, OperationType } from "../../tool.js";
import type { ToolCategory } from "../../tool.js";
import { type ToolArgs, type OperationType, FeatureFlags } from "../../tool.js";
import type { IndexDirection } from "mongodb";

const vectorSearchIndexDefinition = z.object({
type: z.literal("vectorSearch"),
fields: z
.array(
z.discriminatedUnion("type", [
z
.object({
type: z.literal("filter"),
path: z
.string()
.describe(
"Name of the field to index. For nested fields, use dot notation to specify path to embedded fields"
),
})
.strict()
.describe("Definition for a field that will be used for pre-filtering results."),
z
.object({
type: z.literal("vector"),
path: z
.string()
.describe(
"Name of the field to index. For nested fields, use dot notation to specify path to embedded fields"
),
numDimensions: z
.number()
.min(1)
.max(8192)
.describe(
"Number of vector dimensions that MongoDB Vector Search enforces at index-time and query-time"
),
similarity: z
.enum(["cosine", "euclidean", "dotProduct"])
.default("cosine")
.describe(
"Vector similarity function to use to search for top K-nearest neighbors. You can set this field only for vector-type fields."
),
quantization: z
.enum(["none", "scalar", "binary"])
.optional()
.default("none")
.describe(
"Type of automatic vector quantization for your vectors. Use this setting only if your embeddings are float or double vectors."
),
})
.strict()
.describe("Definition for a field that contains vector embeddings."),
])
)
.nonempty()
.refine((fields) => fields.some((f) => f.type === "vector"), {
message: "At least one vector field must be defined",
})
.describe(
"Definitions for the vector and filter fields to index, one definition per document. You must specify `vector` for fields that contain vector embeddings and `filter` for additional fields to filter on. At least one vector-type field definition is required."
),
});

export class CreateIndexTool extends MongoDBToolBase {
public name = "create-index";
protected description = "Create an index for a collection";
protected argsShape = {
...DbOperationArgs,
keys: z.object({}).catchall(z.custom<IndexDirection>()).describe("The index definition"),
name: z.string().optional().describe("The name of the index"),
definition: z
.array(
z.discriminatedUnion("type", [
z.object({
type: z.literal("classic"),
keys: z.object({}).catchall(z.custom<IndexDirection>()).describe("The index definition"),
}),
...(this.isFeatureFlagEnabled(FeatureFlags.VectorSearch) ? [vectorSearchIndexDefinition] : []),
])
)
.describe(
"The index definition. Use 'classic' for standard indexes and 'vectorSearch' for vector search indexes"
),
};

public operationType: OperationType = "create";

protected async execute({
database,
collection,
keys,
name,
definition: definitions,
}: ToolArgs<typeof this.argsShape>): Promise<CallToolResult> {
const provider = await this.ensureConnected();
const indexes = await provider.createIndexes(database, collection, [
{
key: keys,
name,
},
]);
let indexes: string[] = [];
const definition = definitions[0];
if (!definition) {
throw new Error("Index definition not provided. Expected one of the following: `classic`, `vectorSearch`");
}

switch (definition.type) {
case "classic":
indexes = await provider.createIndexes(database, collection, [
{
key: definition.keys,
name,
},
]);
break;
case "vectorSearch":
{
const isVectorSearchSupported = await this.session.isConnectedToMongot;
if (!isVectorSearchSupported) {
// TODO: remove hacky casts once we merge the local dev tools
const isLocalAtlasAvailable =
(this.server?.tools.filter((t) => t.category === ("atlas-local" as unknown as ToolCategory))
.length ?? 0) > 0;

const CTA = isLocalAtlasAvailable ? "`atlas-local` tools" : "Atlas CLI";
return {
content: [
{
text: `The connected MongoDB deployment does not support vector search indexes. Either connect to a MongoDB Atlas cluster or use the ${CTA} to create and manage a local Atlas deployment.`,
type: "text",
},
],
isError: true,
};
}

indexes = await provider.createSearchIndexes(database, collection, [
{
name,
definition: {
fields: definition.fields,
},
type: "vectorSearch",
},
]);
}

break;
}

return {
content: [
Expand Down
2 changes: 1 addition & 1 deletion src/tools/mongodb/mongodbTool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ export const DbOperationArgs = {
};

export abstract class MongoDBToolBase extends ToolBase {
private server?: Server;
protected server?: Server;
public category: ToolCategory = "mongodb";

protected async ensureConnected(): Promise<NodeDriverServiceProvider> {
Expand Down
14 changes: 14 additions & 0 deletions src/tools/tool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ export type ToolCallbackArgs<Args extends ZodRawShape> = Parameters<ToolCallback

export type ToolExecutionContext<Args extends ZodRawShape = ZodRawShape> = Parameters<ToolCallback<Args>>[1];

export const enum FeatureFlags {
VectorSearch = "vectorSearch",
}

/**
* The type of operation the tool performs. This is used when evaluating if a tool is allowed to run based on
* the config's `disabledTools` and `readOnly` settings.
Expand Down Expand Up @@ -314,6 +318,16 @@ export abstract class ToolBase {

this.telemetry.emitEvents([event]);
}

// TODO: Move this to a separate file
protected isFeatureFlagEnabled(flag: FeatureFlags): boolean {
switch (flag) {
case FeatureFlags.VectorSearch:
return this.config.voyageApiKey !== "";
default:
return false;
}
}
}

/**
Expand Down
Loading
Loading