-
Notifications
You must be signed in to change notification settings - Fork 2.2k
/
cloudflare_workersai.ts
94 lines (76 loc) · 2.41 KB
/
cloudflare_workersai.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import { Ai } from "@cloudflare/ai";
import { Fetcher } from "@cloudflare/workers-types";
import { Embeddings, EmbeddingsParams } from "@langchain/core/embeddings";
import { chunkArray } from "@langchain/core/utils/chunk_array";
/**
 * Request payload sent to `Ai.run` for a text-embedding model: either a
 * single text or a list of texts.
 */
type AiTextEmbeddingsInput = {
  text: string | Array<string>;
};
/**
 * Response expected from `Ai.run` for a text-embedding model.
 *
 * `data` holds the embedding vectors; callers in this file assume one vector
 * per input text, in request order (NOTE(review): confirm against the
 * Workers AI embeddings contract). `shape` is declared but not read here.
 */
type AiTextEmbeddingsOutput = {
  shape: Array<number>;
  data: Array<Array<number>>;
};
/**
 * Constructor options for `CloudflareWorkersAIEmbeddings`.
 */
export interface CloudflareWorkersAIEmbeddingsParams extends EmbeddingsParams {
  /**
   * Workers AI binding from the Worker environment (e.g. `env.AI`).
   * Required: the embeddings class throws if it is missing and wraps it in an
   * `Ai` client.
   */
  binding: Fetcher;
  /**
   * Workers AI model identifier to run. Defaults to
   * `"@cf/baai/bge-base-en-v1.5"`.
   */
  modelName?: string;
  /**
   * The maximum number of documents to embed in a single request.
   * Defaults to 50.
   */
  batchSize?: number;
  /**
   * Whether to replace newline characters with spaces before embedding.
   * Defaults to `true`; disable it when line breaks are meaningful for your
   * inputs.
   */
  stripNewLines?: boolean;
}
/**
 * LangChain `Embeddings` implementation that calls a Cloudflare Workers AI
 * text-embedding model through a Workers AI binding.
 *
 * Documents are split into batches of at most `batchSize` texts. Each batch is
 * one `Ai.run` call wrapped in the inherited `this.caller.call`, and all
 * batches are dispatched together with `Promise.all`.
 */
export class CloudflareWorkersAIEmbeddings extends Embeddings {
  /** Workers AI model identifier passed to `Ai.run`. */
  modelName = "@cf/baai/bge-base-en-v1.5";

  /** Maximum number of texts sent in a single `Ai.run` call. */
  batchSize = 50;

  /** When true, `\n` characters are replaced with spaces before embedding. */
  stripNewLines = true;

  /** Workers AI client built from the binding supplied to the constructor. */
  ai: Ai;

  /**
   * @param fields - Options; `binding` is required, the others fall back to
   *   the field defaults declared on this class.
   * @throws Error if `fields.binding` is missing, or if the resulting
   *   `batchSize` is not a positive integer.
   */
  constructor(fields: CloudflareWorkersAIEmbeddingsParams) {
    super(fields);
    if (!fields.binding) {
      throw new Error(
        "Must supply a Workers AI binding, eg { binding: env.AI }"
      );
    }
    this.ai = new Ai(fields.binding);
    this.modelName = fields.modelName ?? this.modelName;
    // `batchSize` is part of the public params but was previously never read,
    // so callers could not change it.
    this.batchSize = fields.batchSize ?? this.batchSize;
    // A zero/negative/fractional batch size cannot partition the input
    // meaningfully; fail loudly instead of producing a wrong result.
    if (!Number.isInteger(this.batchSize) || this.batchSize < 1) {
      throw new Error(
        `batchSize must be a positive integer, got ${this.batchSize}`
      );
    }
    this.stripNewLines = fields.stripNewLines ?? this.stripNewLines;
  }

  /**
   * Embeds a list of documents.
   *
   * @param texts - Documents to embed.
   * @returns One vector per returned embedding, concatenated across batches in
   *   batch order (assumes the service returns vectors in request order).
   */
  async embedDocuments(texts: string[]): Promise<number[][]> {
    const batches = chunkArray(
      texts.map((text) => this.prepareText(text)),
      this.batchSize
    );
    const batchResponses = await Promise.all(
      batches.map((batch) => this.runEmbedding(batch))
    );
    const embeddings: number[][] = [];
    for (const batchResponse of batchResponses) {
      for (const embedding of batchResponse) {
        embeddings.push(embedding);
      }
    }
    return embeddings;
  }

  /**
   * Embeds a single query text.
   *
   * @param text - Query to embed.
   * @returns The first vector returned by the model.
   * @throws Error if the model response contains no embedding.
   */
  async embedQuery(text: string): Promise<number[]> {
    const [embedding] = await this.runEmbedding([this.prepareText(text)]);
    if (embedding === undefined) {
      throw new Error("Workers AI returned no embedding for the query text");
    }
    return embedding;
  }

  /** Applies the `stripNewLines` option to a single input text. */
  private prepareText(text: string): string {
    return this.stripNewLines ? text.replace(/\n/g, " ") : text;
  }

  /**
   * Embeds one batch of texts with a single `Ai.run` call, routed through the
   * inherited `this.caller`.
   */
  private async runEmbedding(texts: string[]): Promise<number[][]> {
    return this.caller.call(async () => {
      const input: AiTextEmbeddingsInput = { text: texts };
      const response: AiTextEmbeddingsOutput = await this.ai.run(
        this.modelName,
        input
      );
      return response.data;
    });
  }
}