In [114]:
from longeval.spark import get_spark
from longeval.collection import ParquetCollection
from pyspark.sql import functions as F, Window
import pandas as pd

spark = get_spark()
root = "../../tests/integration"
collection = ParquetCollection(spark, f"{root}/parquet/train/2023_01/English")

# Peek at both raw inputs: the queries first, then the relevance judgements.
for frame in (collection.queries, collection.qrels):
    frame.printSchema()
    frame.show(3)

# Keep only judgements with a positive grade and gather their docids per query.
# Note: the grade itself is discarded here, so downstream evaluation is binary.
positive_docids = (
    collection.qrels.where("rel > 0")
    .groupBy("qid")
    .agg(F.collect_set("docid").alias("rel_docids"))
)
# Inner join: queries without any positive judgement are dropped.
relevant_queries = collection.queries.join(positive_docids, on="qid").select(
    "qid", "query", "rel_docids"
)

relevant_queries.printSchema()
relevant_queries.show(3)

25/03/15 22:20:14 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.


root
 |-- qid: string (nullable = true)
 |-- query: string (nullable = true)

+--------+--------------------+
|     qid|               query|
+--------+--------------------+
| q012318|case over the border|
| q012396|      water atlantic|
|q0123180|blanquette de vea...|
+--------+--------------------+
only showing top 3 rows

root
 |-- qid: string (nullable = true)
 |-- rank: integer (nullable = true)
 |-- docid: string (nullable = true)
 |-- rel: integer (nullable = true)

+-------+----+---------------+---+
|    qid|rank|          docid|rel|
+-------+----+---------------+---+
|q012318|   0|doc012303114898|  0|
|q012318|   0|doc012307806130|  1|
|q012318|   0|doc012311314092|  0|
+-------+----+---------------+---+
only showing top 3 rows

root
 |-- qid: string (nullable = true)
 |-- query: string (nullable = true)
 |-- rel_docids: array (nullable = false)
 |    |-- element: string (containsNull = false)

+-----------------+-------------------+--------------------+
|              qid|   

In [102]:
from opensearchpy import OpenSearch


def generate_bulk_query(df, index_name: str, size: "int | None" = None) -> list[dict]:
    """Build an OpenSearch ``msearch`` body with one ``match`` query per row of ``df``.

    Each row contributes a header/body pair: the header targets ``index_name`` and
    the body runs a ``match`` query on the ``contents`` field using the row's
    ``query`` text, without returning the document source.

    :param df: pandas DataFrame with a ``query`` column. Rows are emitted in order,
        so the i-th header/body pair corresponds to the i-th row.
    :param index_name: name of the index every search runs against.
    :param size: number of hits to request per query. When ``None`` (default) the
        key is omitted and the server's default hit count applies, which also
        caps how deep any downstream ranking metric can look.
    :returns: flat list alternating header and body dicts (``2 * len(df)`` items).
    """
    data = []
    for query_text in df["query"]:
        body = {
            "query": {"match": {"contents": {"query": query_text}}},
            "_source": False,
        }
        if size is not None:
            body["size"] = size
        data.extend([{"index": index_name}, body])
    return data


client = OpenSearch("http://localhost:9200")
index_name = "train-english-2023_01"

pdf = relevant_queries.toPandas()
display(pdf.head())

results = client.msearch(generate_bulk_query(pdf, index_name))
responses = results["responses"]

# Responses are paired with queries purely by position below, so a count mismatch
# must fail loudly instead of letting zip() silently drop the tail.
if len(responses) != len(pdf):
    raise ValueError(
        f"msearch returned {len(responses)} responses for {len(pdf)} queries"
    )
# A failed sub-search carries an "error" entry instead of hits; it would otherwise
# surface later as a null `hits` value and quietly distort the evaluation.
errors = [response["error"] for response in responses if "error" in response]
if errors:
    raise RuntimeError(
        f"{len(errors)} of {len(responses)} searches failed; first error: {errors[0]}"
    )

# Tag every response with the qid of the query that produced it.
for row, obj in zip(pdf.itertuples(), responses):
    obj["qid"] = row.qid

