// hf_transformers.ts — 113 lines (94 loc), 2.95 KB
import { Pipeline, pipeline } from "@xenova/transformers";
import { Embeddings, type EmbeddingsParams } from "@langchain/core/embeddings";
import { chunkArray } from "@langchain/core/utils/chunk_array";
/**
 * Constructor options for the local Transformers.js embeddings class.
 */
export interface HuggingFaceTransformersEmbeddingsParams
  extends EmbeddingsParams {
  /**
   * Model name to use.
   * Alias for `model`; when both are provided, `model` takes precedence.
   */
  modelName: string;

  /** Hugging Face model id loaded into the `feature-extraction` pipeline. */
  model: string;

  /**
   * Timeout for embedding calls.
   * NOTE(review): `HuggingFaceTransformersEmbeddings` stores this value but
   * does not currently apply it to pipeline calls.
   */
  timeout?: number;

  /**
   * The maximum number of documents passed to the pipeline in a single call.
   */
  batchSize?: number;

  /**
   * Whether to replace newline characters with spaces before embedding.
   * Whether this helps depends on the model, so it may not suit every use case.
   */
  stripNewLines?: boolean;
}
/**
 * Computes text embeddings locally with a Transformers.js
 * `feature-extraction` pipeline, using mean pooling and normalized vectors.
 *
 * @example
 * ```typescript
 * const model = new HuggingFaceTransformersEmbeddings({
 *   model: "Xenova/all-MiniLM-L6-v2",
 * });
 *
 * // Embed a single query
 * const res = await model.embedQuery(
 *   "What would be a good company name for a company that makes colorful socks?"
 * );
 * console.log({ res });
 *
 * // Embed multiple documents
 * const documentRes = await model.embedDocuments(["Hello world", "Bye bye"]);
 * console.log({ documentRes });
 * ```
 */
export class HuggingFaceTransformersEmbeddings
  extends Embeddings
  implements HuggingFaceTransformersEmbeddingsParams
{
  /** Alias for `model`; the constructor keeps both fields identical. */
  modelName = "Xenova/all-MiniLM-L6-v2";

  /** Model id passed to the `feature-extraction` pipeline. */
  model = "Xenova/all-MiniLM-L6-v2";

  /** Maximum number of texts sent to the pipeline in one call. */
  batchSize = 512;

  /** When true, `\n` characters are replaced with spaces before embedding. */
  stripNewLines = true;

  /**
   * Stored for interface compatibility.
   * NOTE(review): not currently applied to pipeline calls.
   */
  timeout?: number;

  /**
   * Lazily created pipeline, shared by every call on this instance. It is
   * cleared if loading fails, so a later call retries instead of replaying
   * the cached rejection forever.
   */
  private pipelinePromise?: Promise<Pipeline>;

  constructor(fields?: Partial<HuggingFaceTransformersEmbeddingsParams>) {
    super(fields ?? {});
    // `model` wins over its alias `modelName`; afterwards both hold the same id.
    this.modelName = fields?.model ?? fields?.modelName ?? this.model;
    this.model = this.modelName;
    // Honor a caller-supplied batch size (previously silently ignored).
    this.batchSize = fields?.batchSize ?? this.batchSize;
    this.stripNewLines = fields?.stripNewLines ?? this.stripNewLines;
    this.timeout = fields?.timeout;
  }

  /**
   * Embed a list of documents, splitting them into batches of `batchSize`.
   *
   * @param texts - Documents to embed.
   * @returns One embedding vector per input text, in input order.
   */
  async embedDocuments(texts: string[]): Promise<number[][]> {
    const batches = chunkArray(
      texts.map((text) => this.prepareText(text)),
      this.batchSize
    );
    const batchResponses = await Promise.all(
      batches.map((batch) => this.runEmbedding(batch))
    );
    // Batches are awaited in order, so flattening preserves input order.
    return batchResponses.flat();
  }

  /**
   * Embed a single query string.
   *
   * @param text - Query text to embed.
   * @returns The embedding vector for `text`.
   */
  async embedQuery(text: string): Promise<number[]> {
    const embeddings = await this.runEmbedding([this.prepareText(text)]);
    return embeddings[0];
  }

  /** Apply the `stripNewLines` preprocessing to one input text. */
  private prepareText(text: string): string {
    return this.stripNewLines ? text.replace(/\n/g, " ") : text;
  }

  /** Return the shared pipeline, loading it on first use and retrying after a failed load. */
  private getPipeline(): Promise<Pipeline> {
    let pending = this.pipelinePromise;
    if (pending === undefined) {
      pending = pipeline("feature-extraction", this.model).catch(
        (error: unknown) => {
          // Drop the failed load so the next call can try again.
          this.pipelinePromise = undefined;
          throw error;
        }
      );
      this.pipelinePromise = pending;
    }
    return pending;
  }

  /** Run one batch through the pipeline (via the retrying caller) and return plain arrays. */
  private async runEmbedding(texts: string[]): Promise<number[][]> {
    const pipe = await this.getPipeline();
    return this.caller.call(async () => {
      const output = await pipe(texts, { pooling: "mean", normalize: true });
      return output.tolist();
    });
  }
}