-
Notifications
You must be signed in to change notification settings - Fork 13.5k
/
stuff.py
266 lines (224 loc) Β· 10.7 KB
/
stuff.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
"""Chain that combines documents by stuffing into context."""
from typing import Any, Dict, List, Optional, Tuple
from langchain_core.callbacks import Callbacks
from langchain_core.documents import Document
from langchain_core.language_models import LanguageModelLike
from langchain_core.output_parsers import BaseOutputParser, StrOutputParser
from langchain_core.prompts import BasePromptTemplate, format_document
from langchain_core.pydantic_v1 import Extra, Field, root_validator
from langchain_core.runnables import Runnable, RunnablePassthrough
from langchain.chains.combine_documents.base import (
DEFAULT_DOCUMENT_PROMPT,
DEFAULT_DOCUMENT_SEPARATOR,
DOCUMENTS_KEY,
BaseCombineDocumentsChain,
_validate_prompt,
)
from langchain.chains.llm import LLMChain
def create_stuff_documents_chain(
    llm: LanguageModelLike,
    prompt: BasePromptTemplate,
    *,
    output_parser: Optional[BaseOutputParser] = None,
    document_prompt: Optional[BasePromptTemplate] = None,
    document_separator: str = DEFAULT_DOCUMENT_SEPARATOR,
) -> Runnable[Dict[str, Any], Any]:
    """Build a runnable that stuffs a list of Documents into a single prompt.

    Every document is rendered to a string with ``document_prompt``, the rendered
    strings are joined with ``document_separator``, and the joined text replaces
    the list of documents in the input before it is piped through ``prompt``,
    ``llm`` and the output parser.

    Args:
        llm: Language model.
        prompt: Prompt template. Must contain input variable "context", which
            will be used for passing in the formatted documents.
        output_parser: Output parser. A new ``StrOutputParser`` is used when this
            is not given.
        document_prompt: Prompt used for formatting each document into a string.
            Input variables can be "page_content" or any metadata keys that are
            in all documents. ``DEFAULT_DOCUMENT_PROMPT`` is used when this is
            not given.
        document_separator: String placed between the formatted document strings.

    Returns:
        An LCEL Runnable. The input is a dictionary that must have a "context" key
        that maps to a List[Document], plus any other input variables expected by
        the prompt. The return type depends on the output parser used.

    Example:
        .. code-block:: python

            # pip install -U langchain langchain-community

            from langchain_community.chat_models import ChatOpenAI
            from langchain_core.documents import Document
            from langchain_core.prompts import ChatPromptTemplate
            from langchain.chains.combine_documents import create_stuff_documents_chain

            prompt = ChatPromptTemplate.from_messages(
                [("system", "What are everyone's favorite colors:\\n\\n{context}")]
            )
            llm = ChatOpenAI(model="gpt-3.5-turbo")
            chain = create_stuff_documents_chain(llm, prompt)

            docs = [
                Document(page_content="Jesse loves red but not yellow"),
                Document(
                    page_content="Jamal loves green but not as much as he loves orange"
                ),
            ]

            chain.invoke({"context": docs})
    """  # noqa: E501
    _validate_prompt(prompt)

    # Fall back to the defaults whenever the caller supplied nothing (or a falsy value).
    per_doc_prompt = document_prompt if document_prompt else DEFAULT_DOCUMENT_PROMPT
    parser = output_parser if output_parser else StrOutputParser()

    def format_docs(inputs: dict) -> str:
        # Render each document on its own, then glue the pieces together.
        rendered = (
            format_document(document, per_doc_prompt)
            for document in inputs[DOCUMENTS_KEY]
        )
        return document_separator.join(rendered)

    # Step 1: overwrite the documents entry of the input with the formatted string.
    format_step = RunnablePassthrough.assign(**{DOCUMENTS_KEY: format_docs})
    format_step = format_step.with_config(run_name="format_inputs")

    # Step 2: feed the updated inputs through prompt -> model -> parser.
    pipeline = format_step | prompt | llm | parser
    return pipeline.with_config(run_name="stuff_documents_chain")
