diff --git a/pymongo_search_utils/pipeline.py b/pymongo_search_utils/pipeline.py index 5129045..1f63bc5 100644 --- a/pymongo_search_utils/pipeline.py +++ b/pymongo_search_utils/pipeline.py @@ -102,29 +102,35 @@ def combine_pipelines( def reciprocal_rank_stage( - score_field: str, penalty: float = 0, **kwargs: Any + score_field: str, penalty: float = 0, weight: float = 1, **kwargs: Any ) -> list[dict[str, Any]]: - """Stage adds Reciprocal Rank Fusion weighting. + """ + Stage adds Weighted Reciprocal Rank Fusion (WRRF) scoring. - First, it pushes documents retrieved from previous stage - into a temporary sub-document. It then unwinds to establish - the rank to each and applies the penalty. + First, it groups documents into an array, assigns rank by array index, + and then computes a weighted RRF score. Args: - score_field: A unique string to identify the search being ranked - penalty: A non-negative float. - extra_fields: Any fields other than text_field that one wishes to keep. + score_field: A unique string to identify the search being ranked. + penalty: A non-negative float (e.g., 60 for RRF-60). Controls the denominator. + weight: A float multiplier for this source's importance. + **kwargs: Ignored; allows future extensions or passthrough args. Returns: - RRF score + Aggregation pipeline stage for weighted RRF scoring. """ - rrf_pipeline = [ + return [ {"$group": {"_id": None, "docs": {"$push": "$$ROOT"}}}, {"$unwind": {"path": "$docs", "includeArrayIndex": "rank"}}, { "$addFields": { - f"docs.{score_field}": {"$divide": [1.0, {"$add": ["$rank", penalty, 1]}]}, + f"docs.{score_field}": { + "$multiply": [ + weight, + {"$divide": [1.0, {"$add": ["$rank", penalty, 1]}]}, + ] + }, "docs.rank": "$rank", "_id": "$docs._id", } @@ -132,8 +138,6 @@ def reciprocal_rank_stage( {"$replaceRoot": {"newRoot": "$docs"}}, ] - return rrf_pipeline # type: ignore[return-value] - def final_hybrid_stage(scores_fields: list[str], limit: int, **kwargs: Any) -> list[dict[str, Any]]: """Sum weighted scores, sort, and apply limit. diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 7cdffd7..8da9a26 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -196,7 +196,9 @@ def test_basic_reciprocal_rank(self): {"$unwind": {"path": "$docs", "includeArrayIndex": "rank"}}, { "$addFields": { - "docs.text_score": {"$divide": [1.0, {"$add": ["$rank", 0, 1]}]}, + "docs.text_score": { + "$multiply": [1, {"$divide": [1.0, {"$add": ["$rank", 0, 1]}]}] + }, "docs.rank": "$rank", "_id": "$docs._id", } @@ -210,7 +212,7 @@ def test_reciprocal_rank_with_penalty(self): result = reciprocal_rank_stage(score_field="vector_score", penalty=60) add_fields_stage = result[2]["$addFields"] - divide_expr = add_fields_stage["docs.vector_score"]["$divide"] + divide_expr = add_fields_stage["docs.vector_score"]["$multiply"][1]["$divide"] add_expr = divide_expr[1]["$add"] assert add_expr == ["$rank", 60, 1] @@ -225,7 +227,11 @@ def test_reciprocal_rank_with_kwargs(self): result = reciprocal_rank_stage(score_field="test_score", penalty=10, extra_param="ignored") assert len(result) == 4 - assert result[2]["$addFields"]["docs.test_score"]["$divide"][1]["$add"] == ["$rank", 10, 1] + assert result[2]["$addFields"]["docs.test_score"]["$multiply"][1]["$divide"][1]["$add"] == [ + "$rank", + 10, + 1, + ] class TestFinalHybridStage: