// index_embeddings.js
// Run this server with the command "node index_embeddings.js" from the project root directory.
// Ensure AIRTABLE_BASE_ID, AIRTABLE_API_KEY and OPENAI_API_KEY are set in .env (PORT is optional and defaults to 5000).
// The Airtable table used in this example is "Frontend Fresh", read through its "Grid view" view.
require("dotenv").config();
const express = require("express");
const { Configuration, OpenAIApi } = require("openai");
const Airtable = require("airtable");
// HTTP server; request bodies are parsed as JSON.
const app = express();
app.use(express.json());

// Airtable: the "Frontend Fresh" table holds the text records and their
// stored embeddings, read through the "Grid view" view.
const airtableBase = new Airtable({ apiKey: process.env.AIRTABLE_API_KEY }).base(
  process.env.AIRTABLE_BASE_ID
);
const airtableTable = airtableBase("Frontend Fresh");
const airtableView = airtableTable.select({ view: "Grid view" });

// OpenAI: client built from the API key in the environment.
const configuration = new Configuration({ apiKey: process.env.OPENAI_API_KEY });
const openai = new OpenAIApi(configuration);

// Listening port; the PORT environment variable overrides the fallback.
const port = process.env.PORT || 5000;

// Model identifiers used for answering and for embedding prompts.
const COMPLETIONS_MODEL = "text-davinci-003";
const EMBEDDING_MODEL = "text-embedding-ada-002";
// functions
// ---
/**
 * Cosine similarity between two numeric vectors of equal length.
 *
 * @param {number[]} A - first vector.
 * @param {number[]} B - second vector; must have the same length as A.
 * @returns {number} similarity in [-1, 1] (up to floating-point rounding).
 *   Returns 0 when either vector has zero magnitude: the similarity is
 *   undefined there, and 0 (instead of NaN) keeps max-score ranking well-defined.
 * @throws {RangeError} if the vectors differ in length.
 */
function cosineSimilarity(A, B) {
  // A length mismatch would silently compare a partial prefix or yield NaN.
  if (A.length !== B.length) {
    throw new RangeError(`Vector length mismatch: ${A.length} vs ${B.length}`);
  }
  let dotProduct = 0;
  let normA = 0;
  let normB = 0;
  for (let i = 0; i < A.length; i++) {
    dotProduct += A[i] * B[i];
    normA += A[i] * A[i];
    normB += B[i] * B[i];
  }
  if (normA === 0 || normB === 0) {
    return 0;
  }
  return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
}
/**
 * Scores every stored text against the prompt embedding.
 *
 * @param {Object<string, (string|number[])>} embeddingsHash - map of text to its
 *   embedding, either a JSON-encoded array string or an already-parsed array.
 * @param {number[]} promptEmbedding - embedding vector of the user's prompt.
 * @returns {Object<string, number>} map of text to its cosine similarity with the
 *   prompt. Entries whose embedding is missing or empty are left out.
 */
function getSimilarityScore(embeddingsHash, promptEmbedding) {
  const similarityScoreHash = {};
  for (const [text, storedEmbedding] of Object.entries(embeddingsHash)) {
    // An empty cell cannot be scored, and JSON.parse(undefined) would throw.
    if (storedEmbedding == null || storedEmbedding === "") {
      continue;
    }
    const embedding = Array.isArray(storedEmbedding)
      ? storedEmbedding
      : JSON.parse(storedEmbedding);
    similarityScoreHash[text] = cosineSimilarity(promptEmbedding, embedding);
  }
  return similarityScoreHash;
}
/**
 * Loads the text/embedding pairs from the configured Airtable view.
 *
 * @returns {Promise<Object<string, string>>} resolves to a map from each record's
 *   "Text" field to its "Embedding" field. Records missing either field are skipped.
 *   Rejects with an Error whose `cause` is the original Airtable error if the
 *   request fails.
 */
function getAirtableData() {
  return new Promise((resolve, reject) => {
    // NOTE(review): firstPage() yields only the first page of the view; switch to
    // eachPage()/all() if the table can grow beyond a single page.
    airtableView.firstPage((error, records) => {
      if (error) {
        return reject(
          new Error(`Failed to load records from Airtable: ${error.message}`, {
            cause: error,
          })
        );
      }
      const recordsHash = {};
      records.forEach((record) => {
        const text = record.get("Text");
        const embedding = record.get("Embedding");
        // Incomplete rows cannot take part in similarity scoring.
        if (text == null || text === "" || embedding == null || embedding === "") {
          return;
        }
        recordsHash[text] = embedding;
      });
      resolve(recordsHash);
    });
  });
}
// ---
/**
 * Picks the text with the highest similarity score.
 *
 * NaN scores never win. On ties, the key visited later wins.
 *
 * @param {Object<string, number>} similarityScoreHash - map of text to similarity score.
 * @returns {string|null} the best-matching text, or null when no entry has a
 *   comparable score (e.g. the map is empty).
 */
function findMostSimilarText(similarityScoreHash) {
  let bestText = null;
  let bestScore = -Infinity;
  for (const [text, score] of Object.entries(similarityScoreHash)) {
    if (score >= bestScore) {
      bestText = text;
      bestScore = score;
    }
  }
  return bestText;
}

// POST /ask: answers `prompt` using the stored text most similar to it as context.
// Request body: { prompt: string }.
// Responses:
//   200 { success: true, message }  - the completion text.
//   400 { success: false, message } - the prompt is missing, not a string, or blank.
//   500 { success: false, message } - no reference text is available, or an upstream call failed.
app.post("/ask", async (req, res) => {
  const prompt = req.body ? req.body.prompt : undefined;
  if (typeof prompt !== "string" || prompt.trim() === "") {
    return res.status(400).json({
      success: false,
      message: "Uh oh, no prompt was provided",
    });
  }
  try {
    // The Airtable lookup and the prompt embedding are independent, so run them concurrently.
    const [embeddingsHash, promptEmbeddingsResponse] = await Promise.all([
      getAirtableData(),
      // The embeddings endpoint has no `max_tokens` parameter; send only model and input.
      openai.createEmbedding({
        model: EMBEDDING_MODEL,
        input: prompt,
      }),
    ]);
    const promptEmbedding = promptEmbeddingsResponse.data.data[0].embedding;
    // Map each stored text to its similarity with the prompt, then keep the best match.
    const similarityScoreHash = getSimilarityScore(embeddingsHash, promptEmbedding);
    const textWithHighestScore = findMostSimilarText(similarityScoreHash);
    if (textWithHighestScore === null) {
      return res.status(500).json({
        success: false,
        message: "No reference text is available to answer the prompt",
      });
    }
    // Produces "\nInfo: ...\nQuestion: ...\nAnswer:\n".
    const finalPrompt = [
      "",
      `Info: ${textWithHighestScore}`,
      `Question: ${prompt}`,
      "Answer:",
      "",
    ].join("\n");
    const response = await openai.createCompletion({
      model: COMPLETIONS_MODEL,
      prompt: finalPrompt,
      max_tokens: 64,
    });
    const completion = response.data.choices[0].text;
    return res.status(200).json({
      success: true,
      message: completion,
    });
  } catch (error) {
    // Log server-side, but always answer so the client is never left hanging.
    console.log(error instanceof Error ? error.message : error);
    return res.status(500).json({
      success: false,
      message: "Something went wrong while answering the prompt",
    });
  }
});
// Start accepting HTTP requests and announce the port once the server is bound.
app.listen(port, () => {
  console.log(`Server is running on port ${port}!!`);
});