-
Notifications
You must be signed in to change notification settings - Fork 2k
/
analyze_documents_chain.ts
132 lines (115 loc) · 3.51 KB
/
analyze_documents_chain.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import { ChainValues } from "@langchain/core/utils/types";
import { CallbackManagerForChainRun } from "@langchain/core/callbacks/manager";
import { BaseChain, ChainInputs } from "./base.js";
import {
TextSplitter,
RecursiveCharacterTextSplitter,
} from "../text_splitter.js";
import { SerializedAnalyzeDocumentChain } from "./serde.js";
/**
 * Arbitrary keyed runtime values supplied alongside serialized data when
 * deserializing a chain — e.g. a live `text_splitter` instance, which
 * cannot itself be serialized (see `AnalyzeDocumentChain.deserialize`).
 */
// eslint-disable-next-line @typescript-eslint/no-explicit-any
export type LoadValues = Record<string, any>;
/**
 * Interface for the input parameters required by the AnalyzeDocumentChain
 * class.
 */
export interface AnalyzeDocumentChainInput extends Omit<ChainInputs, "memory"> {
  /** Chain invoked with the split chunks under the `input_documents` key. */
  combineDocumentsChain: BaseChain;
  /** Splitter used to chunk the raw input text; defaults to a
   * `RecursiveCharacterTextSplitter` when omitted. */
  textSplitter?: TextSplitter;
  /** Key in the call values holding the raw text; defaults to
   * `"input_document"`. */
  inputKey?: string;
}
/**
 * Chain that takes a single long text, splits it into chunks with a
 * `TextSplitter`, and passes the resulting documents to
 * `combineDocumentsChain` (e.g. a summarization or QA chain) under the
 * `input_documents` key.
 *
 * NOTE(review): the previous doc comment ("combines documents by stuffing
 * into context", `@augments StuffDocumentsChainInput`) described
 * StuffDocumentsChain, not this class, and has been corrected.
 * @augments BaseChain
 * @example
 * ```typescript
 * const model = new ChatOpenAI({ temperature: 0 });
 * const combineDocsChain = loadSummarizationChain(model);
 * const chain = new AnalyzeDocumentChain({
 *   combineDocumentsChain: combineDocsChain,
 * });
 *
 * // Read the text from a file (this is a placeholder for actual file reading)
 * const text = readTextFromFile("state_of_the_union.txt");
 *
 * // Invoke the chain to analyze the document
 * const res = await chain.call({
 *   input_document: text,
 * });
 *
 * console.log({ res });
 * ```
 */
export class AnalyzeDocumentChain
  extends BaseChain
  implements AnalyzeDocumentChainInput
{
  static lc_name() {
    return "AnalyzeDocumentChain";
  }

  /** Key under which the raw input text is expected in the call values. */
  inputKey = "input_document";

  /** Chain run over the documents produced by splitting the input text. */
  combineDocumentsChain: BaseChain;

  /** Splitter used to chunk the input text. */
  textSplitter: TextSplitter;

  constructor(fields: AnalyzeDocumentChainInput) {
    super(fields);
    this.combineDocumentsChain = fields.combineDocumentsChain;
    this.inputKey = fields.inputKey ?? this.inputKey;
    // Fall back to the default recursive splitter when none is supplied.
    this.textSplitter =
      fields.textSplitter ?? new RecursiveCharacterTextSplitter();
  }

  get inputKeys(): string[] {
    return [this.inputKey];
  }

  /** Output keys are whatever the wrapped combine chain produces. */
  get outputKeys(): string[] {
    return this.combineDocumentsChain.outputKeys;
  }

  /**
   * Splits the text found under `inputKey` into documents and forwards
   * them (plus any remaining values) to the combine chain.
   * @throws Error when `inputKey` is missing from `values`.
   * @ignore
   */
  async _call(
    values: ChainValues,
    runManager?: CallbackManagerForChainRun
  ): Promise<ChainValues> {
    if (!(this.inputKey in values)) {
      throw new Error(`Document key ${this.inputKey} not found.`);
    }
    const { [this.inputKey]: doc, ...rest } = values;

    const currentDoc = doc as string;
    const currentDocs = await this.textSplitter.createDocuments([currentDoc]);

    // The combine chain expects the chunks under `input_documents`;
    // all other call values are passed through untouched.
    const newInputs = { input_documents: currentDocs, ...rest };
    const result = await this.combineDocumentsChain.call(
      newInputs,
      runManager?.getChild("combine_documents")
    );
    return result;
  }

  _chainType() {
    return "analyze_document_chain" as const;
  }

  /**
   * Reconstructs an AnalyzeDocumentChain from serialized data.
   * @param data Serialized form; must include `combine_document_chain`.
   * @param values Runtime values; must include a live `text_splitter`
   *   instance (splitters are not serialized).
   * @throws Error when either required piece is missing.
   */
  static async deserialize(
    data: SerializedAnalyzeDocumentChain,
    values: LoadValues
  ) {
    if (!("text_splitter" in values)) {
      throw new Error(
        `Need to pass in a text_splitter to deserialize AnalyzeDocumentChain.`
      );
    }
    const { text_splitter } = values;

    if (!data.combine_document_chain) {
      throw new Error(
        `Need to pass in a combine_document_chain to deserialize AnalyzeDocumentChain.`
      );
    }

    return new AnalyzeDocumentChain({
      combineDocumentsChain: await BaseChain.deserialize(
        data.combine_document_chain
      ),
      textSplitter: text_splitter,
    });
  }

  /** Serializes the chain; note the text splitter is intentionally omitted. */
  serialize(): SerializedAnalyzeDocumentChain {
    return {
      _type: this._chainType(),
      combine_document_chain: this.combineDocumentsChain.serialize(),
    };
  }
}