-
Notifications
You must be signed in to change notification settings - Fork 2.1k
/
embeddings_filter.ts
96 lines (86 loc) · 2.97 KB
/
embeddings_filter.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import type { EmbeddingsInterface } from "@langchain/core/embeddings";
import type { DocumentInterface } from "@langchain/core/documents";
import { cosineSimilarity } from "@langchain/core/utils/math";
import { BaseDocumentCompressor } from "./index.js";
/**
 * Interface for the parameters of the `EmbeddingsFilter` class.
 */
export interface EmbeddingsFilterParams {
  /**
   * Embeddings used to embed both the document contents and the query.
   */
  embeddings: EmbeddingsInterface;
  /**
   * Function comparing the rows of `x` with the rows of `y`. It is called
   * with the query embedding as the single row of `x` and the document
   * embeddings as `y`; the first row of the result must hold one score per
   * document, where a higher score means "more similar". Defaults to cosine
   * similarity when omitted.
   */
  similarityFn?: (x: number[][], y: number[][]) => number[][];
  /**
   * Keep only documents whose similarity to the query is strictly greater
   * than this value. Required if `k` is explicitly set to `undefined`.
   */
  similarityThreshold?: number;
  /**
   * Keep at most this many of the most similar documents. Defaults to 20 when
   * the property is omitted; pass `k: undefined` explicitly to disable the
   * limit, in which case `similarityThreshold` must be set.
   */
  k?: number;
}
/**
 * Document compressor that uses embeddings to drop documents unrelated to
 * the query.
 *
 * Every document's `pageContent` and the query are embedded, the query is
 * scored against each document with `similarityFn`, and then:
 * - if `k` is set, only the `k` highest-scoring documents are kept;
 * - if `similarityThreshold` is set, only documents scoring strictly above
 *   it are kept.
 * When both are set, both filters apply. Results are ordered by descending
 * similarity when `k` is set, and keep their input order otherwise.
 * @example
 * ```typescript
 * const embeddingsFilter = new EmbeddingsFilter({
 *   embeddings: new OpenAIEmbeddings(),
 *   similarityThreshold: 0.8,
 *   k: 5,
 * });
 * const retrievedDocs = await embeddingsFilter.compressDocuments(
 *   getDocuments(),
 *   "What did the speaker say about Justice Breyer in the 2022 State of the Union?",
 * );
 * console.log({ retrievedDocs });
 * ```
 */
export class EmbeddingsFilter extends BaseDocumentCompressor {
  /**
   * Embeddings used to embed both the document contents and the query.
   */
  embeddings: EmbeddingsInterface;

  /**
   * Similarity function used to score the query against the documents.
   * Called as `similarityFn([queryEmbedding], documentEmbeddings)`; only the
   * first row of the result is used. Defaults to `cosineSimilarity`.
   */
  similarityFn = cosineSimilarity;

  /**
   * Minimum similarity (exclusive) a document must have with the query to be
   * kept. Must be specified if `k` is explicitly set to `undefined`.
   */
  similarityThreshold?: number;

  /**
   * The number of most relevant documents to return. Can be explicitly set to
   * `undefined`, in which case `similarityThreshold` must be specified.
   * Defaults to 20.
   */
  k? = 20;

  /**
   * @param params Filter configuration; see `EmbeddingsFilterParams`.
   * @throws Error if both `k` and `similarityThreshold` end up undefined.
   */
  constructor(params: EmbeddingsFilterParams) {
    super();
    this.embeddings = params.embeddings;
    this.similarityFn = params.similarityFn ?? this.similarityFn;
    this.similarityThreshold = params.similarityThreshold;
    // `in` rather than `??`: an explicit `k: undefined` must disable top-k
    // instead of falling back to the default of 20.
    this.k = "k" in params ? params.k : this.k;
    if (this.k === undefined && this.similarityThreshold === undefined) {
      throw new Error(`Must specify one of "k" or "similarityThreshold".`);
    }
  }

  /**
   * Filters `documents` down to those most similar to `query`.
   *
   * @param documents Documents to filter. The array is not mutated.
   * @param query Query the documents are compared against.
   * @returns The kept documents, ordered by descending similarity when `k`
   *   is set and in their original order otherwise.
   */
  async compressDocuments(
    documents: DocumentInterface[],
    query: string
  ): Promise<DocumentInterface[]> {
    // Nothing to score: avoid making embedding calls with an empty batch.
    if (documents.length === 0) {
      return [];
    }
    // The two embedding requests are independent, so run them concurrently.
    const [embeddedDocuments, embeddedQuery] = await Promise.all([
      this.embeddings.embedDocuments(documents.map((doc) => doc.pageContent)),
      this.embeddings.embedQuery(query),
    ]);
    const similarity = this.similarityFn([embeddedQuery], embeddedDocuments)[0];
    let includedIdxs = Array.from(
      { length: embeddedDocuments.length },
      (_, i) => i
    );
    if (this.k !== undefined) {
      // A NaN (or missing) score would make a `b - a` comparator return NaN,
      // which leaves the sort order implementation-defined; rank such
      // entries below every real score instead.
      const rankScore = (i: number): number => {
        const s = similarity[i];
        return typeof s === "number" && !Number.isNaN(s) ? s : -Infinity;
      };
      includedIdxs = includedIdxs
        .map((i) => [rankScore(i), i] as const)
        // Descending by score; the explicit `===` check also keeps equal
        // infinities consistent (Infinity - Infinity would be NaN).
        .sort(([a], [b]) => (a === b ? 0 : b > a ? 1 : -1))
        .slice(0, this.k)
        .map(([, i]) => i);
    }
    if (this.similarityThreshold !== undefined) {
      const threshold = this.similarityThreshold;
      includedIdxs = includedIdxs.filter((i) => similarity[i] > threshold);
    }
    return includedIdxs.map((i) => documents[i]);
  }
}