-
Notifications
You must be signed in to change notification settings - Fork 2.2k
/
score_threshold.ts
57 lines (50 loc) · 1.54 KB
/
score_threshold.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import { Document } from "@langchain/core/documents";
import {
VectorStore,
VectorStoreRetriever,
VectorStoreRetrieverInput,
} from "@langchain/core/vectorstores";
/**
 * Options for constructing a {@link ScoreThresholdRetriever}.
 *
 * Accepts the regular vector store retriever options except `k`. The number of
 * results is not fixed up front; it is driven by the score threshold below.
 */
export type ScoreThresholdRetrieverInput<V extends VectorStore> = {
  /** Minimum similarity score (inclusive) a document needs in order to be returned. */
  minSimilarityScore: number;
  /**
   * Caps the number of documents returned and stops the search from widening
   * any further once reached. Defaults to 100.
   */
  maxK?: number;
  /**
   * Step by which the number of requested results grows between successive
   * vector store queries. Defaults to 10.
   */
  kIncrement?: number;
} & Omit<VectorStoreRetrieverInput<V>, "k">;
/**
 * A retriever that returns every document whose similarity score is at least
 * `minSimilarityScore`, instead of a fixed top-`k`.
 *
 * It repeatedly queries the underlying vector store, growing `k` by
 * `kIncrement` each time, for as long as every returned result passes the
 * threshold. It stops once some result falls below the threshold, the store
 * returns fewer results than requested, or `maxK` is reached.
 */
export class ScoreThresholdRetriever<
  V extends VectorStore
> extends VectorStoreRetriever<V> {
  /** Minimum similarity score (inclusive) a document needs in order to be returned. */
  minSimilarityScore: number;

  /** Step by which `k` grows between successive vector store queries. Must be > 0. */
  kIncrement = 10;

  /** Upper bound on both the `k` sent to the vector store and the number of documents returned. */
  maxK = 100;

  constructor(input: ScoreThresholdRetrieverInput<V>) {
    super(input);
    this.maxK = input.maxK ?? this.maxK;
    this.minSimilarityScore = input.minSimilarityScore;
    this.kIncrement = input.kIncrement ?? this.kIncrement;
  }

  /**
   * Returns the documents matching `query` whose score meets
   * `minSimilarityScore`, in the order the vector store returned them, capped
   * at `maxK` documents.
   *
   * @param query - Text to search the vector store for.
   * @returns The matching documents (without their scores).
   * @throws Error if `kIncrement` is not a positive number (the search would never terminate).
   */
  async getRelevantDocuments(query: string): Promise<Document[]> {
    // Never request a fractional number of results; a non-positive (or NaN)
    // cap means nothing can be returned.
    const limit = Math.floor(this.maxK);
    if (!(limit > 0)) {
      return [];
    }
    // A non-positive or NaN step would never advance `currentK`, so the loop
    // below would keep querying the store forever.
    if (!(this.kIncrement > 0)) {
      throw new Error(
        `kIncrement must be a positive number, got ${this.kIncrement}.`
      );
    }

    let currentK = 0;
    let filteredResults: [Document, number][] = [];
    do {
      // Clamp so the store is never asked for more than `maxK` results, even
      // when `kIncrement` does not evenly divide `maxK`.
      currentK = Math.min(currentK + this.kIncrement, limit);
      const results = await this.vectorStore.similaritySearchWithScore(
        query,
        currentK,
        this.filter
      );
      filteredResults = results.filter(
        ([, score]) => score >= this.minSimilarityScore
      );
      // If every requested result passed the threshold, more qualifying
      // documents may exist, so widen the search unless the cap is reached.
    } while (filteredResults.length >= currentK && currentK < limit);

    return filteredResults.slice(0, limit).map(([doc]) => doc);
  }

  /**
   * Convenience factory that builds a retriever over an existing vector store.
   *
   * @param vectorStore - Store to search.
   * @param options - All remaining {@link ScoreThresholdRetrieverInput} options.
   */
  static fromVectorStore<V extends VectorStore>(
    vectorStore: V,
    options: Omit<ScoreThresholdRetrieverInput<V>, "vectorStore">
  ) {
    return new this<V>({ ...options, vectorStore });
  }
}