|
| 1 | +import { ChromaClient } from 'chromadb'; |
| 2 | +import { GoogleGenerativeAI } from '@google/generative-ai'; |
| 3 | +import aiConfig from '../../../../buildScripts/ai/aiConfig.mjs'; |
| 4 | +import dotenv from 'dotenv'; |
| 5 | +import path from 'path'; |
| 6 | + |
// TODO: This dotenv config needs a more robust solution.
// The .env file is located relative to the current working directory: directly inside it when the
// npm package name contains 'neo.mjs', otherwise two directory levels above it.
const cwd = process.cwd();
const insideNeo = process.env.npm_package_name?.includes('neo.mjs') ?? false;
const envLocation = insideNeo ? '.env' : '../../.env';

dotenv.config({ path: path.resolve(cwd, envLocation), quiet: true });
| 14 | + |
/**
 * Performs a semantic search on the knowledge base using a natural language query.
 * Returns a scored and ranked list of the most relevant source files.
 *
 * The query is embedded with the configured Gemini embedding model, the nearest entries are
 * fetched from the ChromaDB collection, and every returned metadata entry is then re-scored
 * with keyword heuristics before the scores are summed per source file.
 *
 * @param {object} [options]
 * @param {string} options.query - The natural language search query.
 * @param {string} [options.type='all'] - The content type to filter by. 'all' disables the filter.
 * @returns {Promise<object>} Either `{ topResult, results }`, where `results` is an array of up to
 *     25 `{ source, score }` entries sorted by descending score (`score` is a string produced by
 *     `toFixed(0)`), or `{ message }` when nothing could be ranked.
 * @throws {Error} If no query is given, if GEMINI_API_KEY is not set, or if the collection
 *     cannot be retrieved (the underlying error is attached as `cause`).
 */
async function queryDocuments({ query, type = 'all' } = {}) {
    if (!query) {
        throw new Error('A query string must be provided.');
    }

    const dbClient = new ChromaClient();
    const GEMINI_API_KEY = process.env.GEMINI_API_KEY;
    if (!GEMINI_API_KEY) {
        throw new Error('The GEMINI_API_KEY environment variable is not set.');
    }

    const genAI = new GoogleGenerativeAI(GEMINI_API_KEY);
    const model = genAI.getGenerativeModel({ model: aiConfig.knowledgeBase.embeddingModel });

    const collection = await getCollectionQuietly(dbClient, aiConfig.knowledgeBase.collectionName);

    const queryEmbedding = await model.embedContent(query);
    const queryLower = query.toLowerCase();

    const queryOptions = {
        queryEmbeddings: [queryEmbedding.embedding.values],
        nResults: aiConfig.knowledgeBase.nResults
    };

    // Only send a `where` filter when a concrete type was requested.
    if (type && type !== 'all') {
        queryOptions.where = { type };
    }

    const results = await collection.query(queryOptions);
    const metadatas = results.metadatas?.[0];

    if (!metadatas || metadatas.length === 0) {
        return { message: 'No results found for your query and type.' };
    }

    const keywords = extractKeywords(queryLower);

    // A Map (rather than a plain object) keeps source paths such as 'constructor' from
    // colliding with Object.prototype members.
    const sourceScores = new Map();
    const addScore = (source, amount) => {
        sourceScores.set(source, (sourceScores.get(source) ?? 0) + amount);
    };

    metadatas.forEach((metadata, index) => {
        // Entries without metadata or without a usable source cannot be attributed to a file.
        if (!metadata?.source || metadata.source === 'unknown') return;

        const score = scoreMetadata(metadata, {
            baseScore: metadatas.length - index, // earlier (closer) results start higher
            keywords,
            queryLower,
            type
        });
        addScore(metadata.source, score);

        // Parents in the inheritance chain receive a boost that shrinks by 40% per level.
        // The decay is applied for every chain entry, including those without a source.
        let boost = 80;
        for (const parent of parseInheritanceChain(metadata.inheritanceChain)) {
            if (parent?.source) {
                addScore(parent.source, boost);
            }
            boost = Math.floor(boost * 0.6);
        }
    });

    if (sourceScores.size === 0) {
        return { message: 'No relevant source files found for the specified type.' };
    }

    const ranked = rankSources(sourceScores);

    return {
        topResult: ranked[0].source,
        results: ranked
    };
}

/**
 * Fetches a collection while temporarily silencing `console.warn`.
 * `console.warn` is restored in a `finally` block, so it is reinstated even when the lookup fails.
 * @param {object} dbClient - Client exposing `getCollection({ name })`.
 * @param {string} name - Name of the collection to fetch.
 * @returns {Promise<object>} The collection returned by the client.
 * @throws {Error} If the collection cannot be retrieved; the original error is passed as `cause`.
 */
