In [2]:
from longeval.spark import get_spark
from pathlib import Path
from longeval.collection import ParquetCollection

spark = get_spark(cores=8, memory="20g")

root = Path("~/shared/longeval/2025/parquet").expanduser()
train = ParquetCollection(spark, (root / "train").as_posix())
test = ParquetCollection(spark, (root / "test").as_posix())
queries = train.queries.union(test.queries).cache()
queries.printSchema()

25/05/26 01:09:36 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.


root
 |-- qid: string (nullable = true)
 |-- query: string (nullable = true)
 |-- split: string (nullable = true)
 |-- language: string (nullable = true)
 |-- date: string (nullable = true)



In [3]:
queries.select("qid").distinct().count()

                                                                                

67822

In [7]:
queries.select("date", "qid", "query").show(truncate=False, n=5)

+-------+---+------------------------+
|date   |qid|query                   |
+-------+---+------------------------+
|2022-08|2  |18 videoz               |
|2022-08|3  |1ere guerre mondiale    |
|2022-08|7  |3949 pole emploi        |
|2022-08|8  |4 mariages 1 enterrement|
|2022-08|10 |a vendre chateau        |
+-------+---+------------------------+
only showing top 5 rows


In [10]:
# how many duplicate queries?
queries.groupBy("qid").count().groupBy("count").count().orderBy("count").show()

+-----+-----+
|count|count|
+-----+-----+
|    1|19327|
|    2| 8804|
|    3|25632|
|    4| 3357|
|    5| 2340|
|    6| 1704|
|    7| 1399|
|    8| 1042|
|    9|  910|
|   10|  768|
|   11|  611|
|   12|  448|
|   13|  414|
|   14|  459|
|   15|  607|
+-----+-----+



In [17]:
queries.where("qid=100").show()

+---+-----------------+-----+--------+-------+
|qid|            query|split|language|   date|
+---+-----------------+-----+--------+-------+
|100|appli pole emploi|train|  French|2022-08|
|100|appli pole emploi|train|  French|2022-07|
|100|appli pole emploi|train|  French|2022-06|
|100|appli pole emploi|train|  French|2023-01|
|100|appli pole emploi|train|  French|2022-12|
|100|appli pole emploi|train|  French|2022-11|
|100|appli pole emploi|train|  French|2022-10|
|100|appli pole emploi|train|  French|2023-02|
|100|appli pole emploi|train|  French|2022-09|
|100|appli pole emploi| test|  French|2023-08|
|100|appli pole emploi| test|  French|2023-05|
|100|appli pole emploi| test|  French|2023-07|
|100|appli pole emploi| test|  French|2023-06|
+---+-----------------+-----+--------+-------+



In [None]:
from pyspark.sql import functions as F

deduped = (
    queries.groupBy("qid")
    .agg(F.first("query").alias("query"))
    .orderBy(F.col("qid").cast("integer"))
    .cache()
)
deduped.take(5)

[Row(qid='2', query='18 videoz'),
 Row(qid='3', query='1ere guerre mondiale'),
 Row(qid='4', query='1ere guerre mondiale ce2'),
 Row(qid='5', query='2 guerre mondial'),
 Row(qid='6', query='2 eme guerre mondiale')]

In [None]:
import os
import requests
import dotenv
import json

dotenv.load_dotenv("../../.env")
# let's choose the longest query and query expand with openrouter


def get_schema():
    return {
        "type": "array",
        "items": {
            "type": "object",
            "properties": {
                "qid": {
                    "type": "string",
                    "description": "The unique identifier for the query.",
                },
                "query": {
                    "type": "string",
                    "description": "The text of the query.",
                },
            },
            "required": ["qid", "query"],
            "additionalProperties": False,
        },
    }


def prompt(queries):
    query_text = "\n".join([f"{row['qid']}: {row['query']}" for row in queries])
    return f"""{query_text}

    For each query above, generate a query expansion in French that includes additional relevant terms or phrases.
    The query expansion should be no more than 100 words long.
    The query engine relies on BM25 and vector search techniques in French.
    The output should be a JSON array of objects, each containing the original 'qid' and the expanded 'query'.
    """


def chat_complete(
    queries,
    api_key=os.environ.get("OPENROUTER_API_KEY"),
    model="google/gemini-2.5-flash-preview-05-20",
    # model="google/gemma-3n-e4b-it:free",
):
    completion = requests.post(
        "https://openrouter.ai/api/v1/chat/completions",
        headers={
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        },
        json=dict(
            model=model,
            messages=[
                {"role": "user", "content": [{"type": "text", "text": prompt(queries)}]}
            ],
            # object with predict list of strings and a reason sting
            response_format={
                "type": "json_schema",
                "json_schema": {
                    "name": "longeval",
                    "strict": True,
                    "schema": get_schema(),
                },
            },
        ),
    )
    return completion.json()


def query_expansion(queries, output):
    # add the response to a logfile
    output = Path(output).expanduser()
    logfile = Path(output) / "completion/log.txt"
    # name is start-end range
    start = queries[0]["qid"]
    end = queries[-1]["qid"]
    output = Path(output) / f"expansion/{start}-{end}.json"
    if output.exists():
        print(f"Output file {output} already exists, skipping.")
        return
    logfile.parent.mkdir(parents=True, exist_ok=True)
    output.parent.mkdir(parents=True, exist_ok=True)

    resp = chat_complete(queries)
    with logfile.open("a") as f:
        f.write(json.dumps(resp) + "\n")

    data = json.loads(resp["choices"][0]["message"]["content"])
    # check that all of the quids are present
    input_qids = {q["qid"] for q in queries}
    output_qids = {d["qid"] for d in data}
    if input_qids != output_qids:
        raise ValueError("Input and output qids do not match.")

    # write the output to a file
    with output.open("w") as f:
        json.dump(data, f, indent=2)

    # return the data
    return data


resp = query_expansion(deduped.take(5), "~/scratch/longeval/query_expansion_tmp")
resp

[{'qid': '2',
  'query': '18 vidéos clips courts tutoriels explicatifs didactiques documentaires cours chaîne formation module série collection épisodes HD 4K'},
 {'qid': '3',
  'query': 'Première Guerre Mondiale Grande Guerre 1914-1918 conflit mondial combats tranchées poilus batailles histoire causes conséquences traité de Versailles armistice'},
 {'qid': '4',
  'query': 'Première Guerre Mondiale CE2 cours leçon histoire pour enfants primaire cycle 2 explication facile résumé quiz exercices ressources pédagogiques école'},
 {'qid': '5',
  'query': 'Deuxième Guerre Mondiale Seconde Guerre Mondiale 1939-1945 conflit mondial nazisme Alliés Axe camps de concentration résistance débarquement bataille de Normandie'},
 {'qid': '6',
  'query': 'Deuxième Guerre Mondiale leçons histoire causes événements conséquences principaux acteurs Hitler Staline Churchill alliances batailles clés Hiroshima Nagasaki capitulation'}]

In [None]:
# okay let's do this in batches of 1000, until we reach the end of the dataset
import tqdm

batch_size = 100
deduped_rows = deduped.collect()
for i in tqdm.tqdm(range(0, len(deduped_rows), batch_size)):
    batch = deduped_rows[i : i + batch_size]
    try:
        query_expansion(batch, "~/scratch/longeval/query_expansion")
    except Exception as e:
        print(f"Error processing batch {i}-{i + batch_size}: {e}")
        continue