-
Notifications
You must be signed in to change notification settings - Fork 2.1k
/
contextual_compression.ts
68 lines (60 loc) · 2.13 KB
/
contextual_compression.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import {
BaseRetriever,
type BaseRetrieverInput,
type BaseRetrieverInterface,
} from "@langchain/core/retrievers";
import type { DocumentInterface } from "@langchain/core/documents";
import { CallbackManagerForRetrieverRun } from "@langchain/core/callbacks/manager";
import { BaseDocumentCompressor } from "./document_compressors/index.js";
/**
 * Constructor options for a ContextualCompressionRetriever.
 *
 * Everything accepted by {@link BaseRetrieverInput} is supported, plus the
 * two components that the retriever chains together: the retriever that
 * fetches candidate documents and the compressor applied to them.
 */
export interface ContextualCompressionRetrieverArgs extends BaseRetrieverInput {
  /** Retriever queried first to fetch candidate documents. */
  baseRetriever: BaseRetrieverInterface;
  /** Compressor applied to the documents returned by `baseRetriever`. */
  baseCompressor: BaseDocumentCompressor;
}
/**
 * A retriever that wraps a base retriever and compresses its results.
 *
 * For a given query it first fetches documents from `baseRetriever`, then
 * hands those documents (together with the same query) to
 * `baseCompressor.compressDocuments`, and returns whatever the compressor
 * produces.
 *
 * When a run manager is supplied, each stage receives its own child of it
 * (`getChild("base_retriever")` and `getChild("base_compressor")`).
 * @example
 * ```typescript
 * const retriever = new ContextualCompressionRetriever({
 *   baseCompressor: LLMChainExtractor.fromLLM(model),
 *   baseRetriever: vectorStore.asRetriever(),
 * });
 * const retrievedDocs = await retriever.invoke(
 *   "What did the speaker say about Justice Breyer?",
 * );
 * ```
 */
export class ContextualCompressionRetriever extends BaseRetriever {
  static lc_name() {
    return "ContextualCompressionRetriever";
  }

  lc_namespace = ["langchain", "retrievers", "contextual_compression"];

  /** Compressor applied to the documents returned by `baseRetriever`. */
  baseCompressor: BaseDocumentCompressor;

  /** Retriever queried first to fetch candidate documents. */
  baseRetriever: BaseRetrieverInterface;

  /**
   * @param fields - Base retriever options plus the `baseCompressor` and
   *   `baseRetriever` components to chain.
   */
  constructor(fields: ContextualCompressionRetrieverArgs) {
    super(fields);
    this.baseCompressor = fields.baseCompressor;
    this.baseRetriever = fields.baseRetriever;
  }

  /**
   * Retrieves documents for `query` from the base retriever and returns the
   * compressor's output for them.
   *
   * @param query - The search query; it is also forwarded to the compressor.
   * @param runManager - Optional run manager for this retrieval; a child of
   *   it is passed to each stage.
   * @returns The documents produced by `baseCompressor.compressDocuments`.
   */
  async _getRelevantDocuments(
    query: string,
    runManager?: CallbackManagerForRetrieverRun
  ): Promise<DocumentInterface[]> {
    // `getRelevantDocuments` is deprecated in @langchain/core in favour of the
    // Runnable `invoke` entry point; callbacks travel in the `callbacks`
    // field of the RunnableConfig.
    const docs = await this.baseRetriever.invoke(query, {
      callbacks: runManager?.getChild("base_retriever"),
    });
    return this.baseCompressor.compressDocuments(
      docs,
      query,
      runManager?.getChild("base_compressor")
    );
  }
}