async function getCollectionQuietly(dbClient, name) {
    const originalWarn = console.warn;
    console.warn = () => {}; // Suppress unwanted warnings from ChromaDB client

    try {
        return await dbClient.getCollection({ name });
    } catch (err) {
        throw new Error('Could not connect to collection. Please run the sync process first.', { cause: err });
    } finally {
        console.warn = originalWarn;
    }
}

/**
 * Turns a lower-cased query into the list of keywords used for scoring.
 * Everything except ASCII letters and spaces is stripped, words of two characters or fewer are
 * dropped, one trailing 's' is removed from each word, and results that became two characters
 * or fewer are dropped again. Repeated words are kept, so they are scored once per occurrence.
 * @param {string} queryLower - The lower-cased query.
 * @returns {string[]} The keywords, in query order.
 */
function extractKeywords(queryLower) {
    return queryLower
        .replace(/[^a-zA-Z ]/g, '')
        .split(' ')
        .filter(word => word.length > 2)
        .map(word => (word.endsWith('s') ? word.slice(0, -1) : word))
        .filter(keyword => keyword.length > 2);
}

/**
 * Computes the heuristic score of a single metadata entry.
 * @param {object} metadata - Metadata entry with a truthy `source`.
 * @param {object} context
 * @param {number} context.baseScore - Rank-derived starting score.
 * @param {string[]} context.keywords - Keywords produced by `extractKeywords`.
 * @param {string} context.queryLower - The lower-cased query.
 * @param {string} context.type - The requested content type filter.
 * @returns {number} The score for this entry.
 */
function scoreMetadata(metadata, { baseScore, keywords, queryLower, type }) {
    const sourcePathLower = metadata.source.toLowerCase();
    const fileName = metadata.source.split('/').pop().toLowerCase();
    const nameLower = String(metadata.name || '').toLowerCase();
    const classNameLower = String(metadata.className || '').toLowerCase();
    const nameParts = nameLower.split('.');
    const isBlog = metadata.type === 'blog';
    const isGuideOrBlog = isBlog || metadata.type === 'guide';

    let score = baseScore;

    for (const keyword of keywords) {
        if (sourcePathLower.includes(`/${keyword}/`)) score += 40;
        if (fileName.includes(keyword)) score += 30;
        if (metadata.type === 'class' && nameLower.includes(keyword)) score += 20;
        if (classNameLower.includes(keyword)) score += 20;
        if (isGuideOrBlog) {
            // This type bonus is granted once per keyword, not once per entry.
            score += isBlog ? 5 : 50;
            if (nameLower.includes(keyword)) score += 50;
        }
        if (nameParts.includes(keyword)) score += 30;
    }

    if (metadata.type === 'ticket' && type === 'all') score -= 70;
    if (metadata.type === 'release') score -= 50;
    if (fileName.endsWith('base.mjs')) score += 20;
    // A query starting with 'v' that equals a release entry's name outweighs every other adjustment.
    if (metadata.type === 'release' && queryLower.startsWith('v') && nameLower === queryLower) score += 1000;

    return score;
}

/**
 * Reads the inheritance chain stored on a metadata entry.
 * The chain only feeds an optional score boost, so a missing, malformed or non-array value
 * yields an empty chain instead of aborting the whole query.
 * @param {string|Array|undefined} rawChain - JSON text of an array, or an already parsed array.
 * @returns {Array} The chain entries, or an empty array.
 */
function parseInheritanceChain(rawChain) {
    if (Array.isArray(rawChain)) return rawChain;
    if (!rawChain) return [];

    try {
        const parsed = JSON.parse(rawChain);
        return Array.isArray(parsed) ? parsed : [];
    } catch {
        return [];
    }
}

/**
 * Produces the final ranking from the accumulated per-source scores.
 * Sources located in the same directory as one of the five best-scoring sources are multiplied
 * by 1.1, the list is sorted again, and the best 25 entries are returned.
 * @param {Map<string, number>} sourceScores - Accumulated score per source path.
 * @returns {Array<{source: string, score: string}>} Ranked entries; `score` is `toFixed(0)` text.
 */
function rankSources(sourceScores) {
    const byScoreDescending = ([, a], [, b]) => b - a;
    const sorted = [...sourceScores].sort(byScoreDescending);
    const topSourceDirs = new Set(sorted.slice(0, 5).map(([source]) => path.dirname(source)));

    return sorted
        .map(([source, score]) => [source, topSourceDirs.has(path.dirname(source)) ? score * 1.1 : score])
        .sort(byScoreDescending)
        .slice(0, 25)
        .map(([source, score]) => ({ source, score: score.toFixed(0) }));
}
| 145 | + |
| 146 | +export { queryDocuments }; |
0 commit comments