-
Notifications
You must be signed in to change notification settings - Fork 2.2k
/
tensorflow.ts
91 lines (82 loc) · 2.88 KB
/
tensorflow.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import { load } from "@tensorflow-models/universal-sentence-encoder";
import * as tf from "@tensorflow/tfjs-core";
import { Embeddings, type EmbeddingsParams } from "@langchain/core/embeddings";
/**
 * Constructor options for the TensorFlowEmbeddings class.
 *
 * Declares no fields of its own: every accepted option is inherited from
 * `EmbeddingsParams`, and the constructor forwards the object unchanged to
 * the `Embeddings` base class.
 */
export interface TensorFlowEmbeddingsParams extends EmbeddingsParams {}
/**
 * Embeddings implementation backed by the Universal Sentence Encoder model
 * from TensorFlow.js (`@tensorflow-models/universal-sentence-encoder`).
 *
 * A TensorFlow.js backend must already be usable when an instance is
 * constructed: the constructor probes it with `tf.backend()` and throws if
 * that fails. The model itself is loaded lazily on the first embedding
 * request and the same loaded model is reused afterwards.
 *
 * @example
 * ```typescript
 * const embeddings = new TensorFlowEmbeddings();
 * const store = new MemoryVectorStore(embeddings);
 *
 * const documents = [
 *   "A document",
 *   "Some other piece of text",
 *   "One more",
 *   "And another",
 * ];
 *
 * await store.addDocuments(
 *   documents.map((pageContent) => new Document({ pageContent }))
 * );
 * ```
 */
export class TensorFlowEmbeddings extends Embeddings {
  /**
   * Shared promise for the Universal Sentence Encoder model: `undefined`
   * until the first load starts, and reset to `undefined` if that load
   * rejects so a subsequent call can try again.
   */
  _cached?: ReturnType<typeof load>;

  /**
   * @param fields Options forwarded to the `Embeddings` base class.
   * @throws Error if no TensorFlow.js backend is available.
   */
  constructor(fields?: TensorFlowEmbeddingsParams) {
    super(fields ?? {});
    try {
      // tf.backend() throws when no backend can be initialized synchronously.
      tf.backend();
    } catch {
      throw new Error("No TensorFlow backend found, see instructions at ...");
    }
  }

  /**
   * Loads the Universal Sentence Encoder model on first use and returns the
   * shared promise for it on every call, so concurrent callers wait on a
   * single load.
   *
   * If the load rejects, the cached promise is cleared before the error is
   * re-thrown. Without this, one failed load (e.g. a failed model fetch)
   * would be replayed to every later call, including retries, for the
   * lifetime of the instance.
   *
   * @returns Promise that resolves to the loaded Universal Sentence Encoder model.
   */
  private async load() {
    if (this._cached === undefined) {
      this._cached = load().catch((error: unknown) => {
        this._cached = undefined;
        throw error;
      });
    }
    return this._cached;
  }

  /**
   * Runs the model on `texts` through `this.caller`, which wraps the call
   * (model load included).
   *
   * @param texts Strings to embed.
   * @returns Promise resolving to the raw embedding tensor produced by the
   *   model. The caller owns the tensor and must dispose it.
   */
  private _embed(texts: string[]) {
    return this.caller.call(async () => {
      const model = await this.load();
      return model.embed(texts);
    });
  }

  /**
   * Embeds `texts` and copies the result out of TensorFlow.js into plain
   * JavaScript arrays, one row per input string.
   *
   * The intermediate tensor is always disposed, even if reading it fails.
   * TF.js does not reliably reclaim tensor memory when a tensor merely goes
   * out of scope (notably on the WebGL backend), so skipping this would leak
   * memory on every call.
   *
   * @param texts Strings to embed.
   * @returns Promise resolving to one embedding vector per input string.
   */
  private async _embedToArray(texts: string[]): Promise<number[][]> {
    const tensor = await this._embed(texts);
    try {
      return await tensor.array();
    } finally {
      tensor.dispose();
    }
  }

  /**
   * Generates an embedding for a single piece of text.
   *
   * @param document Text to generate an embedding for.
   * @returns Promise that resolves to the embedding vector for the input.
   */
  async embedQuery(document: string): Promise<number[]> {
    const embeddings = await this._embedToArray([document]);
    return embeddings[0];
  }

  /**
   * Generates embeddings for a batch of texts in a single model call.
   *
   * An empty input returns `[]` immediately, without loading the model or
   * running inference.
   *
   * @param documents Texts to generate embeddings for.
   * @returns Promise that resolves to one embedding vector per input, in the
   *   same order as `documents`.
   */
  async embedDocuments(documents: string[]): Promise<number[][]> {
    if (documents.length === 0) {
      return [];
    }
    return this._embedToArray(documents);
  }
}