-
Notifications
You must be signed in to change notification settings - Fork 2.1k
/
embeddings.ts
191 lines (162 loc) Β· 5.54 KB
/
embeddings.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
import { GoogleGenerativeAI, GenerativeModel } from "@google/generative-ai";
import type { TaskType, EmbedContentRequest } from "@google/generative-ai";
import { getEnvironmentVariable } from "@langchain/core/utils/env";
import { Embeddings, EmbeddingsParams } from "@langchain/core/embeddings";
import { chunkArray } from "@langchain/core/utils/chunk_array";
/**
 * Interface that extends EmbeddingsParams and defines additional
 * parameters specific to the GoogleGenerativeAIEmbeddings class.
 */
export interface GoogleGenerativeAIEmbeddingsParams extends EmbeddingsParams {
  /**
   * Model name to use.
   *
   * Alias for `model`; if both are supplied, `model` takes precedence.
   *
   * Note: The format must follow the pattern - `{model}`. A leading
   * `models/` prefix is also accepted and stripped.
   */
  modelName?: string;
  /**
   * Model name to use.
   *
   * Note: The format must follow the pattern - `{model}`. A leading
   * `models/` prefix is also accepted and stripped.
   */
  model?: string;
  /**
   * Type of task for which the embedding will be used.
   *
   * Note: currently only supported by `embedding-001` model
   */
  taskType?: TaskType;
  /**
   * An optional title for the text. Only applicable when taskType is
   * `RETRIEVAL_DOCUMENT`; the GoogleGenerativeAIEmbeddings constructor
   * throws if a title is given and taskType is anything else (including
   * unset).
   *
   * Note: currently only supported by `embedding-001` model
   */
  title?: string;
  /**
   * Whether to replace newline characters in the input text with spaces
   * before embedding. Defaults to true.
   */
  stripNewLines?: boolean;
  /**
   * Google API key to use. If omitted, the `GOOGLE_API_KEY` environment
   * variable is read instead.
   */
  apiKey?: string;
}
/**
 * Class that extends the Embeddings class and provides methods for
 * generating embeddings using the Google Generative AI API (via the
 * `@google/generative-ai` SDK).
 * @example
 * ```typescript
 * const model = new GoogleGenerativeAIEmbeddings({
 *   apiKey: "<YOUR API KEY>",
 *   modelName: "embedding-001",
 * });
 *
 * // Embed a single query
 * const res = await model.embedQuery(
 *   "What would be a good company name for a company that makes colorful socks?"
 * );
 * console.log({ res });
 *
 * // Embed multiple documents
 * const documentRes = await model.embedDocuments(["Hello world", "Bye bye"]);
 * console.log({ documentRes });
 * ```
 */
