From 6edcdd9a1a599e6766eddd5de882076f8defc399 Mon Sep 17 00:00:00 2001 From: Rhys Howell Date: Mon, 23 Sep 2024 16:18:23 -0400 Subject: [PATCH 1/6] feat(participant): route generic prompt by intent and update generic prompt --- src/participant/constants.ts | 7 ++ src/participant/participant.ts | 145 +++++++++++++++++++++++------ src/participant/prompts/generic.ts | 16 ++-- src/participant/prompts/intent.ts | 78 ++++++++++++++++ 4 files changed, 211 insertions(+), 35 deletions(-) create mode 100644 src/participant/prompts/intent.ts diff --git a/src/participant/constants.ts b/src/participant/constants.ts index b03e66274..1c5777269 100644 --- a/src/participant/constants.ts +++ b/src/participant/constants.ts @@ -10,6 +10,7 @@ export type ParticipantResponseType = | 'docs' | 'generic' | 'emptyRequest' + | 'cancelledRequest' | 'askToConnect' | 'askForNamespace'; @@ -61,6 +62,12 @@ function createChatResult( }; } +export function createCancelledRequestChatResult( + history: ReadonlyArray +): ChatResult { + return createChatResult('cancelledRequest', history); +} + export function emptyRequestChatResult( history: ReadonlyArray ): ChatResult { diff --git a/src/participant/participant.ts b/src/participant/participant.ts index c662cefdc..087455149 100644 --- a/src/participant/participant.ts +++ b/src/participant/participant.ts @@ -20,6 +20,7 @@ import { queryRequestChatResult, docsRequestChatResult, schemaRequestChatResult, + createCancelledRequestChatResult, } from './constants'; import { QueryPrompt } from './prompts/query'; import { COL_NAME_ID, DB_NAME_ID, NamespacePrompt } from './prompts/namespace'; @@ -41,6 +42,7 @@ import { } from '../telemetry/telemetryService'; import { DocsChatbotAIService } from './docsChatbotAIService'; import type TelemetryService from '../telemetry/telemetryService'; +import { IntentPrompt, type PromptIntent } from './prompts/intent'; const log = createLogger('participant'); @@ -229,8 +231,15 @@ export default class ParticipantController { } } 
- // @MongoDB what is mongodb? - async handleGenericRequest( + _handleCancelledRequest({ + context, + }: { + context: vscode.ChatContext; + }): ChatResult { + return createCancelledRequestChatResult(context.history); + } + + async _handleRoutedGenericRequest( request: vscode.ChatRequest, context: vscode.ChatContext, stream: vscode.ChatResponseStream, @@ -241,10 +250,6 @@ export default class ParticipantController { context, }); - const abortController = new AbortController(); - token.onCancellationRequested(() => { - abortController.abort(); - }); const responseContent = await this.getChatResponseContent({ messages, token, @@ -259,6 +264,91 @@ export default class ParticipantController { return genericRequestChatResult(context.history); } + async _routeRequestToHandler({ + context, + promptIntent, + request, + stream, + token, + }: { + context: vscode.ChatContext; + promptIntent: Omit; + request: vscode.ChatRequest; + stream: vscode.ChatResponseStream; + token: vscode.CancellationToken; + }): Promise { + switch (promptIntent) { + case 'Query': + return this.handleQueryRequest(request, context, stream, token); + case 'Docs': + return this.handleDocsRequest(request, context, stream, token); + case 'Schema': + return this.handleSchemaRequest(request, context, stream, token); + case 'Code': + return this.handleQueryRequest(request, context, stream, token); + default: + return this._handleRoutedGenericRequest( + request, + context, + stream, + token + ); + } + } + + async _getIntentFromChatRequest({ + context, + request, + token, + }: { + context: vscode.ChatContext; + request: vscode.ChatRequest; + token: vscode.CancellationToken; + }): Promise { + const messages = IntentPrompt.buildMessages({ + request, + context, + }); + + const responseContent = await this.getChatResponseContent({ + messages, + token, + }); + + return IntentPrompt.getIntentFromModelResponse(responseContent); + } + + async handleGenericRequest( + request: vscode.ChatRequest, + context: 
vscode.ChatContext, + stream: vscode.ChatResponseStream, + token: vscode.CancellationToken + ): Promise { + // We "prompt chain" to handle the generic requests. + // First we ask the model to parse for intent. + // If there is an intent, we can route it to one of the handlers (/commands). + // When there is no intention or it's generic we handle it with a generic handler. + const promptIntent = await this._getIntentFromChatRequest({ + context, + request, + token, + }); + + if (token.isCancellationRequested) { + return this._handleCancelledRequest({ + context, + }); + } + + return this._routeRequestToHandler({ + context, + promptIntent, + request, + stream, + token, + }); + } + async connectWithParticipant({ id, command, @@ -670,17 +760,17 @@ export default class ParticipantController { // The sample documents returned from this are simplified (strings and arrays shortened). // The sample documents are only returned when a user has the setting enabled. async _fetchCollectionSchemaAndSampleDocuments({ - abortSignal, databaseName, collectionName, amountOfDocumentsToSample = NUM_DOCUMENTS_TO_SAMPLE, schemaFormat = 'simplified', + token, }: { - abortSignal; databaseName: string; collectionName: string; amountOfDocumentsToSample?: number; schemaFormat?: 'simplified' | 'full'; + token: vscode.CancellationToken; }): Promise<{ schema?: string; sampleDocuments?: Document[]; @@ -693,6 +783,11 @@ export default class ParticipantController { }; } + const abortController = new AbortController(); + token.onCancellationRequested(() => { + abortController.abort(); + }); + try { const sampleDocuments = await dataService.sample( `${databaseName}.${collectionName}`, @@ -702,7 +797,7 @@ export default class ParticipantController { }, { promoteValues: false }, { - abortSignal, + abortSignal: abortController.signal, } ); @@ -836,10 +931,11 @@ export default class ParticipantController { }); } - const abortController = new AbortController(); - token.onCancellationRequested(() => { - 
abortController.abort(); - }); + if (token.isCancellationRequested) { + return this._handleCancelledRequest({ + context, + }); + } stream.push( new vscode.ChatResponseProgressPart( @@ -856,11 +952,11 @@ export default class ParticipantController { amountOfDocumentsSampled, // There can be fewer than the amount we attempt to sample. schema, } = await this._fetchCollectionSchemaAndSampleDocuments({ - abortSignal: abortController.signal, databaseName, schemaFormat: 'full', collectionName, amountOfDocumentsToSample: DOCUMENTS_TO_SAMPLE_FOR_SCHEMA_PROMPT, + token, })); if (!schema || amountOfDocumentsSampled === 0) { @@ -958,19 +1054,20 @@ export default class ParticipantController { }); } - const abortController = new AbortController(); - token.onCancellationRequested(() => { - abortController.abort(); - }); + if (token.isCancellationRequested) { + return this._handleCancelledRequest({ + context, + }); + } let schema: string | undefined; let sampleDocuments: Document[] | undefined; try { ({ schema, sampleDocuments } = await this._fetchCollectionSchemaAndSampleDocuments({ - abortSignal: abortController.signal, databaseName, collectionName, + token, })); } catch (e) { // When an error fetching the collection schema or sample docs occurs, @@ -1067,10 +1164,6 @@ export default class ParticipantController { context, }); - const abortController = new AbortController(); - token.onCancellationRequested(() => { - abortController.abort(); - }); const responseContent = await this.getChatResponseContent({ messages, token, @@ -1096,11 +1189,7 @@ export default class ParticipantController { vscode.CancellationToken ] ): Promise { - const [request, context, stream, token] = args; - const abortController = new AbortController(); - token.onCancellationRequested(() => { - abortController.abort(); - }); + const [request, context, stream] = args; const chatId = ChatMetadataStore.getChatIdFromHistoryOrNewChatId( context.history diff --git a/src/participant/prompts/generic.ts 
b/src/participant/prompts/generic.ts index 733a86fe9..9d890acec 100644 --- a/src/participant/prompts/generic.ts +++ b/src/participant/prompts/generic.ts @@ -5,12 +5,15 @@ import { getHistoryMessages } from './history'; export class GenericPrompt { static getAssistantPrompt(): vscode.LanguageModelChatMessage { const prompt = `You are a MongoDB expert. -Your task is to help the user craft MongoDB queries and aggregation pipelines that perform their task. -Keep your response concise. -You should suggest queries that are performant and correct. -Respond with markdown, suggest code in a Markdown code block that begins with \`\`\`javascript and ends with \`\`\`. -You can imagine the schema, collection, and database name. -Respond in MongoDB shell syntax using the \`\`\`javascript code block syntax.`; +Your task is to help the user with MongoDB related questions. +When applicable, you may suggest MongoDB code, queries, and aggregation pipelines that perform their task. +Rules: +1. Keep your response concise. +2. You should suggest code that is performant and correct. +3. Respond with markdown. +4. When relevant, provide code in a Markdown code block that begins with \`\`\`javascript and ends with \`\`\`. +5. Use MongoDB shell syntax for code unless the user requests a specific language. +6. If you require additional information to provide a response, ask the user for it.`; // eslint-disable-next-line new-cap return vscode.LanguageModelChatMessage.Assistant(prompt); @@ -22,7 +25,6 @@ Respond in MongoDB shell syntax using the \`\`\`javascript code block syntax.`; } static getEmptyRequestResponse(): string { - // TODO(VSCODE-572): Generic empty response handler return vscode.l10n.t( 'Ask anything about MongoDB, from writing queries to questions about your cluster.' 
); diff --git a/src/participant/prompts/intent.ts b/src/participant/prompts/intent.ts new file mode 100644 index 000000000..60b9b98a9 --- /dev/null +++ b/src/participant/prompts/intent.ts @@ -0,0 +1,78 @@ +import * as vscode from 'vscode'; + +import { getHistoryMessages } from './history'; + +export type PromptIntent = 'Query' | 'Schema' | 'Docs' | 'Default'; + +export class IntentPrompt { + static getAssistantPrompt(): vscode.LanguageModelChatMessage { + const prompt = `You are a MongoDB expert. +Your task is to help guide a conversation with a user to the correct handler. +You will be provided a conversation and your task is to determine the intent of the user. +The intent handlers are: +- Query +- Schema +- Docs +- Default +Rules: +1. Respond only with the intent handler. +2. Use the "Query" intent handler when the user is asking for code that relates to a specific collection. +3. Use the "Docs" intent handler when the user is asking a question that involves MongoDB documentation. +4. Use the "Schema" intent handler when the user is asking for the schema or shape of documents of a specific collection. +5. Use the "Default" intent handler when a user is asking for code that does NOT relate to a specific collection. +6. Use the "Default" intent handler for everything that may not be handled by another handler. +7. If you are uncertain of the intent, use the "Default" intent handler. + +Example: +User: How do I create an index in my pineapples collection? +Response: +Query + +Example: +User: +What is $vectorSearch? 
+Response: +Docs +`; + + // eslint-disable-next-line new-cap + return vscode.LanguageModelChatMessage.Assistant(prompt); + } + + static getUserPrompt(prompt: string): vscode.LanguageModelChatMessage { + // eslint-disable-next-line new-cap + return vscode.LanguageModelChatMessage.User(prompt); + } + + static buildMessages({ + context, + request, + }: { + request: { + prompt: string; + }; + context: vscode.ChatContext; + }): vscode.LanguageModelChatMessage[] { + const messages = [ + IntentPrompt.getAssistantPrompt(), + ...getHistoryMessages({ context }), + IntentPrompt.getUserPrompt(request.prompt), + ]; + + return messages; + } + + static getIntentFromModelResponse(response: string): PromptIntent { + response = response.trim(); + switch (response) { + case 'Query': + return 'Query'; + case 'Schema': + return 'Schema'; + case 'Docs': + return 'Docs'; + default: + return 'Default'; + } + } +} From 314ae29376e966fdae907a6d37893094b8a152b7 Mon Sep 17 00:00:00 2001 From: Rhys Howell Date: Tue, 24 Sep 2024 12:29:17 -0400 Subject: [PATCH 2/6] fixup: give explicit use syntax, add accuracy tests --- package.json | 2 +- src/participant/participant.ts | 6 - src/participant/prompts/generic.ts | 4 +- .../ai-accuracy-tests/ai-accuracy-tests.ts | 104 +++++++++++++++++- 4 files changed, 104 insertions(+), 12 deletions(-) diff --git a/package.json b/package.json index d1d102828..65c6f1625 100644 --- a/package.json +++ b/package.json @@ -54,7 +54,7 @@ "test": "npm run test-webview && npm run test-extension", "test-extension": "cross-env NODE_OPTIONS=--no-force-async-hooks-checks xvfb-maybe node ./out/test/runTest.js", "test-webview": "mocha -r ts-node/register --file ./src/test/setup-webview.ts src/test/suite/views/webview-app/**/*.test.tsx", - "ai-accuracy-tests": "mocha -r ts-node/register --file ./src/test/ai-accuracy-tests/test-setup.ts ./src/test/ai-accuracy-tests/ai-accuracy-tests.ts", + "ai-accuracy-tests": "env TS_NODE_FILES=true mocha -r ts-node/register --file 
./src/test/ai-accuracy-tests/test-setup.ts ./src/test/ai-accuracy-tests/ai-accuracy-tests.ts", "analyze-bundle": "webpack --mode production --analyze", "vscode:prepublish": "npm run clean && npm run compile:constants && npm run compile:resources && webpack --mode production", "check": "npm run lint && npm run depcheck", diff --git a/src/participant/participant.ts b/src/participant/participant.ts index 51b5e357e..da193fcc2 100644 --- a/src/participant/participant.ts +++ b/src/participant/participant.ts @@ -959,12 +959,6 @@ export default class ParticipantController { }); } - stream.push( - new vscode.ChatResponseProgressPart( - 'Fetching documents and analyzing schema...' - ) - ); - let sampleDocuments: Document[] | undefined; let amountOfDocumentsSampled: number; let schema: string | undefined; diff --git a/src/participant/prompts/generic.ts b/src/participant/prompts/generic.ts index 9d890acec..8cd6972f9 100644 --- a/src/participant/prompts/generic.ts +++ b/src/participant/prompts/generic.ts @@ -13,7 +13,9 @@ Rules: 3. Respond with markdown. 4. When relevant, provide code in a Markdown code block that begins with \`\`\`javascript and ends with \`\`\`. 5. Use MongoDB shell syntax for code unless the user requests a specific language. -6. If you require additional information to provide a response, ask the user for it.`; +6. If you require additional information to provide a response, ask the user for it. +7. When specifying a database, use the MongoDB syntax use('databaseName'). 
+`; // eslint-disable-next-line new-cap return vscode.LanguageModelChatMessage.Assistant(prompt); diff --git a/src/test/ai-accuracy-tests/ai-accuracy-tests.ts b/src/test/ai-accuracy-tests/ai-accuracy-tests.ts index 712a98647..46a3247d5 100644 --- a/src/test/ai-accuracy-tests/ai-accuracy-tests.ts +++ b/src/test/ai-accuracy-tests/ai-accuracy-tests.ts @@ -22,6 +22,7 @@ import { import { NamespacePrompt } from '../../participant/prompts/namespace'; import { runCodeInMessage } from './assertions'; import { parseForDatabaseAndCollectionName } from '../../participant/participant'; +import { IntentPrompt } from '../../participant/prompts/intent'; const numberOfRunsPerTest = 1; @@ -37,7 +38,7 @@ type AssertProps = { type TestCase = { testCase: string; - type: 'generic' | 'query' | 'namespace'; + type: 'intent' | 'generic' | 'query' | 'namespace'; userInput: string; // Some tests can edit the documents in a collection. // As we want tests to run in isolation this flag will cause the fixture @@ -51,7 +52,9 @@ type TestCase = { only?: boolean; // Translates to mocha's it.only so only this test will run. 
}; -const namespaceTestCases: TestCase[] = [ +const namespaceTestCases: (TestCase & { + type: 'namespace'; +})[] = [ { testCase: 'Namespace included in query', type: 'namespace', @@ -101,7 +104,9 @@ const namespaceTestCases: TestCase[] = [ }, ]; -const queryTestCases: TestCase[] = [ +const queryTestCases: (TestCase & { + type: 'query'; +})[] = [ { testCase: 'Basic query', type: 'query', @@ -238,7 +243,92 @@ const queryTestCases: TestCase[] = [ }, ]; -const testCases: TestCase[] = [...namespaceTestCases, ...queryTestCases]; +const intentTestCases: (TestCase & { + type: 'intent'; +})[] = [ + { + testCase: 'Docs intent', + type: 'intent', + userInput: + 'Where can I find more information on how to connect to MongoDB?', + assertResult: ({ responseContent }: AssertProps): void => { + expect(responseContent).to.equal('Docs'); + }, + }, + { + testCase: 'Docs intent 2', + type: 'intent', + userInput: 'What are the options when creating an aggregation cursor?', + assertResult: ({ responseContent }: AssertProps): void => { + expect(responseContent).to.equal('Docs'); + }, + }, + { + testCase: 'Query intent', + type: 'intent', + userInput: + 'which collectors specialize only in mint items? and are located in London or New York? 
an array of their names in a field called collectors', + assertResult: ({ responseContent }: AssertProps): void => { + expect(responseContent).to.equal('Query'); + }, + }, + { + testCase: 'Schema intent', + type: 'intent', + userInput: 'What do the documents in the collection pineapple look like?', + assertResult: ({ responseContent }: AssertProps): void => { + expect(responseContent).to.equal('Schema'); + }, + }, + { + testCase: 'Default/Generic intent 1', + type: 'intent', + userInput: 'How can I connect to MongoDB?', + assertResult: ({ responseContent }: AssertProps): void => { + expect(responseContent).to.equal('Default'); + }, + }, + { + testCase: 'Default/Generic intent 2', + type: 'intent', + userInput: 'What is the size breakdown of all of the databases?', + assertResult: ({ responseContent }: AssertProps): void => { + expect(responseContent).to.equal('Default'); + }, + }, +]; + +const genericTestCases: (TestCase & { + type: 'generic'; +})[] = [ + { + testCase: 'Database meta data question', + type: 'generic', + userInput: + 'How do I print the name and size of the largest database? Using the print function', + assertResult: async ({ + responseContent, + connectionString, + }: AssertProps): Promise => { + const output = await runCodeInMessage(responseContent, connectionString); + const printOutput = output.printOutput.join(''); + + expect(printOutput).to.include('Antiques'); + expect(printOutput).to.not.include('UFO'); + expect(printOutput).to.not.include('CookBook'); + expect(printOutput).to.not.include('pets'); + expect(printOutput).to.not.include('FarmData'); + expect(printOutput).to.include('8192'); // The size of the Antiques database. 
+ }, + }, +]; + +const testCases: TestCase[] = [ + ...namespaceTestCases, + ...queryTestCases, + ...intentTestCases, + ...genericTestCases, +]; const projectRoot = path.join(__dirname, '..', '..', '..'); @@ -310,6 +400,12 @@ const buildMessages = async ({ fixtures: Fixtures; }): Promise => { switch (testCase.type) { + case 'intent': + return IntentPrompt.buildMessages({ + request: { prompt: testCase.userInput }, + context: { history: [] }, + }); + case 'generic': return GenericPrompt.buildMessages({ request: { prompt: testCase.userInput }, From af64f9a7009144c4c49f34e4887a7fee88b1670c Mon Sep 17 00:00:00 2001 From: Rhys Howell Date: Tue, 24 Sep 2024 15:37:08 -0400 Subject: [PATCH 3/6] fixup: update accuracy tests --- .../ai-accuracy-tests/ai-accuracy-tests.ts | 43 ++++++++++++++++--- 1 file changed, 36 insertions(+), 7 deletions(-) diff --git a/src/test/ai-accuracy-tests/ai-accuracy-tests.ts b/src/test/ai-accuracy-tests/ai-accuracy-tests.ts index 04d08d327..5603f30b2 100644 --- a/src/test/ai-accuracy-tests/ai-accuracy-tests.ts +++ b/src/test/ai-accuracy-tests/ai-accuracy-tests.ts @@ -20,7 +20,7 @@ import { type TestResult, } from './create-test-results-html-page'; import { NamespacePrompt } from '../../participant/prompts/namespace'; -import { runCodeInMessage } from './assertions'; +import { anyOf, runCodeInMessage } from './assertions'; import { parseForDatabaseAndCollectionName } from '../../participant/participant'; import { IntentPrompt } from '../../participant/prompts/intent'; @@ -313,12 +313,41 @@ const genericTestCases: (TestCase & { const output = await runCodeInMessage(responseContent, connectionString); const printOutput = output.printOutput.join(''); - expect(printOutput).to.include('Antiques'); - expect(printOutput).to.not.include('UFO'); - expect(printOutput).to.not.include('CookBook'); - expect(printOutput).to.not.include('pets'); - expect(printOutput).to.not.include('FarmData'); - expect(printOutput).to.include('8192'); // The size of the Antiques 
database. + // Don't check the name since they're all the base 8192. + expect(printOutput).to.include('8192'); + }, + }, + { + testCase: 'Code question with database, collection, and fields named', + type: 'generic', + userInput: + 'How many sightings happened in the "year" "2020" and "2021"? database "UFO" collection "sightings". code to just return the one total number. also, the year is a string', + assertResult: async ({ + responseContent, + connectionString, + }: AssertProps): Promise => { + const output = await runCodeInMessage(responseContent, connectionString); + anyOf([ + (): void => { + expect(output.printOutput.join('')).to.equal('2'); + }, + (): void => { + expect(output.data?.result?.content).to.equal('2'); + }, + (): void => { + expect(output.data?.result?.content).to.equal(2); + }, + (): void => { + expect( + Object.entries(output.data?.result?.content[0])[0][1] + ).to.equal(2); + }, + (): void => { + expect( + Object.entries(output.data?.result?.content[0])[0][1] + ).to.equal('2'); + }, + ])(null); }, }, ]; From 9e835b484d195208141667dd90e01ef33e95fe68 Mon Sep 17 00:00:00 2001 From: Rhys Howell Date: Tue, 24 Sep 2024 19:05:16 -0400 Subject: [PATCH 4/6] fixup: avoid extra variable declaration, add tests --- src/participant/prompts/generic.ts | 4 +- src/participant/prompts/intent.ts | 4 +- .../suite/participant/participant.test.ts | 61 ++++++++++++++++++- 3 files changed, 61 insertions(+), 8 deletions(-) diff --git a/src/participant/prompts/generic.ts b/src/participant/prompts/generic.ts index 8cd6972f9..57d72f9b6 100644 --- a/src/participant/prompts/generic.ts +++ b/src/participant/prompts/generic.ts @@ -41,13 +41,11 @@ Rules: }; context: vscode.ChatContext; }): vscode.LanguageModelChatMessage[] { - const messages = [ + return [ GenericPrompt.getAssistantPrompt(), ...getHistoryMessages({ context }), GenericPrompt.getUserPrompt(request.prompt), ]; - - return messages; } } diff --git a/src/participant/prompts/intent.ts 
b/src/participant/prompts/intent.ts index 60b9b98a9..5dee3ff8c 100644 --- a/src/participant/prompts/intent.ts +++ b/src/participant/prompts/intent.ts @@ -53,13 +53,11 @@ Docs }; context: vscode.ChatContext; }): vscode.LanguageModelChatMessage[] { - const messages = [ + return [ IntentPrompt.getAssistantPrompt(), ...getHistoryMessages({ context }), IntentPrompt.getUserPrompt(request.prompt), ]; - - return messages; } static getIntentFromModelResponse(response: string): PromptIntent { diff --git a/src/test/suite/participant/participant.test.ts b/src/test/suite/participant/participant.test.ts index 6e1e9889e..954dc6ec5 100644 --- a/src/test/suite/participant/participant.test.ts +++ b/src/test/suite/participant/participant.test.ts @@ -435,14 +435,71 @@ suite('Participant Controller Test Suite', function () { }); suite('generic command', function () { - test('generates a query', async function () { + suite('when the intent is recognized', function () { + beforeEach(function () { + sendRequestStub.onCall(0).resolves({ + text: ['Schema'], + }); + }); + + test('routes to the appropriate handler', async function () { + const chatRequestMock = { + prompt: + 'what is the shape of the documents in the pineapple collection?', + command: undefined, + references: [], + }; + const res = await invokeChatHandler(chatRequestMock); + + expect(sendRequestStub).to.have.been.calledTwice; + const intentRequest = sendRequestStub.firstCall.args[0]; + expect(intentRequest).to.have.length(2); + expect(intentRequest[0].content).to.include( + 'Your task is to help guide a conversation with a user to the correct handler.' + ); + expect(intentRequest[1].content).to.equal( + 'what is the shape of the documents in the pineapple collection?' + ); + const genericRequest = sendRequestStub.secondCall.args[0]; + expect(genericRequest).to.have.length(2); + expect(genericRequest[0].content).to.include( + 'Parse all user messages to find a database name and a collection name.' 
+ ); + expect(genericRequest[1].content).to.equal( + 'what is the shape of the documents in the pineapple collection?' + ); + + expect(res?.metadata.intent).to.equal('askForNamespace'); + }); + }); + + test('default handler asks for intent and shows code run actions', async function () { const chatRequestMock = { prompt: 'how to find documents in my collection?', command: undefined, references: [], }; - await invokeChatHandler(chatRequestMock); + const res = await invokeChatHandler(chatRequestMock); + + expect(sendRequestStub).to.have.been.calledTwice; + const intentRequest = sendRequestStub.firstCall.args[0]; + expect(intentRequest).to.have.length(2); + expect(intentRequest[0].content).to.include( + 'Your task is to help guide a conversation with a user to the correct handler.' + ); + expect(intentRequest[1].content).to.equal( + 'how to find documents in my collection?' + ); + const genericRequest = sendRequestStub.secondCall.args[0]; + expect(genericRequest).to.have.length(2); + expect(genericRequest[0].content).to.include( + 'Your task is to help the user with MongoDB related questions.' + ); + expect(genericRequest[1].content).to.equal( + 'how to find documents in my collection?' 
+ ); + expect(res?.metadata.intent).to.equal('generic'); expect(chatStreamStub?.button.getCall(0).args[0]).to.deep.equal({ command: 'mdb.runParticipantQuery', title: '▶️ Run', From d805bcb5cb3ad1f7f0ca38ddc22f0e35d70c642d Mon Sep 17 00:00:00 2001 From: Rhys Howell Date: Wed, 25 Sep 2024 12:13:59 -0400 Subject: [PATCH 5/6] chore: add links to test names in accuracy test results --- .../create-test-results-html-page.ts | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/src/test/ai-accuracy-tests/create-test-results-html-page.ts b/src/test/ai-accuracy-tests/create-test-results-html-page.ts index f58a50950..c3774a59e 100644 --- a/src/test/ai-accuracy-tests/create-test-results-html-page.ts +++ b/src/test/ai-accuracy-tests/create-test-results-html-page.ts @@ -23,6 +23,9 @@ export type TestOutputs = { [testName: string]: TestOutput; }; +const createTestLinkId = (testName: string): string => + encodeURIComponent(testName.replace(/ /g, '-')); + function getTestResultsTable(testResults: TestResult[]): string { const headers = Object.keys(testResults[0]) .map((key) => `${key}`) @@ -30,8 +33,15 @@ function getTestResultsTable(testResults: TestResult[]): string { const resultRows = testResults .map((result) => { - const row = Object.values(result) - .map((value) => `${value}`) + const row = Object.entries(result) + .map( + ([field, value]) => + `${ + field === 'Test' + ? `${value}` + : value + }` + ) .join(''); return `${row}`; }) @@ -56,7 +66,9 @@ function getTestOutputTables(testOutputs: TestOutputs): string { .map((out) => `${out}`) .join(''); return ` -

${testName} [${output.testType}]

+

Prompt: ${output.prompt}

From 1fde75cb071871c0743cea01d257273b48f254fe Mon Sep 17 00:00:00 2001 From: Rhys Howell Date: Wed, 25 Sep 2024 12:57:45 -0400 Subject: [PATCH 6/6] add complex aggregation for query and generic --- .../ai-accuracy-tests/ai-accuracy-tests.ts | 102 ++++++++++++++---- .../ai-accuracy-tests/fixtures/recipes.ts | 15 +++ 2 files changed, 98 insertions(+), 19 deletions(-) diff --git a/src/test/ai-accuracy-tests/ai-accuracy-tests.ts b/src/test/ai-accuracy-tests/ai-accuracy-tests.ts index 5603f30b2..11382b4a1 100644 --- a/src/test/ai-accuracy-tests/ai-accuracy-tests.ts +++ b/src/test/ai-accuracy-tests/ai-accuracy-tests.ts @@ -241,6 +241,43 @@ const queryTestCases: (TestCase & { expect(output.data?.result?.content[0].collectors).to.include('Monkey'); }, }, + { + testCase: 'Complex aggregation with string and number manipulation', + type: 'query', + databaseName: 'CookBook', + collectionName: 'recipes', + userInput: + 'what percentage of recipes have "salt" in their ingredients? "ingredients" is a field ' + + 'with an array of strings of the ingredients. Only consider recipes ' + + 'that have the "difficulty Medium or Easy. 
Return is as a string named "saltPercentage" like ' + + '"75%", rounded to the nearest whole number.', + assertResult: async ({ + responseContent, + connectionString, + }: AssertProps): Promise => { + const output = await runCodeInMessage(responseContent, connectionString); + + anyOf([ + (): void => { + const lines = responseContent.trim().split('\n'); + const lastLine = lines[lines.length - 1]; + + expect(lastLine).to.include('saltPercentage'); + expect(output.data?.result?.content).to.include('67%'); + }, + (): void => { + expect(output.printOutput[output.printOutput.length - 1]).to.equal( + "{ saltPercentage: '67%' }" + ); + }, + (): void => { + expect(output.data?.result?.content[0].saltPercentage).to.equal( + '67%' + ); + }, + ])(null); + }, + }, ]; const intentTestCases: (TestCase & { @@ -350,6 +387,23 @@ const genericTestCases: (TestCase & { ])(null); }, }, + { + testCase: 'Complex aggregation code generation', + type: 'generic', + userInput: + 'what percentage of recipes have "salt" in their ingredients? "ingredients" is a field ' + + 'with an array of strings of the ingredients. Only consider recipes ' + + 'that have the "difficulty Medium or Easy. Return is as a string named "saltPercentage" like ' + + '"75%", rounded to the nearest whole number. 
db CookBook, collection recipes', + assertResult: async ({ + responseContent, + connectionString, + }: AssertProps): Promise => { + const output = await runCodeInMessage(responseContent, connectionString); + + expect(output.data?.result?.content[0].saltPercentage).to.equal('67%'); + }, + }, ]; const testCases: TestCase[] = [ @@ -588,12 +642,14 @@ describe('AI Accuracy Tests', function () { testFunction( `should pass for input: "${testCase.userInput}" if average accuracy is above threshold`, - // eslint-disable-next-line no-loop-func + // eslint-disable-next-line no-loop-func, complexity async function () { console.log(`Starting test run of ${testCase.testCase}.`); const testRunDidSucceed: boolean[] = []; - const successFullRunStats: { + // Successful and unsuccessful runs are both tracked as long as the model + // returns a response. + const runStats: { promptTokens: number; completionTokens: number; executionTimeMS: number; @@ -620,12 +676,15 @@ describe('AI Accuracy Tests', function () { } const startTime = Date.now(); + let responseContent: ChatCompletion | undefined; + let executionTimeMS = 0; try { - const responseContent = await runTest({ + responseContent = await runTest({ testCase, aiBackend, fixtures, }); + executionTimeMS = Date.now() - startTime; testOutputs[testCase.testCase].outputs.push( responseContent.content ); @@ -636,11 +695,6 @@ describe('AI Accuracy Tests', function () { mongoClient, }); - successFullRunStats.push({ - completionTokens: responseContent.usageStats.completionTokens, - promptTokens: responseContent.usageStats.promptTokens, - executionTimeMS: Date.now() - startTime, - }); success = true; console.log( @@ -653,6 +707,18 @@ describe('AI Accuracy Tests', function () { ); } + if ( + responseContent && + responseContent?.usageStats?.completionTokens > 0 && + executionTimeMS !== 0 + ) { + runStats.push({ + completionTokens: responseContent.usageStats.completionTokens, + promptTokens: responseContent.usageStats.promptTokens, + 
executionTimeMS, + }); + } + testRunDidSucceed.push(success); } @@ -673,21 +739,19 @@ describe('AI Accuracy Tests', function () { Accuracy: averageAccuracy, Pass: didFail ? '✗' : '✓', 'Avg Execution Time (ms)': - successFullRunStats.length > 0 - ? successFullRunStats.reduce((a, b) => a + b.executionTimeMS, 0) / - successFullRunStats.length + runStats.length > 0 + ? runStats.reduce((a, b) => a + b.executionTimeMS, 0) / + runStats.length : 0, 'Avg Prompt Tokens': - successFullRunStats.length > 0 - ? successFullRunStats.reduce((a, b) => a + b.promptTokens, 0) / - successFullRunStats.length + runStats.length > 0 + ? runStats.reduce((a, b) => a + b.promptTokens, 0) / + runStats.length : 0, 'Avg Completion Tokens': - successFullRunStats.length > 0 - ? successFullRunStats.reduce( - (a, b) => a + b.completionTokens, - 0 - ) / successFullRunStats.length + runStats.length > 0 + ? runStats.reduce((a, b) => a + b.completionTokens, 0) / + runStats.length : 0, }); diff --git a/src/test/ai-accuracy-tests/fixtures/recipes.ts b/src/test/ai-accuracy-tests/fixtures/recipes.ts index efb347bb3..f9e8d7346 100644 --- a/src/test/ai-accuracy-tests/fixtures/recipes.ts +++ b/src/test/ai-accuracy-tests/fixtures/recipes.ts @@ -12,6 +12,7 @@ const recipes: Fixture = { 'tomato sauce', 'onions', 'garlic', + 'salt', ], preparationTime: 60, difficulty: 'Medium', @@ -23,6 +24,19 @@ const recipes: Fixture = { preparationTime: 10, difficulty: 'Easy', }, + { + title: 'Pineapple', + ingredients: ['pineapple'], + preparationTime: 5, + difficulty: 'Very Hard', + }, + { + title: 'Pizza', + ingredients: ['dough', 'tomato sauce', 'mozzarella cheese', 'basil'], + optionalIngredients: ['pineapple'], + preparationTime: 40, + difficulty: 'Medium', + }, { title: 'Beef Wellington', ingredients: [ @@ -30,6 +44,7 @@ const recipes: Fixture = { 'mushroom duxelles', 'puff pastry', 'egg wash', + 'salt', ], preparationTime: 120, difficulty: 'Hard',