-
Notifications
You must be signed in to change notification settings - Fork 2.2k
/
chain_extract.ts
120 lines (108 loc) Β· 3.71 KB
/
chain_extract.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import type { BaseLanguageModelInterface } from "@langchain/core/language_models/base";
import { type DocumentInterface, Document } from "@langchain/core/documents";
import { PromptTemplate } from "@langchain/core/prompts";
import { BaseOutputParser } from "@langchain/core/output_parsers";
import { LLMChain } from "../../chains/llm_chain.js";
import { BaseDocumentCompressor } from "./index.js";
import { PROMPT_TEMPLATE } from "./chain_extract_prompt.js";
/**
 * Builds the default chain input for a query/document pair: the query is
 * exposed as the "question" prompt variable and the document text as
 * "context".
 */
function defaultGetInput(
  query: string,
  doc: DocumentInterface
): Record<string, unknown> {
  const { pageContent } = doc;
  return { context: pageContent, question: query };
}
/**
 * Output parser that recognizes the "NO_OUTPUT" sentinel the extraction
 * prompt tells the model to emit when a document contains nothing relevant,
 * mapping it to the empty string.
 */
class NoOutputParser extends BaseOutputParser<string> {
  lc_namespace = [
    "langchain",
    "retrievers",
    "document_compressors",
    "chain_extract",
  ];

  // Sentinel string the prompt instructs the model to return for "nothing".
  noOutputStr = "NO_OUTPUT";

  /** Trims the completion; the sentinel collapses to "". */
  async parse(text: string): Promise<string> {
    const trimmed = text.trim();
    return trimmed === this.noOutputStr ? "" : trimmed;
  }

  /** Format instructions are baked into the prompt template instead. */
  getFormatInstructions(): string {
    throw new Error("Method not implemented.");
  }
}
function getDefaultChainPrompt(): PromptTemplate {
const outputParser = new NoOutputParser();
const template = PROMPT_TEMPLATE(outputParser.noOutputStr);
return new PromptTemplate({
template,
inputVariables: ["question", "context"],
outputParser,
});
}
/**
 * Interface for the arguments required to create an instance of
 * LLMChainExtractor.
 */
export interface LLMChainExtractorArgs {
  /** Chain invoked per document to extract its query-relevant portion. */
  llmChain: LLMChain;
  /** Maps a query and a document to the chain's input variables. */
  getInput: (query: string, doc: DocumentInterface) => Record<string, unknown>;
}
/**
 * A class that uses an LLM chain to extract relevant parts of documents.
 * It extends the BaseDocumentCompressor class.
 */
export class LLMChainExtractor extends BaseDocumentCompressor {
  llmChain: LLMChain;

  /** Maps a query and a document to the chain's input variables. */
  getInput: (query: string, doc: DocumentInterface) => Record<string, unknown> =
    defaultGetInput;

  constructor({ llmChain, getInput }: LLMChainExtractorArgs) {
    super();
    this.llmChain = llmChain;
    // Fall back to the field's default when an (untyped) caller omits
    // getInput; the previous unconditional assignment clobbered the
    // defaultGetInput initializer with undefined in that case.
    this.getInput = getInput ?? defaultGetInput;
  }

  /**
   * Compresses a list of documents based on the output of an LLM chain.
   * Documents for which the chain yields an empty string (the parsed
   * "NO_OUTPUT" sentinel) are dropped from the result.
   * @param documents The list of documents to be compressed.
   * @param query The query to be used for document compression.
   * @returns A list of compressed documents, preserving input order.
   */
  async compressDocuments(
    documents: DocumentInterface[],
    query: string
  ): Promise<DocumentInterface[]> {
    // Run every document through the chain in parallel.
    const compressedDocs = await Promise.all(
      documents.map(async (doc) => {
        const input = this.getInput(query, doc);
        const output = await this.llmChain.predict(input);
        return output.length > 0
          ? new Document({
              pageContent: output,
              metadata: doc.metadata,
            })
          : undefined;
      })
    );
    return compressedDocs.filter((doc): doc is Document => doc !== undefined);
  }

  /**
   * Creates a new instance of LLMChainExtractor from a given LLM, prompt
   * template, and getInput function.
   * @param llm The BaseLanguageModel instance used for document extraction.
   * @param prompt The PromptTemplate instance used for document extraction;
   *   defaults to the built-in extraction prompt with a NoOutputParser.
   * @param getInput A function used for constructing the chain input from the
   *   query and a Document; defaults to `{ question, context }`.
   * @returns A new instance of LLMChainExtractor.
   */
  static fromLLM(
    llm: BaseLanguageModelInterface,
    prompt?: PromptTemplate,
    getInput?: (
      query: string,
      doc: DocumentInterface
    ) => Record<string, unknown>
  ): LLMChainExtractor {
    // `??` rather than `||`: only substitute defaults for null/undefined.
    const _prompt = prompt ?? getDefaultChainPrompt();
    const _getInput = getInput ?? defaultGetInput;
    const llmChain = new LLMChain({ llm, prompt: _prompt });
    return new LLMChainExtractor({ llmChain, getInput: _getInput });
  }
}