Unnamed: 0,qid,query,rel_docids
0,q0123103079215124,bill of sale vessel,"[doc012308214224, doc012311713704, doc01230130..."
1,q0123103079215188,areches beaufort,"[doc012312405753, doc012304816793, doc01230430..."
2,q0123103079215846,schengen space,"[doc012300218460, doc012303114703, doc01231200..."
3,q0123103079215871,office chair,"[doc012300401944, doc012300814609, doc01230430..."
4,q0123103079215891,fontainebleau cinema,"[doc012301119210, doc012303612581, doc01230890..."


In [103]:
# Top-level keys of the msearch envelope and its overall `took` value.
print(results.keys())
print(results["took"])

# One row per sub-search.
# NOTE(review): `resp` is rebound to a Spark DataFrame in a later cell; a distinct
# name here would avoid confusion when cells are re-run out of order.
resp = pd.DataFrame(results["responses"])
display(resp.head())

# Unpack the nested `_shards` and `hits` dicts into frames of their own.
shards = pd.DataFrame(resp["_shards"].tolist())
display(shards.head())

hits = pd.DataFrame(resp["hits"].tolist()).assign(qid=pdf["qid"])
display(hits.head())

# The first three hits returned for the first query.
display(hits.iloc[0]["hits"][:3])

dict_keys(['took', 'responses'])
998


Unnamed: 0,took,timed_out,_shards,hits,status,qid
0,22,False,"{'total': 1, 'successful': 1, 'skipped': 0, 'f...","{'total': {'value': 1191, 'relation': 'eq'}, '...",200,q0123103079215124
1,13,False,"{'total': 1, 'successful': 1, 'skipped': 0, 'f...","{'total': {'value': 2, 'relation': 'eq'}, 'max...",200,q0123103079215188
2,12,False,"{'total': 1, 'successful': 1, 'skipped': 0, 'f...","{'total': {'value': 138, 'relation': 'eq'}, 'm...",200,q0123103079215846
3,15,False,"{'total': 1, 'successful': 1, 'skipped': 0, 'f...","{'total': {'value': 136, 'relation': 'eq'}, 'm...",200,q0123103079215871
4,14,False,"{'total': 1, 'successful': 1, 'skipped': 0, 'f...","{'total': {'value': 35, 'relation': 'eq'}, 'ma...",200,q0123103079215891


Unnamed: 0,total,successful,skipped,failed
0,1,1,0,0
1,1,1,0,0
2,1,1,0,0
3,1,1,0,0
4,1,1,0,0


Unnamed: 0,total,max_score,hits,qid
0,"{'value': 1191, 'relation': 'eq'}",10.457349,"[{'_index': 'train-english-2023_01', '_id': 'd...",q0123103079215124
1,"{'value': 2, 'relation': 'eq'}",7.458131,"[{'_index': 'train-english-2023_01', '_id': 'd...",q0123103079215188
2,"{'value': 138, 'relation': 'eq'}",5.751015,"[{'_index': 'train-english-2023_01', '_id': 'd...",q0123103079215846
3,"{'value': 136, 'relation': 'eq'}",9.848803,"[{'_index': 'train-english-2023_01', '_id': 'd...",q0123103079215871
4,"{'value': 35, 'relation': 'eq'}",8.403374,"[{'_index': 'train-english-2023_01', '_id': 'd...",q0123103079215891


[{'_index': 'train-english-2023_01',
  '_id': 'doc012301800007',
  '_score': 10.457349},
 {'_index': 'train-english-2023_01',
  '_id': 'doc012302900007',
  '_score': 9.1460285},
 {'_index': 'train-english-2023_01',
  '_id': 'doc012303600003',
  '_score': 8.8720665}]

In [105]:
# Letting Spark infer a schema from the raw response dicts forces the nested
# values into map<string, ... long> types: string and list fields come back as
# NULL (see the output), so an explicit schema is needed to parse them properly.
tmp = spark.createDataFrame(results["responses"])
tmp.printSchema()
tmp.select(F.col("hits")).show(n=5, truncate=100)

root
 |-- _shards: map (nullable = true)
 |    |-- key: string
 |    |-- value: long (valueContainsNull = true)
 |-- hits: map (nullable = true)
 |    |-- key: string
 |    |-- value: map (valueContainsNull = true)
 |    |    |-- key: string
 |    |    |-- value: long (valueContainsNull = true)
 |-- qid: string (nullable = true)
 |-- status: long (nullable = true)
 |-- timed_out: boolean (nullable = true)
 |-- took: long (nullable = true)

