-
Notifications
You must be signed in to change notification settings - Fork 13.6k
/
merger_retriever.py
128 lines (102 loc) · 3.52 KB
/
merger_retriever.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import asyncio
from typing import List
from langchain_core.callbacks import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
class MergerRetriever(BaseRetriever):
    """Retriever that merges the results of multiple retrievers.

    Every retriever in ``retrievers`` is queried with the same query and the
    per-retriever result lists are interleaved round-robin: first the top
    document of each retriever (in ``retrievers`` order), then the second
    document of each, and so on. Retrievers that returned fewer documents are
    skipped once exhausted. No de-duplication is performed.
    """

    retrievers: List[BaseRetriever]
    """A list of retrievers to merge."""

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """
        Get the relevant documents for a given query.

        Args:
            query: The query to search for.
            run_manager: Callback manager for this run; forwarded to
                ``merge_documents``.

        Returns:
            A list of relevant documents.
        """
        return self.merge_documents(query, run_manager)

    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: AsyncCallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """
        Asynchronously get the relevant documents for a given query.

        Args:
            query: The query to search for.
            run_manager: Async callback manager for this run; forwarded to
                ``amerge_documents``.

        Returns:
            A list of relevant documents.
        """
        return await self.amerge_documents(query, run_manager)

    def merge_documents(
        self, query: str, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """
        Merge the results of the retrievers.

        The retrievers are invoked one after another, in list order. Each
        call receives the callbacks returned by
        ``run_manager.get_child("retriever_<n>")``, where ``n`` is the
        1-based position of the retriever.

        Args:
            query: The query to search for.
            run_manager: Callback manager used to derive per-retriever child
                callbacks.

        Returns:
            A list of merged documents, interleaved round-robin.
        """
        per_retriever_docs = [
            retriever.invoke(
                query,
                config={"callbacks": run_manager.get_child(f"retriever_{position}")},
            )
            for position, retriever in enumerate(self.retrievers, start=1)
        ]
        return self._interleave(per_retriever_docs)

    async def amerge_documents(
        self, query: str, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        """
        Asynchronously merge the results of the retrievers.

        All retrievers are awaited together through ``asyncio.gather``; the
        gathered results keep the order of ``retrievers``. Each call receives
        the callbacks returned by ``run_manager.get_child("retriever_<n>")``,
        where ``n`` is the 1-based position of the retriever.

        Args:
            query: The query to search for.
            run_manager: Async callback manager used to derive per-retriever
                child callbacks.

        Returns:
            A list of merged documents, interleaved round-robin.
        """
        per_retriever_docs = await asyncio.gather(
            *(
                retriever.ainvoke(
                    query,
                    config={
                        "callbacks": run_manager.get_child(f"retriever_{position}")
                    },
                )
                for position, retriever in enumerate(self.retrievers, start=1)
            )
        )
        return self._interleave(per_retriever_docs)

    @staticmethod
    def _interleave(per_retriever_docs: List[List[Document]]) -> List[Document]:
        """Interleave several result lists round-robin, preserving list order.

        Shared by the sync and async merge paths so both produce the same
        ordering.

        Args:
            per_retriever_docs: One list of documents per retriever.

        Returns:
            The rank-0 document of every list, then the rank-1 document of
            every list, and so on; lists shorter than the current rank are
            skipped. Empty input yields an empty list.
        """
        merged: List[Document] = []
        # ``default=0`` keeps this safe when there are no retrievers at all.
        deepest_rank = max(map(len, per_retriever_docs), default=0)
        for rank in range(deepest_rank):
            merged.extend(
                docs[rank] for docs in per_retriever_docs if rank < len(docs)
            )
        return merged