In [1]:
!pip install datasets
!pip install --upgrade sentence-transformers
!pip install langchain_experimental

import pandas as pd
import numpy as np
import ast
import json
from tqdm import tqdm
from sentence_transformers import CrossEncoder
from torch.utils.data import DataLoader
import torch
from datasets import Dataset
from datasets import load_dataset
from sentence_transformers.util import mine_hard_negatives
from sentence_transformers import SentenceTransformer
from sentence_transformers.cross_encoder.losses import BinaryCrossEntropyLoss
from sentence_transformers.cross_encoder import CrossEncoderTrainer
from sentence_transformers.cross_encoder import CrossEncoderTrainingArguments
import os
from collections import defaultdict
from sentence_transformers.cross_encoder.evaluation import CrossEncoderRerankingEvaluator
from langchain_experimental.text_splitter import SemanticChunker
from langchain_community.embeddings import HuggingFaceEmbeddings

Collecting datasets
  Downloading datasets-3.5.1-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2025.3.0,>=2023.1.0 (from fsspec[http]<=2025.3.0,>=2023.1.0->datasets)
  Downloading fsspec-2025.3.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.5.1-py3-none-any.whl (491 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.4/491.4 kB[0m [31m33.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m16.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2025.3.0-py3-none-any.whl 

In [3]:
# --- Input locations -------------------------------------------------------
# Paths are relative to the working directory; adjust the ones marked below.
PATH_COLLECTION_DATA = 'subtask4b_collection_data.pkl'
PATH_QUERY_TRAIN_DATA = 'subtask4b_query_tweets_train.tsv' #MODIFY PATH
PATH_QUERY_DEV_DATA = 'subtask4b_query_tweets_dev.tsv' #MODIFY PATH
PATH_QUERY_TRAIN_BM25 = 'df_train_bm25_50.csv' #MODIFY PATH
PATH_QUERY_DEV_BM25 = 'df_dev_bm25_50.csv' #MODIFY PATH
PATH_QUERY_TRAIN_GRANITE = 'granite_top75_train.json' #MODIFY PATH
PATH_QUERY_DEV_GRANITE = 'granite_top75_dev.json' #MODIFY PATH

# NOTE(review): unpickling can execute arbitrary code embedded in the file;
# only load a collection pickle that comes from a trusted source.
df_collection = pd.read_pickle(PATH_COLLECTION_DATA)
df_train = pd.read_csv(PATH_QUERY_TRAIN_DATA, sep = '\t')
df_dev = pd.read_csv(PATH_QUERY_DEV_DATA, sep = '\t')

# The BM25 candidate lists are currently not used by this notebook; the loading
# code is kept (disabled) for reference.
#df_train_bm25 = pd.read_csv(PATH_QUERY_TRAIN_BM25, sep = ',')
#df_dev_bm25 = pd.read_csv(PATH_QUERY_DEV_BM25, sep = ',')

#df_dev_bm25["bm25_topk"] = df_dev_bm25["bm25_topk"].apply(ast.literal_eval)
#df_train_bm25["bm25_topk"] = df_train_bm25["bm25_topk"].apply(ast.literal_eval)


def _attach_tweet_text(df_retrieval, df_queries):
    """Add the ``tweet_text`` column of ``df_queries`` to a retrieval frame.

    Rows are matched on ``df_retrieval['tweet'] == df_queries['post_id']`` with
    a left join, so every retrieval row is kept. A left join leaves
    ``tweet_text`` as NaN for ids without a match; that case is reported
    instead of being passed on silently to the model.

    Returns a new DataFrame; neither input is modified.
    """
    merged = pd.merge(
        df_retrieval,
        df_queries[['post_id', 'tweet_text']],
        left_on='tweet',
        right_on='post_id',
        how='left',
    ).drop(columns='post_id')

    n_unmatched = int(merged['tweet_text'].isna().sum())
    if n_unmatched:
        print(f"WARNING: {n_unmatched} retrieval rows have no matching tweet_text")
    return merged


# First-stage retrieval results (columns seen in the output: tweet, gold_paper, retrieved).
df_train_granite = _attach_tweet_text(pd.read_json(PATH_QUERY_TRAIN_GRANITE), df_train)
df_dev_granite = _attach_tweet_text(pd.read_json(PATH_QUERY_DEV_GRANITE), df_dev)

print(df_train_granite)
print(df_dev_granite)

       tweet gold_paper                                          retrieved  \
0          0   htlvpvz5  [htlvpvz5, dp9x046e, 9gnqfmbq, 2aowm09g, yq2dt...   
1          1   4kfl29ul  [wvfw94n1, 7k8nlea3, 3o7rd8pt, m4vu77v6, 29z4q...   
2          2   jtwb17u8  [jtwb17u8, iobpcfs5, 5aev7ltr, 79m3sdfe, bzeqs...   
3          3   0w9k8iy1  [0w9k8iy1, 2dfw87sl, 8fbsaocw, nvbt5gxl, v4xsz...   
4          4   tiqksd69  [tiqksd69, snk26ii3, aqbhxv1f, 8n4zf9oo, b0dzh...   
...      ...        ...                                                ...   
12848  14248   9169o29b  [tz2shoso, 08bw0h8m, n3nwra0o, 2veblo5v, wuazk...   
12849  14249   s2bpha8l  [s2bpha8l, 8a3fp7ym, 4vq9ljlg, 307rt03e, 9rxv6...   
12850  14250   atloc9th  [pc2cnhjd, atloc9th, iqe6sdq2, e0pbs354, x51jo...   
12851  14251   t4y1ylb3  [7a543f7v, t4y1ylb3, rnfh9v1h, sgo76prc, zctjk...   
12852  14252   nlsv8bin  [nlsv8bin, 43joavrl, apfimvix, nmrxjal1, fup55...   

                                              tweet_text  
0   

In [13]:
# Cross-encoder checkpoint that is fine-tuned as the reranker.
MODEL_NAME = "cross-encoder/ms-marco-MiniLM-L12-v2"
# Embedding model used both by the semantic chunker and for hard-negative mining.
EMBEDDING_MODEL_NAME = "sentence-transformers/static-retrieval-mrl-en-v1"
# Batch size shared by training, evaluation and the reranking evaluator.
TRAIN_BATCH_SIZE = 32
NUM_EPOCHS = 3
# Negatives mined per positive pair; also passed as the BCE loss pos_weight.
NUM_HARD_NEGATIVES = 3
# Disables Weights & Biases logging. NOTE(review): the training output reports
# this variable as deprecated in favour of the `report_to` training argument.
os.environ["WANDB_DISABLED"] = "true"

In [14]:
# Embedding backend that the chunker uses to decide where to split a text.
# NOTE: a later cell rebinds the name `embedding_model` to a SentenceTransformer;
# `text_splitter` keeps its own reference to the object created here.
embedding_model = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)

text_splitter = SemanticChunker(
    embeddings=embedding_model,
    breakpoint_threshold_type="gradient",
    breakpoint_threshold_amount=0.3,
)


def semantic_chunking(text):
    """Split ``text`` with the module-level ``text_splitter``.

    Args:
        text: The text to split (here: a paper abstract).

    Returns:
        A list with the ``page_content`` string of every document the
        splitter produced, in order.
    """
    pieces = []
    for piece in text_splitter.create_documents([text]):
        pieces.append(piece.page_content)
    return pieces

In [6]:
# Build the (tweet, gold abstract) positive pairs.
#
# The paper-id -> abstract mapping is built once so that every tweet needs a
# single dict lookup instead of a full boolean-mask scan of `df_collection`.
# `keep='first'` mirrors the previous `.iloc[0]` choice if an id is duplicated.
abstract_by_uid = (
    df_collection
    .drop_duplicates(subset='cord_uid', keep='first')
    .set_index('cord_uid')['abstract']
    .to_dict()
)

# Fail early, and say which ids are affected, if a gold paper is absent from
# the collection.
unknown_gold = [uid for uid in set(df_train_granite['gold_paper']) if uid not in abstract_by_uid]
if unknown_gold:
    raise KeyError(
        f"{len(unknown_gold)} gold_paper ids are missing from df_collection, e.g. {unknown_gold[:5]}"
    )

queries = df_train_granite['tweet_text'].tolist()
documents = [abstract_by_uid[uid] for uid in df_train_granite['gold_paper']]
labels = [1.0] * len(queries)  # every pair built here is a positive

full_dataset = Dataset.from_dict({
    "query": queries,
    "answer": documents,
    "label": labels
})

# Hold out 1000 pairs for evaluation; the fixed seed keeps the split reproducible.
dataset = full_dataset.train_test_split(test_size=1000, seed=12)
train_dataset = dataset["train"]
eval_dataset = dataset["test"]


100%|██████████| 12853/12853 [00:13<00:00, 956.48it/s]


In [15]:
# Bi-encoder that scores candidate negatives during mining.
# NOTE: this rebinds `embedding_model`, which previously referred to the
# HuggingFaceEmbeddings instance handed to the semantic chunker.
embedding_model = SentenceTransformer(EMBEDDING_MODEL_NAME)

# 'labeled-pair' output: the rows consumed downstream expose the keys
# query / answer / label.
hard_train_dataset = mine_hard_negatives(
    train_dataset,
    embedding_model,
    output_format='labeled-pair',
    num_negatives=NUM_HARD_NEGATIVES,
    sampling_strategy="top",
    range_min=1,
    range_max=2000,
    max_score=0.8,
    relative_margin=0.1,
    batch_size=128,
)

Found 11843 unique queries out of 11853 total queries.
Found an average of 1.001 positives per query.


Batches:   0%|          | 0/52 [00:00<?, ?it/s]

Batches:   0%|          | 0/93 [00:00<?, ?it/s]

Metric       Positive       Negative     Difference
Count          11,853         34,180               
Mean           0.5195         0.4490         0.0842
Median         0.5362         0.4560         0.0638
Std            0.1639         0.1217         0.0633
Min           -0.0981         0.0285        -0.6680
25%            0.4158         0.3673         0.0463
50%            0.5362         0.4560         0.0638
75%            0.6411         0.5380         0.0981
Max            0.9577         0.7882         0.6012
Skipped 2,686,569 potential negatives (11.33%) due to the relative_margin of 0.1.
Could not find enough negatives for 1379 samples (3.88%). Consider adjusting the range_max, range_min, relative_margin and max_score parameters if you'd like to find more valid negatives.


In [16]:
# Expand every (query, abstract, label) row into one row per abstract chunk.
#
# * The same abstract occurs in many rows (as the positive of one query and as
#   a mined negative of others), so chunks are computed once per distinct
#   abstract and reused.
# * The cell rebinds `hard_train_dataset`; the guard makes a second run a no-op
#   instead of failing on the 'answer' column that no longer exists.
if "answer" in hard_train_dataset.column_names:
    chunks_by_abstract = {}
    query, document, labels = [], [], []

    for row in tqdm(hard_train_dataset):
        abstract = row['answer']
        if abstract not in chunks_by_abstract:
            chunks_by_abstract[abstract] = semantic_chunking(abstract)

        # Each chunk inherits the query and the label of its source row.
        for chunk in chunks_by_abstract[abstract]:
            query.append(row['query'])
            document.append(chunk)
            labels.append(row['label'])

    hard_train_dataset = Dataset.from_dict({
        "query": query,
        "document": document,
        "label": labels
    })

100%|██████████| 46006/46006 [03:09<00:00, 242.95it/s]


In [17]:
# Mine negatives for the held-out pairs; 'n-tuple' output gives one row per
# query with its mined negatives in additional columns.
hard_eval_dataset = mine_hard_negatives(
    eval_dataset,
    embedding_model,
    output_format='n-tuple',
    num_negatives=NUM_HARD_NEGATIVES,
    sampling_strategy="top",
    range_min=1,
    range_max=250,
    max_score=0.8,
    relative_margin=0.1,
    batch_size=128,
)

# Every column after the first two is treated as a negative
# (presumably the first two are "query" and "answer" -- verify against the mined dataset).
negative_columns = hard_eval_dataset.column_names[2:]

eval_samples = []
for sample in hard_eval_dataset:
    eval_samples.append({
        "query": sample["query"],
        "positive": [sample["answer"]],
        "negative": [sample[column] for column in negative_columns],
    })

reranking_evaluator = CrossEncoderRerankingEvaluator(
    samples=eval_samples,
    batch_size=TRAIN_BATCH_SIZE,
)

Batches:   0%|          | 0/7 [00:00<?, ?it/s]

Batches:   0%|          | 0/8 [00:00<?, ?it/s]

Metric       Positive       Negative     Difference
Count           1,000          2,844               
Mean           0.5174         0.4193         0.1161
Median         0.5309         0.4238         0.0847
Std            0.1652         0.1141         0.0876
Min           -0.0474        -0.0157        -0.0012
25%            0.4085         0.3414         0.0539
50%            0.5310         0.4238         0.0848
75%            0.6416         0.4989         0.1540
Max            0.9473         0.7379         0.5897
Skipped 32,181 potential negatives (12.82%) due to the relative_margin of 0.1.
Could not find enough negatives for 156 samples (5.20%). Consider adjusting the range_max, range_min, relative_margin and max_score parameters if you'd like to find more valid negatives.


In [18]:
# Train on the GPU when one is visible, otherwise fall back to the CPU instead
# of failing while the model is being created.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = CrossEncoder(MODEL_NAME, device=device)

# Positives are up-weighted by the number of negatives mined per positive.
loss = BinaryCrossEntropyLoss(model=model, pos_weight=torch.tensor(NUM_HARD_NEGATIVES))

args = CrossEncoderTrainingArguments(
    output_dir="models",
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=TRAIN_BATCH_SIZE,
    per_device_eval_batch_size=TRAIN_BATCH_SIZE,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    dataloader_num_workers=2,
    # Checkpoints are written on the same schedule as evaluation so that the
    # best one (by MRR@10 of the reranking evaluator) can be restored at the end.
    load_best_model_at_end=True,
    metric_for_best_model='eval_mrr@10',
    eval_strategy="steps",
    eval_steps=5000,
    save_steps=5000,
    seed=12,
    # Replaces the deprecated WANDB_DISABLED environment variable: no external
    # experiment tracker is used.
    report_to="none",
)

trainer = CrossEncoderTrainer(
    model=model,
    args=args,
    train_dataset=hard_train_dataset,
    loss=loss,
    evaluator=reranking_evaluator,
)

trainer.train()

# Score the model left in memory after training on the held-out reranking samples.
results = reranking_evaluator(model)
print(results)

trainer.save_model("reranker_model")

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Step,Training Loss,Validation Loss,Map,Mrr@10,Ndcg@10
5000,0.5535,No log,0.938994,0.938994,0.954703
10000,0.4728,No log,0.95596,0.95596,0.967309
15000,0.3762,No log,0.966772,0.966772,0.97534
20000,0.3328,No log,0.954729,0.954729,0.966484
25000,0.3076,No log,0.973365,0.973365,0.980273
30000,0.2478,No log,0.956136,0.956136,0.967506
35000,0.2375,No log,0.958861,0.958861,0.96955
40000,0.2427,No log,0.957982,0.957982,0.968918


{'map': 0.9733649789029536, 'mrr@10': 0.9733649789029536, 'ndcg@10': 0.9802729920406151}


In [19]:
# Optional: uncomment to run inference in half precision.
#model.half()

# Number of first-stage candidates per query that are passed to the reranker.
RERANK_TOP_K = 25

collection_dict = df_collection.set_index('cord_uid')['abstract'].to_dict()

# Popular papers are retrieved for many tweets; chunk each abstract only once.
chunks_by_uid = {}

pairs = []          # [query, chunk] inputs for the cross-encoder
query_indices = []  # row position of the query each pair belongs to
uid_mappings = []   # paper id each pair's chunk was taken from

for idx, row in tqdm(enumerate(df_dev_granite.itertuples()), total=len(df_dev_granite)):
    query = row.tweet_text
    candidate_uids = row.retrieved

    for uid in candidate_uids[:RERANK_TOP_K]:
        if uid not in chunks_by_uid:
            chunks_by_uid[uid] = semantic_chunking(collection_dict[uid])

        for chunk in chunks_by_uid[uid]:
            pairs.append([query, chunk])
            query_indices.append(idx)
            uid_mappings.append(uid)


all_scores = model.predict(pairs)

# A paper's score is the maximum over its chunks. The running maximum starts at
# -inf (not 0.0) so that papers whose chunk scores are all negative keep their
# real score instead of being tied at zero.
query_results = [{} for _ in range(len(df_dev_granite))]

for idx, uid, score in zip(query_indices, uid_mappings, all_scores):
    best_so_far = query_results[idx].get(uid, float("-inf"))
    query_results[idx][uid] = max(best_so_far, score)

reranked_uids = []

for idx in range(len(df_dev_granite)):
    max_scores = query_results[idx]
    # sorted() is stable, so equally scored papers keep their retrieval order.
    sorted_uids = sorted(max_scores.items(), key=lambda x: x[1], reverse=True)
    reranked_uids.append([uid for uid, _ in sorted_uids])

df_dev_granite['reranked'] = reranked_uids

1400it [02:35,  8.98it/s]


In [20]:
def get_performance_mrr(data, col_gold, col_pred, list_k=(1, 5, 10)):
    """Compute the mean reciprocal rank at several cut-offs.

    For every row, the reciprocal rank is ``1 / position`` of the gold item
    within the first ``k`` predictions (1-based), or ``0`` when the gold item
    is not among them. The value reported per ``k`` is the mean over all rows.

    Args:
        data: DataFrame with one row per query.
        col_gold: Name of the column holding the gold item of each row.
        col_pred: Name of the column holding the ranked list of predictions.
        list_k: Iterable of cut-offs to evaluate.

    Returns:
        Dict mapping each ``k`` to its MRR@k (NaN when ``data`` has no rows).

    ``data`` is left untouched: no helper column is written into it.
    """
    def _reciprocal_rank(gold, preds, k):
        # Reciprocal of the gold item's 1-based rank within the top k, else 0.
        top_k = list(preds[:k])
        return 1 / (top_k.index(gold) + 1) if gold in top_k else 0

    d_performance = {}
    for k in list_k:
        reciprocal_ranks = [
            _reciprocal_rank(gold, preds, k)
            for gold, preds in zip(data[col_gold], data[col_pred])
        ]
        d_performance[k] = np.mean(reciprocal_ranks) if reciprocal_ranks else float("nan")
    return d_performance

# MRR of the first-stage ranking vs. the cross-encoder reranking on the dev set.
results_dev = get_performance_mrr(
    data=df_dev_granite, col_gold='gold_paper', col_pred='retrieved'
)
results_dev_reranked = get_performance_mrr(
    data=df_dev_granite, col_gold='gold_paper', col_pred='reranked'
)

for summary_line in (
    f"Granite Results: {results_dev}",
    f"Reranked Results: {results_dev_reranked}",
):
    print(summary_line)

Granite Results: {1: np.float64(0.5257142857142857), 5: np.float64(0.5813214285714287), 10: np.float64(0.5882225056689343)}
Reranked Results: {1: np.float64(0.5742857142857143), 5: np.float64(0.6289642857142856), 10: np.float64(0.6347859977324263)}