export class GoogleGenerativeAIEmbeddings
  extends Embeddings
  implements GoogleGenerativeAIEmbeddingsParams
{
  apiKey?: string;

  modelName = "embedding-001";

  model = "embedding-001";

  taskType?: TaskType;

  title?: string;

  stripNewLines = true;

  // Max batch size for embedDocuments set by GenerativeModel client's batchEmbedContents call
  maxBatchSize = 100;

  private client: GenerativeModel;

  /**
   * @param fields Optional configuration; see GoogleGenerativeAIEmbeddingsParams.
   * @throws Error if `title` is set while `taskType` is not `RETRIEVAL_DOCUMENT`.
   * @throws Error if no API key is given via `fields.apiKey` or the
   *   `GOOGLE_API_KEY` environment variable.
   */
  constructor(fields?: GoogleGenerativeAIEmbeddingsParams) {
    super(fields ?? {});

    // `model` wins over the `modelName` alias; an optional `models/`
    // resource prefix is stripped from whichever one is used.
    this.modelName =
      fields?.model?.replace(/^models\//, "") ??
      fields?.modelName?.replace(/^models\//, "") ??
      this.modelName;
    this.model = this.modelName;

    this.taskType = fields?.taskType ?? this.taskType;
    this.title = fields?.title ?? this.title;
    // Honor the `stripNewLines` option declared on the params interface;
    // without this the field default (true) could never be overridden.
    this.stripNewLines = fields?.stripNewLines ?? this.stripNewLines;

    if (this.title && this.taskType !== "RETRIEVAL_DOCUMENT") {
      throw new Error(
        "title can only be specified with TaskType.RETRIEVAL_DOCUMENT"
      );
    }

    this.apiKey = fields?.apiKey ?? getEnvironmentVariable("GOOGLE_API_KEY");
    if (!this.apiKey) {
      throw new Error(
        "Please set an API key for Google GenerativeAI " +
          "in the environment variable GOOGLE_API_KEY " +
          "or in the `apiKey` field of the " +
          "GoogleGenerativeAIEmbeddings constructor"
      );
    }

    this.client = new GoogleGenerativeAI(this.apiKey).getGenerativeModel({
      model: this.model,
    });
  }

  /**
   * Builds a single embed request for `text`, optionally replacing
   * newlines with spaces, and attaching the configured taskType/title.
   */
  private _convertToContent(text: string): EmbedContentRequest {
    const cleanedText = this.stripNewLines ? text.replace(/\n/g, " ") : text;
    return {
      content: { role: "user", parts: [{ text: cleanedText }] },
      taskType: this.taskType,
      title: this.title,
    };
  }

  /**
   * Embeds one text with a single `embedContent` call.
   * Returns `[]` if the response carries no values.
   */
  protected async _embedQueryContent(text: string): Promise<number[]> {
    const req = this._convertToContent(text);
    const res = await this.client.embedContent(req);
    return res.embedding.values ?? [];
  }

  /**
   * Embeds many texts by splitting them into chunks of at most
   * `maxBatchSize` and issuing one `batchEmbedContents` call per chunk
   * concurrently.
   *
   * Best-effort: if a chunk's request fails, every document in that chunk
   * gets an empty embedding (`[]`) so the result stays index-aligned with
   * `documents` instead of rejecting the whole call.
   */
  protected async _embedDocumentsContent(
    documents: string[]
  ): Promise<number[][]> {
    const batchEmbedChunks: string[][] = chunkArray<string>(
      documents,
      this.maxBatchSize
    );

    const batchEmbedRequests = batchEmbedChunks.map((chunk) => ({
      requests: chunk.map((doc) => this._convertToContent(doc)),
    }));

    const responses = await Promise.allSettled(
      batchEmbedRequests.map((req) => this.client.batchEmbedContents(req))
    );

    return responses.flatMap((res, idx): number[][] => {
      if (res.status === "fulfilled") {
        return res.value.embeddings.map((e) => e.values ?? []);
      }
      // Build a distinct empty array per slot: `Array(n).fill([])` would
      // make every placeholder the same array instance.
      const chunkSize = batchEmbedChunks[idx]?.length ?? 0;
      return Array.from({ length: chunkSize }, (): number[] => []);
    });
  }

  /**
   * Method that takes a document as input and returns a promise that
   * resolves to an embedding for the document. The request is executed
   * through `this.caller`.
   * @param document Document for which to generate an embedding.
   * @returns Promise that resolves to an embedding for the input document.
   */
  embedQuery(document: string): Promise<number[]> {
    return this.caller.call(this._embedQueryContent.bind(this), document);
  }

  /**
   * Method that takes an array of documents as input and returns a promise
   * that resolves to a 2D array of embeddings, one per input document, in
   * input order. Documents are sent in batches of at most `maxBatchSize`;
   * documents in a batch whose request fails receive an empty embedding.
   * @param documents Array of documents for which to generate embeddings.
   * @returns Promise that resolves to a 2D array of embeddings for each input document.
   */
  embedDocuments(documents: string[]): Promise<number[][]> {
    return this.caller.call(this._embedDocumentsContent.bind(this), documents);
  }
}