# Transformer-Kernel (TK) Reranker Pipeline
Includes robust file checks, BM25 candidates, TK model fine-tuning with scheduler, and detailed logging.

In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
import random
import pickle
from collections import defaultdict
import math
from tqdm.auto import tqdm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [3]:
# Load data - Traditional IR BM25 Embeddings
df_query_train = pd.read_pickle("df_query_train_top100.pkl")
df_query_dev   = pd.read_pickle("df_query_dev_top100.pkl")
df_query_test  = pd.read_pickle("df_query_test_top100.pkl")

df_collection = pd.read_pickle("subtask_4b/subtask4b_collection_data.pkl")

In [4]:
# print(df_collection.columns.tolist())
print(df_query_train.columns.tolist())

['post_id', 'tweet_text', 'cord_uid', 'normalized_tweet_text', 'cleaned_tweet_text', 'final_query', 'bm25_topk', 'in_topx']


## TK - Model Definition

- Loads SciBERT (`AutoModel`)
- Defines Gaussian kernels
- Implements `forward(q_input, d_input)` returning a score tensor

In [5]:
class TKReRanker(nn.Module):
    def __init__(self, model_name='allenai/scibert_scivocab_uncased', num_kernels=11):
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        self.kernels = self._init_kernels(num_kernels)
        self.linear = nn.Linear(num_kernels, 1)

    def _init_kernels(self, num_kernels):
        mus = torch.linspace(-1, 1, steps=num_kernels)
        sigmas = torch.full((num_kernels,), 0.1)
        return nn.ParameterDict({
            'mus': nn.Parameter(mus, requires_grad=False),
            'sigmas': nn.Parameter(sigmas, requires_grad=False)
        })

    def forward(self, q_input, d_input):
        q_vecs = self.bert(**q_input).last_hidden_state
        d_vecs = self.bert(**d_input).last_hidden_state

        q_norm = F.normalize(q_vecs, p=2, dim=-1)
        d_norm = F.normalize(d_vecs, p=2, dim=-1)
        sim = torch.bmm(q_norm, d_norm.transpose(1, 2))  # [B, Q, D]

        batch_kernels = []
        for mu, sigma in zip(self.kernels['mus'], self.kernels['sigmas']):
            kernel = torch.exp(-((sim - mu)**2) / (2 * sigma**2))
            log_pool = torch.log(torch.clamp(kernel.sum(dim=2), min=1e-8))
            batch_kernels.append(log_pool.sum(dim=1))

        K = torch.stack(batch_kernels, dim=1)
        return self.linear(K).squeeze(1)

In [6]:
class TKTripletDataset(Dataset):
    def __init__(self, df_q, df_coll, tokenizer, num_negatives=1, doc_field='abstract'):
        self.tokenizer = tokenizer
        self.samples = []
        coll_text = df_coll.set_index('cord_uid')[doc_field].to_dict()
        for _, r in df_q.iterrows():
            q, pos_uid = r['final_query'], r['cord_uid']
            pos_txt = coll_text[pos_uid]
            negs = [uid for uid in r['bm25_topk'] if uid != pos_uid]
            for _ in range(num_negatives):
                neg_uid = random.choice(negs)
                self.samples.append((q, pos_txt, coll_text[neg_uid]))
    def __len__(self): return len(self.samples)
    def __getitem__(self, idx): return self.samples[idx]

def collate_tk(batch):
    claims, pos_docs, neg_docs = zip(*batch)
    labels = torch.zeros(len(batch), dtype=torch.float)

    q_tok = tokenizer(list(claims), padding=True, truncation=True, max_length=512, return_tensors='pt' )
    p_tok = tokenizer(list(pos_docs), padding=True, truncation=True, max_length=512, return_tensors='pt')
    n_tok = tokenizer(list(neg_docs), padding=True, truncation=True, max_length=512, return_tensors='pt')

    return q_tok, p_tok, n_tok, labels

In [7]:
def rerank_batch(df_q, df_coll, model, tokenizer, prerank_col='bm25_topk', query_col='final_query', 
                 doc_id_col='cord_uid', doc_field='abstract', rerank_k=10, device=device):

    model.eval()
    reranked = []

    # build a quick lookup for doc text
    coll_text = df_coll.set_index(doc_id_col)[doc_field].to_dict()

    with torch.no_grad():
        for q_text, cands in zip(df_q[query_col], df_q[prerank_col]):
            scores = []
            # only consider the top-N from prerank
            for uid in cands:
                doc_text = coll_text[uid]

                # tokenize & move to device
                q_tok = tokenizer(q_text, padding=True, truncation=True, max_length=512, return_tensors='pt').to(device)
                d_tok = tokenizer(doc_text, padding=True, truncation=True, max_length=512, return_tensors='pt').to(device)

                # score
                s = model(q_tok, d_tok).item()
                scores.append((uid, s))

            # sort by score descending, take top rerank_k
            topk = [uid for uid, _ in sorted(scores, key=lambda x: x[1], reverse=True)[:rerank_k]]
            reranked.append(topk)

    return reranked

def get_performance_mrr(data, col_gold, col_pred, list_k=[1,5,10]):
    perf = {}
    for k in list_k:
        data['in_topx'] = data.apply(
            lambda x: 1/([i for i in x[col_pred][:k]].index(x[col_gold])+1)
                    if x[col_gold] in x[col_pred][:k] else 0,
            axis=1
        )
        perf[k] = data['in_topx'].mean()
    return perf

