-
Notifications
You must be signed in to change notification settings - Fork 2.1k
/
hf.ts
77 lines (67 loc) · 2.61 KB
/
hf.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import { HfInference, HfInferenceEndpoint } from "@huggingface/inference";
import { Embeddings, type EmbeddingsParams } from "@langchain/core/embeddings";
import { getEnvironmentVariable } from "@langchain/core/utils/env";
/**
* Constructor options for HuggingFaceInferenceEmbeddings, layered on top
* of the base EmbeddingsParams. Every field is optional.
*/
export interface HuggingFaceInferenceEmbeddingsParams extends EmbeddingsParams {
/**
* API key handed to the HfInference client. When omitted, the class falls
* back to the HUGGINGFACEHUB_API_KEY environment variable.
*/
apiKey?: string;
/**
* Model identifier sent with each feature-extraction request. When
* omitted, the class uses "BAAI/bge-base-en-v1.5".
*/
model?: string;
/**
* When set, the class builds its client with HfInference#endpoint(url)
* instead of using the plain HfInference client.
*/
endpointUrl?: string;
}
/**
 * Embeddings implementation that produces vectors by calling
 * `featureExtraction` on a Hugging Face inference client.
 *
 * The client is built once in the constructor: a plain `HfInference`
 * client by default, or an endpoint-bound client when `endpointUrl` is set.
 */
export class HuggingFaceInferenceEmbeddings
  extends Embeddings
  implements HuggingFaceInferenceEmbeddingsParams
{
  /** API key passed to the `HfInference` constructor; may be undefined. */
  apiKey?: string;

  /** Model identifier sent with every feature-extraction request. */
  model: string;

  /** Optional endpoint URL the client is bound to. */
  endpointUrl?: string;

  /** Client used to issue feature-extraction requests. */
  client: HfInference | HfInferenceEndpoint;

  /**
   * @param fields Optional settings. `model` defaults to
   * "BAAI/bge-base-en-v1.5"; `apiKey` falls back to the
   * HUGGINGFACEHUB_API_KEY environment variable.
   */
  constructor(fields?: HuggingFaceInferenceEmbeddingsParams) {
    super(fields ?? {});

    this.model = fields?.model ?? "BAAI/bge-base-en-v1.5";
    this.apiKey =
      fields?.apiKey ?? getEnvironmentVariable("HUGGINGFACEHUB_API_KEY");
    this.endpointUrl = fields?.endpointUrl;

    // Truthiness check on purpose: an empty-string URL selects the default client.
    this.client = this.endpointUrl
      ? new HfInference(this.apiKey).endpoint(this.endpointUrl)
      : new HfInference(this.apiKey);
  }

  /**
   * Embeds a batch of texts with a single `featureExtraction` request
   * issued through `this.caller`.
   * @param texts Texts to embed.
   * @returns Promise resolving to the embeddings returned by the client;
   * an empty array when `texts` is empty.
   */
  async _embed(texts: string[]): Promise<number[][]> {
    // Nothing to embed: skip the request instead of sending `inputs: []`.
    if (texts.length === 0) {
      return [];
    }

    // replace newlines, which can negatively affect performance.
    const clean = texts.map((text) => text.replace(/\n/g, " "));

    // NOTE(review): the cast assumes the model yields one flat vector per
    // input; confirm for models that return token-level features.
    return this.caller.call(() =>
      this.client.featureExtraction({
        model: this.model,
        inputs: clean,
      })
    ) as Promise<number[][]>;
  }

  /**
   * Embeds a single text by delegating to `_embed` with a one-element
   * batch and returning the first resulting vector.
   * @param document Document to generate an embedding for.
   * @returns Promise that resolves to an embedding for the document.
   */
  async embedQuery(document: string): Promise<number[]> {
    const embeddings = await this._embed([document]);
    return embeddings[0];
  }

  /**
   * Embeds every document in the input array by delegating to `_embed`.
   * @param documents Array of documents to generate embeddings for.
   * @returns Promise that resolves to a 2D array of embeddings for each
   * document; an empty array when `documents` is empty.
   */
  embedDocuments(documents: string[]): Promise<number[][]> {
    return this._embed(documents);
  }
}