From fb6c21c5121caae3b84eb129d15a28b67ce29b52 Mon Sep 17 00:00:00 2001 From: Sergey Petushkov Date: Thu, 20 Jul 2023 07:50:42 +0200 Subject: [PATCH 1/2] chore(atlas-service, query-bar): move ai endpoint handling to atlas-service package --- package-lock.json | 8 +- packages/atlas-service/package.json | 2 + packages/atlas-service/src/main.spec.ts | 188 +++++++++++- packages/atlas-service/src/main.ts | 117 +++++++- packages/atlas-service/src/renderer.ts | 14 +- packages/atlas-service/src/util.ts | 6 + packages/compass-query-bar/package.json | 2 - .../components/query-history/index.spec.tsx | 47 ++- .../src/modules/ai-query-request.spec.ts | 270 ------------------ .../src/modules/ai-query-request.ts | 105 ------- .../src/stores/ai-query-reducer.spec.ts | 192 +++---------- .../src/stores/ai-query-reducer.ts | 7 +- .../src/stores/query-bar-store.ts | 39 ++- .../test/create-mock-ai-endpoint.ts | 108 ------- 14 files changed, 401 insertions(+), 704 deletions(-) delete mode 100644 packages/compass-query-bar/src/modules/ai-query-request.spec.ts delete mode 100644 packages/compass-query-bar/src/modules/ai-query-request.ts delete mode 100644 packages/compass-query-bar/test/create-mock-ai-endpoint.ts diff --git a/package-lock.json b/package-lock.json index 23895a1307c..35dc9643e74 100644 --- a/package-lock.json +++ b/package-lock.json @@ -39819,6 +39819,8 @@ "depcheck": "^1.4.1", "eslint": "^7.25.0", "mocha": "^10.2.0", + "mongodb": "^5.7.0", + "mongodb-schema": "^11.2.1", "nyc": "^15.1.0", "prettier": "^2.7.1", "sinon": "^9.2.3", @@ -43281,7 +43283,6 @@ "@testing-library/react": "^12.1.4", "@testing-library/user-event": "^13.5.0", "chai": "^4.2.0", - "chai-as-promised": "^7.1.1", "depcheck": "^1.4.1", "electron": "^23.3.9", "eslint": "^7.25.0", @@ -43294,7 +43295,6 @@ "mongodb-query-parser": "^2.5.0", "mongodb-query-util": "^2.0.0", "mongodb-schema": "^11.2.1", - "node-fetch": "^2.6.7", "nyc": "^15.1.0", "react": "^17.0.2", "react-dom": "^17.0.2", @@ -54485,6 +54485,8 @@ "electron": "^23.3.9", "eslint": "^7.25.0", "mocha": "^10.2.0", + "mongodb": "^5.7.0", + "mongodb-schema": "^11.2.1", "node-fetch": "^2.6.7", "nyc": "^15.1.0", "prettier": "^2.7.1", @@ -55897,7 +55899,6 @@ "@testing-library/user-event": "^13.5.0", "bson": "^5.3.0", "chai": "^4.2.0", - "chai-as-promised": "^7.1.1", "compass-preferences-model": "^2.10.0", "depcheck": "^1.4.1", "electron": "^23.3.9", @@ -55911,7 +55912,6 @@ "mongodb-query-parser": "^2.5.0", "mongodb-query-util": "^2.0.0", "mongodb-schema": "^11.2.1", - "node-fetch": "^2.6.7", "nyc": "^15.1.0", "react": "^17.0.2", "react-dom": "^17.0.2", diff --git a/packages/atlas-service/package.json b/packages/atlas-service/package.json index 5fce66ecc76..b7ec614cece 100644 --- a/packages/atlas-service/package.json +++ b/packages/atlas-service/package.json @@ -59,6 +59,8 @@ "depcheck": "^1.4.1", "eslint": "^7.25.0", "mocha": "^10.2.0", + "mongodb": "^5.7.0", + "mongodb-schema": "^11.2.1", "nyc": "^15.1.0", "prettier": "^2.7.1", "sinon": "^9.2.3", diff --git a/packages/atlas-service/src/main.spec.ts b/packages/atlas-service/src/main.spec.ts index 3cd9c80a839..fc75e8e4ce6 100644 --- a/packages/atlas-service/src/main.spec.ts +++ b/packages/atlas-service/src/main.spec.ts @@ -1,6 +1,6 @@ import Sinon from 'sinon'; import { expect } from 'chai'; -import { AtlasService } from './main'; +import { AtlasService, throwIfNotOk } from './main'; describe('AtlasServiceMain', function () { const sandbox = Sinon.createSandbox(); @@ -23,17 +23,22 @@ describe('AtlasServiceMain', function () { AtlasService['plugin'] = mockOidcPlugin; + const fetch = AtlasService['fetch']; + const apiBaseUrl = process.env.DEV_AI_QUERY_ENDPOINT; const issuer = process.env.COMPASS_OIDC_ISSUER; const clientId = process.env.COMPASS_CLIENT_ID; before(function () { + process.env.DEV_AI_QUERY_ENDPOINT = 'http://example.com'; process.env.COMPASS_OIDC_ISSUER = 'http://example.com'; process.env.COMPASS_CLIENT_ID = '1234abcd'; }); after(function () { + process.env.DEV_AI_QUERY_ENDPOINT = apiBaseUrl; process.env.COMPASS_OIDC_ISSUER = issuer; process.env.COMPASS_CLIENT_ID = clientId; + AtlasService['fetch'] = fetch; }); afterEach(function () { @@ -88,4 +93,185 @@ describe('AtlasServiceMain', function () { expect(err).to.have.property('message', 'COMPASS_CLIENT_ID is required'); } }); + + describe('getQueryFromUserPrompt', function () { + it('makes a post request with the user prompt to the endpoint in the environment', async function () { + AtlasService['fetch'] = sandbox.stub().resolves({ + ok: true, + json() { + return Promise.resolve({ + content: { query: { find: { test: 'pineapple' } } }, + }); + }, + }) as any; + + const res = await AtlasService.getQueryFromUserPrompt({ + userPrompt: 'test', + signal: new AbortController().signal, + collectionName: 'jam', + schema: { _id: { types: [{ bsonType: 'ObjectId' }] } }, + sampleDocuments: [{ _id: 1234 }], + }); + + const { args } = ( + AtlasService['fetch'] as unknown as Sinon.SinonStub + ).getCall(0); + + expect(AtlasService['fetch']).to.have.been.calledOnce; + expect(args[0]).to.eq('http://example.com/ai/api/v1/mql-query'); + expect(args[1].body).to.eq( + '{"userPrompt":"test","collectionName":"jam","schema":{"_id":{"types":[{"bsonType":"ObjectId"}]}},"sampleDocuments":[{"_id":1234}]}' + ); + expect(res).to.have.nested.property( + 'content.query.find.test', + 'pineapple' + ); + }); + + it('uses the abort signal in the fetch request', async function () { + const c = new AbortController(); + c.abort(); + try { + await AtlasService.getQueryFromUserPrompt({ + signal: c.signal, + userPrompt: 'test', + collectionName: 'test.test', + }); + expect.fail('Expected getQueryFromUserPrompt to throw'); + } catch (err) { + expect(err).to.have.property('message', 'This operation was aborted'); + } + }); + + it('throws if the request would be too much for the ai', async function () { + try { + await AtlasService.getQueryFromUserPrompt({ + userPrompt: 'test', + collectionName: 'test.test', + sampleDocuments: [{ test: '4'.repeat(60000) }], + }); + expect.fail('Expected getQueryFromUserPrompt to throw'); + } catch (err) { + expect(err).to.have.property( + 'message', + 'Error: too large of a request to send to the ai. Please use a smaller prompt or collection with smaller documents.' + ); + } + }); + + it('passes fewer documents if the request would be too much for the ai with all of the documents', async function () { + AtlasService['fetch'] = sandbox.stub().resolves({ + ok: true, + json() { + return Promise.resolve({}); + }, + }) as any; + + await AtlasService.getQueryFromUserPrompt({ + userPrompt: 'test', + collectionName: 'test.test', + sampleDocuments: [ + { a: '1' }, + { a: '2' }, + { a: '3' }, + { a: '4'.repeat(50000) }, + ], + }); + + const { args } = ( + AtlasService['fetch'] as unknown as Sinon.SinonStub + ).getCall(0); + + expect(AtlasService['fetch']).to.have.been.calledOnce; + expect(args[1].body).to.eq( + '{"userPrompt":"test","collectionName":"test.test","sampleDocuments":[{"a":"1"}]}' + ); + }); + + it('throws the error', async function () { + AtlasService['fetch'] = sandbox.stub().resolves({ + ok: false, + status: 500, + statusText: 'Internal Server Error', + }) as any; + + try { + await AtlasService.getQueryFromUserPrompt({ + userPrompt: 'test', + collectionName: 'test.test', + }); + expect.fail('Expected getQueryFromUserPrompt to throw'); + } catch (err) { + expect(err).to.have.property('message', '500 Internal Server Error'); + } + }); + + it('should throw if DEV_AI_QUERY_ENDPOINT is not set', async function () { + delete process.env.DEV_AI_QUERY_ENDPOINT; + + try { + await AtlasService.getQueryFromUserPrompt({ + userPrompt: 'test', + collectionName: 'test.test', + }); + expect.fail('Expected AtlasService.signIn() to throw'); + } catch (err) { + expect(err).to.have.property( + 'message', + 'No AI Query endpoint to fetch. Please set the environment variable `DEV_AI_QUERY_ENDPOINT`' + ); + } + }); + }); + + describe('throwIfNotOk', function () { + it('should not throw if res is ok', async function () { + await throwIfNotOk({ + ok: true, + status: 200, + statusText: 'OK', + json() { + return Promise.resolve({}); + }, + }); + }); + + it('should throw network error if res is not ok', async function () { + try { + await throwIfNotOk({ + ok: false, + status: 500, + statusText: 'Whoops', + json() { + return Promise.resolve({}); + }, + }); + expect.fail('Expected throwIfNotOk to throw'); + } catch (err) { + expect(err).to.have.property('name', 'NetworkError'); + expect(err).to.have.property('message', '500 Whoops'); + } + }); + + it('should try to parse AIError from body and throw it', async function () { + try { + await throwIfNotOk({ + ok: false, + status: 500, + statusText: 'Whoops', + json() { + return Promise.resolve({ + name: 'AIError', + errorMessage: 'tortillas', + codeName: 'ExampleCode', + }); + }, + }); + expect.fail('Expected throwIfNotOk to throw'); + } catch (err) { + expect(err).to.have.property('name', 'Error'); + expect(err).to.have.property('message', 'ExampleCode: tortillas'); + } + }); + }); }); diff --git a/packages/atlas-service/src/main.ts b/packages/atlas-service/src/main.ts index 55639eb46fd..347dc3a850b 100644 --- a/packages/atlas-service/src/main.ts +++ b/packages/atlas-service/src/main.ts @@ -2,9 +2,14 @@ import { shell } from 'electron'; import { URL, URLSearchParams } from 'url'; import * as plugin from '@mongodb-js/oidc-plugin'; import { oidcServerRequestHandler } from '@mongodb-js/devtools-connect'; +// TODO(https://github.com/node-fetch/node-fetch/issues/1652): Remove this when +// node-fetch types match the built in AbortSignal from node. +import type { AbortSignal as NodeFetchAbortSignal } from 'node-fetch/externals'; import type { Response } from 'node-fetch'; import fetch from 'node-fetch'; -import type { IntrospectInfo, Token, UserInfo } from './util'; +import type { SimplifiedSchema } from 'mongodb-schema'; +import type { Document } from 'mongodb'; +import type { AIQuery, IntrospectInfo, Token, UserInfo } from './util'; import { ipcExpose } from './util'; const redirectRequestHandler = oidcServerRequestHandler.bind(null, { @@ -12,16 +17,42 @@ const redirectRequestHandler = oidcServerRequestHandler.bind(null, { productDocsLink: 'https://www.mongodb.com/docs/compass', }); -function throwIfNotOk(res: Response) { +const SPECIAL_AI_ERROR_NAME = 'AIError'; + +export async function throwIfNotOk( + res: Pick +) { if (res.ok) { return; } - const err = new Error(`NetworkError: ${res.statusText}`); - err.name = 'NetworkError'; + + let serverErrorName = 'NetworkError'; + let serverErrorMessage = `${res.status} ${res.statusText}`; + // Special case for AI endpoint only: + // We try to parse the response to see if the server returned any information + // we can show a user. + try { + // Why are we having a custom format and not following what mms does? + const messageJSON = await res.json(); + if (messageJSON.name === SPECIAL_AI_ERROR_NAME) { + serverErrorName = 'Error'; + serverErrorMessage = `${messageJSON.codeName as string}: ${ + messageJSON.errorMessage as string + }`; + } + } catch (err) { + // no-op, use the default status and statusText in the message. + } + const err = new Error(serverErrorMessage); + err.name = serverErrorName; (err as any).statusCode = res.status; throw err; } +const MAX_REQUEST_SIZE = 5000; + +const MIN_SAMPLE_DOCUMENTS = 1; + export class AtlasService { private constructor() { // singleton @@ -54,6 +85,8 @@ export class AtlasService { private static signInPromise: Promise | null = null; + private static fetch: typeof fetch = fetch; + private static get clientId() { if (!process.env.COMPASS_CLIENT_ID) { throw new Error('COMPASS_CLIENT_ID is required'); @@ -68,6 +101,15 @@ export class AtlasService { return process.env.COMPASS_OIDC_ISSUER; } + private static get apiBaseUrl() { + if (!process.env.DEV_AI_QUERY_ENDPOINT) { + throw new Error( + 'No AI Query endpoint to fetch. Please set the environment variable `DEV_AI_QUERY_ENDPOINT`' + ); + } + return process.env.DEV_AI_QUERY_ENDPOINT; + } + static init() { if (this.calledOnce) { return; @@ -78,6 +120,7 @@ export class AtlasService { 'introspect', 'isAuthenticated', 'signIn', + 'getQueryFromUserPrompt', ]); } @@ -113,20 +156,20 @@ export class AtlasService { } static async getUserInfo(): Promise { - const res = await fetch(`${this.issuer}/v1/userinfo`, { + const res = await this.fetch(`${this.issuer}/v1/userinfo`, { headers: { Authorization: `Bearer ${this.token?.accessToken ?? ''}`, Accept: 'application/json', }, }); - throwIfNotOk(res); + await throwIfNotOk(res); return res.json(); } static async introspect() { const url = new URL(`${this.issuer}/v1/introspect`); url.searchParams.set('client_id', this.clientId); - const res = await fetch(url.toString(), { + const res = await this.fetch(url.toString(), { method: 'POST', body: new URLSearchParams([ ['token', this.token?.accessToken ?? ''], @@ -136,7 +179,65 @@ export class AtlasService { Accept: 'application/json', }, }); - throwIfNotOk(res); + await throwIfNotOk(res); return res.json() as Promise; } + + static async getQueryFromUserPrompt({ + signal, + userPrompt, + collectionName, + schema, + sampleDocuments, + }: { + userPrompt: string; + collectionName: string; + schema?: SimplifiedSchema; + sampleDocuments?: Document[]; + signal?: AbortSignal; + }) { + if (signal?.aborted) { + const err = signal.reason ?? new Error('This operation was aborted.'); + throw err; + } + + let msgBody = JSON.stringify({ + userPrompt, + collectionName, + schema, + sampleDocuments, + }); + if (msgBody.length > MAX_REQUEST_SIZE) { + // When the message body is over the max size, we try + // to see if with fewer sample documents we can still perform the request. + // If that fails we throw an error indicating this collection's + // documents are too large to send to the ai. + msgBody = JSON.stringify({ + userPrompt, + collectionName, + schema, + sampleDocuments: sampleDocuments?.slice(0, MIN_SAMPLE_DOCUMENTS), + }); + // Why this is not happening on the backend? + if (msgBody.length > MAX_REQUEST_SIZE) { + throw new Error( + 'Error: too large of a request to send to the ai. Please use a smaller prompt or collection with smaller documents.' + ); + } + } + + const res = await this.fetch(`${this.apiBaseUrl}/ai/api/v1/mql-query`, { + signal: signal as NodeFetchAbortSignal | undefined, + method: 'POST', + headers: { + Authorization: `Bearer ${this.token?.accessToken ?? ''}`, + 'Content-Type': 'application/json', + }, + body: msgBody, + }); + + await throwIfNotOk(res); + + return res.json() as Promise; + } } diff --git a/packages/atlas-service/src/renderer.ts b/packages/atlas-service/src/renderer.ts index 61e84678bed..6fb1273fa45 100644 --- a/packages/atlas-service/src/renderer.ts +++ b/packages/atlas-service/src/renderer.ts @@ -4,6 +4,16 @@ import { ipcInvoke } from './util'; export function AtlasService() { return ipcInvoke< typeof AtlasServiceMain, - 'getUserInfo' | 'introspect' | 'isAuthenticated' | 'signIn' - >('AtlasService', ['getUserInfo', 'introspect', 'isAuthenticated', 'signIn']); + | 'getUserInfo' + | 'introspect' + | 'isAuthenticated' + | 'signIn' + | 'getQueryFromUserPrompt' + >('AtlasService', [ + 'getUserInfo', + 'introspect', + 'isAuthenticated', + 'signIn', + 'getQueryFromUserPrompt', + ]); } diff --git a/packages/atlas-service/src/util.ts b/packages/atlas-service/src/util.ts index 442b4b9f1a0..0447508b601 100644 --- a/packages/atlas-service/src/util.ts +++ b/packages/atlas-service/src/util.ts @@ -7,6 +7,12 @@ export type IntrospectInfo = { active: boolean }; export type Token = plugin.IdPServerResponse; +export type AIQuery = { + content?: { + query?: unknown; + }; +}; + type SerializedError = { $$error: Error & { statusCode?: number } }; function serializeErrorForIpc(err: any): SerializedError { diff --git a/packages/compass-query-bar/package.json b/packages/compass-query-bar/package.json index 0e8fecf90db..fa2495faf90 100644 --- a/packages/compass-query-bar/package.json +++ b/packages/compass-query-bar/package.json @@ -76,7 +76,6 @@ "@testing-library/react": "^12.1.4", "@testing-library/user-event": "^13.5.0", "chai": "^4.2.0", - "chai-as-promised": "^7.1.1", "depcheck": "^1.4.1", "electron": "^23.3.9", "eslint": "^7.25.0", @@ -88,7 +87,6 @@ "mongodb-ns": "^2.4.0", "mongodb-query-parser": "^2.5.0", "mongodb-schema": "^11.2.1", - "node-fetch": "^2.6.7", "nyc": "^15.1.0", "react": "^17.0.2", "react-dom": "^17.0.2", diff --git a/packages/compass-query-bar/src/components/query-history/index.spec.tsx b/packages/compass-query-bar/src/components/query-history/index.spec.tsx index 67b82efd5ec..5a535effab3 100644 --- a/packages/compass-query-bar/src/components/query-history/index.spec.tsx +++ b/packages/compass-query-bar/src/components/query-history/index.spec.tsx @@ -1,6 +1,4 @@ import React from 'react'; -import { applyMiddleware, createStore as _createStore } from 'redux'; -import thunk from 'redux-thunk'; import { expect } from 'chai'; import { render, @@ -13,15 +11,10 @@ import { Provider } from 'react-redux'; import Sinon from 'sinon'; import fs from 'fs'; import os from 'os'; - import QueryHistory from '.'; import { FavoriteQueryStorage, RecentQueryStorage } from '../../utils'; -import { - INITIAL_STATE, - fetchRecents, - fetchFavorites, -} from '../../stores/query-bar-reducer'; -import { rootQueryBarReducer } from '../../stores/query-bar-store'; +import { fetchRecents, fetchFavorites } from '../../stores/query-bar-reducer'; +import configureStore from '../../stores/query-bar-store'; const BASE_QUERY = { filter: { name: 'hello' }, @@ -31,6 +24,7 @@ const BASE_QUERY = { skip: 10, limit: 20, }; + const RECENT_QUERY = { _id: 'one', _lastExecuted: new Date(), @@ -48,28 +42,21 @@ function createStore(basepath?: string) { const favoriteQueryStorage = new FavoriteQueryStorage(basepath); const recentQueryStorage = new RecentQueryStorage(basepath); - const store = _createStore( - rootQueryBarReducer, - { - queryBar: { - ...INITIAL_STATE, - namespace: 'airbnb.listings', - host: 'localhost', - }, - } as any, - applyMiddleware( - thunk.withExtraArgument({ - favoriteQueryStorage, - recentQueryStorage, - dataProvider: { - sample: () => { - /* no-op for unsupported environments. */ - return Promise.resolve([]); - }, + const store = configureStore({ + namespace: 'airbnb.listings', + favoriteQueryStorage, + recentQueryStorage, + dataProvider: { + dataProvider: { + sample() { + return Promise.resolve([]); }, - }) - ) - ); + getConnectionString() { + return { hosts: [] } as any; + }, + }, + }, + }); return { store, diff --git a/packages/compass-query-bar/src/modules/ai-query-request.spec.ts b/packages/compass-query-bar/src/modules/ai-query-request.spec.ts deleted file mode 100644 index b242baf8fe1..00000000000 --- a/packages/compass-query-bar/src/modules/ai-query-request.spec.ts +++ /dev/null @@ -1,270 +0,0 @@ -import chai from 'chai'; -import chaiAsPromised from 'chai-as-promised'; -import { ObjectId } from 'mongodb'; -import type { Document } from 'mongodb'; -import type { SimplifiedSchema } from 'mongodb-schema'; - -const { expect } = chai; -chai.use(chaiAsPromised); - -import { runFetchAIQuery } from './ai-query-request'; -import { - startMockAIServer, - TEST_AUTH_USERNAME, - TEST_AUTH_PASSWORD, -} from '../../test/create-mock-ai-endpoint'; - -const mockUserPrompt: { - userPrompt: string; - collectionName: string; - schema?: SimplifiedSchema; - sampleDocuments?: Document[]; -} = { - userPrompt: 'test', - collectionName: 'jam', - schema: { - _id: { - types: [ - { - bsonType: 'ObjectId', - }, - ], - }, - }, - sampleDocuments: [ - { - _id: new ObjectId(), - }, - ], -}; - -describe('#runFetchAIQuery', function () { - describe('with a valid server endpoint set in the environment', function () { - let stopServer: () => Promise; - let getRequests: () => any[]; - - beforeEach(async function () { - // Start a mock server to pass an ai response. - // Set the server endpoint in the env. - const { - endpoint, - getRequests: _getRequests, - stop, - } = await startMockAIServer(); - - getRequests = _getRequests; - stopServer = stop; - process.env.DEV_AI_QUERY_ENDPOINT = endpoint; - process.env.DEV_AI_USERNAME = TEST_AUTH_USERNAME; - process.env.DEV_AI_PASSWORD = TEST_AUTH_PASSWORD; - }); - - afterEach(async function () { - await stopServer(); - delete process.env.DEV_AI_QUERY_ENDPOINT; - delete process.env.DEV_AI_USERNAME; - delete process.env.DEV_AI_PASSWORD; - }); - - it('makes a post request with the user prompt to the endpoint in the environment', async function () { - const id = new ObjectId(); - const response = await runFetchAIQuery({ - userPrompt: 'test', - signal: new AbortController().signal, - collectionName: 'jam', - schema: { - _id: { - types: [ - { - bsonType: 'ObjectId', - }, - ], - }, - }, - sampleDocuments: [ - { - _id: id, - }, - ], - }); - const requests = getRequests(); - expect(requests[0].content).to.deep.equal({ - userPrompt: 'test', - collectionName: 'jam', - schema: { - _id: { - types: [ - { - bsonType: 'ObjectId', - }, - ], - }, - }, - sampleDocuments: [ - { - _id: id.toString(), - }, - ], - }); - expect(requests[0].req.url).to.equal('/ai/api/v1/mql-query'); - - expect(response).to.deep.equal({ - content: { - query: { - find: { - test: 'pineapple', - }, - }, - }, - }); - }); - - it('uses the abort signal in the fetch request', async function () { - const abortController = new AbortController(); - abortController.abort(); - - const promise = runFetchAIQuery({ - ...mockUserPrompt, - signal: abortController.signal, - }); - - await expect(promise).to.be.rejectedWith('The user aborted a request.'); - }); - - it('throws if the request would be too much for the ai', async function () { - const promise = runFetchAIQuery({ - ...mockUserPrompt, - sampleDocuments: [ - { - test: '4'.repeat(60000), - }, - ], - signal: new AbortController().signal, - }); - - await expect(promise).to.be.rejectedWith( - 'Error: too large of a request to send to the ai. Please use a smaller prompt or collection with smaller documents.' - ); - }); - - it('passes fewer documents if the request would be too much for the ai with all of the documents', async function () { - const response = await runFetchAIQuery({ - ...mockUserPrompt, - sampleDocuments: [ - { - a: ['1'], - }, - { - a: ['2'], - }, - { - a: ['3'], - }, - { - a: ['4'.repeat(50000)], - }, - ], - signal: new AbortController().signal, - }); - - const requests = getRequests(); - expect(requests[0].content.sampleDocuments).to.deep.equal([ - { - a: ['1'], - }, - ]); - expect(!!response).to.be.true; - }); - }); - - describe('with no endpoint set in environment', function () { - it('throws an error', async function () { - const promise = runFetchAIQuery({ - ...mockUserPrompt, - signal: new AbortController().signal, - }); - - await expect(promise).to.be.rejectedWith( - 'No AI Query endpoint to fetch. Please set the environment variable `DEV_AI_QUERY_ENDPOINT`' - ); - }); - }); - - describe('when the server errors', function () { - let stopServer: () => Promise; - - beforeEach(async function () { - // Start a mock server to pass an ai response. - // Set the server endpoint in the env. - const { endpoint, stop } = await startMockAIServer({ - response: { - status: 500, - body: 'error', - }, - }); - - stopServer = stop; - process.env.DEV_AI_QUERY_ENDPOINT = endpoint; - process.env.DEV_AI_USERNAME = TEST_AUTH_USERNAME; - process.env.DEV_AI_PASSWORD = TEST_AUTH_PASSWORD; - }); - - afterEach(async function () { - await stopServer(); - delete process.env.DEV_AI_QUERY_ENDPOINT; - delete process.env.DEV_AI_USERNAME; - delete process.env.DEV_AI_PASSWORD; - }); - - it('throws the error', async function () { - const promise = runFetchAIQuery({ - ...mockUserPrompt, - signal: new AbortController().signal, - }); - - await expect(promise).to.be.rejectedWith( - 'Error: 500 Internal Server Error' - ); - }); - }); - - describe('when the server errors with an AIError', function () { - let stopServer: () => Promise; - - beforeEach(async function () { - // Start a mock server to pass an ai response. - // Set the server endpoint in the env. - const { endpoint, stop } = await startMockAIServer({ - response: { - status: 500, - body: { - name: 'AIError', - errorMessage: 'tortillas', - codeName: 'ExampleCode', - }, - }, - }); - - stopServer = stop; - process.env.DEV_AI_QUERY_ENDPOINT = endpoint; - process.env.DEV_AI_USERNAME = TEST_AUTH_USERNAME; - process.env.DEV_AI_PASSWORD = TEST_AUTH_PASSWORD; - }); - - afterEach(async function () { - await stopServer(); - delete process.env.DEV_AI_QUERY_ENDPOINT; - delete process.env.DEV_AI_USERNAME; - delete process.env.DEV_AI_PASSWORD; - }); - - it('throws the error', async function () { - const promise = runFetchAIQuery({ - ...mockUserPrompt, - signal: new AbortController().signal, - }); - - await expect(promise).to.be.rejectedWith('Error: ExampleCode: tortillas'); - }); - }); -}); diff --git a/packages/compass-query-bar/src/modules/ai-query-request.ts b/packages/compass-query-bar/src/modules/ai-query-request.ts deleted file mode 100644 index ca32da2013f..00000000000 --- a/packages/compass-query-bar/src/modules/ai-query-request.ts +++ /dev/null @@ -1,105 +0,0 @@ -import fetch from 'node-fetch'; -// TODO(https://github.com/node-fetch/node-fetch/issues/1652): Remove this when -// node-fetch types match the built in AbortSignal from node. -import type { AbortSignal as NodeFetchAbortSignal } from 'node-fetch/externals'; -import type { SimplifiedSchema } from 'mongodb-schema'; -import type { Document } from 'mongodb'; - -const serverErrorMessageName = 'AIError'; - -function getAIQueryEndpoint(): string { - if (!process.env.DEV_AI_QUERY_ENDPOINT) { - throw new Error( - 'No AI Query endpoint to fetch. Please set the environment variable `DEV_AI_QUERY_ENDPOINT`' - ); - } - - return process.env.DEV_AI_QUERY_ENDPOINT; -} - -function getAIBasicAuth(): string { - if (!process.env.DEV_AI_USERNAME || !process.env.DEV_AI_PASSWORD) { - throw new Error( - 'No AI auth information found. Please set the environment variable `DEV_AI_USERNAME` and `DEV_AI_PASSWORD`' - ); - } - - const authBuffer = Buffer.from( - `${process.env.DEV_AI_USERNAME}:${process.env.DEV_AI_PASSWORD}` - ); - return `Basic ${authBuffer.toString('base64')}`; -} - -const MAX_REQUEST_SIZE = 5000; -const MIN_SAMPLE_DOCUMENTS = 1; - -export async function runFetchAIQuery({ - signal, - userPrompt, - collectionName, - schema, - sampleDocuments, -}: { - signal: AbortSignal; - userPrompt: string; - collectionName: string; - schema?: SimplifiedSchema; - sampleDocuments?: Document[]; -}) { - let msgBody = JSON.stringify({ - userPrompt, - collectionName, - schema, - sampleDocuments, - }); - if (msgBody.length > MAX_REQUEST_SIZE) { - // When the message body is over the max size, we try - // to see if with fewer sample documents we can still perform the request. - // If that fails we throw an error indicating this collection's - // documents are too large to send to the ai. - msgBody = JSON.stringify({ - userPrompt, - collectionName, - schema, - sampleDocuments: sampleDocuments?.slice(0, MIN_SAMPLE_DOCUMENTS), - }); - if (msgBody.length > MAX_REQUEST_SIZE) { - throw new Error( - 'Error: too large of a request to send to the ai. Please use a smaller prompt or collection with smaller documents.' - ); - } - } - - const endpoint = `${getAIQueryEndpoint()}/ai/api/v1/mql-query`; - - const res = await fetch(endpoint, { - signal: signal as NodeFetchAbortSignal, - method: 'POST', - headers: { - Authorization: getAIBasicAuth(), - 'Content-Type': 'application/json', - }, - body: msgBody, - }); - - if (!res.ok) { - // We try to parse the response to see if the server returned any - // information we can show a user. - let serverErrorMessage = `${res.status} ${res.statusText}`; - try { - const messageJSON = await res.json(); - if (messageJSON.name === serverErrorMessageName) { - serverErrorMessage = `${messageJSON.codeName as string}: ${ - messageJSON.errorMessage as string - }`; - } - } catch (err) { - // no-op, use the default status and statusText in the message. - } - throw new Error(`Error: ${serverErrorMessage}`); - } - - const jsonResponse = await res.json(); - - return jsonResponse; -} diff --git a/packages/compass-query-bar/src/stores/ai-query-reducer.spec.ts b/packages/compass-query-bar/src/stores/ai-query-reducer.spec.ts index c70e43cfa51..06537cafd87 100644 --- a/packages/compass-query-bar/src/stores/ai-query-reducer.spec.ts +++ b/packages/compass-query-bar/src/stores/ai-query-reducer.spec.ts @@ -1,8 +1,7 @@ import { expect } from 'chai'; -import { ObjectId } from 'mongodb'; import { promises as fs } from 'fs'; import os from 'os'; - +import Sinon from 'sinon'; import configureStore from './query-bar-store'; import type { QueryBarStoreOptions } from './query-bar-store'; import { @@ -10,33 +9,14 @@ import { cancelAIQuery, runAIQuery, } from './ai-query-reducer'; -import { - startMockAIServer, - TEST_AUTH_USERNAME, - TEST_AUTH_PASSWORD, -} from '../../test/create-mock-ai-endpoint'; - -function _createStore(opts: Partial) { - return configureStore({ - dataProvider: { - dataProvider: { - getConnectionString: () => - ({ - hosts: [], - } as any), - sample: () => - Promise.resolve([ - { - _id: new ObjectId(), - }, - ]), - }, - }, - ...opts, - }); -} describe('aiQueryReducer', function () { + const sandbox = Sinon.createSandbox(); + + afterEach(function () { + sandbox.resetHistory(); + }); + let tmpDir: string; before(async function () { @@ -44,120 +24,49 @@ describe('aiQueryReducer', function () { }); function createStore(opts: Partial = {}) { - return _createStore({ + return configureStore({ basepath: tmpDir, ...opts, }); } describe('runAIQuery', function () { - describe('with a successful server response (mock server)', function () { - let stopServer: () => Promise; - let getRequests: () => any[]; - - beforeEach(async function () { - // Start a mock server to pass an ai response. - // Set the server endpoint in the env. - const { - endpoint, - getRequests: _getRequests, - stop, - } = await startMockAIServer(); - - stopServer = stop; - getRequests = _getRequests; - process.env.DEV_AI_QUERY_ENDPOINT = endpoint; - process.env.DEV_AI_USERNAME = TEST_AUTH_USERNAME; - process.env.DEV_AI_PASSWORD = TEST_AUTH_PASSWORD; - }); - - afterEach(async function () { - await stopServer(); - delete process.env.DEV_AI_QUERY_ENDPOINT; - delete process.env.DEV_AI_USERNAME; - delete process.env.DEV_AI_PASSWORD; - }); - + describe('with a successful server response', function () { it('should succeed', async function () { - const sampleDocs = [ - { - _id: new ObjectId(), - a: { - b: 3, - }, - }, - { - _id: new ObjectId(), - a: { - b: 'a', - }, - c: 'pineapple', - }, - ]; - const resultSchema = { - _id: { - types: [ - { - bsonType: 'ObjectId', - }, - ], - }, - a: { - types: [ - { - bsonType: 'Document', - fields: { - b: { - types: [ - { - bsonType: 'Number', - }, - { - bsonType: 'String', - }, - ], - }, - }, - }, - ], - }, - c: { - types: [ - { - bsonType: 'String', - }, - ], - }, + const mockAtlasService = { + getQueryFromUserPrompt: sandbox + .stub() + .resolves({ content: { query: { _id: 1 } } }), + }; + + const mockDataService = { + sample: sandbox.stub().resolves([{ _id: 42 }]), + getConnectionString: sandbox.stub().returns({ hosts: [] }), }; + const store = createStore({ namespace: 'database.collection', dataProvider: { - dataProvider: { - getConnectionString: () => - ({ - hosts: [], - } as any), - sample: () => Promise.resolve(sampleDocs), - }, + dataProvider: mockDataService as any, }, + atlasService: mockAtlasService as any, }); - let didSetFetchId = false; - store.subscribe(() => { - if (store.getState().aiQuery.aiQueryFetchId !== -1) { - didSetFetchId = true; - } - }); + expect(store.getState().aiQuery.status).to.equal('ready'); + await store.dispatch(runAIQuery('testing prompt')); - expect(didSetFetchId).to.equal(true); - expect(getRequests()[0].content).to.deep.equal({ - userPrompt: 'testing prompt', - schema: resultSchema, - // Parse stringify to make _ids stringified for deep check. - sampleDocuments: JSON.parse(JSON.stringify(sampleDocs)), - collectionName: 'collection', - }); + expect(mockAtlasService.getQueryFromUserPrompt).to.have.been.calledOnce; + expect( + mockAtlasService.getQueryFromUserPrompt.getCall(0) + ).to.have.nested.property('args[0].userPrompt', 'testing prompt'); + expect( + mockAtlasService.getQueryFromUserPrompt.getCall(0) + ).to.have.nested.property('args[0].collectionName', 'collection'); + expect(mockAtlasService.getQueryFromUserPrompt.getCall(0)) + .to.have.nested.property('args[0].sampleDocuments') + .deep.eq([{ _id: 42 }]); + expect(store.getState().aiQuery.aiQueryFetchId).to.equal(-1); expect(store.getState().aiQuery.errorMessage).to.equal(undefined); expect(store.getState().aiQuery.status).to.equal('success'); @@ -165,36 +74,19 @@ describe('aiQueryReducer', function () { }); describe('when there is an error', function () { - let stopServer: () => Promise; - - beforeEach(async function () { - const { endpoint, stop } = await startMockAIServer({ - response: { - status: 500, - body: 'test', - }, - }); - - stopServer = stop; - process.env.DEV_AI_QUERY_ENDPOINT = endpoint; - process.env.DEV_AI_USERNAME = TEST_AUTH_USERNAME; - process.env.DEV_AI_PASSWORD = TEST_AUTH_PASSWORD; - }); - - afterEach(async function () { - await stopServer(); - delete process.env.DEV_AI_QUERY_ENDPOINT; - delete process.env.DEV_AI_USERNAME; - delete process.env.DEV_AI_PASSWORD; - }); - it('sets the error on the store', async function () { - const store = createStore(); + const mockAtlasService = { + getQueryFromUserPrompt: sandbox + .stub() + .rejects(new Error('500 Internal Server Error')), + }; + + const store = createStore({ atlasService: mockAtlasService as any }); expect(store.getState().aiQuery.errorMessage).to.equal(undefined); await store.dispatch(runAIQuery('testing prompt') as any); expect(store.getState().aiQuery.aiQueryFetchId).to.equal(-1); expect(store.getState().aiQuery.errorMessage).to.equal( - 'Error: 500 Internal Server Error' + '500 Internal Server Error' ); expect(store.getState().aiQuery.status).to.equal('ready'); }); diff --git a/packages/compass-query-bar/src/stores/ai-query-reducer.ts b/packages/compass-query-bar/src/stores/ai-query-reducer.ts index 348131edea6..65ed6256fe9 100644 --- a/packages/compass-query-bar/src/stores/ai-query-reducer.ts +++ b/packages/compass-query-bar/src/stores/ai-query-reducer.ts @@ -6,7 +6,6 @@ import preferences from 'compass-preferences-model'; import type { QueryBarThunkAction } from './query-bar-store'; import { isAction } from '../utils'; -import { runFetchAIQuery } from '../modules/ai-query-request'; import { mapQueryToFormFields } from '../utils/query'; import type { QueryFormFields } from '../constants/query-properties'; import { DEFAULT_FIELD_VALUES } from '../constants/query-bar-store'; @@ -110,7 +109,7 @@ export const runAIQuery = ( Promise, AIQueryStartedAction | AIQueryFailedAction | AIQuerySucceededAction > => { - return async (dispatch, getState, { dataProvider }) => { + return async (dispatch, getState, { dataService, atlasService }) => { const { aiQuery: { aiQueryFetchId: existingFetchId }, queryBar: { namespace }, @@ -131,7 +130,7 @@ export const runAIQuery = ( let jsonResponse; try { - const sampleDocuments = await dataProvider.sample( + const sampleDocuments = await dataService.sample( namespace, { query: {}, @@ -148,7 +147,7 @@ export const runAIQuery = ( const schema = await getSimplifiedSchema(sampleDocuments); const { collection: collectionName } = toNS(namespace); - jsonResponse = await runFetchAIQuery({ + jsonResponse = await atlasService.getQueryFromUserPrompt({ signal: abortController.signal, userPrompt, collectionName, diff --git a/packages/compass-query-bar/src/stores/query-bar-store.ts b/packages/compass-query-bar/src/stores/query-bar-store.ts index c358a3d8171..5f78e54f3a7 100644 --- a/packages/compass-query-bar/src/stores/query-bar-store.ts +++ b/packages/compass-query-bar/src/stores/query-bar-store.ts @@ -30,10 +30,7 @@ import { AtlasService } from '@mongodb-js/atlas-service/renderer'; const { basepath } = getStoragePaths() || {}; // Partial of DataService that mms shares with Compass. -type DataProvider = { - getConnectionString: DataService['getConnectionString']; - sample: DataService['sample']; -}; +type QueryBarDataService = Pick; export type QueryBarStoreOptions = { serverVersion: string; @@ -42,11 +39,14 @@ export type QueryBarStoreOptions = { query: BaseQuery; namespace: string; dataProvider: { - dataProvider?: DataProvider; + dataProvider?: QueryBarDataService; }; + atlasService: ReturnType; // For testing. basepath?: string; + favoriteQueryStorage?: FavoriteQueryStorage; + recentQueryStorage?: RecentQueryStorage; }; export const rootQueryBarReducer = combineReducers({ @@ -61,9 +61,8 @@ export type QueryBarExtraArgs = { localAppRegistry?: AppRegistry; favoriteQueryStorage: FavoriteQueryStorage; recentQueryStorage: RecentQueryStorage; - dataProvider: { - sample: DataProvider['sample']; - }; + dataService: Pick; + atlasService: ReturnType; }; export type QueryBarThunkDispatch = @@ -82,17 +81,17 @@ function createStore(options: Partial = {}) { query, namespace, dataProvider, + atlasService = AtlasService(), + recentQueryStorage = new RecentQueryStorage( + options.basepath ?? basepath, + namespace + ), + favoriteQueryStorage = new FavoriteQueryStorage( + options.basepath ?? basepath, + namespace + ), } = options; - const recentQueryStorage = new RecentQueryStorage( - options.basepath ?? basepath, - namespace - ); - const favoriteQueryStorage = new FavoriteQueryStorage( - options.basepath ?? basepath, - namespace - ); - return _createStore( rootQueryBarReducer, { @@ -109,9 +108,9 @@ function createStore(options: Partial = {}) { }, applyMiddleware( thunk.withExtraArgument({ - dataProvider: dataProvider?.dataProvider ?? { + dataService: dataProvider?.dataProvider ?? { sample: () => { - /* no-op for unsupported environments. */ + /* no-op for environments where dataService is not provided at all. */ return Promise.resolve([]); }, }, @@ -119,7 +118,7 @@ function createStore(options: Partial = {}) { globalAppRegistry, recentQueryStorage, favoriteQueryStorage, - atlasService: AtlasService(), + atlasService, }) ) ); diff --git a/packages/compass-query-bar/test/create-mock-ai-endpoint.ts b/packages/compass-query-bar/test/create-mock-ai-endpoint.ts deleted file mode 100644 index fdcfc6628a5..00000000000 --- a/packages/compass-query-bar/test/create-mock-ai-endpoint.ts +++ /dev/null @@ -1,108 +0,0 @@ -import http from 'http'; -import { once } from 'events'; -import type { AddressInfo } from 'net'; - -export const TEST_AUTH_USERNAME = 'testuser'; -export const TEST_AUTH_PASSWORD = 'testpass'; - -// Throws if doesn't match. -function checkReqAuth(req: http.IncomingMessage) { - const header = req.headers.authorization ?? ''; - const token = header.split(/\s+/).pop() ?? ''; - const auth = Buffer.from(token, 'base64').toString(); - const [username, password] = auth.split(':'); - - if (username !== TEST_AUTH_USERNAME || password !== TEST_AUTH_PASSWORD) { - throw new Error('no match'); - } -} - -export async function startMockAIServer( - { - response, - }: { - response: { - status: number; - body: any; - }; - } = { - response: { - status: 200, - body: { - content: { - query: { - find: { - test: 'pineapple', - }, - }, - }, - }, - }, - } -): Promise<{ - getRequests: () => { - content: any; - req: any; - }[]; - endpoint: string; - server: http.Server; - stop: () => Promise; -}> { - const requests: { - content: any; - req: any; - }[] = []; - const server = http - .createServer((req, res) => { - try { - checkReqAuth(req); - } catch (err) { - res.writeHead(401); - res.end('Not authorized.'); - return; - } - - let body = ''; - req - .setEncoding('utf8') - .on('data', (chunk) => { - body += chunk; - }) - .on('end', () => { - const jsonObject = JSON.parse(body); - requests.push({ - req, - content: jsonObject, - }); - - res.setHeader('Content-Type', 'application/json'); - if (response.status !== 200) { - res.writeHead(response.status); - } - return res.end(JSON.stringify(response.body)); - }); - }) - .listen(0); - await once(server, 'listening'); - - // address() returns either a string or AddressInfo. - const address = server.address() as AddressInfo; - - const endpoint = `http://localhost:${address.port}`; - - async function stop() { - server.close(); - await once(server, 'close'); - } - - function getRequests() { - return requests; - } - - return { - getRequests, - endpoint, - server, - stop, - }; -} From d7dc64749e14f9e5d6022e91d8ec139ac35fdc29 Mon Sep 17 00:00:00 2001 From: Sergey Petushkov Date: Thu, 20 Jul 2023 12:32:43 +0200 Subject: [PATCH 2/2] chore(atlas-service): add signal handling for ipc --- packages/atlas-service/src/main.ts | 41 ++++++- packages/atlas-service/src/util.spec.ts | 135 ++++++++++++++++++++++++ packages/atlas-service/src/util.ts | 112 +++++++++++++++++--- 3 files changed, 267 insertions(+), 21 deletions(-) diff --git a/packages/atlas-service/src/main.ts b/packages/atlas-service/src/main.ts index 347dc3a850b..6da21e04dcf 100644 --- a/packages/atlas-service/src/main.ts +++ b/packages/atlas-service/src/main.ts @@ -135,17 +135,29 @@ export class AtlasService { } } - static async signIn(): Promise { + static async signIn({ + signal, + }: { signal?: AbortSignal } = {}): Promise { if (this.signInPromise) { return this.signInPromise; } try { + if (signal?.aborted) { + const err = signal.reason ?? new Error('This operation was aborted.'); + throw err; + } + this.signInPromise = (async () => { this.token = await this.plugin.mongoClientOptions.authMechanismProperties.REQUEST_TOKEN_CALLBACK( { clientId: this.clientId, issuer: this.issuer }, - // Required driver specific stuff - { version: 0 } + { + // Required driver specific stuff + version: 0, + // This seems to be just an abort signal? We probably can make it + // explicit when adding a proper interface for this + timeoutContext: signal, + } ); return this.token; })(); @@ -155,20 +167,36 @@ export class AtlasService { } } - static async getUserInfo(): Promise { + static async getUserInfo({ + signal, + }: { signal?: AbortSignal } = {}): Promise { + if (signal?.aborted) { + const err = signal.reason ?? new Error('This operation was aborted.'); + throw err; + } + const res = await this.fetch(`${this.issuer}/v1/userinfo`, { headers: { Authorization: `Bearer ${this.token?.accessToken ?? ''}`, Accept: 'application/json', }, + signal: signal as NodeFetchAbortSignal | undefined, }); + await throwIfNotOk(res); + return res.json(); } - static async introspect() { + static async introspect({ signal }: { signal?: AbortSignal } = {}) { + if (signal?.aborted) { + const err = signal.reason ?? new Error('This operation was aborted.'); + throw err; + } + const url = new URL(`${this.issuer}/v1/introspect`); url.searchParams.set('client_id', this.clientId); + const res = await this.fetch(url.toString(), { method: 'POST', body: new URLSearchParams([ @@ -178,8 +206,11 @@ export class AtlasService { headers: { Accept: 'application/json', }, + signal: signal as NodeFetchAbortSignal | undefined, }); + await throwIfNotOk(res); + return res.json() as Promise; } diff --git a/packages/atlas-service/src/util.spec.ts b/packages/atlas-service/src/util.spec.ts index e69de29bb2d..3c91c8c6914 100644 --- a/packages/atlas-service/src/util.spec.ts +++ b/packages/atlas-service/src/util.spec.ts @@ -0,0 +1,135 @@ +import Sinon from 'sinon'; +import { expect } from 'chai'; +import { promisify } from 'util'; +import { ipcExpose, ipcInvoke, ControllerMap } from './util'; + +const wait = promisify(setTimeout); + +describe('ipc', function () { + const sandbox = Sinon.createSandbox(); + + const MockIpc = class { + handlers = new Map any>(); + handle = sandbox + .stub() + .callsFake((channel: string, fn: (_evt: any, ...args: any[]) => any) => { + this.handlers.set(channel, fn); + }); + async invoke(channel: string, ...args: any[]) { + return await this.handlers.get(channel)?.({}, ...args); + } + }; + + const mockIpc = new MockIpc(); + + const mockHandler = { + foo: sandbox.stub().resolves(42), + bar: sandbox.stub().rejects(new Error('Whoops!')), + buz: sandbox.stub().callsFake(({ signal }: { signal: AbortSignal }) => { + return new Promise((_resolve, reject) => { + if (signal.aborted) { + throw signal.reason; + } + signal.addEventListener('abort', () => { + reject(signal.reason); + }); + }); + }), + }; + + afterEach(function () { + mockIpc.handlers.clear(); + sandbox.resetHistory(); + }); + + it('should pass arguments from invoker to handler', async function () { + ipcExpose('Test', mockHandler, ['foo'], mockIpc, true); + const { foo } = ipcInvoke( + 'Test', + ['foo'], + mockIpc + ); + + await foo({ test: 1 }); + + expect(mockHandler.foo).to.have.been.calledOnceWith({ + signal: new AbortController().signal, + test: 1, + }); + }); + + it('should return handler result when invoked', async function () { + ipcExpose('Test', mockHandler, ['foo'], mockIpc, true); + const { foo } = ipcInvoke( + 'Test', + ['foo'], + mockIpc + ); + + const res = await foo({ test: 1 }); + + expect(res).to.eq(42); + }); + + it('should serialize and de-serialize errors when thrown in handler', async function () { + ipcExpose('Test', mockHandler, ['bar'], mockIpc, true); + const { bar } = ipcInvoke( + 'Test', + ['bar'], + mockIpc + ); + + try { + await bar(); + expect.fail('Expected bar() to throw'); + } catch (err) { + expect(err).to.have.property('message', 'Whoops!'); + } + }); + + it('should handle signals being passed from invoker to handler', async function () { + ipcExpose('Test', mockHandler, ['buz'], mockIpc, true); + const { buz } = ipcInvoke( + 'Test', + ['buz'], + mockIpc + ); + + const invokeController = new AbortController(); + + const promise = buz({ signal: invokeController.signal }); + + // Wait a bit before aborting so that we don't throw right in the invoker + await wait(100); + + expect(ControllerMap).to.have.property('size', 1); + + const [handlerController] = Array.from(ControllerMap.values()); + + invokeController.abort(); + + try { + await promise; + expect.fail('Expected promise to throw abort error'); + } catch (err) { + expect(handlerController).to.have.nested.property('signal.aborted', true); + expect(err).to.have.property('message', 'This operation was aborted'); + } + }); + + it('should clean up abort controllers when handlers are executed', async function () { + ipcExpose('Test', mockHandler, ['foo'], mockIpc, true); + const { foo } = ipcInvoke( + 'Test', + ['foo'], + mockIpc + ); + + await foo(); + await foo(); + await Promise.all([foo(), foo(), foo()]); + await foo(); + + expect(ControllerMap).to.have.property('size', 0); + }); +}); diff --git a/packages/atlas-service/src/util.ts b/packages/atlas-service/src/util.ts index 0447508b601..3c41caf5c63 100644 --- a/packages/atlas-service/src/util.ts +++ b/packages/atlas-service/src/util.ts @@ -15,15 +15,23 @@ export type AIQuery = { type SerializedError = { $$error: Error & { statusCode?: number } }; +// We are serializing errors to get a better error shape on the other end, ipc +// will only preserve message from the original error. See https://github.com/electron/electron/issues/24427 function serializeErrorForIpc(err: any): SerializedError { return { - $$error: { name: err.name, message: err.message, statusCode: err.status }, + $$error: { + name: err.name, + message: err.message, + statusCode: err.status, + stack: err.stack, + }, }; } function deserializeErrorFromIpc({ $$error: err }: SerializedError) { const e = new Error(err.message); e.name = err.name; + e.stack = err.stack; (e as any).stausCode = err.statusCode; return e; } @@ -32,6 +40,32 @@ function isSerializedError(err: any): err is { $$error: Error } { return err !== null && typeof err === 'object' && !!err.$$error; } +// Exported for testing purposes +export const ControllerMap = new Map(); + +let cId = 0; + +let setup = false; + +export function setupSignalHandler( + _ipcMain: Pick = ipcMain, + forceSetup = false +) { + if (!forceSetup && setup) { + return; + } + + setup = true; + + _ipcMain.handle('ipcHandlerInvoke', (_evt, id: string) => { + ControllerMap.set(id, new AbortController()); + }); + + _ipcMain.handle('ipcHandlerAborted', (_evt, id: string) => { + ControllerMap.get(id)?.abort(); + }); +} + type PickByValue = Pick< T, { [k in keyof T]: T[k] extends K ? k : never }[keyof T] @@ -41,37 +75,83 @@ export function ipcExpose( serviceName: string, obj: T, methodNames: Extract< - keyof PickByValue Promise>, + keyof PickByValue Promise>, string - >[] + >[], + _ipcMain: Pick = ipcMain, + _forceSetup = false ) { + setupSignalHandler(_ipcMain, _forceSetup); + for (const name of methodNames) { - ipcMain.handle(`${serviceName}.${name}`, async (_evt, ...args) => { - try { - return await (obj[name] as (...args: any[]) => any).call(obj, ...args); - } catch (err) { - return serializeErrorForIpc(err); + const channel = `${serviceName}.${name}`; + _ipcMain.handle( + channel, + async ( + _evt, + { signal, ...rest }: { signal: string } & Record + ) => { + try { + const controller = ControllerMap.get(signal); + return await (obj[name] as (...args: any[]) => any).call(obj, { + signal: controller?.signal, + ...rest, + }); + } catch (err) { + return serializeErrorForIpc(err); + } finally { + ControllerMap.delete(signal); + } } - }); + ); } } export function ipcInvoke< T, K extends Extract< - keyof PickByValue Promise>, + keyof PickByValue Promise>, string > ->(serviceName: string, methodNames: K[]) { +>( + serviceName: string, + methodNames: K[], + _ipcRenderer: Pick = ipcRenderer +) { return Object.fromEntries( methodNames.map((name) => { + const channel = `${serviceName}.${name}`; + const signalId = `${channel}:${++cId}`; return [ name, - async (...args: any[]) => { - const res = await ipcRenderer.invoke( - `${serviceName}.${name}`, - ...args - ); + async ({ + signal, + ...rest + }: { signal?: AbortSignal } & Record = {}) => { + await _ipcRenderer.invoke('ipcHandlerInvoke', signalId); + const onAbort = () => { + return _ipcRenderer.invoke('ipcHandlerAborted', signalId); + }; + // If signal is already aborted, make sure that handler will see it + // when it runs, otherwise just set up abort listener to communicate + // this to main process + if (signal?.aborted) { + await onAbort(); + } else { + signal?.addEventListener( + 'abort', + () => { + void onAbort(); + }, + { once: true } + ); + } + const res = await _ipcRenderer.invoke(`${serviceName}.${name}`, { + // We replace this with a matched signal on the other side, this + // is mostly for testing / debugging purposes + signal: signalId, + ...rest, + }); if (isSerializedError(res)) { throw deserializeErrorFromIpc(res); }