Skip to content

Commit

Permalink
Keep also original query - multi_query.py (#12696)
Browse files Browse the repository at this point in the history
When you use a MultiQuery it might be useful to use the original query
as well as the newly generated ones to maximise the changes to retriever
the correct document. I haven't created an issue, it seems a very small
and easy thing.

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
  • Loading branch information
manuelrech and baskaryan committed Nov 3, 2023
1 parent 4fe9bf7 commit 2e2b9c7
Showing 1 changed file with 10 additions and 0 deletions.
10 changes: 10 additions & 0 deletions libs/langchain/langchain/retrievers/multi_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ class MultiQueryRetriever(BaseRetriever):
llm_chain: LLMChain
verbose: bool = True
parser_key: str = "lines"
include_original: bool = False
"""Whether to include the original query in the list of generated queries."""

@classmethod
def from_llm(
Expand All @@ -69,12 +71,15 @@ def from_llm(
llm: BaseLLM,
prompt: PromptTemplate = DEFAULT_QUERY_PROMPT,
parser_key: str = "lines",
include_original: bool = False,
) -> "MultiQueryRetriever":
"""Initialize from llm using default template.
Args:
retriever: retriever to query documents from
llm: llm for query generation using DEFAULT_QUERY_PROMPT
include_original: Whether to include the original query in the list of
generated queries.
Returns:
MultiQueryRetriever
Expand All @@ -85,6 +90,7 @@ def from_llm(
retriever=retriever,
llm_chain=llm_chain,
parser_key=parser_key,
include_original=include_original,
)

async def _aget_relevant_documents(
Expand All @@ -102,6 +108,8 @@ async def _aget_relevant_documents(
Unique union of relevant documents from all generated queries
"""
queries = await self.agenerate_queries(query, run_manager)
if self.include_original:
queries.append(query)
documents = await self.aretrieve_documents(queries, run_manager)
return self.unique_union(documents)

Expand Down Expand Up @@ -160,6 +168,8 @@ def _get_relevant_documents(
Unique union of relevant documents from all generated queries
"""
queries = self.generate_queries(query, run_manager)
if self.include_original:
queries.append(query)
documents = self.retrieve_documents(queries, run_manager)
return self.unique_union(documents)

Expand Down

0 comments on commit 2e2b9c7

Please sign in to comment.