Skip to content

Commit

Permalink
feat: allow DocumentJoiner to accept top_k parameter in run method (#…
Browse files Browse the repository at this point in the history
…7709)

* feat: allow DocumentJoiner to accept top_k parameter in run method

* Added release note for DocumentJoiner top_k fix
  • Loading branch information
Varun-Krishnan1 committed May 23, 2024
1 parent 482f60e commit badb05b
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 2 deletions.
9 changes: 7 additions & 2 deletions haystack/components/joiners/document_joiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,14 @@ def __init__(
self.sort_by_score = sort_by_score

@component.output_types(documents=List[Document])
def run(self, documents: Variadic[List[Document]]):
def run(self, documents: Variadic[List[Document]], top_k: Optional[int] = None):
"""
Joins multiple lists of Documents into a single list depending on the `join_mode` parameter.
:param documents:
List of list of Documents to be merged.
:param top_k:
The maximum number of Documents to return. Overrides the instance's `top_k` if provided.
:returns:
A dictionary with the following keys:
Expand All @@ -103,8 +105,11 @@ def run(self, documents: Variadic[List[Document]]):
"score, so those with score=None were sorted as if they had a score of -infinity."
)

if self.top_k:
if top_k:
output_documents = output_documents[:top_k]
elif self.top_k:
output_documents = output_documents[: self.top_k]

return {"documents": output_documents}

def _concatenate(self, document_lists):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---

enhancements:
- |
The `DocumentJoiner` component's `run` method now accepts an optional `top_k` parameter. When provided, it overrides the `top_k` set at initialization, allowing users to specify the maximum number of documents to return at query time. This fixes issue #7702.
8 changes: 8 additions & 0 deletions test/components/joiners/test_document_joiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,14 @@ def test_run_with_reciprocal_rank_fusion_join_mode(self):
]
assert all(doc.id in expected_document_ids for doc in output["documents"])

def test_run_with_top_k_in_run_method(self):
    """A `top_k` passed to `run` caps the number of joined documents returned."""
    limit = 4
    first_batch = [Document(content=text) for text in ("a", "b", "c")]
    second_batch = [Document(content=text) for text in ("d", "e", "f")]
    # Six documents go in across two lists; only `limit` of them should come out.
    result = DocumentJoiner().run([first_batch, second_batch], top_k=limit)
    assert len(result["documents"]) == limit

def test_sort_by_score_without_scores(self, caplog):
joiner = DocumentJoiner()
with caplog.at_level(logging.INFO):
Expand Down

0 comments on commit badb05b

Please sign in to comment.