Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 17 additions & 13 deletions pymongo_search_utils/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,38 +102,42 @@ def combine_pipelines(


def reciprocal_rank_stage(
score_field: str, penalty: float = 0, **kwargs: Any
score_field: str, penalty: float = 0, weight: float = 1, **kwargs: Any
) -> list[dict[str, Any]]:
"""Stage adds Reciprocal Rank Fusion weighting.
"""
Stage adds Weighted Reciprocal Rank Fusion (WRRF) scoring.

First, it pushes documents retrieved from previous stage
into a temporary sub-document. It then unwinds to establish
the rank to each and applies the penalty.
First, it groups documents into an array, assigns rank by array index,
and then computes a weighted RRF score.

Args:
score_field: A unique string to identify the search being ranked
penalty: A non-negative float.
extra_fields: Any fields other than text_field that one wishes to keep.
score_field: A unique string to identify the search being ranked.
penalty: A non-negative float (e.g., 60 for RRF-60). Controls the denominator.
weight: A float multiplier for this source's importance.
**kwargs: Ignored; allows future extensions or passthrough args.

Returns:
RRF score
Aggregation pipeline stage for weighted RRF scoring.
"""

rrf_pipeline = [
return [
{"$group": {"_id": None, "docs": {"$push": "$$ROOT"}}},
{"$unwind": {"path": "$docs", "includeArrayIndex": "rank"}},
{
"$addFields": {
f"docs.{score_field}": {"$divide": [1.0, {"$add": ["$rank", penalty, 1]}]},
f"docs.{score_field}": {
"$multiply": [
weight,
{"$divide": [1.0, {"$add": ["$rank", penalty, 1]}]},
]
},
"docs.rank": "$rank",
"_id": "$docs._id",
}
},
{"$replaceRoot": {"newRoot": "$docs"}},
]

return rrf_pipeline # type: ignore[return-value]


def final_hybrid_stage(scores_fields: list[str], limit: int, **kwargs: Any) -> list[dict[str, Any]]:
"""Sum weighted scores, sort, and apply limit.
Expand Down
12 changes: 9 additions & 3 deletions tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,9 @@ def test_basic_reciprocal_rank(self):
{"$unwind": {"path": "$docs", "includeArrayIndex": "rank"}},
{
"$addFields": {
"docs.text_score": {"$divide": [1.0, {"$add": ["$rank", 0, 1]}]},
"docs.text_score": {
"$multiply": [1, {"$divide": [1.0, {"$add": ["$rank", 0, 1]}]}]
},
"docs.rank": "$rank",
"_id": "$docs._id",
}
Expand All @@ -210,7 +212,7 @@ def test_reciprocal_rank_with_penalty(self):
result = reciprocal_rank_stage(score_field="vector_score", penalty=60)

add_fields_stage = result[2]["$addFields"]
divide_expr = add_fields_stage["docs.vector_score"]["$divide"]
divide_expr = add_fields_stage["docs.vector_score"]["$multiply"][1]["$divide"]
add_expr = divide_expr[1]["$add"]

assert add_expr == ["$rank", 60, 1]
Expand All @@ -225,7 +227,11 @@ def test_reciprocal_rank_with_kwargs(self):
result = reciprocal_rank_stage(score_field="test_score", penalty=10, extra_param="ignored")

assert len(result) == 4
assert result[2]["$addFields"]["docs.test_score"]["$divide"][1]["$add"] == ["$rank", 10, 1]
assert result[2]["$addFields"]["docs.test_score"]["$multiply"][1]["$divide"][1]["$add"] == [
"$rank",
10,
1,
]


class TestFinalHybridStage:
Expand Down