-
Notifications
You must be signed in to change notification settings - Fork 2.2k
/
history_aware_retriever.ts
91 lines (89 loc) · 2.94 KB
/
history_aware_retriever.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import type { LanguageModelLike } from "@langchain/core/language_models/base";
import {
type Runnable,
type RunnableInterface,
RunnableSequence,
RunnableBranch,
} from "@langchain/core/runnables";
import { type BasePromptTemplate } from "@langchain/core/prompts";
import { StringOutputParser } from "@langchain/core/output_parsers";
import type { DocumentInterface } from "@langchain/core/documents";
import type { BaseMessage } from "@langchain/core/messages";
/**
 * Params for the createHistoryAwareRetriever method.
 */
export type CreateHistoryAwareRetrieverParams = {
  /**
   * Language model to use for generating a search term given chat history.
   */
  llm: LanguageModelLike;
  /**
   * RetrieverLike object that takes a string as input and outputs a list of Documents.
   */
  retriever: RunnableInterface<string, DocumentInterface[]>;
  /**
   * The prompt used to generate the search query for the retriever.
   * Must declare an `input` variable — `createHistoryAwareRetriever`
   * throws if it does not.
   */
  rephrasePrompt: BasePromptTemplate;
};
/**
 * Create a chain that takes conversation history and returns documents.
 * If there is no `chat_history`, then the `input` is just passed directly to the
 * retriever. If there is `chat_history`, then the prompt and LLM will be used
 * to generate a search query. That search query is then passed to the retriever.
 * @param {CreateHistoryAwareRetrieverParams} params
 * @returns An LCEL Runnable. The runnable input must take in `input`, and if there
 * is chat history should take it in the form of `chat_history`.
 * The Runnable output is a list of Documents
 * @throws {Error} If `rephrasePrompt` does not declare an `input` variable.
 * @example
 * ```typescript
 * // yarn add langchain @langchain/openai
 *
 * import { ChatOpenAI } from "@langchain/openai";
 * import { pull } from "langchain/hub";
 * import { createHistoryAwareRetriever } from "langchain/chains/history_aware_retriever";
 *
 * const rephrasePrompt = await pull("langchain-ai/chat-langchain-rephrase");
 * const llm = new ChatOpenAI({});
 * const retriever = ...
 * const chain = await createHistoryAwareRetriever({
 *   llm,
 *   retriever,
 *   rephrasePrompt,
 * });
 * const result = await chain.invoke({"input": "...", "chat_history": [] })
 * ```
 */
export async function createHistoryAwareRetriever({
  llm,
  retriever,
  rephrasePrompt,
}: CreateHistoryAwareRetrieverParams): Promise<
  Runnable<
    // `chat_history` is optional at the type level to match the runtime
    // branch below, which treats a missing history the same as an empty one.
    { input: string; chat_history?: string | BaseMessage[] },
    DocumentInterface[]
  >
> {
  // Fail fast: the rephrase prompt must accept the user's latest `input`.
  if (!rephrasePrompt.inputVariables.includes("input")) {
    throw new Error(
      `Expected "input" to be a prompt variable, but got ${JSON.stringify(
        rephrasePrompt.inputVariables
      )}`
    );
  }
  const retrieveDocuments = RunnableBranch.from([
    [
      // No (or empty) chat history: skip the LLM round-trip and query the
      // retriever with the raw `input` string directly.
      (input) => !input.chat_history || input.chat_history.length === 0,
      RunnableSequence.from([(input) => input.input, retriever]),
    ],
    // Default branch: rephrase the conversation into a standalone search
    // query (prompt -> LLM -> string), then pass that query to the retriever.
    RunnableSequence.from([
      rephrasePrompt,
      llm,
      new StringOutputParser(),
      retriever,
    ]),
  ]).withConfig({
    runName: "history_aware_retriever",
  });
  return retrieveDocuments;
}