From ed37fbaeff57fadd84b67b130d3a1dc8977748f5 Mon Sep 17 00:00:00 2001
From: Ji
Date: Sun, 19 Feb 2023 20:48:23 -0800
Subject: [PATCH] for ChatVectorDBChain, add top_k_docs_for_context to allow
 controlling how many chunks of context are retrieved (#1155)

Given that we allow the user to define the chunk size, it would also be
useful to let the user define how many chunks of context are retrieved.
---
 langchain/chains/chat_vector_db/base.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/langchain/chains/chat_vector_db/base.py b/langchain/chains/chat_vector_db/base.py
index f9431baf5b85de..c2bbd4df4986dc 100644
--- a/langchain/chains/chat_vector_db/base.py
+++ b/langchain/chains/chat_vector_db/base.py
@@ -32,6 +32,8 @@ class ChatVectorDBChain(Chain, BaseModel):
     question_generator: LLMChain
     output_key: str = "answer"
     return_source_documents: bool = False
     """Return the source documents."""
+    top_k_docs_for_context: int = 4
+    """Number of document chunks to retrieve as context."""
 
     @property
@@ -88,7 +90,9 @@ def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
             )
         else:
             new_question = question
-        docs = self.vectorstore.similarity_search(new_question, k=4, **vectordbkwargs)
+        docs = self.vectorstore.similarity_search(
+            new_question, k=self.top_k_docs_for_context, **vectordbkwargs
+        )
         new_inputs = inputs.copy()
         new_inputs["question"] = new_question
         new_inputs["chat_history"] = chat_history_str
@@ -109,7 +113,9 @@ async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, str]:
         else:
             new_question = question
         # TODO: This blocks the event loop, but it's not clear how to avoid it.
-        docs = self.vectorstore.similarity_search(new_question, k=4, **vectordbkwargs)
+        docs = self.vectorstore.similarity_search(
+            new_question, k=self.top_k_docs_for_context, **vectordbkwargs
+        )
         new_inputs = inputs.copy()
         new_inputs["question"] = new_question
         new_inputs["chat_history"] = chat_history_str
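
Usage note (not part of the patch): a minimal sketch of how the new field
could be set, assuming from_llm forwards extra keyword arguments to the
chain constructor, as other chain constructors in this codebase do. The
vectorstore variable and the question text are hypothetical placeholders.

    from langchain.chains import ChatVectorDBChain
    from langchain.llms import OpenAI

    # Assumes `vectorstore` is an already-populated VectorStore (e.g. FAISS)
    # whose documents were split with a custom chunk size.
    chain = ChatVectorDBChain.from_llm(
        OpenAI(temperature=0),
        vectorstore,
        top_k_docs_for_context=2,  # retrieve 2 chunks instead of the default 4
    )
    result = chain({"question": "What does the document say?", "chat_history": []})

The new knob pairs naturally with the chunk size: with large chunks, a smaller
top_k_docs_for_context helps keep the stuffed prompt within the model's token
limit, while with small chunks a larger value recovers more surrounding context.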