diff --git a/.github/workflows/test-e2e.yml b/.github/workflows/test-e2e.yml index 2581fd81e..baebdbbda 100644 --- a/.github/workflows/test-e2e.yml +++ b/.github/workflows/test-e2e.yml @@ -2,14 +2,14 @@ name: e2e Tests on: push: - branches: [ alpha, beta, rc, main ] + branches: [alpha, beta, rc, main] pull_request: types: [opened, synchronize] - branches: [ alpha, beta, rc, main ] + branches: [alpha, beta, rc, main] jobs: test-e2e: - runs-on: ubuntu-latest + runs-on: windows-latest strategy: matrix: node-version: [18.x] @@ -20,7 +20,12 @@ jobs: uses: actions/setup-node@v4 with: node-version: ${{ matrix.node-version }} + - name: Install dependencies run: npm install + + - name: Fix auto-fixable eslint issues + run: npm run eslint -- --fix + - name: Run e2e tests run: npm run test:e2e diff --git a/jest.config.electron.e2e.ts b/jest.config.electron.e2e.ts index 6888b4edc..5431441d4 100644 --- a/jest.config.electron.e2e.ts +++ b/jest.config.electron.e2e.ts @@ -5,7 +5,15 @@ const jestConfig = { roots: ["/src/electron"], testMatch: ["**/*.test.e2e.ts"], transform: { - "^.+\\.(ts|js)$": ["@swc/jest"], + "^.+\\.(ts)$": ["@swc/jest"], + "^.+\\.(js)$": [ + "@swc/jest", + { + jsc: { + target: "es5", + }, + }, + ], }, moduleNameMapper: { "@/(.*)": "/src/electron/future/$1", diff --git a/package-lock.json b/package-lock.json index 8ad6059f2..98b2e8014 100644 --- a/package-lock.json +++ b/package-lock.json @@ -14,6 +14,7 @@ "electron-dl": "3.5.2", "electron-serve": "1.3.0", "electron-store": "^8.2.0", + "onnxruntime-node": "^1.17.0", "sharp": "0.33.2" }, "devDependencies": { @@ -61,6 +62,7 @@ "@types/uuid": "^9.0.8", "@typescript-eslint/eslint-plugin": "^7.1.1", "@typescript-eslint/parser": "^7.1.1", + "@xenova/transformers": "github:xenova/transformers.js#v3", "axios": "^1.6.7", "codemirror": "^5.65.16", "conventional-changelog-cli": "^4.1.0", @@ -3319,6 +3321,15 @@ "integrity": "sha512-9TANp6GPoMtYzQdt54kfAyMmz1+osLlXdg2ENroU7zzrtflTLrrC/lgrIfaSe+Wu0b89GKccT7vxXA0MoAIO+Q==", "dev": true }, + "node_modules/@huggingface/jinja": { + "version": "0.2.1", + "resolved": "https://registry.npmjs.org/@huggingface/jinja/-/jinja-0.2.1.tgz", + "integrity": "sha512-HxjVCll8oGfgUQmN91NYWCjfuaQ5mYZkc/BB1gjfp28q3s48yiB5jUEV7BvaRdIAb/+14cNdX8TIdalFykwywA==", + "dev": true, + "engines": { + "node": ">=18" + } + }, "node_modules/@humanwhocodes/config-array": { "version": "0.11.14", "resolved": "https://registry.npmjs.org/@humanwhocodes/config-array/-/config-array-0.11.14.tgz", @@ -5293,6 +5304,70 @@ "url": "https://opencollective.com/popperjs" } }, + "node_modules/@protobufjs/aspromise": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/@protobufjs/aspromise/-/aspromise-1.1.2.tgz", + "integrity": "sha512-j+gKExEuLmKwvz3OgROXtrJ2UG2x8Ch2YZUxahh+s1F2HZ+wAceUNLkvy6zKCPVRkU++ZWQrdxsUeQXmcg4uoQ==", + "dev": true + }, + "node_modules/@protobufjs/base64": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/@protobufjs/base64/-/base64-1.1.2.tgz", + "integrity": "sha512-AZkcAA5vnN/v4PDqKyMR5lx7hZttPDgClv83E//FMNhR2TMcLUhfRUBHCmSl0oi9zMgDDqRUJkSxO3wm85+XLg==", + "dev": true + }, + "node_modules/@protobufjs/codegen": { + "version": "2.0.4", + "resolved": "https://registry.npmjs.org/@protobufjs/codegen/-/codegen-2.0.4.tgz", + "integrity": "sha512-YyFaikqM5sH0ziFZCN3xDC7zeGaB/d0IUb9CATugHWbd1FRFwWwt4ld4OYMPWu5a3Xe01mGAULCdqhMlPl29Jg==", + "dev": true + }, + "node_modules/@protobufjs/eventemitter": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/@protobufjs/eventemitter/-/eventemitter-1.1.0.tgz", + "integrity": "sha512-j9ednRT81vYJ9OfVuXG6ERSTdEL1xVsNgqpkxMsbIabzSo3goCjDIveeGv5d03om39ML71RdmrGNjG5SReBP/Q==", + "dev": true + }, + "node_modules/@protobufjs/fetch": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/@protobufjs/fetch/-/fetch-1.1.0.tgz", + "integrity": "sha512-lljVXpqXebpsijW71PZaCYeIcE5on1w5DlQy5WH6GLbFryLUrBD4932W/E2BSpfRJWseIL4v/KPgBFxDOIdKpQ==", + "dev": true, + "dependencies": { + "@protobufjs/aspromise": "^1.1.1", + "@protobufjs/inquire": "^1.1.0" + } + }, + "node_modules/@protobufjs/float": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/@protobufjs/float/-/float-1.0.2.tgz", + "integrity": "sha512-Ddb+kVXlXst9d+R9PfTIxh1EdNkgoRe5tOX6t01f1lYWOvJnSPDBlG241QLzcyPdoNTsblLUdujGSE4RzrTZGQ==", + "dev": true + }, + "node_modules/@protobufjs/inquire": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/@protobufjs/inquire/-/inquire-1.1.0.tgz", + "integrity": "sha512-kdSefcPdruJiFMVSbn801t4vFK7KB/5gd2fYvrxhuJYg8ILrmn9SKSX2tZdV6V+ksulWqS7aXjBcRXl3wHoD9Q==", + "dev": true + }, + "node_modules/@protobufjs/path": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/@protobufjs/path/-/path-1.1.2.tgz", + "integrity": "sha512-6JOcJ5Tm08dOHAbdR3GrvP+yUUfkjG5ePsHYczMFLq3ZmMkAD98cDgcT2iA1lJ9NVwFd4tH/iSSoe44YWkltEA==", + "dev": true + }, + "node_modules/@protobufjs/pool": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/@protobufjs/pool/-/pool-1.1.0.tgz", + "integrity": "sha512-0kELaGSIDBKvcgS4zkjz1PeddatrjYcmMWOlAuAPwAeccUrPHdUqo/J6LiymHHEiJT5NrF1UVwxY14f+fy4WQw==", + "dev": true + }, + "node_modules/@protobufjs/utf8": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/@protobufjs/utf8/-/utf8-1.1.0.tgz", + "integrity": "sha512-Vvn3zZrhQZkkBE8LSuW3em98c0FwgO4nxzv6OdSxPKJIEKY2bGbHn+mhGIPerzI4twdxaP8/0+06HBpwf345Lw==", + "dev": true + }, "node_modules/@qdrant/js-client-rest": { "version": "1.8.0", "resolved": "https://registry.npmjs.org/@qdrant/js-client-rest/-/js-client-rest-1.8.0.tgz", @@ -7365,6 +7440,20 @@ "@xtuc/long": "4.2.2" } }, + "node_modules/@xenova/transformers": { + "version": "3.0.0-alpha.0", + "resolved": "git+ssh://git@github.com/xenova/transformers.js.git#b0b5e412beba47bd95f45e88d961a4cdaf5fcc01", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@huggingface/jinja": "^0.2.1", + "onnxruntime-web": "1.17.1", + "sharp": "^0.33.2" + }, + "optionalDependencies": { + "onnxruntime-node": "1.17.0" + } + }, "node_modules/@xmldom/xmldom": { "version": "0.8.10", "resolved": "https://registry.npmjs.org/@xmldom/xmldom/-/xmldom-0.8.10.tgz", @@ -12583,6 +12672,12 @@ "node": "^10.12.0 || >=12.0.0" } }, + "node_modules/flatbuffers": { + "version": "1.12.0", + "resolved": "https://registry.npmjs.org/flatbuffers/-/flatbuffers-1.12.0.tgz", + "integrity": "sha512-c7CZADjRcl6j0PlvFy0ZqXQ67qSEZfrVPynmnL+2zPc+NtMvrF8Y0QceMo7QqnSPc7+uWjUIAbvCQ5WIKlMVdQ==", + "dev": true + }, "node_modules/flatted": { "version": "3.2.9", "resolved": "https://registry.npmjs.org/flatted/-/flatted-3.2.9.tgz", @@ -13323,6 +13418,12 @@ "integrity": "sha512-D9cPgkvLlV3t3IzL0D0YLvGA9Ahk4PcvVwUbN0dSGr1aP0Nrt4AEnTUbuGvquEC0mA64Gqt1fzirlRs5ibXx8g==", "dev": true }, + "node_modules/guid-typescript": { + "version": "1.0.9", + "resolved": "https://registry.npmjs.org/guid-typescript/-/guid-typescript-1.0.9.tgz", + "integrity": "sha512-Y8T4vYhEfwJOTbouREvG+3XDsjr8E3kIr7uf+JZ0BYloFsttiHU0WfvANVsR7TxNUJa/WpCnw/Ino/p+DeBhBQ==", + "dev": true + }, "node_modules/handlebars": { "version": "4.7.8", "resolved": "https://registry.npmjs.org/handlebars/-/handlebars-4.7.8.tgz", @@ -16300,6 +16401,12 @@ "integrity": "sha512-sReKOYJIJf74dhJONhU4e0/shzi1trVbSWDOhKYE5XV2O+H7Sb2Dihwuc7xWxVl+DgFPyTqIN3zMfT9cq5iWDg==", "dev": true }, + "node_modules/long": { + "version": "5.2.3", + "resolved": "https://registry.npmjs.org/long/-/long-5.2.3.tgz", + "integrity": "sha512-lcHwpNoggQTObv5apGNCTdJrO69eHOZMi4BNC+rTLER8iHAqGrUVeLh/irVIM7zTw2bOXA8T6uNPeujwOLg/2Q==", + "dev": true + }, "node_modules/longest-streak": { "version": "3.1.0", "resolved": "https://registry.npmjs.org/longest-streak/-/longest-streak-3.1.0.tgz", @@ -21046,6 +21153,44 @@ "node": ">=6" } }, + "node_modules/onnxruntime-common": { + "version": "1.17.1", + "resolved": "https://registry.npmjs.org/onnxruntime-common/-/onnxruntime-common-1.17.1.tgz", + "integrity": "sha512-6wLNhpn+1hnsKN+jq6ulqUEJ61TdRmyFkGCvtRNnZkAupH8Yfr805UeNxjl9jtiX9B1q48pq6Q/67fEFpxT7Dw==", + "dev": true + }, + "node_modules/onnxruntime-node": { + "version": "1.17.0", + "resolved": "https://registry.npmjs.org/onnxruntime-node/-/onnxruntime-node-1.17.0.tgz", + "integrity": "sha512-pRxdqSP3a6wtiFVkVX1V3/gsEMwBRUA9D2oYmcN3cjF+j+ILS+SIY2L7KxdWapsG6z64i5rUn8ijFZdIvbojBg==", + "os": [ + "win32", + "darwin", + "linux" + ], + "dependencies": { + "onnxruntime-common": "1.17.0" + } + }, + "node_modules/onnxruntime-node/node_modules/onnxruntime-common": { + "version": "1.17.0", + "resolved": "https://registry.npmjs.org/onnxruntime-common/-/onnxruntime-common-1.17.0.tgz", + "integrity": "sha512-Vq1remJbCPITjDMJ04DA7AklUTnbYUp4vbnm6iL7ukSt+7VErH0NGYfekRSTjxxurEtX7w41PFfnQlE6msjPJw==" + }, + "node_modules/onnxruntime-web": { + "version": "1.17.1", + "resolved": "https://registry.npmjs.org/onnxruntime-web/-/onnxruntime-web-1.17.1.tgz", + "integrity": "sha512-EotY9uJU4xFY/ZVZ2Zrl2OZmBcbTVTWn/2OOh4cCWODPwtsYN2xeJYgoz8LfCgZSrhenGg0q4ceYUWATXqEsYQ==", + "dev": true, + "dependencies": { + "flatbuffers": "^1.12.0", + "guid-typescript": "^1.0.9", + "long": "^5.2.3", + "onnxruntime-common": "1.17.1", + "platform": "^1.3.6", + "protobufjs": "^7.2.4" + } + }, "node_modules/openai": { "version": "4.28.4", "resolved": "https://registry.npmjs.org/openai/-/openai-4.28.4.tgz", @@ -21670,6 +21815,12 @@ "node": ">=4" } }, + "node_modules/platform": { + "version": "1.3.6", + "resolved": "https://registry.npmjs.org/platform/-/platform-1.3.6.tgz", + "integrity": "sha512-fnWVljUchTro6RiCFvCXBbNhJc2NijN7oIQxbwsyL0buWJPG85v81ehlHI9fXrJsMNgTofEoWIQeClKpgxFLrg==", + "dev": true + }, "node_modules/playwright": { "version": "1.42.1", "resolved": "https://registry.npmjs.org/playwright/-/playwright-1.42.1.tgz", @@ -21899,6 +22050,30 @@ "integrity": "sha512-vtK/94akxsTMhe0/cbfpR+syPuszcuwhqVjJq26CuNDgFGj682oRBXOP5MJpv2r7JtE8MsiepGIqvvOTBwn2vA==", "dev": true }, + "node_modules/protobufjs": { + "version": "7.2.6", + "resolved": "https://registry.npmjs.org/protobufjs/-/protobufjs-7.2.6.tgz", + "integrity": "sha512-dgJaEDDL6x8ASUZ1YqWciTRrdOuYNzoOf27oHNfdyvKqHr5i0FV7FSLU+aIeFjyFgVxrpTOtQUi0BLLBymZaBw==", + "dev": true, + "hasInstallScript": true, + "dependencies": { + "@protobufjs/aspromise": "^1.1.2", + "@protobufjs/base64": "^1.1.2", + "@protobufjs/codegen": "^2.0.4", + "@protobufjs/eventemitter": "^1.1.0", + "@protobufjs/fetch": "^1.1.0", + "@protobufjs/float": "^1.0.2", + "@protobufjs/inquire": "^1.1.0", + "@protobufjs/path": "^1.1.2", + "@protobufjs/pool": "^1.1.0", + "@protobufjs/utf8": "^1.1.0", + "@types/node": ">=13.7.0", + "long": "^5.0.0" + }, + "engines": { + "node": ">=12.0.0" + } + }, "node_modules/proxy-from-env": { "version": "1.1.0", "resolved": "https://registry.npmjs.org/proxy-from-env/-/proxy-from-env-1.1.0.tgz", diff --git a/package.json b/package.json index bf81fadd1..1952d7ec0 100644 --- a/package.json +++ b/package.json @@ -21,7 +21,7 @@ "spj": "npx sort-package-json@latest", "test:client": "jest --runInBand --config jest.config.client.ts --verbose", "pretest:e2e": "nextron build --no-pack", - "test:e2e": "xvfb-run --auto-servernum --server-args=\"-screen 0 1280x960x24\" -- cross-env TEST_ENV=test npx playwright test", + "test:e2e": "cross-env TEST_ENV=test npx playwright test", "test:e2e:local": "cross-env TEST_ENV=local npx playwright test", "test:electron": "jest --runInBand --config jest.config.electron.ts --verbose", "test:electron:e2e": "cross-env TEST_ENV=local jest --runInBand --config jest.config.electron.e2e.ts --verbose", @@ -34,6 +34,7 @@ "electron-dl": "3.5.2", "electron-serve": "1.3.0", "electron-store": "^8.2.0", + "onnxruntime-node": "^1.17.0", "sharp": "0.33.2" }, "devDependencies": { @@ -81,6 +82,7 @@ "@types/uuid": "^9.0.8", "@typescript-eslint/eslint-plugin": "^7.1.1", "@typescript-eslint/parser": "^7.1.1", + "@xenova/transformers": "github:xenova/transformers.js#v3", "axios": "^1.6.7", "codemirror": "^5.65.16", "conventional-changelog-cli": "^4.1.0", @@ -151,8 +153,8 @@ "tslib": "2.6.2", "type-fest": "4.12.0", "typescript": "^5.4.2", - "uuid": "^9.0.1", "use-debounce": "^10.0.0", + "uuid": "^9.0.1", "webpack": "^5.90.3" }, "optionalDependencies": { diff --git a/src/electron/future/langchain/__tests__/custom-hugging-face-transformers-embeddings.test.e2e.ts b/src/electron/future/langchain/__tests__/custom-hugging-face-transformers-embeddings.test.e2e.ts new file mode 100644 index 000000000..52585e9a3 --- /dev/null +++ b/src/electron/future/langchain/__tests__/custom-hugging-face-transformers-embeddings.test.e2e.ts @@ -0,0 +1,88 @@ +import path from "node:path"; + +import { env } from "@xenova/transformers"; + +const originalImplementation = Array.isArray; +// @ts-expect-error we just want to mock this +Array.isArray = jest.fn(type => { + if ( + type && + type.constructor && + (type.constructor.name === "Float32Array" || type.constructor.name === "BigInt64Array") + ) { + return true; + } + + return originalImplementation(type); +}); + +import { CustomHuggingFaceTransformersEmbeddings } from "../custom-hugging-face-transformers-embeddings"; + +describe("CustomHuggingFaceEmbeddings", () => { + let embeddings: any; + + beforeAll(() => { + env.localModelPath = path.join(process.cwd(), "models"); + + embeddings = new CustomHuggingFaceTransformersEmbeddings({ + modelName: "Xenova/all-MiniLM-L6-v2", + maxTokens: 128, + stripNewLines: true, + }); + }); + + it("should create embeddings for a given text", async () => { + const text = "Hello, world!"; + const result = await embeddings.embedQuery(text); + + expect(Array.isArray(result)).toBeTruthy(); + expect(result.length).toBeGreaterThan(0); + }); + + it("should create embeddings for multiple documents", async () => { + const texts = ["Hello, world!", "Goodbye, world!"]; + const results = await embeddings.embedDocuments(texts); + + // Check that results is an array of arrays + expect(Array.isArray(results)).toBeTruthy(); + expect(results.length).toEqual(texts.length); + + // Check each result to ensure it's an array and not empty + for (const [index, embedding] of results.entries()) { + expect(Array.isArray(embedding)).toBeTruthy(); + expect(embedding.length).toBeGreaterThan(0); + } + }); + + describe("without maxTokens defined", () => { + let defaultEmbeddings: any; + + beforeAll(() => { + defaultEmbeddings = new CustomHuggingFaceTransformersEmbeddings({ + modelName: "Xenova/all-MiniLM-L6-v2", + stripNewLines: true, + }); + }); + + it("should create embeddings for a given text using default settings", async () => { + const text = "Hello, world!"; + const result = await defaultEmbeddings.embedQuery(text); + + expect(Array.isArray(result)).toBeTruthy(); + expect(result.length).toBeGreaterThan(0); + }); + + it("should create embeddings for multiple documents using default settings", async () => { + const texts = ["Hello, world!", "Goodbye, world!"]; + const results = await defaultEmbeddings.embedDocuments(texts); + + expect(Array.isArray(results)).toBeTruthy(); + expect(results.length).toEqual(texts.length); + + for (const [index, embedding] of results.entries()) { + expect(Array.isArray(embedding)).toBeTruthy(); + expect(embedding.length).toBeGreaterThan(0); + } + }); + }); +}); diff --git a/src/electron/future/langchain/custom-hugging-face-transformers-embeddings.ts b/src/electron/future/langchain/custom-hugging-face-transformers-embeddings.ts new file mode 100644 index 000000000..54f3b9131 --- /dev/null +++ b/src/electron/future/langchain/custom-hugging-face-transformers-embeddings.ts @@ -0,0 +1,86 @@ +import type { HuggingFaceTransformersEmbeddingsParams } from "@langchain/community/embeddings/hf_transformers"; +import { HuggingFaceTransformersEmbeddings } from "@langchain/community/embeddings/hf_transformers"; +import { AutoTokenizer, env } from "@xenova/transformers"; + +// Configuration for Transformers.js to only use local models +env.allowRemoteModels = false; +env.allowLocalModels = true; + +/** + * Truncate texts to a specified maximum token length. + * + * @param texts - The array of text strings to truncate. + * @param modelName - The name of the model used for tokenization. + * @param maxTokens - The maximum number of tokens allowed for each text. + * @returns A Promise that resolves to an array of truncated text strings. + */ +async function truncateTexts( + texts: string[], + modelName: string, + maxTokens: number +): Promise { + const tokenizer = await AutoTokenizer.from_pretrained(modelName); + return Promise.all( + texts.map(async text => { + const { input_ids } = await tokenizer(text, { + truncation: true, + max_length: maxTokens, + }); + return tokenizer.decode(input_ids, { skip_special_tokens: true }); + }) + ); +} + +/** + * Custom wrapper around HuggingFaceTransformersEmbeddings to support text truncation + * to a specified maximum number of tokens before embedding. This can be useful + * when working with models that have a fixed maximum input size, but produce better results + * when you use a lower input size as the maximum. + * + * @extends HuggingFaceTransformersEmbeddings + */ +export class CustomHuggingFaceTransformersEmbeddings extends HuggingFaceTransformersEmbeddings { + private maxTokens?: number; + + constructor( + fields?: Partial & { + maxTokens?: number; + } + ) { + super(fields); + this.maxTokens = fields?.maxTokens; + } + + /** + * Embeds multiple documents, optionally truncating each to a maximum token length. + * + * @param texts - The array of text strings to embed. + * @returns A Promise that resolves to a two-dimensional array of embeddings, with each + * sub-array representing the embedding of one input text. + */ + async embedDocuments(texts: string[]): Promise { + // Truncate texts if maxTokens is specified + if (this.maxTokens) { + const truncatedTexts = await truncateTexts(texts, this.modelName, this.maxTokens); + return super.embedDocuments(truncatedTexts); + } + + return super.embedDocuments(texts); + } + + /** + * Embeds a single query, optionally truncating it to a maximum token length. + * + * @param text - The text string to embed. + * @returns A Promise that resolves to an array representing the embedding of the input text. + */ + async embedQuery(text: string): Promise { + // Truncate text if maxTokens is specified + if (this.maxTokens) { + const [truncatedText] = await truncateTexts([text], this.modelName, this.maxTokens); + return super.embedQuery(truncatedText); + } + + return super.embedQuery(text); + } +} diff --git a/src/electron/future/main.ts b/src/electron/future/main.ts index 64995db59..97b1a4ea4 100644 --- a/src/electron/future/main.ts +++ b/src/electron/future/main.ts @@ -3,7 +3,8 @@ import fsp from "node:fs/promises"; import path from "path"; import url from "url"; -import { OpenAIEmbeddings } from "@langchain/openai"; +import { HuggingFaceTransformersEmbeddings } from "@langchain/community/embeddings/hf_transformers"; +import { env } from "@xenova/transformers"; import type { BrowserWindowConstructorOptions } from "electron"; import { app, ipcMain, BrowserWindow, Menu, protocol, screen, globalShortcut } from "electron"; import { globby } from "globby"; @@ -11,7 +12,7 @@ import matter from "gray-matter"; import { version } from "../../../package.json"; -import { appSettingsStore, keyStore, userStore } from "./stores"; +import { appSettingsStore, userStore } from "./stores"; import { buildKey } from "#/build-key"; import { LOCAL_PROTOCOL, VECTOR_STORE_COLLECTION } from "#/constants"; @@ -21,7 +22,8 @@ import { VectorStore } from "@/services/vector-store"; import { isCoreApp, isCoreView } from "@/utils/core"; import { createWindow } from "@/utils/create-window"; import { loadURL } from "@/utils/load-window"; -import { getCaptainData, getDirectory } from "@/utils/path-helpers"; +import { getCaptainData, getCaptainDownloads, getDirectory } from "@/utils/path-helpers"; +import { CustomHuggingFaceTransformersEmbeddings } from "@/langchain/custom-hugging-face-transformers-embeddings"; /** * Creates and displays the installer window with predefined dimensions. @@ -324,17 +326,23 @@ async function populateVectorStoreFromDocuments() { const apps: Record = {}; async function runStartup(withDashboard?: boolean) { - const apiKey = keyStore.get("openAiApiKey", ""); + env.localModelPath = getCaptainDownloads("llm/embeddings"); + env.allowRemoteModels = false; + env.allowLocalModels = true; + await VectorStore.init( - new OpenAIEmbeddings({ - openAIApiKey: apiKey, - modelName: "text-embedding-3-large", + new CustomHuggingFaceTransformersEmbeddings({ + modelName: "Xenova/all-MiniLM-L6-v2", + maxTokens: 128, + stripNewLines: true, }) ); - // - // await VectorStore.getInstance.deleteCollection(VECTOR_STORE_COLLECTION); - // await populateVectorStoreFromDocuments(); + try { + await VectorStore.getInstance.deleteCollection(VECTOR_STORE_COLLECTION); + } catch {} + + await populateVectorStoreFromDocuments(); apps.prompt = await createPromptWindow(); if (withDashboard) { diff --git a/src/electron/future/services/__tests__/vector-store.test.e2e.ts b/src/electron/future/services/__tests__/vector-store.test.e2e.ts index 39630d0f0..663a6f772 100644 --- a/src/electron/future/services/__tests__/vector-store.test.e2e.ts +++ b/src/electron/future/services/__tests__/vector-store.test.e2e.ts @@ -1,4 +1,6 @@ -import { OpenAIEmbeddings } from "@langchain/openai"; +import path from "node:path"; + +import { env } from "@xenova/transformers"; import axios from "axios"; import dotenv from "dotenv"; @@ -20,15 +22,32 @@ jest.mock("electron", () => ({ }, })); +const originalImplementation = Array.isArray; +// @ts-expect-error we just want to mock this +Array.isArray = jest.fn(type => { + if ( + type && + type.constructor && + (type.constructor.name === "Float32Array" || type.constructor.name === "BigInt64Array") + ) { + return true; + } + + return originalImplementation(type); +}); + +import { CustomHuggingFaceTransformersEmbeddings } from "@/langchain/custom-hugging-face-transformers-embeddings"; import { VectorStore } from "@/services/vector-store"; describe("VectorStore Integration Tests", () => { let vectorStore: VectorStore; const collectionName = "test_collection"; + const collectionName2 = "test_collection_2"; const document1 = { content: "Live Painting is very nice", payload: { id: "live-painting:schema", + label: "Live Painting", language: "en", }, }; @@ -38,21 +57,27 @@ describe("VectorStore Integration Tests", () => { content: "Story Creator writes any story", payload: { id: "story-creator:schema", + label: "Story Creator", language: "en", }, }; beforeAll(async () => { - const embedding = new OpenAIEmbeddings({ - openAIApiKey: process.env.OPENAI_API_KEY, - modelName: "text-embedding-3-large", + env.localModelPath = path.join(process.cwd(), "models"); + + const embedding = new CustomHuggingFaceTransformersEmbeddings({ + modelName: "Xenova/all-MiniLM-L6-v2", + maxTokens: 128, + stripNewLines: true, }); vectorStore = await VectorStore.init(embedding); + + await vectorStore.deleteCollection(collectionName); + await vectorStore.deleteCollection(collectionName2); }); afterAll(async () => { - await vectorStore.deleteCollection(collectionName); await vectorStore.stop(); }); @@ -108,10 +133,13 @@ describe("VectorStore Integration Tests", () => { it("should throw an error as searching in a non-existing collection doesn't work", async () => { try { - await vectorStore.search("doesnt-exist", document1.content); + await vectorStore.search(collectionName2, document1.content); } catch (error) { expect(error).toBeDefined(); - expect((error as Error).message).toContain("Collection doesnt-exist doesn't exist"); + console.log(error); + expect((error as Error).message).toContain( + `Collection ${collectionName2} doesn't exist` + ); } }, 10_000); diff --git a/src/electron/future/services/vector-store.ts b/src/electron/future/services/vector-store.ts index e6d3c51c3..a6d6db7b4 100644 --- a/src/electron/future/services/vector-store.ts +++ b/src/electron/future/services/vector-store.ts @@ -287,7 +287,7 @@ class VectorStore { * @returns {Promise} A promise that resolves when the collection has been deleted. */ public async deleteCollection(collectionName: string) { - this.ensureCollection(collectionName, false); + await this.ensureCollection(collectionName, false); return this.client?.deleteCollection(collectionName); }