In [47]:
import pandas as pd
import random
import numpy as np

# For reproducibility: seed Python's built-in `random` module, which
# extract_triplets relies on (via random.choice) when picking negatives.
# NOTE(review): NumPy's random state is not seeded here; the pandas .sample
# calls further down pass random_state explicitly instead.
random.seed(42)


In [48]:
# Read the combined dataset from parquet and report how many rows it holds
combined_path = "dataprep/ms_marco_combined/combined.parquet"
df = pd.read_parquet(combined_path)
print(f"Loaded {len(df)} queries.")


Loaded 102023 queries.


In [49]:
# Collect, per query, the passages flagged as selected (positives) and pool
# every other passage as a negative. Blank or non-string texts are dropped.
pos_map = {}    # query_id -> {"query": query text, "positives": [passage texts]}
neg_list = []   # (query_id, passage text) pairs for non-selected passages

for record in df.itertuples():
    query_id = record.query_id
    query_text = record.query
    flags = record.passages['is_selected']
    texts = record.passages['passage_text']

    for flag, text in zip(flags, texts):
        has_content = isinstance(text, str) and bool(text.strip())
        if not has_content:
            continue

        if flag != 1:
            # Not selected for this query -> goes into the global negative pool
            neg_list.append((query_id, text))
            continue

        # First positive seen for this query creates its entry
        if query_id not in pos_map:
            pos_map[query_id] = {"query": query_text, "positives": []}
        pos_map[query_id]["positives"].append(text)


In [50]:
# Summarise the pools built above: dataset size, how many queries have at
# least one positive, and how many negatives were collected overall.
total_queries, queries_with_pos, total_negatives = len(df), len(pos_map), len(neg_list)

for summary_line in (
    f"Queries total: {total_queries}",
    f"Queries with ≥1 positive: {queries_with_pos}",
    f"Negatives in pool: {total_negatives}",
):
    print(summary_line)


Queries total: 102023
Queries with ≥1 positive: 98755
Negatives in pool: 727975


In [53]:
def extract_triplets(raw_dataframe: pd.DataFrame, offset: int = 4) -> pd.DataFrame:
    """Build (query, positive passage, negative passage) triplets.

    For every row the positive is the first passage flagged
    ``is_selected == 1`` whose text is a non-blank string. The negative is
    drawn at random from the passages of the row ``offset`` positions further
    on (wrapping around the end of the frame), after removing blank texts and
    any text identical to the positive.

    A row is skipped (produces no triplet) when:
      * its ``passages`` value is not a dict or has no usable positive,
      * the offset wraps back onto the row itself,
      * the offset row has no usable passage different from the positive.

    Randomness comes from the module-level ``random`` generator, so calling
    ``random.seed(...)`` beforehand makes the output repeatable.

    Args:
        raw_dataframe: Frame with a ``query`` column and a ``passages`` column
            whose values are dicts holding parallel ``is_selected`` and
            ``passage_text`` sequences.
        offset: Row distance between a query and the row its negative is
            sampled from.

    Returns:
        DataFrame with the columns ``query``, ``positive_passage`` and
        ``negative_passage`` (always present, even when no triplet is built).
    """
    columns = ['query', 'positive_passage', 'negative_passage']

    num_rows = len(raw_dataframe)
    if num_rows == 0:
        return pd.DataFrame(columns=columns)

    def _is_usable(text) -> bool:
        # Same filter as the pool-building cell: real, non-blank strings only.
        return isinstance(text, str) and bool(text.strip())

    def _field(passages, key):
        # Missing dict / missing key / explicit None all behave as "no data".
        if not isinstance(passages, dict):
            return []
        values = passages.get(key)
        return [] if values is None else values

    # Pull both columns out once; DataFrame.iloc[i] inside the loop would build
    # a new Series per access and dominates the runtime on large frames.
    queries = raw_dataframe['query'].tolist()
    passages_column = raw_dataframe['passages'].tolist()

    triplets = []
    for i, (query, passages) in enumerate(zip(queries, passages_column)):
        # zip() tolerates is_selected / passage_text of different lengths
        # instead of indexing past the end of the shorter one.
        flagged = zip(_field(passages, 'is_selected'), _field(passages, 'passage_text'))
        positive_passage = next(
            (text for selected, text in flagged if selected == 1 and _is_usable(text)),
            None,
        )
        if positive_passage is None:
            continue

        negative_index = (i + offset) % num_rows
        if negative_index == i:
            # Offset is a multiple of the row count: no "other" row to use.
            continue

        # Filter first, then sample: sampling with a fixed number of retries
        # could give up even though a valid negative was available.
        candidates = [
            text
            for text in _field(passages_column[negative_index], 'passage_text')
            if _is_usable(text) and text != positive_passage
        ]
        if not candidates:
            continue

        triplets.append({
            'query': query,
            'positive_passage': positive_passage,
            'negative_passage': random.choice(candidates),
        })

    return pd.DataFrame(triplets, columns=columns)

