In [4]:
from longeval.spark import get_spark
from longeval.collection import ParquetCollection
from pyspark.sql import functions as F, Window
from pathlib import Path
from opensearchpy import OpenSearch
from sklearn.metrics import ndcg_score

spark = get_spark()
root = Path("~/").expanduser() / "scratch/longeval"
collection = ParquetCollection(spark, f"{root}/parquet/train/2023_01/English")

relevant_queries = collection.queries.join(
    collection.qrels.where("rel > 0")
    .groupBy("qid")
    .agg(F.collect_set("docid").alias("rel_docids")),
    on="qid",
).select("qid", "query", "rel_docids")

25/03/16 03:18:52 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.


In [13]:
relevant_queries.select(F.size("rel_docids")).describe().show()

+-------+------------------+
|summary|  size(rel_docids)|
+-------+------------------+
|  count|               599|
|   mean| 7.282136894824708|
| stddev|2.0460087848745068|
|    min|                 6|
|    max|                19|
+-------+------------------+



In [9]:
def generate_bulk_query(df, index_name: str, k: int = 50) -> list[dict]:
    data = []
    for row in df.itertuples():
        data += [
            {
                "index": index_name,
            },
            {
                "query": {
                    "match": {
                        "contents": {
                            "query": row.query,
                        }
                    }
                },
                "_source": False,
                "size": k,
            },
        ]
    return data


@F.udf("struct<ndcg: double, rel_count: int>")
def compute_scores(docids: list[str], rel_docids: list[str]) -> dict:
    if not docids:
        return {"ndcg": 0.0, "rel_count": 0}
    true_rel = [1 if docid in rel_docids else 0 for docid in docids]
    pred_rel = list(range(len(docids), 0, -1))

    # handle the case where there is only a single document, which the sklearn
    # metric does not handle well
    if len(true_rel) == 1:
        score = true_rel[0]
        return {"ndcg": float(score), "rel_count": int(score)}
    return {
        "ndcg": float(ndcg_score([true_rel], [pred_rel])),
        "rel_count": sum(true_rel),
    }


def run_search(relevant_queries, index_name: str, host="localhost:9200") -> list[dict]:
    client = OpenSearch(host)

    pdf = relevant_queries.toPandas()
    results = client.msearch(generate_bulk_query(pdf, index_name))
    # now iterate over the results and add in the original query
    for row, obj in zip(pdf.itertuples(), results["responses"]):
        obj["qid"] = row.qid

    schema = """
        qid: string,
        hits: struct<
            total: struct<value: long, relation: string>,
            max_score: double,
            hits: array<struct<_index: string, _id: string, _score: double>>
        >
    """
    resp = (
        spark.createDataFrame(results["responses"], schema=schema)
        .select("qid", "hits.*")
        .withColumn("total", F.col("total.value"))
    )

    window = Window.partitionBy("qid").orderBy(F.desc("hits._score"), F.asc("hits._id"))
    return (
        resp.select(
            "qid", "total", "max_score", F.posexplode("hits").alias("pos", "hits")
        )
        .withColumn("docids", F.collect_list(F.col("hits._id")).over(window))
        .withColumn("scores", F.collect_list(F.col("hits._score")).over(window))
        .groupBy("qid")
        .agg(
            F.any_value("total").alias("total"),
            F.any_value("max_score").alias("max_score"),
            F.max("docids").alias("docids"),
            F.max("scores").alias("scores"),
        )
        .join(
            relevant_queries.select("qid", "rel_docids"),
            on="qid",
        )
        .withColumn("scores", compute_scores(F.col("docids"), F.col("rel_docids")))
    )

In [10]:
res = run_search(relevant_queries, "longeval-train-english-2023_01").cache()
res.show()

INFO:opensearch:POST http://localhost:9200/_msearch [status:200 request:4.422s]

+-----------------+-----+---------+--------------------+--------------------+--------------------+
|              qid|total|max_score|              docids|              scores|          rel_docids|
+-----------------+-----+---------+--------------------+--------------------+--------------------+
|  q01238589934750|  325| 18.65921|[doc012303114671,...|{0.34561852186877...|[doc012311714403,...|
| q012334359739164|10000|26.772526|[doc012302714978,...|{0.63216048917480...|[doc012311108501,...|
| q012360129542513|10000|14.505418|[doc012305205192,...|{0.44020499676612...|[doc012311313198,...|
| q012360129543646|10000|  12.9868|[doc012312210422,...|            {1.0, 1}|[doc012312210422,...|
|q0123103079215891|10000|20.341152|[doc012301119210,...|{0.94563527046735...|[doc012301119210,...|
|q0123111669150911|10000| 14.23442|[doc012312006786,...|            {0.0, 0}|[doc012307814240,...|
|q0123120259085172| 6752| 24.29629|[doc012304008081,...|{0.47574215113581...|[doc012311313428,...|
|q01231202

                                                                                

In [None]:
# average ndcg score, how accurate is this?
res.select("scores.ndcg").describe().show()



+-------+-------------------+
|summary|               ndcg|
+-------+-------------------+
|  count|                597|
|   mean|0.44121392504198237|
| stddev| 0.2897673578153641|
|    min|                0.0|
|    max|                1.0|
+-------+-------------------+



                                                                                