In [4]:
from longeval.spark import get_spark
from longeval.collection import ParquetCollection
from pyspark.sql import functions as F, Window
from pathlib import Path
from opensearchpy import OpenSearch
from sklearn.metrics import ndcg_score

spark = get_spark()
root = Path("~/").expanduser() / "scratch/longeval"
collection = ParquetCollection(spark, f"{root}/parquet/train/2023_01/English")

relevant_queries = collection.queries.join(
    collection.qrels.where("rel > 0")
    .groupBy("qid")
    .agg(F.collect_set("docid").alias("rel_docids")),
    on="qid",
).select("qid", "query", "rel_docids")

25/03/16 03:18:52 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.


In [13]:
relevant_queries.select(F.size("rel_docids")).describe().show()

+-------+------------------+
|summary|  size(rel_docids)|
+-------+------------------+
|  count|               599|
|   mean| 7.282136894824708|
| stddev|2.0460087848745068|
|    min|                 6|
|    max|                19|
+-------+------------------+



In [55]:
def generate_bulk_query(df, index_name: str, k: int = 50) -> list[dict]:
    data = []
    for row in df.itertuples():
        data += [
            {
                "index": index_name,
            },
            {
                "query": {
                    "match": {
                        "contents": {
                            "query": row.query,
                        }
                    }
                },
                "_source": False,
                "size": k,
            },
        ]
    return data


@F.udf("struct<ndcg: double, rel_count: int>")
def compute_scores(
    docids: list[str], scores: list[float], rel_docids: list[str]
) -> dict:
    if not docids:
        return {"ndcg": 0.0, "rel_count": 0}
    true_rel = [1 if docid in rel_docids else 0 for docid in docids]
    pred_rel = scores

    # handle the case where there is only a single document, which the sklearn
    # metric does not handle well
    if len(true_rel) == 1:
        score = true_rel[0]
        return {"ndcg": float(score), "rel_count": int(score)}
    return {
        "ndcg": float(ndcg_score([true_rel], [pred_rel])),
        "rel_count": sum(true_rel),
    }


def run_search(relevant_queries, index_name: str, host="localhost:9200") -> list[dict]:
    client = OpenSearch(host)

    pdf = relevant_queries.toPandas()
    results = client.msearch(generate_bulk_query(pdf, index_name))
    # now iterate over the results and add in the original query
    for row, obj in zip(pdf.itertuples(), results["responses"]):
        obj["qid"] = row.qid

    schema = """
        qid: string,
        hits: struct<
            total: struct<value: long, relation: string>,
            max_score: double,
            hits: array<struct<_index: string, _id: string, _score: double>>
        >
    """
    resp = (
        spark.createDataFrame(results["responses"], schema=schema)
        .select("qid", "hits.*")
        .withColumn("total", F.col("total.value"))
    )

    window = Window.partitionBy("qid").orderBy("pos")
    return (
        resp.select(
            "qid", "total", "max_score", F.posexplode("hits").alias("pos", "hits")
        )
        .withColumn("docids", F.collect_list(F.col("hits._id")).over(window))
        .withColumn("scores", F.collect_list(F.col("hits._score")).over(window))
        .groupBy("qid")
        .agg(
            F.any_value("total").alias("total"),
            F.any_value("max_score").alias("max_score"),
            F.max("docids").alias("docids"),
            F.max("scores").alias("scores"),
        )
        .join(
            relevant_queries.select("qid", "rel_docids"),
            on="qid",
        )
        .withColumn(
            "computed_scores",
            compute_scores(F.col("docids"), F.col("scores"), F.col("rel_docids")),
        )
    )

In [58]:
res = run_search(relevant_queries, "longeval-train-english-2023_01").cache()
res.show(n=5)
# average ndcg score, how accurate is this?
res.select("computed_scores.ndcg").describe().show()

INFO:opensearch:POST http://localhost:9200/_msearch [status:200 request:2.077s]
                                                                                