In [8]:
# --- Training with mixed‐precision, scheduler & Dev‐MRR logging ---
model = TKReRanker().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, weight_decay=1e-2)
epochs = 4

# Pre‐tokenization is optional—your collate_tk already does truncation to 512
train_ds = TKTripletDataset(df_query_train, df_collection, tokenizer, num_negatives=1)

train_loader = DataLoader(
    train_ds,
    batch_size=16,         # ↑ larger batch if GPU permits
    shuffle=True,
    num_workers=4,         # ↑ parallel tokenization
    pin_memory=True,       # ↑ faster CPU→GPU copies
    collate_fn=collate_tk
)

total_steps = len(train_loader) * epochs
scheduler   = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(0.1 * total_steps),
    num_training_steps=total_steps
)

criterion    = nn.MarginRankingLoss(margin=0.5)
loss_history = []
mrr_history  = []
best_tk_mrr  = 0.0

# initialize mixed‐precision scaler
scaler = GradScaler()

for epoch in range(epochs):
    model.train()
    epoch_loss = 0.0

    for step, (q_tok, p_tok, n_tok, _) in enumerate(train_loader):
        # ─── Move everything to GPU here ───
        q_tok = {k: v.to(device) for k, v in q_tok.items()}
        p_tok = {k: v.to(device) for k, v in p_tok.items()}
        n_tok = {k: v.to(device) for k, v in n_tok.items()}
        
        labels = _.to(device)
    
        optimizer.zero_grad()
        with autocast():
            s_pos = model(q_tok, p_tok)
            s_neg = model(q_tok, n_tok)
            target = torch.ones_like(s_pos)
            loss   = criterion(s_pos, s_neg, target)

        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        epoch_loss += loss.item()
        if step % 100 == 0:
            print(f"  [Step {step}/{len(train_loader)}]  loss = {loss.item():.4f}")

    avg_loss = epoch_loss / len(train_loader)
    loss_history.append(avg_loss)

    # --- Dev‐set evaluation ---
    model.eval()
    tk_dev_top10 = rerank_batch(df_query_dev, df_collection, model, tokenizer)
    df_query_dev['tk_top10'] = tk_dev_top10
    dev_mrr = get_performance_mrr(df_query_dev, 'cord_uid', 'tk_top10')[5]

    print(f"*** Epoch {epoch} complete. Avg loss = {avg_loss:.4f} | Dev MRR@5 = {dev_mrr:.4f} ***")

    # --- Save best model checkpoint ---
    if dev_mrr > best_tk_mrr:
        best_tk_mrr = dev_mrr
        torch.save(model.state_dict(), "tk_best.pt")
        print(f" --- New best Dev MRR@5 = {dev_mrr:.4f} — saved tk_best.pt")

# --- Persist histories for reporting/plotting ---
with open("tk_loss_history.pkl", "wb") as f:
    pickle.dump(loss_history, f)
with open("tk_mrr_history.pkl", "wb") as f:
    pickle.dump(mrr_history, f)

print("Training done. Best Dev MRR@5 =", best_tk_mrr)

  scaler = GradScaler()
Traceback (most recent call last):
  File [35m"<string>"[0m, line [35m1[0m, in [35m<module>[0m
    from multiprocessing.spawn import spawn_main; [31mspawn_main[0m[1;31m(tracker_fd=82, pipe_handle=96)[0m
                                                  [31m~~~~~~~~~~[0m[1;31m^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^[0m
  File [35m"/Users/kartikarya/.pyenv/versions/3.13.2/lib/python3.13/multiprocessing/spawn.py"[0m, line [35m122[0m, in [35mspawn_main[0m
    exitcode = _main(fd, parent_sentinel)
  File [35m"/Users/kartikarya/.pyenv/versions/3.13.2/lib/python3.13/multiprocessing/spawn.py"[0m, line [35m132[0m, in [35m_main[0m
    self = reduction.pickle.load(from_parent)
[1;35mAttributeError[0m: [35mCan't get attribute 'TKTripletDataset' on <module '__main__' (<class '_frozen_importlib.BuiltinImporter'>)>[0m


KeyboardInterrupt: 

In [None]:
# --- Evaluation with logs ---
model = TKReRanker().to(device)
model.load_state_dict(torch.load('tk_reranker_kartik.pt', map_location=device))
model.eval()

tk_topk = []
for idx, (query, bm25_cands) in enumerate(zip(df_query_test['final_query'], df_query_test['bm25_topk'])):
    if idx % 10 == 0:
        print(f'Reranking query {idx}/{len(df_query_test)}')
    scores = []
    for uid in bm25_cands[:100]:
        doc = df_collection.loc[df_collection['cord_uid']==uid, 'abstract'].item()
        q_tok = tokenizer(query, padding=True, truncation=True, max_length=512, return_tensors='pt').to(device)
        d_tok = tokenizer(doc,   padding=True, truncation=True, max_length=512, return_tensors='pt').to(device)
        with torch.no_grad(): s = model(q_tok, d_tok).item()
        scores.append((uid, s))
    ranked = [u for u,_ in sorted(scores, key=lambda x: x[1], reverse=True)[:10]]
    tk_topk.append(ranked)
df_query_dev['tk_topk'] = tk_topk

results = get_performance_mrr(df_query_test, 'cord_uid', 'tk_topk')
print('TK Reranker MRR@k:', results)