In [54]:
# Re-read the combined dataset so this cell does not depend on earlier state,
# then pair each query with a negative taken from the row four positions on.
combined_parquet = "dataprep/ms_marco_combined/combined.parquet"
df = pd.read_parquet(combined_parquet)
triples_df = extract_triplets(df, offset=4)
print("Total triplets:", len(triples_df))

Total triplets: 98755


In [55]:
# Show a fixed (random_state=42) sample of ten triplets as a sanity check
preview_columns = ['query', 'positive_passage', 'negative_passage']
sampled = (
    triples_df[preview_columns]
    .sample(10, random_state=42)
    .reset_index(drop=True)
)
display(sampled)

Unnamed: 0,query,positive_passage,negative_passage
0,what is a landmark in colombia,Bogota. Colombia's capital of Bogota is in the...,The second level of the food chains is called ...
1,how to prepare a field for planting pumpkins,1. Prepare your soil. If you are anticipating ...,A. The Locarno Conference was called partly be...
2,should you lighten your hair as you get older,I don't want to be a blond. I was thinking of ...,HIP ABDUCTOR MUSCLES The hip abductor muscles ...
3,why is carbon dioxide inorganic,an organic compound is a compound which must c...,The truth: Age is the most important factor af...
4,is licorice fattening,"No, the red part is artificial. But licorice i...","Amount-With the Child Tax Credit, you may be a..."
5,how is the term flashback defined,Flashback or flashbacks may refer to: 1 Flash...,over 55 years. Building Blocks of Nutrition: F...
6,what age does universal credit include,Young people and Universal Credit. From April ...,These are the richest athletes in the world! H...
7,what is white zircon,Zircon is a type of gemstone. It's desirable b...,Those teaching elementary school earned annual...
8,how has the first amendment been incorporated ...,The First Amendment has been fully incorporate...,Rating Newest Oldest. Best Answer: The five-po...
9,is a company liable for employees driving own ...,When you decide to permit an employee to drive...,A barebones system package will differ dependi...


In [56]:
# Randomly permute the triplets with a fixed seed before partitioning
shuffled = triples_df.sample(frac=1, random_state=42).reset_index(drop=True)

# 80% train, 10% validation; the remainder (about 10%) becomes the test split
n = len(shuffled)
n_train, n_val = int(0.8 * n), int(0.1 * n)
val_end = n_train + n_val

# Three contiguous, non-overlapping slices of the shuffled frame
train_df = shuffled.iloc[:n_train]
validation_df = shuffled.iloc[n_train:val_end]
test_df = shuffled.iloc[val_end:]

# Report the size of each split
print(f"Train: {len(train_df)}")
print(f"Validation: {len(validation_df)}")
print(f"Test: {len(test_df)}")

# Write each split to its own parquet file, dropping the index
for split_frame, split_path in (
    (train_df, "dataprep/ms_marco_combined/train.parquet"),
    (validation_df, "dataprep/ms_marco_combined/validation.parquet"),
    (test_df, "dataprep/ms_marco_combined/test.parquet"),
):
    split_frame.to_parquet(split_path, index=False)


Train: 79004
Validation: 9875
Test: 9876