class StuffDocumentsChain(BaseCombineDocumentsChain):
    """Chain that combines documents by stuffing into context.

    This chain takes a list of documents and first combines them into a single
    string. It does this by formatting each document into a string with the
    `document_prompt` and then joining them together with `document_separator`.
    It then adds that new string to the inputs with the variable name set by
    `document_variable_name`. Those inputs are then passed to the `llm_chain`.

    Example:
        .. code-block:: python

            from langchain.chains import StuffDocumentsChain, LLMChain
            from langchain_core.prompts import PromptTemplate
            from langchain_community.llms import OpenAI

            # This controls how each document will be formatted. Specifically,
            # it will be passed to `format_document` - see that function for more
            # details.
            document_prompt = PromptTemplate(
                input_variables=["page_content"],
                template="{page_content}"
            )
            document_variable_name = "context"
            llm = OpenAI()
            # The prompt here should take as an input variable the
            # `document_variable_name`
            prompt = PromptTemplate.from_template(
                "Summarize this content: {context}"
            )
            llm_chain = LLMChain(llm=llm, prompt=prompt)
            chain = StuffDocumentsChain(
                llm_chain=llm_chain,
                document_prompt=document_prompt,
                document_variable_name=document_variable_name
            )
    """

    llm_chain: LLMChain
    """LLM chain which is called with the formatted document string,
    along with any other inputs."""
    document_prompt: BasePromptTemplate = Field(
        default_factory=lambda: DEFAULT_DOCUMENT_PROMPT
    )
    """Prompt to use to format each document, gets passed to `format_document`."""
    document_variable_name: str
    """The variable name in the llm_chain to put the documents in.
    If only one variable in the llm_chain, this need not be provided."""
    document_separator: str = "\n\n"
    """The string with which to join the formatted documents"""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @root_validator(pre=True)
    def get_default_document_variable_name(cls, values: Dict) -> Dict:
        """Get default document variable name, if not provided.

        If only one variable is present in the llm_chain.prompt,
        we can infer that the formatted documents should be passed in
        with this variable name.

        Args:
            values: Raw constructor values, before field validation.

        Returns:
            The same mapping, with "document_variable_name" filled in when it
            could be inferred.

        Raises:
            ValueError: If "document_variable_name" is missing and the prompt has
                more than one input variable, or if the provided name is not one
                of the prompt's input variables.
        """
        if "llm_chain" not in values:
            # This validator runs before field validation, so a missing
            # `llm_chain` would otherwise surface as a bare KeyError here.
            # Returning the values untouched lets the regular required-field
            # validation report the problem instead.
            return values

        llm_chain_variables = values["llm_chain"].prompt.input_variables
        if "document_variable_name" not in values:
            if len(llm_chain_variables) == 1:
                values["document_variable_name"] = llm_chain_variables[0]
            else:
                raise ValueError(
                    "document_variable_name must be provided if there are "
                    "multiple llm_chain_variables"
                )
        else:
            if values["document_variable_name"] not in llm_chain_variables:
                raise ValueError(
                    f"document_variable_name {values['document_variable_name']} was "
                    f"not found in llm_chain input_variables: {llm_chain_variables}"
                )
        return values

    @property
    def input_keys(self) -> List[str]:
        """Return the parent's input keys plus the extra `llm_chain` input keys.

        The document variable is excluded from the extras because this chain
        fills it in itself from the documents.
        """
        extra_keys = [
            k for k in self.llm_chain.input_keys if k != self.document_variable_name
        ]
        return super().input_keys + extra_keys

    def _get_inputs(self, docs: List[Document], **kwargs: Any) -> dict:
        """Construct inputs from kwargs and docs.

        Format and then join all the documents together into one input with name
        `self.document_variable_name`. Also pluck any additional variables
        from **kwargs.

        Args:
            docs: List of documents to format and then join into single input
            **kwargs: additional inputs to chain, will pluck any other required
                arguments from here.

        Returns:
            dictionary of inputs to LLMChain
        """
        # Format each document according to the prompt
        doc_strings = [format_document(doc, self.document_prompt) for doc in docs]
        # Keep only the kwargs that the llm_chain prompt actually declares.
        inputs = {
            k: v
            for k, v in kwargs.items()
            if k in self.llm_chain.prompt.input_variables
        }
        # Join the documents together to put them in the prompt.
        inputs[self.document_variable_name] = self.document_separator.join(doc_strings)
        return inputs

    def prompt_length(self, docs: List[Document], **kwargs: Any) -> Optional[int]:
        """Return the prompt length given the documents passed in.

        This can be used by a caller to determine whether passing in a list
        of documents would exceed a certain prompt length. This is useful when
        trying to ensure that the size of a prompt remains below a certain
        context limit.

        Args:
            docs: List[Document], a list of documents to use to calculate the
                total prompt length.
            **kwargs: additional inputs; those named in the llm_chain prompt's
                input variables are included when formatting the prompt.

        Returns:
            The value of `llm_chain._get_num_tokens` for the fully formatted
            prompt.
        """
        inputs = self._get_inputs(docs, **kwargs)
        prompt = self.llm_chain.prompt.format(**inputs)
        return self.llm_chain._get_num_tokens(prompt)

    def combine_docs(
        self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
    ) -> Tuple[str, dict]:
        """Stuff all documents into one prompt and pass to LLM.

        Args:
            docs: List of documents to join together into one variable
            callbacks: Optional callbacks to pass along
            **kwargs: additional parameters to use to get inputs to LLMChain.

        Returns:
            The first element returned is the single string output. The second
            element returned is a dictionary of other keys to return; it is
            always empty here.
        """
        inputs = self._get_inputs(docs, **kwargs)
        # Call predict on the LLM.
        return self.llm_chain.predict(callbacks=callbacks, **inputs), {}

    async def acombine_docs(
        self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
    ) -> Tuple[str, dict]:
        """Async stuff all documents into one prompt and pass to LLM.

        Args:
            docs: List of documents to join together into one variable
            callbacks: Optional callbacks to pass along
            **kwargs: additional parameters to use to get inputs to LLMChain.

        Returns:
            The first element returned is the single string output. The second
            element returned is a dictionary of other keys to return; it is
            always empty here.
        """
        inputs = self._get_inputs(docs, **kwargs)
        # Call predict on the LLM.
        return await self.llm_chain.apredict(callbacks=callbacks, **inputs), {}

    @property
    def _chain_type(self) -> str:
        """Identifier string for this chain type."""
        return "stuff_documents_chain"