-
Notifications
You must be signed in to change notification settings - Fork 2.1k
/
api_chain.ts
155 lines (132 loc) · 4.28 KB
/
api_chain.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import { BaseChain, ChainInputs } from "../base.js";
import { SerializedAPIChain } from "../serde.js";
import { LLMChain } from "../llm_chain.js";
import { BaseLanguageModel } from "../../base_language/index.js";
import { CallbackManagerForChainRun } from "../../callbacks/manager.js";
import { ChainValues } from "../../schema/index.js";
import {
API_URL_PROMPT_TEMPLATE,
API_RESPONSE_PROMPT_TEMPLATE,
} from "./prompts.js";
import { BasePromptTemplate } from "../../prompts/base.js";
/**
 * Constructor fields for an APIChain: the standard ChainInputs (minus
 * `memory`) plus the two LLM chains and the API-specific settings.
 */
export interface APIChainInput extends Omit<ChainInputs, "memory"> {
  /** Documentation of the target API, handed to both chains as `api_docs`. */
  apiDocs: string;

  /** Chain that turns the question and API docs into a request URL. */
  apiRequestChain: LLMChain;

  /** Chain that turns the API response into the final answer. */
  apiAnswerChain: LLMChain;

  /** Headers attached to the API request. */
  headers?: Record<string, string>;

  /** Key the question is read from, defaults to `question` */
  inputKey?: string;

  /** Key to use for output, defaults to `output` */
  outputKey?: string;
}
/**
 * Optional settings accepted by `APIChain.fromLLMAndAPIDocs`.
 */
export type APIChainOptions = {
  /** Prompt for generating the request URL; defaults to API_URL_PROMPT_TEMPLATE. */
  apiUrlPrompt?: BasePromptTemplate;
  /** Prompt for generating the answer; defaults to API_RESPONSE_PROMPT_TEMPLATE. */
  apiResponsePrompt?: BasePromptTemplate;
  /** Headers attached to the API request. */
  headers?: Record<string, string>;
};
/**
 * Class that extends BaseChain and represents a chain specifically
 * designed for making API requests and processing API responses.
 *
 * A single call runs three steps:
 *  1. `apiRequestChain` turns the question plus `apiDocs` into a URL.
 *  2. That URL is requested with `fetch` (default GET) using `headers`.
 *  3. `apiAnswerChain` turns the question, docs, URL and raw response body
 *     into the answer, returned under `outputKey`.
 *
 * NOTE(review): the URL in step 2 is produced by a language model from the
 * caller's question and is not restricted to any host, so this chain can be
 * steered into requesting arbitrary URLs (SSRF risk). Also, `headers` are
 * sent to whatever host that URL names — avoid putting credentials there
 * unless the model's output is otherwise constrained.
 */
export class APIChain extends BaseChain implements APIChainInput {
  /** Chain that writes the final answer from the API response. */
  apiAnswerChain: LLMChain;

  /** Chain that generates the request URL. */
  apiRequestChain: LLMChain;

  /** API documentation passed to both chains as `api_docs`. */
  apiDocs: string;

  /** Headers sent with every API request. */
  headers: Record<string, string> = {};

  /** Key of the question in the input values. */
  inputKey = "question";

  /** Key under which the answer is returned. */
  outputKey = "output";

  get inputKeys() {
    return [this.inputKey];
  }

  get outputKeys() {
    return [this.outputKey];
  }

  constructor(fields: APIChainInput) {
    super(fields);
    this.apiRequestChain = fields.apiRequestChain;
    this.apiAnswerChain = fields.apiAnswerChain;
    this.apiDocs = fields.apiDocs;
    this.inputKey = fields.inputKey ?? this.inputKey;
    this.outputKey = fields.outputKey ?? this.outputKey;
    this.headers = fields.headers ?? this.headers;
  }

  /** @ignore */
  async _call(
    values: ChainValues,
    runManager?: CallbackManagerForChainRun
  ): Promise<ChainValues> {
    const question: string = values[this.inputKey];
    // LLM output may carry surrounding whitespace or a trailing newline;
    // trim it so the URL echoed into the answer prompt is exactly the one
    // that was requested.
    const api_url = (
      await this.apiRequestChain.predict(
        { question, api_docs: this.apiDocs },
        runManager?.getChild("request")
      )
    ).trim();
    const res = await fetch(api_url, { headers: this.headers });
    // `res.ok` is deliberately not checked: the body is forwarded for any
    // HTTP status, so error responses reach the answer chain as text.
    const api_response = await res.text();
    const answer = await this.apiAnswerChain.predict(
      { question, api_docs: this.apiDocs, api_url, api_response },
      runManager?.getChild("response")
    );
    return { [this.outputKey]: answer };
  }

  _chainType() {
    return "api_chain" as const;
  }

  /**
   * Rebuilds an APIChain from its serialized form.
   * Only the two chains and the API docs are restored; `inputKey`,
   * `outputKey` and `headers` are not part of the serialized form and fall
   * back to their defaults.
   * @throws Error if any of the required serialized fields is missing.
   */
  static async deserialize(data: SerializedAPIChain) {
    const { api_request_chain, api_answer_chain, api_docs } = data;
    if (!api_request_chain) {
      throw new Error("APIChain must have api_request_chain");
    }
    if (!api_answer_chain) {
      throw new Error("APIChain must have api_answer_chain");
    }
    if (!api_docs) {
      throw new Error("APIChain must have api_docs");
    }
    return new APIChain({
      apiAnswerChain: await LLMChain.deserialize(api_answer_chain),
      apiRequestChain: await LLMChain.deserialize(api_request_chain),
      apiDocs: api_docs,
    });
  }

  /**
   * Serializes the chain. Only the two sub-chains and the API docs are
   * written; `inputKey`, `outputKey` and `headers` are not persisted.
   */
  serialize(): SerializedAPIChain {
    return {
      _type: this._chainType(),
      api_answer_chain: this.apiAnswerChain.serialize(),
      api_request_chain: this.apiRequestChain.serialize(),
      api_docs: this.apiDocs,
    };
  }

  /**
   * Static method to create a new APIChain from a BaseLanguageModel and API
   * documentation. The same model backs both the request and answer chains.
   * @param llm BaseLanguageModel instance.
   * @param apiDocs API documentation.
   * @param options Optional configuration options for the APIChain. The
   *   prompt overrides are consumed here; every other field is forwarded to
   *   the APIChain constructor.
   * @returns New APIChain instance.
   */
  static fromLLMAndAPIDocs(
    llm: BaseLanguageModel,
    apiDocs: string,
    options: APIChainOptions &
      Omit<APIChainInput, "apiAnswerChain" | "apiRequestChain" | "apiDocs"> = {}
  ): APIChain {
    // Pull the prompt overrides out so they are not passed on as chain
    // constructor fields (they are not part of APIChainInput).
    const {
      apiUrlPrompt = API_URL_PROMPT_TEMPLATE,
      apiResponsePrompt = API_RESPONSE_PROMPT_TEMPLATE,
      ...chainFields
    } = options;
    const apiRequestChain = new LLMChain({ prompt: apiUrlPrompt, llm });
    const apiAnswerChain = new LLMChain({ prompt: apiResponsePrompt, llm });
    // `new this` keeps subclasses constructing instances of themselves.
    return new this({
      apiAnswerChain,
      apiRequestChain,
      apiDocs,
      ...chainFields,
    });
  }
}