-
Notifications
You must be signed in to change notification settings - Fork 2.1k
/
llm_router.ts
73 lines (65 loc) · 2.43 KB
/
llm_router.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import type { BaseLanguageModelInterface } from "@langchain/core/language_models/base";
import { BasePromptTemplate } from "../../prompts/base.js";
import { LLMChain } from "../../chains/llm_chain.js";
import { RouterChain } from "./multi_route.js";
import { CallbackManagerForChainRun } from "../../callbacks/manager.js";
import { ChainValues } from "../../schema/index.js";
import { ChainInputs } from "../../chains/base.js";
/**
 * Output produced by a router chain's routing step.
 * `destination` is the key/name of the chain the input should be routed
 * to next; `next_inputs` is the string-keyed map of inputs that will be
 * forwarded to that destination chain.
 */
export type RouterOutputSchema = {
destination: string;
next_inputs: { [key: string]: string };
};
/**
 * Constructor fields for an LLM-backed router chain: the base
 * `ChainInputs` plus the `llmChain` whose prediction (typed as
 * `RouterOutputSchema`) supplies the routing decision.
 */
export interface LLMRouterChainInput extends ChainInputs {
llmChain: LLMChain<RouterOutputSchema>;
}
/**
 * A router chain that delegates the routing decision to an LLM.
 *
 * Extends `RouterChain`: `_call` runs the wrapped `llmChain` on the
 * incoming values and returns its prediction, which conforms to
 * `RouterOutputSchema` (a destination name plus the inputs to forward).
 */
export class LLMRouterChain extends RouterChain implements LLMRouterChainInput {
  /** The LLMChain whose prediction yields the routing decision. */
  llmChain: LLMChain<RouterOutputSchema>;

  constructor(fields: LLMRouterChainInput) {
    super(fields);
    this.llmChain = fields.llmChain;
  }

  /** Input keys are exactly those required by the underlying LLMChain. */
  get inputKeys(): string[] {
    return this.llmChain.inputKeys;
  }

  /**
   * Runs the underlying LLMChain on `values` and returns its routing
   * prediction.
   *
   * @param values Inputs matching {@link inputKeys}.
   * @param runManager Optional callback manager; its child is passed down
   *   so callbacks propagate to the inner LLMChain run.
   */
  async _call(
    values: ChainValues,
    // `| undefined` removed: the `?` modifier already permits undefined.
    runManager?: CallbackManagerForChainRun
  ): Promise<RouterOutputSchema> {
    return this.llmChain.predict(values, runManager?.getChild());
  }

  _chainType(): string {
    return "llm_router_chain";
  }

  /**
   * Builds an LLMRouterChain from a language model and a prompt by
   * constructing the intermediate LLMChain internally.
   *
   * @param llm A BaseLanguageModel instance.
   * @param prompt A BasePromptTemplate instance.
   * @param options Optional LLMRouterChainInput fields, excluding "llmChain".
   * @returns An instance of LLMRouterChain.
   */
  static fromLLM(
    llm: BaseLanguageModelInterface,
    prompt: BasePromptTemplate,
    options?: Omit<LLMRouterChainInput, "llmChain">
  ): LLMRouterChain {
    const llmChain = new LLMChain<RouterOutputSchema>({ llm, prompt });
    return new LLMRouterChain({ ...options, llmChain });
  }
}