-
Notifications
You must be signed in to change notification settings - Fork 2.2k
/
retrieval.ts
109 lines (106 loc) · 3.83 KB
/
retrieval.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import type { BaseRetrieverInterface } from "@langchain/core/retrievers";
import {
type Runnable,
RunnableSequence,
type RunnableInterface,
RunnablePassthrough,
} from "@langchain/core/runnables";
import type { BaseMessage } from "@langchain/core/messages";
import type { DocumentInterface } from "@langchain/core/documents";
/**
 * Parameters for the createRetrievalChain method.
 */
export type CreateRetrievalChainParams<RunOutput> = {
  /**
   * Retriever-like object that returns a list of documents. Should
   * either be a subclass of BaseRetriever or a Runnable that returns
   * a list of documents. If it is a BaseRetriever (any object exposing a
   * callable `getRelevantDocuments` member is treated as one), then an
   * `input` key is expected in the chain inputs - its value is what will
   * be passed into the retriever. If it is NOT a BaseRetriever, then all
   * the inputs will be passed into this runnable, meaning that runnable
   * should take an object as input.
   */
  retriever:
    | BaseRetrieverInterface
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    | RunnableInterface<Record<string, any>, DocumentInterface[]>;
  /**
   * Runnable that takes inputs and produces the final answer, typed as
   * `RunOutput`; its output is stored under the `answer` key of the
   * retrieval chain's result.
   * The inputs to this will be any original inputs to this chain, a new
   * `context` key with the retrieved documents, and `chat_history` (if not
   * present in the inputs) with a value of `[]` (to easily enable
   * conversational retrieval).
   */
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  combineDocsChain: RunnableInterface<Record<string, any>, RunOutput>;
};
/**
 * Duck-typed check used to decide how the retriever is wired into the chain:
 * `x` counts as a retriever when it is truthy and its `getRelevantDocuments`
 * member is a function.
 */
function isBaseRetriever(x: unknown): x is BaseRetrieverInterface {
  if (!x) {
    return false;
  }
  const lookup = (x as BaseRetrieverInterface).getRelevantDocuments;
  return typeof lookup === "function";
}
/**
 * Create a retrieval chain that retrieves documents and then passes them on
 * to a document-combining chain to produce an answer.
 *
 * The returned runnable (run name `retrieval_chain`):
 * 1. Runs the retriever as a child run named `retrieve_documents` and stores
 *    the retrieved documents under `context`; `chat_history` is defaulted to
 *    `[]` when the caller did not supply one.
 * 2. Passes the resulting object to `combineDocsChain` and stores its output
 *    under `answer`.
 *
 * @param params A params object containing a retriever and a combineDocsChain.
 * @returns An LCEL Runnable which returns an object containing at least
 * `context` (the retrieved documents) and `answer` keys.
 * @example
 * ```typescript
 * // yarn add langchain @langchain/openai
 *
 * import { ChatOpenAI } from "@langchain/openai";
 * import { pull } from "langchain/hub";
 * import { createRetrievalChain } from "langchain/chains/retrieval";
 * import { createStuffDocumentsChain } from "langchain/chains/combine_documents";
 *
 * const retrievalQAChatPrompt = await pull("langchain-ai/retrieval-qa-chat");
 * const llm = new ChatOpenAI({});
 * const retriever = ...
 * const combineDocsChain = await createStuffDocumentsChain({
 *   llm,
 *   prompt: retrievalQAChatPrompt,
 * });
 * const retrievalChain = await createRetrievalChain({
 *   retriever,
 *   combineDocsChain,
 * });
 * const response = await retrievalChain.invoke({ input: "..." });
 * ```
 */
export async function createRetrievalChain<RunOutput>({
  retriever,
  combineDocsChain,
}: CreateRetrievalChainParams<RunOutput>): Promise<
  Runnable<
    { input: string; chat_history?: BaseMessage[] | string } & {
      [key: string]: unknown;
    },
    // `context` holds the retrieved documents, not a string.
    { context: DocumentInterface[]; answer: RunOutput } & {
      [key: string]: unknown;
    }
  >
> {
  let retrieveDocumentsChain: Runnable<{ input: string }, DocumentInterface[]>;
  if (isBaseRetriever(retriever)) {
    // A retriever is fed only the `input` value, so extract it first.
    retrieveDocumentsChain = RunnableSequence.from([
      (input) => input.input,
      retriever,
    ]);
  } else {
    // Any other runnable receives the whole chain input object unchanged.
    // TODO: Fix typing by adding withConfig to core RunnableInterface
    retrieveDocumentsChain = retriever as Runnable;
  }
  const retrievalChain = RunnableSequence.from<{
    input: string;
    chat_history?: BaseMessage[] | string;
  }>([
    RunnablePassthrough.assign({
      context: retrieveDocumentsChain.withConfig({
        runName: "retrieve_documents",
      }),
      // Default to an empty history so prompts with a chat_history
      // placeholder work for non-conversational calls too.
      chat_history: (input) => input.chat_history ?? [],
    }),
    RunnablePassthrough.assign({
      answer: combineDocsChain,
    }),
  ]).withConfig({ runName: "retrieval_chain" });
  return retrievalChain;
}