-
Notifications
You must be signed in to change notification settings - Fork 2.1k
/
embeddings.ts
174 lines (151 loc) · 5.14 KB
/
embeddings.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
import { CohereClient } from "cohere-ai";
import { getEnvironmentVariable } from "@langchain/core/utils/env";
import { Embeddings, EmbeddingsParams } from "@langchain/core/embeddings";
import { chunkArray } from "@langchain/core/utils/chunk_array";
/**
* Configuration accepted by the CohereEmbeddings class, in addition to the
* base options inherited from EmbeddingsParams.
*/
export interface CohereEmbeddingsParams extends EmbeddingsParams {
/**
* Name of the Cohere embedding model to request. It is forwarded as the
* `model` field of every embed call. CohereEmbeddings falls back to
* "small" when the constructor is not given a value.
*/
model: string;
/**
* The maximum number of documents to embed in a single request. This is
* limited by the Cohere API to a maximum of 96. CohereEmbeddings uses 48
* when the constructor is not given a value.
*/
batchSize?: number;
/**
* Specifies the type of input you're giving to the model.
* Not required for embedding models older than v3, but required for v3
* and later models.
*
* * `search_document` - Use this when you encode documents for embeddings that you store in a vector database for search use-cases.
* * `search_query` - Use this when you query your vector DB to find relevant documents.
* * `classification` - Use this when you use the embeddings as an input to a text classifier.
* * `clustering` - Use this when you want to cluster the embeddings.
*
* The value is passed through to the API as-is; it is not validated here.
*/
inputType?: string;
}
/**
 * A class for generating embeddings using the Cohere API.
 *
 * Texts are split into batches of at most `batchSize` items and each batch is
 * sent through the inherited `caller` (defined on the `Embeddings` base class,
 * not in this file). Both response layouts returned by the SDK are accepted:
 * a plain array of vectors, or an object carrying the vectors under `float`.
 */
export class CohereEmbeddings
  extends Embeddings
  implements CohereEmbeddingsParams
{
  /** Model name sent with every embed request. */
  model = "small";

  /** Maximum number of texts sent in a single embed request. */
  batchSize = 48;

  /** Optional input-type hint forwarded to the API unchanged. */
  inputType: string | undefined;

  private client: CohereClient;

  /**
   * Constructor for the CohereEmbeddings class.
   * @param fields - An optional object with properties to configure the
   * instance. `maxConcurrency` defaults to 2 when not supplied. The API key is
   * taken from `fields.apiKey`, falling back to the `COHERE_API_KEY`
   * environment variable.
   * @throws Error if no API key can be found in either place.
   */
  constructor(
    fields?: Partial<CohereEmbeddingsParams> & {
      verbose?: boolean;
      apiKey?: string;
    }
  ) {
    // Caller-supplied fields win over the default concurrency.
    const fieldsWithDefaults = { maxConcurrency: 2, ...fields };
    super(fieldsWithDefaults);

    // `||` (not `??`) so that an empty-string key also falls back to the env.
    const apiKey =
      fieldsWithDefaults.apiKey || getEnvironmentVariable("COHERE_API_KEY");
    if (!apiKey) {
      throw new Error("Cohere API key not found");
    }

    this.client = new CohereClient({
      token: apiKey,
    });
    this.model = fieldsWithDefaults.model ?? this.model;
    this.batchSize = fieldsWithDefaults.batchSize ?? this.batchSize;
    this.inputType = fieldsWithDefaults.inputType;
  }

  /**
   * Generates embeddings for an array of texts.
   *
   * All batches are requested up front and awaited together; the output
   * preserves the order of `texts`.
   * @param texts - An array of strings to generate embeddings for.
   * @returns A Promise that resolves to an array of embeddings.
   * @throws Error if any batch response has an unrecognised shape. Skipping
   * such a batch would leave the result shorter than, and misaligned with,
   * `texts`.
   */
  async embedDocuments(texts: string[]): Promise<number[][]> {
    const batches = chunkArray(texts, this.batchSize);
    const batchResponses = await Promise.all(
      batches.map((batch) =>
        this.embeddingWithRetry({
          model: this.model,
          texts: batch,
          // eslint-disable-next-line @typescript-eslint/no-explicit-any
          inputType: this.inputType as any,
        })
      )
    );

    const embeddings: number[][] = [];
    batchResponses.forEach((response, i) => {
      const vectors = this.extractFloatEmbeddings(response.embeddings);
      // Take exactly one vector per input text of this batch.
      for (let j = 0; j < batches[i].length; j += 1) {
        embeddings.push(vectors[j]);
      }
    });
    return embeddings;
  }

  /**
   * Generates an embedding for a single text.
   * @param text - A string to generate an embedding for.
   * @returns A Promise that resolves to an array of numbers representing the embedding.
   * @throws Error if the response has an unrecognised shape.
   */
  async embedQuery(text: string): Promise<number[]> {
    const { embeddings } = await this.embeddingWithRetry({
      model: this.model,
      texts: [text],
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
      inputType: this.inputType as any,
    });
    return this.extractFloatEmbeddings(embeddings)[0];
  }

  /**
   * Normalises the `embeddings` payload of an embed response to a plain list
   * of float vectors.
   * @param embeddings - The `embeddings` field of an embed response.
   * @returns The float vectors, in the order the API returned them.
   * @throws Error if the payload is neither an object with a truthy `float`
   * property nor an array.
   */
  private extractFloatEmbeddings(
    embeddings: Awaited<ReturnType<CohereClient["embed"]>>["embeddings"]
  ): number[][] {
    if ("float" in embeddings && embeddings.float) {
      return embeddings.float;
    } else if (Array.isArray(embeddings)) {
      return embeddings;
    }
    throw new Error(
      `Invalid response from Cohere API. Received: ${JSON.stringify(
        embeddings,
        null,
        2
      )}`
    );
  }

  /**
   * Sends one embed request through the inherited `caller`.
   * @param request - An object containing the request parameters for generating embeddings.
   * @returns A Promise that resolves to the API response.
   */
  private async embeddingWithRetry(
    request: Parameters<typeof this.client.embed>[0]
  ) {
    return this.caller.call(async () => {
      try {
        return await this.client.embed(request);
      } catch (e: unknown) {
        // Mirror `statusCode` onto `status` when the latter is missing.
        // NOTE(review): presumably the caller's retry logic reads `status` -
        // confirm against the base-class caller implementation.
        // The object guard keeps a thrown null/primitive from turning into a
        // TypeError that would mask the original failure.
        if (typeof e === "object" && e !== null) {
          const err = e as { status?: unknown; statusCode?: unknown };
          err.status = err.status ?? err.statusCode;
        }
        throw e;
      }
    });
  }

  /** Maps secret constructor fields to the environment variable holding them. */
  get lc_secrets(): { [key: string]: string } | undefined {
    return {
      apiKey: "COHERE_API_KEY",
      api_key: "COHERE_API_KEY",
    };
  }

  /** Maps constructor field names to their serialized alias. */
  get lc_aliases(): { [key: string]: string } | undefined {
    return {
      apiKey: "cohere_api_key",
      api_key: "cohere_api_key",
    };
  }
}