+-----------------+-----+---------+--------------------+--------------------+--------------------+--------------------+
|              qid|total|max_score|              docids|              scores|          rel_docids|     computed_scores|
+-----------------+-----+---------+--------------------+--------------------+--------------------+--------------------+
|  q01238589934750|  325| 18.65921|[doc012303114671,...|[18.65921, 18.654...|[doc012311714403,...|{0.34561852186877...|
| q012334359739164|10000|26.772526|[doc012302714978,...|[26.772526, 26.71...|[doc012311108501,...|{0.63216048917480...|
| q012360129542513|10000|14.505418|[doc012305205192,...|[14.505418, 14.49...|[doc012311313198,...|{0.44020499676612...|
| q012360129543646|10000|  12.9868|[doc012312210422,...|[12.9868, 12.9073...|[doc012312210422,...|            {1.0, 1}|
|q0123103079215891|10000|20.341152|[doc012301119210,...|[20.341152, 19.73...|[doc012301119210,...|{0.94563527046735...|
+-----------------+-----+---------+-----

                                                                                

In [54]:
import pytrec_eval


@F.udf("map<string, float>")
def run_udf(docids: list[str], scores: list[float]) -> dict[str, int]:
    if not docids:
        return {}
    return {k: v for k, v in zip(docids, scores)}


@F.udf("map<string, int>")
def qrel_udf(rel_docids: list[str]) -> dict[str, int]:
    if not rel_docids:
        return {}
    # return the relevant documents
    return {docid: 1 for docid in rel_docids}


# for each qid, we need a mapping of docid to relevance
run_df = res.select(
    "qid",
    run_udf("docids", "scores").alias("run"),
    qrel_udf("rel_docids").alias("qrels"),
)
run_df.printSchema()
run_df.show(n=3, truncate=100)

# now convert to the format required by trec_eval
qrel = {}
run = {}
for row in run_df.collect():
    qrel[row.qid] = row.qrels
    run[row.qid] = row.run

evaluator = pytrec_eval.RelevanceEvaluator(qrel, {"ndcg", "map"})
evals = evaluator.evaluate(run)

evals_df = spark.createDataFrame([{"qid": k, **v} for k, v in evals.items()])
evals_df.describe().show()

root
 |-- qid: string (nullable = true)
 |-- run: map (nullable = true)
 |    |-- key: string
 |    |-- value: float (valueContainsNull = true)
 |-- qrels: map (nullable = true)
 |    |-- key: string
 |    |-- value: integer (valueContainsNull = true)

+----------------+----------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------+
|             qid|                                                                                                 run|                                                                                               qrels|
+----------------+----------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------+
| q01238589934750|{doc012301806966 -> 14.238972, doc012311312899 -> 14.43931, doc012

                                                                                

+-------+-------------------+-------------------+-----------------+
|summary|                map|               ndcg|              qid|
+-------+-------------------+-------------------+-----------------+
|  count|                597|                597|              597|
|   mean|0.16477464955560864|0.31424901046884035|             NULL|
| stddev|0.19069715746288896|0.25924436825908315|             NULL|
|    min|                0.0|                0.0|q0123103079215124|
|    max| 0.8491452991452991| 0.9491546172962253|          q012396|
+-------+-------------------+-------------------+-----------------+



In [52]:
# okay let's see where we went wrong

debug = res.select("qid", F.col("computed_scores.ndcg").alias("my_ndcg")).join(
    evals_df.select("qid", "ndcg").withColumnRenamed("ndcg", "trec_ndcg"),
    on="qid",
)
debug.show()

                                                                                

+-----------------+------------------+-------------------+
|              qid|           my_ndcg|          trec_ndcg|
+-----------------+------------------+-------------------+
|  q01238589934750|0.3456185218687763|0.17057078643005402|
| q012334359739164|0.6321604891748087|  0.564020444522855|
| q012360129542513|0.4402049967661204| 0.2578466240460916|
| q012360129543646|               1.0|0.27487633291429087|
|q0123103079215891|0.9456352704673574|  0.843706044175124|
|q0123111669150911|               0.0|                0.0|
|q0123120259085172|0.4753190249891079|0.28593204868013994|
|q0123120259086252|0.1790522317510419|  0.054181637470214|
| q012317179871244|0.4306765580733929|0.13032376591037062|
|        q01231876|0.7996199827438175| 0.7996199827438173|
| q012325769805427|               0.0|                0.0|
| q012334359739354|0.5221860067660785| 0.3676841955471875|
|  q01238589936278|               0.0|                0.0|
|q0123103079216681|               0.0|                0.

Okay, so I have no idea why my ndcg stuff doesn't work and overestimates the score. I'm just going to use the tried and true library -- might be worth figuring out how to reproduce the scores between the two systems at some point though.