diff --git a/libs/langchain/langchain/chains/combine_documents/stuff.py b/libs/langchain/langchain/chains/combine_documents/stuff.py index 2b113c7ab466ba..996a5ef7138781 100644 --- a/libs/langchain/langchain/chains/combine_documents/stuff.py +++ b/libs/langchain/langchain/chains/combine_documents/stuff.py @@ -100,6 +100,13 @@ def get_default_document_variable_name(cls, values: Dict) -> Dict: ) return values + @property + def input_keys(self) -> List[str]: + extra_keys = [ + k for k in self.llm_chain.input_keys if k != self.document_variable_name + ] + return super().input_keys + extra_keys + def _get_inputs(self, docs: List[Document], **kwargs: Any) -> dict: """Construct inputs from kwargs and docs. diff --git a/libs/langchain/tests/unit_tests/chains/test_combine_documents.py b/libs/langchain/tests/unit_tests/chains/test_combine_documents.py index df2212588e640d..a970c33cd4d292 100644 --- a/libs/langchain/tests/unit_tests/chains/test_combine_documents.py +++ b/libs/langchain/tests/unit_tests/chains/test_combine_documents.py @@ -9,8 +9,10 @@ _collapse_docs, _split_list_of_docs, ) +from langchain.chains.qa_with_sources import load_qa_with_sources_chain from langchain.docstore.document import Document from langchain.schema import format_document +from tests.unit_tests.llms.fake_llm import FakeLLM def _fake_docs_len_func(docs: List[Document]) -> int: @@ -21,6 +23,11 @@ def _fake_combine_docs_func(docs: List[Document], **kwargs: Any) -> str: return "".join([d.page_content for d in docs]) +def test_multiple_input_keys() -> None: + chain = load_qa_with_sources_chain(FakeLLM(), chain_type="stuff") + assert chain.input_keys == ["input_documents", "question"] + + def test__split_list_long_single_doc() -> None: """Test splitting of a long single doc.""" docs = [Document(page_content="foo" * 100)]