-
Notifications
You must be signed in to change notification settings - Fork 2.1k
/
embeddings.ts
143 lines (129 loc) Β· 4.46 KB
/
embeddings.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import { getEnvironmentVariable } from "@langchain/core/utils/env";
import { Embeddings, type EmbeddingsParams } from "@langchain/core/embeddings";
import { type EmbeddingsResult as MistralAIEmbeddingsResult } from "@mistralai/mistralai";
import { chunkArray } from "@langchain/core/utils/chunk_array";
/**
 * Interface for MistralAIEmbeddings parameters. Extends EmbeddingsParams and
 * defines additional parameters specific to the MistralAIEmbeddings class.
 */
export interface MistralAIEmbeddingsParams extends EmbeddingsParams {
  /**
   * The API key to use. When omitted, the value of the `MISTRAL_API_KEY`
   * environment variable is used; construction fails if neither is set.
   */
  apiKey?: string;
  /**
   * The name of the embedding model to use.
   * @default "mistral-embed"
   */
  modelName?: string;
  /**
   * The requested format of the output data.
   * NOTE(review): this value is stored on `MistralAIEmbeddings` but is not
   * currently included in the embeddings request sent to the MistralAI client.
   * @default "float"
   */
  encodingFormat?: string;
  /**
   * Override the default API endpoint. Passed to the MistralAI client
   * constructor together with the API key.
   */
  endpoint?: string;
  /**
   * The maximum number of documents to embed in a single request.
   * @default 512
   */
  batchSize?: number;
  /**
   * Whether to replace newline characters in the input text with spaces
   * before embedding. This is recommended, but may not be suitable for all
   * use cases.
   * @default true
   */
  stripNewLines?: boolean;
}
/**
 * Class for generating embeddings using the MistralAI API.
 *
 * Input text is optionally newline-stripped, split into batches of
 * `batchSize`, and sent to the MistralAI client. Retries are delegated to
 * `this.caller`, which is inherited from the `Embeddings` base class.
 */
export class MistralAIEmbeddings
  extends Embeddings
  implements MistralAIEmbeddingsParams
{
  modelName = "mistral-embed";

  // NOTE(review): kept for configuration compatibility, but not currently
  // forwarded to the MistralAI request in `embeddingWithRetry`.
  encodingFormat = "float";

  batchSize = 512;

  stripNewLines = true;

  apiKey: string;

  endpoint?: string;

  /**
   * @param fields Optional configuration; any field left unset keeps the
   *   class default shown above.
   * @throws Error if no API key is passed and `MISTRAL_API_KEY` is not set.
   */
  constructor(fields?: Partial<MistralAIEmbeddingsParams>) {
    super(fields ?? {});
    const apiKey = fields?.apiKey ?? getEnvironmentVariable("MISTRAL_API_KEY");
    if (!apiKey) {
      throw new Error("API key missing for MistralAI, but it is required.");
    }
    this.apiKey = apiKey;
    this.endpoint = fields?.endpoint;
    this.modelName = fields?.modelName ?? this.modelName;
    this.encodingFormat = fields?.encodingFormat ?? this.encodingFormat;
    this.batchSize = fields?.batchSize ?? this.batchSize;
    this.stripNewLines = fields?.stripNewLines ?? this.stripNewLines;
  }

  /**
   * Generates embeddings for an array of documents. The documents are split
   * into batches of at most `batchSize`, every batch request is started and
   * awaited together via `Promise.all`, and the per-batch results are
   * concatenated in batch order.
   * @param texts Documents to generate embeddings for.
   * @returns One embedding per input document. Order matches `texts`,
   *   assuming the API returns each batch's embeddings in input order.
   * @throws Error if the API returns fewer embeddings than inputs for a batch.
   */
  async embedDocuments(texts: string[]): Promise<number[][]> {
    const batches = chunkArray(
      texts.map((text) => this.prepareText(text)),
      this.batchSize
    );
    const batchEmbeddings = await Promise.all(
      batches.map(async (batch) => {
        const { data } = await this.embeddingWithRetry(batch);
        // Fail with a descriptive message instead of an opaque
        // "cannot read properties of undefined" TypeError on a short response.
        if (data.length < batch.length) {
          throw new Error(
            `MistralAI returned ${data.length} embeddings for a batch of ${batch.length} inputs.`
          );
        }
        // Any surplus entries beyond the batch size are ignored.
        return data.slice(0, batch.length).map((item) => item.embedding);
      })
    );
    return batchEmbeddings.flat();
  }

  /**
   * Generates an embedding for a single query string.
   * @param text Text to generate an embedding for.
   * @returns The first embedding in the API response.
   * @throws Error if the API response contains no embeddings.
   */
  async embedQuery(text: string): Promise<number[]> {
    const { data } = await this.embeddingWithRetry(this.prepareText(text));
    const [first] = data;
    if (first === undefined) {
      throw new Error("MistralAI returned no embedding for the query.");
    }
    return first.embedding;
  }

  /**
   * Replaces every "\n" in `text` with a space when `stripNewLines` is
   * enabled; otherwise returns `text` unchanged.
   */
  private prepareText(text: string): string {
    return this.stripNewLines ? text.replace(/\n/g, " ") : text;
  }

  /**
   * Sends one embeddings request to the MistralAI API through `this.caller`,
   * which handles retries.
   * @param input A single text or a batch of texts to embed.
   * @returns The raw response from the MistralAI client.
   */
  private async embeddingWithRetry(
    input: string | Array<string>
  ): Promise<MistralAIEmbeddingsResult> {
    return this.caller.call(async () => {
      const { MistralClient } = await this.imports();
      // A fresh client is built for every attempt from the current
      // `apiKey`/`endpoint` values.
      const client = new MistralClient(this.apiKey, this.endpoint);
      return client.embeddings({
        model: this.modelName,
        input,
      });
    });
  }

  /**
   * Loads the MistralAI client lazily via a dynamic import of
   * "@mistralai/mistralai".
   * @returns An object exposing the client's default export as `MistralClient`.
   */
  async imports() {
    const { default: MistralClient } = await import("@mistralai/mistralai");
    return { MistralClient };
  }
}