+-----------------------------------------------------------------------------+
|                                                                         hits|
+-----------------------------------------------------------------------------+
|{hits -> NULL, total -> {value -> 1191, relation -> NULL}, max_score -> NULL}|
|   {hits -> NULL, total -> {value -> 2, relation -> NULL}, max_score -> NULL}|
| {hits -> NULL, total -> {value -> 138, relation -> NULL}, max_score -> NULL}|
| {hits -> NULL, total -> {value -> 136, relation -> NULL}, max_score -> NUL

In [110]:
schema = """
qid: string,
hits: struct<
    total: struct<value: long, relation: string>,
    max_score: double,
    hits: array<struct<_index: string, _id: string, _score: double>>
>
"""
# Parse with the explicit schema, then lift the fields of the `hits` envelope to
# top-level columns (qid, total, max_score, hits), keeping only the count of `total`.
resp = spark.createDataFrame(results["responses"], schema=schema).select(
    "qid",
    F.col("hits.total.value").alias("total"),
    "hits.max_score",
    "hits.hits",
)
resp.printSchema()
resp.show()

root
 |-- qid: string (nullable = true)
 |-- total: long (nullable = true)
 |-- max_score: double (nullable = true)
 |-- hits: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- _index: string (nullable = true)
 |    |    |-- _id: string (nullable = true)
 |    |    |-- _score: double (nullable = true)

+-----------------+-----+----------+--------------------+
|              qid|total| max_score|                hits|
+-----------------+-----+----------+--------------------+
|q0123103079215124| 1191| 10.457349|[{train-english-2...|
|q0123103079215188|    2|  7.458131|[{train-english-2...|
|q0123103079215846|  138| 5.7510147|[{train-english-2...|
|q0123103079215871|  136|  9.848803|[{train-english-2...|
|q0123103079215891|   35|  8.403374|[{train-english-2...|
|q0123103079215932|  360| 17.849213|[{train-english-2...|
|q0123103079215943|   12|  6.336317|[{train-english-2...|
|q0123103079215951|   80| 5.1309576|[{train-english-2...|
|q0123103079215952|  

In [126]:
# Row layout produced by exploding the hits array: a `pos` index plus the hit
# struct in `col`.
resp.select(F.posexplode(F.col("hits"))).printSchema()

root
 |-- pos: integer (nullable = false)
 |-- col: struct (nullable = true)
 |    |-- _index: string (nullable = true)
 |    |-- _id: string (nullable = true)
 |    |-- _score: double (nullable = true)



In [141]:
# Rank each query's hits by score (descending), breaking ties by docid (ascending).
# sort_array on structs compares fields in order, so sorting (-score, _id)
# ascending yields exactly that ordering.
#
# The array is sorted in place rather than exploded and re-aggregated. This keeps
# queries that returned no hits (their arrays are just empty), so they still count
# in the evaluation; posexplode dropped those rows. It also avoids building a
# cumulative list per row and picking the longest via lexicographic max().
ranked_hits = F.sort_array(
    F.transform(
        "hits",
        lambda hit: F.struct(
            (-hit["_score"]).alias("neg_score"),
            hit["_id"].alias("_id"),
            hit["_score"].alias("_score"),
        ),
    )
)
scored = resp.select(
    "qid",
    "total",
    "max_score",
    ranked_hits.getField("_id").alias("docids"),
    ranked_hits.getField("_score").alias("scores"),
).join(
    relevant_queries.select("qid", "rel_docids"),
    on="qid",
)
scored.show()

+-----------------+-----+----------+--------------------+--------------------+--------------------+
|              qid|total| max_score|              docids|              scores|          rel_docids|
+-----------------+-----+----------+--------------------+--------------------+--------------------+
|q0123103079215124| 1191| 10.457349|[doc012301800007,...|[10.457349, 9.146...|[doc012308214224,...|
|q0123103079215188|    2|  7.458131|[doc012305500004,...|[7.458131, 4.7410...|[doc012312405753,...|
|q0123103079215846|  138| 5.7510147|[doc012308200002,...|[5.7510147, 4.486...|[doc012300218460,...|
|q0123103079215871|  136|  9.848803|[doc012302600006,...|[9.848803, 9.3083...|[doc012300401944,...|
|q0123103079215891|   35|  8.403374|[doc012308500005,...|[8.403374, 6.6840...|[doc012301119210,...|
|q0123103079215932|  360| 17.849213|[doc012312300001,...|[17.849213, 11.66...|[doc012303908297,...|
|q0123103079215943|   12|  6.336317|[doc012312500008,...|[6.336317, 5.9054...|[doc012310507883,...|


In [148]:
from sklearn.metrics import ndcg_score


# now compute the ndcg
@F.udf("struct<ndcg: double, rel_count: int>")
def compute_scores(docids: list[str], rel_docids: list[str]) -> dict:
    """Score one query's ranked result list against its relevant documents.

    Relevance is binary: a retrieved doc has gain 1 if it is in ``rel_docids``
    and 0 otherwise. The position in ``docids`` is the rank (index 0 = best), and
    the NDCG cutoff k is the number of retrieved docs.

    The previous version used ``sklearn.metrics.ndcg_score`` with all-equal
    predicted scores. That treats every doc as tied, ignoring the ranking. It also
    took its ideal DCG only from the retrieved docs, and it raises on a single
    retrieved doc. Here the ideal DCG uses *all* known relevant docs (up to k), so
    a run that misses most relevant docs is not rewarded for ranking its few hits
    well.

    :param docids: retrieved doc ids, best first; may be empty or null.
    :param rel_docids: doc ids judged relevant for the query; may be empty or null.
    :returns: ``{"ndcg": ..., "rel_count": ...}``. ``ndcg`` lies in [0, 1] when
        ``docids`` are unique. ``rel_count`` is the number of retrieved docs that
        are relevant.
    """
    import math

    if not docids:
        return {"ndcg": 0.0, "rel_count": 0}
    relevant = set(rel_docids or ())
    gains = [1 if docid in relevant else 0 for docid in docids]
    # Standard log2(rank + 1) discount for 1-based ranks (enumerate is 0-based).
    dcg = sum(gain / math.log2(rank + 2) for rank, gain in enumerate(gains))
    # Best achievable DCG: every top position filled with a relevant doc.
    ideal_hits = min(len(relevant), len(docids))
    idcg = sum(1.0 / math.log2(rank + 2) for rank in range(ideal_hits))
    return {
        "ndcg": dcg / idcg if idcg > 0 else 0.0,
        "rel_count": sum(gains),
    }


# Attach the per-query metrics struct (ndcg, rel_count).
# NOTE(review): this overwrites the retrieval `scores` array with the metrics
# struct; a distinct column name (e.g. "metrics") would keep both.
scored = scored.withColumn("scores", compute_scores("docids", "rel_docids"))
scored.show()

[Stage 285:>                                                        (0 + 1) / 1]

+-----------------+-----+----------+--------------------+--------+--------------------+
|              qid|total| max_score|              docids|  scores|          rel_docids|
+-----------------+-----+----------+--------------------+--------+--------------------+
|q0123103079215124| 1191| 10.457349|[doc012301800007,...|{0.0, 0}|[doc012308214224,...|
|q0123103079215188|    2|  7.458131|[doc012305500004,...|{0.0, 0}|[doc012312405753,...|
|q0123103079215846|  138| 5.7510147|[doc012308200002,...|{0.0, 0}|[doc012300218460,...|
|q0123103079215871|  136|  9.848803|[doc012302600006,...|{0.0, 0}|[doc012300401944,...|
|q0123103079215891|   35|  8.403374|[doc012308500005,...|{0.0, 0}|[doc012301119210,...|
|q0123103079215932|  360| 17.849213|[doc012312300001,...|{0.0, 0}|[doc012303908297,...|
|q0123103079215943|   12|  6.336317|[doc012312500008,...|{0.0, 0}|[doc012310507883,...|
|q0123103079215951|   80| 5.1309576|[doc012300200003,...|{0.0, 0}|[doc012301309936,...|
|q0123103079215989|   12| 13.040

                                                                                