In [1]:
import torch
from torch.nn.functional import cosine_similarity
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from transformers import AutoTokenizer, AutoModel

from tqdm.notebook import tqdm
import json
from pathlib import Path
import re
import random
import os
import numpy as np
import pandas as pd

# 1) Importing query and collection data

In [2]:
PATH_COLLECTION_DATA = 'subtask_4b/subtask4b_collection_data.pkl' 
df_collection = pd.read_pickle(PATH_COLLECTION_DATA)

In [3]:
PATH_QUERY_TRAIN_DATA = 'subtask_4b/subtask4b_query_tweets_train.tsv'
PATH_QUERY_DEV_DATA = 'subtask_4b/subtask4b_query_tweets_dev.tsv' 
df_query_train = pd.read_csv(PATH_QUERY_TRAIN_DATA, sep = '\t')
df_query_dev = pd.read_csv(PATH_QUERY_DEV_DATA, sep = '\t')

In [25]:
#df_query_train.head()

In [5]:
#df_collection.head()

# 2) Running the BM25 baseline
The following code runs a BM25 baseline.


In [15]:
from rank_bm25 import BM25Okapi

In [16]:
# Create the BM25 corpus
corpus = df_collection[:][['title', 'abstract']].apply(lambda x: f"{x['title']} {x['abstract']}", axis=1).tolist()
cord_uids = df_collection[:]['cord_uid'].tolist()
tokenized_corpus = [doc.split(' ') for doc in corpus]
bm25 = BM25Okapi(tokenized_corpus)

In [None]:
def get_top_cord_uids(query):
  text2bm25top = {}
  if query in text2bm25top.keys():
      return text2bm25top[query]
  else:
      tokenized_query = query.split(' ')
      doc_scores = bm25.get_scores(tokenized_query)
      indices = np.argsort(-doc_scores)[:100] # @k: how many docs shall the ranked list include?
      bm25_topk = [cord_uids[x] for x in indices]

      text2bm25top[query] = bm25_topk
      return bm25_topk


In [7]:
# Retrieve top50 candidates using the BM25 model

train_pkl_path = 'df_query_train_top100.pkl'
dev_pkl_path = 'df_query_dev_top100.pkl'

if not os.path.exists(train_pkl_path):
    df_query_train['bm25_topk'] = df_query_train['tweet_text'].apply(lambda x: get_top_cord_uids(x))
    df_query_train.to_pickle(train_pkl_path)
else:
    df_query_train = pd.read_pickle(train_pkl_path)

if not os.path.exists(dev_pkl_path):
    df_query_dev['bm25_topk'] = df_query_dev['tweet_text'].apply(lambda x: get_top_cord_uids(x))
    df_query_dev.to_pickle(dev_pkl_path)
else:
    df_query_dev = pd.read_pickle(dev_pkl_path)

# 3) Neural Re-Ranking

In [25]:
# load BM25 pre-ranked query dataframes
train_pkl_path = 'df_query_train_top100.pkl'
dev_pkl_path = 'df_query_dev_top100.pkl'
test_pkl_path = 'df_query_test_top100.pkl'

df_query_dev = pd.read_pickle(dev_pkl_path)
df_query_train = pd.read_pickle(train_pkl_path)
df_query_test = pd.read_pickle(test_pkl_path)

In [17]:
train_pkl_path = 'df_query_train_top200.pkl'
dev_pkl_path = 'df_query_dev_top200.pkl'
test_pkl_path = 'df_query_test_top200.pkl'

df_query_dev = pd.read_pickle(dev_pkl_path)
df_query_train = pd.read_pickle(train_pkl_path)
df_query_test = pd.read_pickle(test_pkl_path)

In [27]:
train_pkl_path = 'df_query_train_top50.pkl'
dev_pkl_path = 'df_query_dev_top50.pkl'
test_pkl_path = 'df_query_test_top50.pkl'

df_query_dev = pd.read_pickle(dev_pkl_path)
df_query_train = pd.read_pickle(train_pkl_path)
df_query_test = pd.read_pickle(test_pkl_path)

In [36]:
train_pkl_path = 'df_query_train_top25.pkl'
dev_pkl_path = 'df_query_dev_top25.pkl'
test_pkl_path = 'df_query_test_top25.pkl'

df_query_dev = pd.read_pickle(dev_pkl_path)
df_query_train = pd.read_pickle(train_pkl_path)
df_query_test = pd.read_pickle(test_pkl_path)

## 3.1) ColBERT w/ fine-tuned BERT

In [28]:
# get token embeddings of a specified text passage from some model
def get_token_embeddings(text, tokenizer, model, device='cpu'):
    inputs = tokenizer(text, return_tensors='pt', truncation=True, max_length=512)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    outputs = model(**inputs)
    token_embeddings = outputs.last_hidden_state.squeeze(0)
    attention_mask = inputs['attention_mask'].squeeze(0).bool()
    token_embeddings = token_embeddings[attention_mask] 
    return token_embeddings

# pre compute all the token embeddings of the documents
def build_and_save_doc_embeddings(
    docs_df,
    model_name,
    save_dir,
    max_len=512,
    device="cuda"
):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name).to(device)
    model.eval()

    save_path = Path("doc_embeddings_" + save_dir)
    save_path.mkdir(parents=True, exist_ok=True)

    metadata_path = save_path / "metadata.json"
    if metadata_path.exists():
        with open(metadata_path, "r") as f:
            metadata = json.load(f)
    else:
        metadata = {}

    print("Precomputing document embeddings.")
    for i, row in tqdm(docs_df.iterrows(), total=len(docs_df)):
        doc_id = row.get("cord_uid", f"doc_{i}")
        file_path = save_path / f"{doc_id}.pt"

        if file_path.exists() and doc_id in metadata:
            continue

        text = str(row.get('title', '')) + " " + str(row.get('abstract', '')) + " Authors: " + str(row.get('authors', ''))

        inputs = tokenizer(text, return_tensors='pt', truncation=True, max_length=max_len)
        inputs = {k: v.to(device) for k, v in inputs.items()}

        output = model(**inputs)
        token_embeddings = output.last_hidden_state.squeeze(0)
        attention_mask = inputs['attention_mask'].squeeze(0).bool()
        token_embeddings = token_embeddings[attention_mask]

        n_tokens = token_embeddings.size(0)
        pad_len = max_len - n_tokens

        if pad_len > 0:
            padding = torch.zeros(pad_len, token_embeddings.size(1), device=device)
            token_embeddings = torch.cat([token_embeddings, padding], dim=0)
        else:
            token_embeddings = token_embeddings[:max_len]

        try:
            torch.save(token_embeddings.cuda(), file_path)
        except Exception as e:
            print(f"Error saving document {doc_id}: {e}")
            continue

        metadata[doc_id] = {
            "title": row.get("title", ""),
            "abstract": row.get("abstract", ""),
            "authors": row.get("authors", ""),
            "length": min(n_tokens, max_len),
            "path": str(file_path)
        }

    with open(metadata_path, "w") as f:
        json.dump(metadata, f)

    return metadata

# either precompute or load precomputed doc embeddings
def get_precomputed_doc_embeddings(save_name):
    def split_at_slash(s):
        if '/' in s:
            return s.split('/', 1)
        else:
            return ['', s]
        
    if not os.path.exists("doc_embeddings_" + split_at_slash(save_name)[1] + "/metadata.json"):
        metadata = build_and_save_doc_embeddings(df_collection, model_name=save_name, save_dir=save_name, device="cuda")
    else:
        with open("doc_embeddings_" + save_name + "/metadata.json", "r") as f:
            metadata = json.load(f)
    return metadata

In [29]:
# creating training dataset by getting the positive and a random negative document for each query
class ColBERTTripletDataset(Dataset):
    def __init__(self, df, metadata, tokenizer, num_negatives=1):
        self.data = []
        self.tokenizer = tokenizer
        self.metadata = metadata
        for _, row in df.iterrows():
            query = row["tweet_text"]
            pos = row["cord_uid"]
            negatives = [doc for doc in row["bm25_topk"] if doc != pos]
            if negatives:
                for _ in range(num_negatives):
                    neg = random.choice(negatives)
                    self.data.append((query, pos, neg))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

# basic ColBERT scoring i.e. match matrix aggregation
def colbert_score_from_emb(q_emb, d_emb):
    q_norm = q_emb / q_emb.norm(dim=1, keepdim=True)
    d_norm = d_emb / d_emb.norm(dim=1, keepdim=True)
    sim_matrix = torch.matmul(q_norm, d_norm.T)
    max_sim_per_q = sim_matrix.max(dim=1).values
    return max_sim_per_q.sum()

# finetuning some BERT-model to get higher ColBERT-score 
# for the positive document than for the negative (per query)
def bert_finetune(save_name, MARGIN=0.5, BATCH_SIZE=8, EPOCHS=6, LR=2e-5, num_negatives=1):    
    model_name = "allenai/scibert_scivocab_uncased" # specify baseline BERT model
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)
    metadata = get_precomputed_doc_embeddings(model_name)

    # create training triplets
    train_dataset = ColBERTTripletDataset(df_query_train, metadata, tokenizer, num_negatives)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    
    # optimizer
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    model.train()
    model.to(DEVICE)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
    
    for epoch in range(EPOCHS):
        total_loss = 0.0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
            queries, pos_ids, neg_ids = batch
    
            inputs = tokenizer(list(queries), return_tensors='pt', padding=True, truncation=True, max_length=512)
            inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
            outputs = model(**inputs)
            q_emb_batch = outputs.last_hidden_state  # [B, L, D]
            attention_mask = inputs["attention_mask"].bool()
            q_embs = [emb[mask] for emb, mask in zip(q_emb_batch, attention_mask)]
    
            score_pos_list = []
            score_neg_list = []
    
            for i in range(len(queries)):
                d_pos_emb = torch.load(metadata[pos_ids[i]]["path"]).to(DEVICE)[:metadata[pos_ids[i]]["length"]]
                d_neg_emb = torch.load(metadata[neg_ids[i]]["path"]).to(DEVICE)[:metadata[neg_ids[i]]["length"]]
    
                q_emb = q_embs[i]
                score_pos = colbert_score_from_emb(q_emb, d_pos_emb)
                score_neg = colbert_score_from_emb(q_emb, d_neg_emb)
    
                score_pos_list.append(score_pos)
                score_neg_list.append(score_neg)
    
            score_pos_batch = torch.stack(score_pos_list)
            score_neg_batch = torch.stack(score_neg_list)
    
            loss = F.relu(MARGIN + score_neg_batch - score_pos_batch).mean()
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            total_loss += loss.item()
    
        print(f"Epoch {epoch+1} Loss: {total_loss:.4f}")

    model.save_pretrained(save_name)
    tokenizer.save_pretrained(save_name)

In [33]:
bert_finetune("colB_sciB_marg05", MARGIN=0.5)

Precomputing document embeddings.


  0%|          | 0/7718 [00:00<?, ?it/s]

Epoch 1:   0%|          | 0/1607 [00:00<?, ?it/s]

Epoch 1 Loss: 188.6070


Epoch 2:   0%|          | 0/1607 [00:00<?, ?it/s]

Epoch 2 Loss: 81.0262


Epoch 3:   0%|          | 0/1607 [00:00<?, ?it/s]

Epoch 3 Loss: 35.0394


Epoch 4:   0%|          | 0/1607 [00:00<?, ?it/s]

Epoch 4 Loss: 18.8931


Epoch 5:   0%|          | 0/1607 [00:00<?, ?it/s]

Epoch 5 Loss: 16.2508


Epoch 6:   0%|          | 0/1607 [00:00<?, ?it/s]

Epoch 6 Loss: 12.9944


### Reranking:

In [30]:
def rerank(df, metadata, tokenizer, model, save_name):
    device = next(model.parameters()).device
    df[save_name + '_scores'] = [[] for _ in range(len(df))]

    doc_embeddings = {}
    for doc_id, data in metadata.items():
        emb = torch.load(data["path"], map_location="cpu")
        doc_embeddings[doc_id] = emb

    with torch.no_grad():
        for idx, row in tqdm(df.iterrows(), total=len(df)):
            tweet_text = row['tweet_text']
            pre_ranked_docs = row['bm25_topk']

            q_emb = get_token_embeddings(tweet_text, tokenizer, model).to(device)
            q_norm = q_emb / q_emb.norm(dim=1, keepdim=True)

            scores = []
            for doc in pre_ranked_docs:
                emb = doc_embeddings[doc].to(device)
                length = metadata[doc]["length"]
                d_emb = emb[:length]
                d_norm = d_emb / d_emb.norm(dim=1, keepdim=True)

                sim_matrix = torch.matmul(q_norm, d_norm.T)
                max_sim_per_q = sim_matrix.max(dim=1).values
                score = max_sim_per_q.sum().item()
                scores.append(score)

            df.at[idx, save_name + '_scores'] = scores

    def sort_docs_by_score(row):
        doc_ids = row['bm25_topk']
        scores = row[save_name + '_scores']
        sorted_docs = [doc for doc, _ in sorted(zip(doc_ids, scores), key=lambda x: x[1], reverse=True)]
        return sorted_docs

    df[save_name + '_topk'] = df.apply(sort_docs_by_score, axis=1)
    return df

In [37]:
# specify model for re-ranking
model_name = "colB_sciB_marg05"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

# pre-compute embeddings
metadata = get_precomputed_doc_embeddings(model_name)

In [38]:
# re-rank BM25 list for dev data
df_query_dev = rerank(df_query_dev, metadata, tokenizer, model, model_name)

  0%|          | 0/1400 [00:00<?, ?it/s]

In [39]:
df_query_dev

Unnamed: 0,post_id,tweet_text,cord_uid,normalized_tweet_text,cleaned_tweet_text,final_query,bm25_topk,colB_sciB_marg05_scores,colB_sciB_marg05_topk
0,16,covid recovery: this study from the usa reveal...,3qvh482o,covid recovery: this study from the usa reveal...,covid recoveri studi usa reveal proport case e...,covid recoveri studi usa reveal proport case e...,"[25aj8rj5, 66g5lpm6, o4vvlmr4, vmmwtdia, trrg1...","[32.9312629699707, 31.442752838134766, 31.2328...","[styavbvi, bqn29m9k, atji1xge, vymre7uv, trrg1..."
1,69,"""Among 139 clients exposed to two symptomatic ...",r58aohnu,"""among 139 clients exposed to two symptomatic ...",among 139 client expos two symptomat hair styl...,among 139 client expos two symptomat hair styl...,"[r58aohnu, p0kg6dyz, s2vckt2w, yrowv62k, g5hg3...","[45.07097625732422, 36.53841781616211, 40.1458...","[r58aohnu, icgsbelo, tgd6gy3z, s2vckt2w, ncayc..."
2,73,I recall early on reading that researchers who...,sts48u9i,i recall early on reading that researchers wor...,recal earli read research who examin coronavir...,recal earli read research who examin coronavir...,"[mkwgkkoi, gruir7aw, xavegbty, vx1hjh26, ntxuf...","[23.75394630432129, 25.241092681884766, 25.281...","[sgo76prc, xavegbty, gruir7aw, l4o7nicc, ntxuf..."
3,93,You know you're credible when NIH website has ...,3sr2exq9,you know you're credible when national institu...,know your credibl nih websit paper 💃💃 someon p...,know your credibl nih websit paper 💃💃 someon p...,"[3sr2exq9, sv48gjkk, tx8ypqsm, z795y51f, k0f4c...","[45.87574005126953, 42.32080841064453, 36.4390...","[3sr2exq9, k0f4cwig, sv48gjkk, 8j3bb6zx, kca5r..."
4,96,Resistance to antifungal medications is a grow...,ybwwmyqy,resistance to antifungal medications is a grow...,resist antifung medic grow issu global scope d...,resist antifung medic grow issu global scope d...,"[ybwwmyqy, ouvq2wpq, rs3umc1x, sxx3yid9, vabb2...","[31.468902587890625, 29.60997200012207, 29.042...","[ybwwmyqy, vabb2f26, 3l6ipiwk, lzddnb8j, ouvq2..."
...,...,...,...,...,...,...,...,...,...
1395,14193,Residents at high risk of covid-19: effectiven...,0gn3b98n,residents at high risk of covid-19: effectiven...,resid high risk covid19 effect isol affect siz...,resid high risk covid19 effect isol affect siz...,"[0gn3b98n, n5sei1oc, d8x3b9a3, wotf0lzx, 4sqjv...","[37.77806091308594, 31.755569458007812, 32.635...","[0gn3b98n, bwmpamea, ueb7mjnv, zqekxlz9, qpzg8..."
1396,14196,"61% of teenagers hospitalized for covid were ""...",25bdifv6,"61% of teenagers hospitalized for covid were ""...",61 teenag hospit covid extrem obes less 5 teen...,61 teenag hospit covid extrem obes less 5 teen...,"[yhmcx7ae, s1gdbsfx, dq3qunwe, 0yysikc1, a1xjh...","[42.61531066894531, 40.86850357055664, 38.4249...","[cjmmwl2q, lpqdnuil, ocl5qf9o, yhmcx7ae, a1xjh..."
1397,14203,"""fresh evidence backing melatonin against covi...",qn6wawxk,"""fresh evidence backing melatonin against covi...",fresh evid back melatonin covid melatonin medi...,fresh evid back melatonin covid melatonin medi...,"[qn6wawxk, dsz66r4u, b3ui95vx, 059oj76m, 7x1aj...","[36.224021911621094, 31.55619239807129, 32.542...","[qn6wawxk, wrsk5vh9, b3ui95vx, dsz66r4u, yk2th..."
1398,14233,"the vaccine doesn't halt the spread, it is pro...",3u3i5myh,"the vaccine doesn't halt the spread, it is pro...",vaccine doesnt halt spread proven allevi sympt...,vaccine doesnt halt spread proven allevi sympt...,"[wt6azxc1, 25aj8rj5, uuxo3jk9, qh6fqna8, gtp5d...","[36.293182373046875, 32.04758071899414, 35.591...","[h9nzxlaf, 7368psat, y4m987yn, u66awao9, yx0u0..."


### 3.2) Finetuned BERT + MLP on matchmatrix

In [18]:
class MatchMatrixMLP(nn.Module):
    def __init__(self):
        super(MatchMatrixMLP, self).__init__()
        self.mlp = nn.Sequential(
            nn.Linear(12, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

    def forward(self, sim_matrix):
        """
        sim_matrix: [q_len, d_len]
        Returns: scalar score
        """

        # Query-wise pooling
        max_per_q = sim_matrix.max(dim=1).values  # [q_len]
        mean_per_q = sim_matrix.mean(dim=1)
        std_per_q = sim_matrix.std(dim=1)

        # Document-wise pooling
        max_per_d = sim_matrix.max(dim=0).values
        mean_per_d = sim_matrix.mean(dim=0)
        std_per_d = sim_matrix.std(dim=0)

        # Global pooling
        global_max = sim_matrix.max()
        global_mean = sim_matrix.mean()
        global_std = sim_matrix.std()

        # Aggregate features
        features = torch.tensor([
            max_per_q.mean(), mean_per_q.mean(), std_per_q.mean(),
            max_per_d.mean(), mean_per_d.mean(), std_per_d.mean(),
            global_max, global_mean, global_std,
            max_per_q.max(), max_per_d.max(), global_std
        ], device=sim_matrix.device)

        return self.mlp(features.unsqueeze(0)).squeeze()


In [19]:
def train_matchmlp(
    bert_model,
    tokenizer,
    metadata,
    train_dataset,
    device="cuda",
    epochs=5,
    margin=0.3,
    batch_size=8,
    lr=2e-5
):
    from torch.utils.data import DataLoader

    match_model = MatchMatrixMLP().to(device)
    bert_model.to(device).eval()
    match_model.train()

    optimizer = torch.optim.Adam(match_model.parameters(), lr=lr)
    triplet_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    def sim_matrix(q, d):
        q_norm = q / q.norm(dim=1, keepdim=True)
        d_norm = d / d.norm(dim=1, keepdim=True)
        return torch.matmul(q_norm, d_norm.T)  # [q_len, d_len]

    for epoch in range(epochs):
        total_loss = 0.0
        for batch in tqdm(triplet_loader, desc=f"Epoch {epoch+1}"):
            queries, pos_ids, neg_ids = batch

            # Encode queries
            inputs = tokenizer(list(queries), return_tensors='pt', padding=True, truncation=True, max_length=512)
            inputs = {k: v.to(device) for k, v in inputs.items()}
            with torch.no_grad():
                q_output = bert_model(**inputs)

            attention_mask = inputs["attention_mask"].bool()
            q_embs = [emb[mask] for emb, mask in zip(q_output.last_hidden_state, attention_mask)]

            loss = 0.0
            for i in range(len(queries)):
                q_emb = q_embs[i]
                d_pos = torch.load(metadata[pos_ids[i]]["path"]).to(device)[:metadata[pos_ids[i]]["length"]]
                d_neg = torch.load(metadata[neg_ids[i]]["path"]).to(device)[:metadata[neg_ids[i]]["length"]]

                sim_pos = sim_matrix(q_emb, d_pos)
                sim_neg = sim_matrix(q_emb, d_neg)

                score_pos = match_model(sim_pos)
                score_neg = match_model(sim_neg)

                loss += F.relu(margin + score_neg - score_pos)

            loss = loss / len(queries)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            total_loss += loss.item()

        print(f"Epoch {epoch+1}: Loss = {total_loss:.4f}")

    return match_model


In [22]:
model_name = "colB_sciB_marg05"
tokenizer = AutoTokenizer.from_pretrained(model_name)
scibert_model = AutoModel.from_pretrained(model_name)
train_dataset = ColBERTTripletDataset(df_query_train, metadata, tokenizer)

mlp_model = train_matchmlp(
    bert_model=scibert_model,
    tokenizer=tokenizer,
    metadata=metadata,
    train_dataset=train_dataset,
    device="cuda",
    epochs=10,
    margin=0.5,
    batch_size=8,
    lr=2e-5
)

Epoch 1:   0%|          | 0/1607 [00:00<?, ?it/s]

Epoch 1: Loss = 718.9339


Epoch 2:   0%|          | 0/1607 [00:00<?, ?it/s]

Epoch 2: Loss = 602.8695


Epoch 3:   0%|          | 0/1607 [00:00<?, ?it/s]

Epoch 3: Loss = 461.1193


Epoch 4:   0%|          | 0/1607 [00:00<?, ?it/s]

Epoch 4: Loss = 341.4806


Epoch 5:   0%|          | 0/1607 [00:00<?, ?it/s]

Epoch 5: Loss = 262.1859


Epoch 6:   0%|          | 0/1607 [00:00<?, ?it/s]

Epoch 6: Loss = 212.2987


Epoch 7:   0%|          | 0/1607 [00:00<?, ?it/s]

Epoch 7: Loss = 180.7391


Epoch 8:   0%|          | 0/1607 [00:00<?, ?it/s]

Epoch 8: Loss = 160.4647


Epoch 9:   0%|          | 0/1607 [00:00<?, ?it/s]

Epoch 9: Loss = 147.2528


Epoch 10:   0%|          | 0/1607 [00:00<?, ?it/s]

Epoch 10: Loss = 138.4881


In [23]:
def rerank_with_mlp(df, metadata, tokenizer, bert_model, match_model, save_name):
    device = next(bert_model.parameters()).device
    bert_model.eval()
    match_model.eval()

    doc_embeddings = {
        doc_id: torch.load(meta["path"], map_location=device)[:meta["length"]]
        for doc_id, meta in metadata.items()
    }

    df[save_name + "_scores"] = [[] for _ in range(len(df))]

    with torch.no_grad():
        for idx, row in tqdm(df.iterrows(), total=len(df)):
            tweet_text = row['tweet_text']
            doc_ids = row['bm25_topk']

            q_emb = get_token_embeddings(tweet_text, tokenizer, bert_model, device=device)
            q_emb = q_emb / q_emb.norm(dim=1, keepdim=True)

            scores = []
            for doc_id in doc_ids:
                d_emb = doc_embeddings[doc_id]
                d_emb = d_emb / d_emb.norm(dim=1, keepdim=True)

                sim = torch.matmul(q_emb, d_emb.T)  # [q_len, d_len]
                score = match_model(sim).item()
                scores.append(score)

            df.at[idx, save_name + "_scores"] = scores

    df[save_name + "_topk"] = df.apply(
        lambda row: [doc for doc, _ in sorted(
            zip(row['bm25_topk'], row[save_name + '_scores']),
            key=lambda x: x[1], reverse=True
        )],
        axis=1
    )

    return df

In [24]:
df_reranked = rerank_with_cnn(
    df=df_query_dev,
    metadata=metadata,
    tokenizer=tokenizer,
    bert_model=scibert_model,
    match_model=mlp_model,
    save_name="sciB_mlp_1"
)

  0%|          | 0/1400 [00:00<?, ?it/s]

### 3.2) Finetuned BERT + CNN on matchmatrix

In [26]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MatchPyramid(nn.Module):
    def __init__(self, input_channels=1, conv_channels=16, pool_size=(6, 6)):
        super(MatchPyramid, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(input_channels, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(8, conv_channels, kernel_size=3, padding=1),
            nn.ReLU()
        )
        self.pool = nn.AdaptiveMaxPool2d(pool_size)
        self.mlp = nn.Sequential(
            nn.Flatten(),
            nn.Linear(conv_channels * pool_size[0] * pool_size[1], 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )

    def forward(self, sim_matrix):
        # sim_matrix: [q_len, d_len]
        x = sim_matrix.unsqueeze(0).unsqueeze(0)  # [1, 1, q_len, d_len]
        x = self.conv(x)                          # [1, C, H, W]
        x = self.pool(x)                          # [1, C, pool_H, pool_W]
        score = self.mlp(x)                       # [1, 1]
        return score.squeeze()


In [27]:
def train_matchpyramid(
    bert_model,
    tokenizer,
    metadata,
    train_dataset,
    device="cuda",
    epochs=6,
    margin=0.5,
    batch_size=8,
    lr=2e-5
):
    from torch.utils.data import DataLoader

    model = MatchPyramid().to(device)
    bert_model.to(device).eval()
    model.train()

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    def get_sim(q, d):
        q_norm = q / q.norm(dim=1, keepdim=True)
        d_norm = d / d.norm(dim=1, keepdim=True)
        return torch.matmul(q_norm, d_norm.T)

    for epoch in range(epochs):
        total_loss = 0.0
        for queries, pos_ids, neg_ids in tqdm(dataloader, desc=f"Epoch {epoch+1}"):
            inputs = tokenizer(list(queries), return_tensors='pt', padding=True, truncation=True, max_length=512)
            inputs = {k: v.to(device) for k, v in inputs.items()}

            with torch.no_grad():
                output = bert_model(**inputs)
                attention_mask = inputs["attention_mask"].bool()
                q_embs = [o[mask] for o, mask in zip(output.last_hidden_state, attention_mask)]

            loss = 0.0
            for i in range(len(queries)):
                q_emb = q_embs[i]
                d_pos = torch.load(metadata[pos_ids[i]]["path"]).to(device)[:metadata[pos_ids[i]]["length"]]
                d_neg = torch.load(metadata[neg_ids[i]]["path"]).to(device)[:metadata[neg_ids[i]]["length"]]

                sim_pos = get_sim(q_emb, d_pos)
                sim_neg = get_sim(q_emb, d_neg)

                score_pos = model(sim_pos)
                score_neg = model(sim_neg)

                loss += F.relu(margin + score_neg - score_pos)

            loss = loss / len(queries)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            total_loss += loss.item()

        print(f"Epoch {epoch+1}: Loss = {total_loss:.4f}")

    return model

In [28]:
model_name = "colB_sciB_marg05"
tokenizer = AutoTokenizer.from_pretrained(model_name)
scibert_model = AutoModel.from_pretrained(model_name)
train_dataset = ColBERTTripletDataset(df_query_train, metadata, tokenizer)

matchpyramid_model = train_matchpyramid(
    bert_model=scibert_model,
    tokenizer=tokenizer,
    metadata=metadata,
    train_dataset=train_dataset
)

Epoch 1:   0%|          | 0/1607 [00:00<?, ?it/s]

Epoch 1: Loss = 341.4415


Epoch 2:   0%|          | 0/1607 [00:00<?, ?it/s]

Epoch 2: Loss = 161.8874


Epoch 3:   0%|          | 0/1607 [00:00<?, ?it/s]

Epoch 3: Loss = 150.7042


Epoch 4:   0%|          | 0/1607 [00:00<?, ?it/s]

Epoch 4: Loss = 144.8148


Epoch 5:   0%|          | 0/1607 [00:00<?, ?it/s]

Epoch 5: Loss = 141.0205


Epoch 6:   0%|          | 0/1607 [00:00<?, ?it/s]

Epoch 6: Loss = 138.0958


In [None]:
torch.save(cnn_model.state_dict(), "matchcnn_model_1.pt")

In [29]:
def rerank_with_matchpyramid(df, metadata, tokenizer, bert_model, match_model, save_name):
    device = next(bert_model.parameters()).device
    bert_model.eval()
    match_model.eval()

    doc_embeddings = {
        doc_id: torch.load(meta["path"], map_location=device)[:meta["length"]]
        for doc_id, meta in metadata.items()
    }

    df[save_name + "_scores"] = [[] for _ in range(len(df))]

    with torch.no_grad():
        for idx, row in tqdm(df.iterrows(), total=len(df)):
            tweet_text = row['tweet_text']
            doc_ids = row['bm25_topk']

            q_emb = get_token_embeddings(tweet_text, tokenizer, bert_model, device=device)
            q_emb = q_emb / q_emb.norm(dim=1, keepdim=True)

            scores = []
            for doc_id in doc_ids:
                d_emb = doc_embeddings[doc_id]
                d_emb = d_emb / d_emb.norm(dim=1, keepdim=True)

                sim = torch.matmul(q_emb, d_emb.T)
                score = match_model(sim).item()
                scores.append(score)

            df.at[idx, save_name + "_scores"] = scores

    df[save_name + "_topk"] = df.apply(
        lambda row: [doc for doc, _ in sorted(zip(row['bm25_topk'], row[save_name + '_scores']), key=lambda x: x[1], reverse=True)],
        axis=1
    )

    return df


In [None]:
cnn_model = MatchCNN(input_size=32)
cnn_model.load_state_dict(torch.load("matchcnn_model_1.pt"))
cnn_model.to("cuda")
cnn_model.eval()

In [31]:
df_reranked = rerank_with_matchpyramid(
    df=df_query_dev,
    metadata=metadata,
    tokenizer=tokenizer,
    bert_model=scibert_model,
    match_model=matchpyramid_model,
    save_name="sciB_matchpyramid_1"
)

  0%|          | 0/1400 [00:00<?, ?it/s]

# 4) Evaluation
The following code evaluates the BM25 retrieval baseline on the query set using the Mean Reciprocal Rank score (MRR@5).

In [40]:
# Evaluate retrieved candidates using MRR@k
def get_performance_mrr(data, col_gold, col_pred, list_k = [1, 5, 10]):
    d_performance = {}
    for k in list_k:
        data["in_topx"] = data.apply(lambda x: (1/([i for i in x[col_pred][:k]].index(x[col_gold]) + 1) if x[col_gold] in [i for i in x[col_pred][:k]] else 0), axis=1)
        #performances.append(data["in_topx"].mean())
        d_performance[k] = data["in_topx"].mean()
    return d_performance

In [25]:
# ---- BM25 Baseline ----
results_train = get_performance_mrr(df_query_train, 'cord_uid', 'bm25_topk')
results_dev = get_performance_mrr(df_query_dev, 'cord_uid', 'bm25_topk')

print("---- BM25 Baseline ----")
print(f"Results on the train set: {results_train}")
print(f"Results on the dev set: {results_dev}")

---- BM25 Baseline ----
Results on the train set: {1: 0.5731735781529604, 5: 0.625250914183459, 10: 0.6308237901348459}
Results on the dev set: {1: 0.5657142857142857, 5: 0.616095238095238, 10: 0.6224325396825396}


In [32]:
# ---- ColBERT Re-Ranking (100 docs) @ Margin 0.5 ----
model_name = "colB_sciB_marg05"

results_dev = get_performance_mrr(df_query_dev, 'cord_uid', f'{model_name}_topk')
print("---- Re-Ranking Finetune: ColBERT (SciBERT) ----")
print(f"MRR@5 on dev set: {results_dev[5]}")

---- Re-Ranking Finetune-4: ColBERT (SciBERT) ----
MRR@5: 0.6806309523809524


In [22]:
# ---- ColBERT Re-Ranking (200 docs) @ Margin 0.5 ----
model_name = "colB_sciB_marg05"

results_dev = get_performance_mrr(df_query_dev, 'cord_uid', f'{model_name}_topk')
print("---- Re-Ranking Finetune: ColBERT (SciBERT) ----")
print(f"MRR@5 on dev set: {results_dev[5]}")

---- Re-Ranking Finetune: ColBERT (SciBERT) ----
MRR@5 on dev set: 0.6758214285714286


In [35]:
# ---- ColBERT Re-Ranking (50 docs) @ Margin 0.5 ----
model_name = "colB_sciB_marg05"

results_dev = get_performance_mrr(df_query_dev, 'cord_uid', f'{model_name}_topk')
print("---- Re-Ranking Finetune: ColBERT (SciBERT) ----")
print(f"MRR@5 on dev set: {results_dev[5]}")

---- Re-Ranking Finetune: ColBERT (SciBERT) ----
MRR@5 on dev set: 0.6803214285714285


In [41]:
# ---- ColBERT Re-Ranking (25 docs) @ Margin 0.5 ----
model_name = "colB_sciB_marg05"

results_dev = get_performance_mrr(df_query_dev, 'cord_uid', f'{model_name}_topk')
print("---- Re-Ranking Finetune: ColBERT (SciBERT) ----")
print(f"MRR@5 on dev set: {results_dev[5]}")

---- Re-Ranking Finetune: ColBERT (SciBERT) ----
MRR@5 on dev set: 0.6775357142857142


In [39]:
# ---- ColBERT Re-Ranking @ Margin 0.6 ----
model_name = "colB_sciBERT_marg06"

results_dev = get_performance_mrr(df_query_dev, 'cord_uid', f'{model_name}_topk')
print("---- Re-Ranking Finetune: ColBERT (SciBERT) ----")
print(f"MRR@5 on dev set: {results_dev[5]}")

---- Re-Ranking Finetune: ColBERT (SciBERT) ----
MRR@5 on dev set: 0.6532380952380952


In [25]:
# ---- SciBERT + MLP Re-Ranking ----

results_dev = get_performance_mrr(df_query_dev, 'cord_uid', 'sciB_mlp_1_topk')
print("---- Re-Ranking Finetune: ColBERT (SciBERT) ----")
print(f"MRR@5 on dev set: {results_dev[5]}")

---- Re-Ranking Finetune: ColBERT (SciBERT) ----
MRR@5 on dev set: 0.616095238095238


In [32]:
# ---- SciBERT + MLP Re-Ranking ----

results_dev = get_performance_mrr(df_query_dev, 'cord_uid', 'sciB_matchpyramid_1_topk')
print("---- Re-Ranking Finetune: ColBERT (SciBERT) ----")
print(f"MRR@5 on dev set: {results_dev[5]}")

---- Re-Ranking Finetune: ColBERT (SciBERT) ----
MRR@5 on dev set: 0.5443095238095238


## Results documentation

### 1) SciBERT in ColBERT architecture
Re-Ranking of top 50 BM25 results for each query:
- SciBERT (ColBERT)
    - used SciBERT model out of the box
- SciBERT (ColBERT) @ 2EP
    - fine tuned SciBERT in 2 epochs on train data
- SciBERT (ColBERT) @ 4EP
    - based on previous fine tune (SciBERT (ColBERT) @ 2EP), fine tuned SciBERT in additional 2 epochs, resulting in 4 epochs fine tune compared to out of the box SciBERT
- SciBERT (ColBERT) @ 8EP
    - based on previous fine tune (SciBERT (ColBERT) @ 4EP), fine tuned SciBERT in additional 4 epochs, resulting in 8 epochs fine tune compared to out of the box SciBERT

Hyperparameters:
- BATCH_SIZE = 8
- LR = 2e-5
- MARGIN = 0.2

Loss Function: <br>
loss = F.relu(MARGIN + score_neg_batch - score_pos_batch).mean()
<br>
<br>

|Model                        | MRR@5 (dev)     |
|-----------------------------|-----------------|
|BM25 (baseline)              |55.20%           |
|SciBERT (ColBERT)            |56.94%           | 
|SciBERT (ColBERT) @ 2EP      |63.51%           |
|SciBERT (ColBERT) @ 4EP      |64.02%           |
|SciBERT (ColBERT) @ 8EP      |62.26%           |

#### 1.1) Grid search for finding best margin and number of negative samples
| Margin | Negatives | MRR@5     | Last Epoch Loss |
|--------|-----------|-----------|--------------|
| 0.3    | 1         | 0.6537    | 8.0212       |
| 0.3    | 2         | 0.6477    | 16.7373      |
| 0.3    | 4         | 0.6118    | 29.3307      |
| 0.4    | 1         | 0.6574    | 12.9181      |
| 0.4    | 2         | 0.6503    | 21.4273      |
| 0.4    | 4         | 0.6360    | 27.8490      |
| **0.5**    | **1**         | **0.6610**    | **9.6190**       |
| 0.5    | 2         | 0.6393    | 18.1380      |
| 0.5    | 4         | 0.6259    | 31.5070      |

For margin 0.6 MRR@5 goes down again: 0.6532

# 5) Exporting results to prepare the submission on Codalab

In [15]:
model_name = "colB_sciB_marg05"

df_query_test['preds'] = df_query_test[f'{model_name}_topk'].apply(lambda x: x[:5])
df_query_test["post_id"] = df_query_test["post_id"].astype(str)

In [16]:
df_query_test[['post_id', 'preds']].to_csv('predictions.tsv', index=None, sep='\t')

## _) Archive

### Grid search for best parameters

Trying out different combinations of margin and number of negative document samples for training. <br>
Result: **margin of 0.5** seems to work best.

In [11]:
def get_performance_mrr(data, col_gold, col_pred, list_k = [1, 5, 10]):
    d_performance = {}
    for k in list_k:
        data["in_topx"] = data.apply(lambda x: (1/([i for i in x[col_pred][:k]].index(x[col_gold]) + 1) if x[col_gold] in [i for i in x[col_pred][:k]] else 0), axis=1)
        #performances.append(data["in_topx"].mean())
        d_performance[k] = data["in_topx"].mean()
    return d_performance

In [69]:
from itertools import product
import random, torch, numpy as np

param_grid = {
    "margin":        [0.3, 0.4, 0.5],
    "num_negatives": [1, 2, 4]
}
keys, values = zip(*param_grid.items())

def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

for i, (margin, num_neg) in enumerate(product(*values), 1):
    set_seed()
    model_name = f"colB_sciB_m{str(margin)[-1]}_neg{num_neg}"
    print(f"\n🔍 Combination {i}/9 → {model_name}")

    scibert_finetune(
        save_name      = model_name,
        MARGIN         = margin,
        BATCH_SIZE     = 8,
        EPOCHS         = 6,
        LR             = 2e-5,
        num_negatives  = num_neg
    )

    # ---------- Evaluation ----------
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model     = AutoModel.from_pretrained(model_name)
    metadata  = pre_compute_embeddings(model_name)

    df_query_dev = rerank(df_query_dev, metadata, tokenizer, model, model_name)
    mrr_scores   = get_performance_mrr(df_query_dev, "cord_uid", f"{model_name}_topk")
    print("MRR@5:", mrr_scores[5])



🔍 Combination 1/9 → colB_sciB_m3_neg1


Epoch 1:   0%|          | 0/1607 [00:00<?, ?it/s]

Epoch 1 Loss: 121.9535


Epoch 2:   0%|          | 0/1607 [00:00<?, ?it/s]

Epoch 2 Loss: 40.3961


Epoch 3:   0%|          | 0/1607 [00:00<?, ?it/s]

Epoch 3 Loss: 23.0165


Epoch 4:   0%|          | 0/1607 [00:00<?, ?it/s]

Epoch 4 Loss: 12.7019


Epoch 5:   0%|          | 0/1607 [00:00<?, ?it/s]

Epoch 5 Loss: 7.5942


Epoch 6:   0%|          | 0/1607 [00:00<?, ?it/s]

Epoch 6 Loss: 8.0212


  0%|          | 0/7718 [00:00<?, ?it/s]

  0%|          | 0/1400 [00:00<?, ?it/s]

MRR@5: 0.6537142857142857

🔍 Combination 2/9 → colB_sciB_m3_neg2


Epoch 1:   0%|          | 0/3214 [00:00<?, ?it/s]

Epoch 1 Loss: 198.2437


Epoch 2:   0%|          | 0/3214 [00:00<?, ?it/s]

Epoch 2 Loss: 72.1588


Epoch 3:   0%|          | 0/3214 [00:00<?, ?it/s]

Epoch 3 Loss: 46.9973


Epoch 4:   0%|          | 0/3214 [00:00<?, ?it/s]

Epoch 4 Loss: 28.1074


Epoch 5:   0%|          | 0/3214 [00:00<?, ?it/s]

Epoch 5 Loss: 23.5954


Epoch 6:   0%|          | 0/3214 [00:00<?, ?it/s]

Epoch 6 Loss: 16.7373


  0%|          | 0/7718 [00:00<?, ?it/s]

  0%|          | 0/1400 [00:00<?, ?it/s]

MRR@5: 0.6477142857142858

🔍 Combination 3/9 → colB_sciB_m3_neg4


Epoch 1:   0%|          | 0/6427 [00:00<?, ?it/s]

Epoch 1 Loss: 352.6436


Epoch 2:   0%|          | 0/6427 [00:00<?, ?it/s]

Epoch 2 Loss: 134.0740


Epoch 3:   0%|          | 0/6427 [00:00<?, ?it/s]

Epoch 3 Loss: 63.8498


Epoch 4:   0%|          | 0/6427 [00:00<?, ?it/s]

Epoch 4 Loss: 42.9125


Epoch 5:   0%|          | 0/6427 [00:00<?, ?it/s]

Epoch 5 Loss: 31.2010


Epoch 6:   0%|          | 0/6427 [00:00<?, ?it/s]

Epoch 6 Loss: 29.3307


  0%|          | 0/7718 [00:00<?, ?it/s]

  0%|          | 0/1400 [00:00<?, ?it/s]

MRR@5: 0.6118333333333333

🔍 Combination 4/9 → colB_sciB_m4_neg1


Epoch 1:   0%|          | 0/1607 [00:00<?, ?it/s]

Epoch 1 Loss: 137.3641


Epoch 2:   0%|          | 0/1607 [00:00<?, ?it/s]

Epoch 2 Loss: 52.0689


Epoch 3:   0%|          | 0/1607 [00:00<?, ?it/s]

Epoch 3 Loss: 23.6070


Epoch 4:   0%|          | 0/1607 [00:00<?, ?it/s]

Epoch 4 Loss: 17.4979


Epoch 5:   0%|          | 0/1607 [00:00<?, ?it/s]

Epoch 5 Loss: 9.9772


Epoch 6:   0%|          | 0/1607 [00:00<?, ?it/s]

Epoch 6 Loss: 12.9181


  0%|          | 0/7718 [00:00<?, ?it/s]

  0%|          | 0/1400 [00:00<?, ?it/s]

MRR@5: 0.6574285714285714

🔍 Combination 5/9 → colB_sciB_m4_neg2


Epoch 1:   0%|          | 0/3214 [00:00<?, ?it/s]

Epoch 1 Loss: 227.4885


Epoch 2:   0%|          | 0/3214 [00:00<?, ?it/s]

Epoch 2 Loss: 92.5719


Epoch 3:   0%|          | 0/3214 [00:00<?, ?it/s]

Epoch 3 Loss: 51.2666


Epoch 4:   0%|          | 0/3214 [00:00<?, ?it/s]

Epoch 4 Loss: 30.8887


Epoch 5:   0%|          | 0/3214 [00:00<?, ?it/s]

Epoch 5 Loss: 23.0024


Epoch 6:   0%|          | 0/3214 [00:00<?, ?it/s]

Epoch 6 Loss: 21.4273


  0%|          | 0/7718 [00:00<?, ?it/s]

  0%|          | 0/1400 [00:00<?, ?it/s]

MRR@5: 0.6502738095238095

🔍 Combination 6/9 → colB_sciB_m4_neg4


Epoch 1:   0%|          | 0/6427 [00:00<?, ?it/s]

Epoch 1 Loss: 359.6384


Epoch 2:   0%|          | 0/6427 [00:00<?, ?it/s]

Epoch 2 Loss: 131.8794


Epoch 3:   0%|          | 0/6427 [00:00<?, ?it/s]

Epoch 3 Loss: 63.0106


Epoch 4:   0%|          | 0/6427 [00:00<?, ?it/s]

Epoch 4 Loss: 50.9670


Epoch 5:   0%|          | 0/6427 [00:00<?, ?it/s]

Epoch 5 Loss: 38.3159


Epoch 6:   0%|          | 0/6427 [00:00<?, ?it/s]

Epoch 6 Loss: 27.8490


  0%|          | 0/7718 [00:00<?, ?it/s]

  0%|          | 0/1400 [00:00<?, ?it/s]

MRR@5: 0.6360357142857143

🔍 Combination 7/9 → colB_sciB_m5_neg1


Epoch 1:   0%|          | 0/1607 [00:00<?, ?it/s]

Epoch 1 Loss: 160.1443


Epoch 2:   0%|          | 0/1607 [00:00<?, ?it/s]

Epoch 2 Loss: 65.4520


Epoch 3:   0%|          | 0/1607 [00:00<?, ?it/s]

Epoch 3 Loss: 29.0368


Epoch 4:   0%|          | 0/1607 [00:00<?, ?it/s]

Epoch 4 Loss: 29.2456


Epoch 5:   0%|          | 0/1607 [00:00<?, ?it/s]

Epoch 5 Loss: 15.5198


Epoch 6:   0%|          | 0/1607 [00:00<?, ?it/s]

Epoch 6 Loss: 9.6190


  0%|          | 0/7718 [00:00<?, ?it/s]

  0%|          | 0/1400 [00:00<?, ?it/s]

MRR@5: 0.6610238095238096

🔍 Combination 8/9 → colB_sciB_m5_neg2


Epoch 1:   0%|          | 0/3214 [00:00<?, ?it/s]

Epoch 1 Loss: 257.0342


Epoch 2:   0%|          | 0/3214 [00:00<?, ?it/s]

Epoch 2 Loss: 89.1174


Epoch 3:   0%|          | 0/3214 [00:00<?, ?it/s]

Epoch 3 Loss: 41.8070


Epoch 4:   0%|          | 0/3214 [00:00<?, ?it/s]

Epoch 4 Loss: 29.2456


Epoch 5:   0%|          | 0/3214 [00:00<?, ?it/s]

Epoch 5 Loss: 21.2743


Epoch 6:   0%|          | 0/3214 [00:00<?, ?it/s]

Epoch 6 Loss: 18.1380


  0%|          | 0/7718 [00:00<?, ?it/s]

  0%|          | 0/1400 [00:00<?, ?it/s]

MRR@5: 0.639345238095238

🔍 Combination 9/9 → colB_sciB_m5_neg4


Epoch 1:   0%|          | 0/6427 [00:00<?, ?it/s]

Epoch 1 Loss: 395.6334


Epoch 2:   0%|          | 0/6427 [00:00<?, ?it/s]

Epoch 2 Loss: 110.6216


Epoch 3:   0%|          | 0/6427 [00:00<?, ?it/s]

Epoch 3 Loss: 54.9982


Epoch 4:   0%|          | 0/6427 [00:00<?, ?it/s]

Epoch 4 Loss: 41.8254


Epoch 5:   0%|          | 0/6427 [00:00<?, ?it/s]

Epoch 5 Loss: 32.5702


Epoch 6:   0%|          | 0/6427 [00:00<?, ?it/s]

Epoch 6 Loss: 31.5070


  0%|          | 0/7718 [00:00<?, ?it/s]

  0%|          | 0/1400 [00:00<?, ?it/s]

MRR@5: 0.6259285714285714


### 3.1) Baseline: ColBERT architecture with SciBERT
Use a pretrained SciBERT model to:
- embed each query-token
- embed each doc-token (can be pre-computed)

For each query-doc pair:
- calculate match-matrix: each query-token – doc-token pair gets cosine similarity value
- aggregate the score: 
    - for each query-token take max cosine similarity value with corresponding doc-tokens
    - sum over all of the max elements

In [14]:
def get_token_embeddings(text, tokenizer, model, device='cpu'):
    inputs = tokenizer(text, return_tensors='pt', truncation=True, max_length=512)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    outputs = model(**inputs)
    token_embeddings = outputs.last_hidden_state.squeeze(0)
    attention_mask = inputs['attention_mask'].squeeze(0).bool()
    token_embeddings = token_embeddings[attention_mask] 
    return token_embeddings

def build_and_save_doc_embeddings(
    docs_df,
    model_name,
    save_dir,
    max_len=512,
    device="cuda"
):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name).to(device)
    model.eval()

    save_path = Path("doc_embeddings_" + save_dir)
    save_path.mkdir(parents=True, exist_ok=True)

    metadata_path = save_path / "metadata.json"
    if metadata_path.exists():
        with open(metadata_path, "r") as f:
            metadata = json.load(f)
    else:
        metadata = {}

    for i, row in tqdm(docs_df.iterrows(), total=len(docs_df)):
        doc_id = row.get("cord_uid", f"doc_{i}")
        file_path = save_path / f"{doc_id}.pt"

        if file_path.exists() and doc_id in metadata:
            continue

        text = str(row.get('title', '')) + " " + str(row.get('abstract', '')) + " Authors: " + str(row.get('authors', ''))

        inputs = tokenizer(text, return_tensors='pt', truncation=True, max_length=max_len)
        inputs = {k: v.to(device) for k, v in inputs.items()}

        output = model(**inputs)
        token_embeddings = output.last_hidden_state.squeeze(0)
        attention_mask = inputs['attention_mask'].squeeze(0).bool()
        token_embeddings = token_embeddings[attention_mask]

        n_tokens = token_embeddings.size(0)
        pad_len = max_len - n_tokens

        if pad_len > 0:
            padding = torch.zeros(pad_len, token_embeddings.size(1), device=device)
            token_embeddings = torch.cat([token_embeddings, padding], dim=0)
        else:
            token_embeddings = token_embeddings[:max_len]

        try:
            torch.save(token_embeddings.cuda(), file_path)
        except Exception as e:
            print(f"Fehler beim Speichern von {doc_id}: {e}")
            continue

        metadata[doc_id] = {
            "title": row.get("title", ""),
            "abstract": row.get("abstract", ""),
            "authors": row.get("authors", ""),
            "length": min(n_tokens, max_len),
            "path": str(file_path)
        }

    with open(metadata_path, "w") as f:
        json.dump(metadata, f)

    return metadata


In [15]:
def pre_compute_embeddings(save_name):
    if not os.path.exists("doc_embeddings_" + save_name + "/metadata.json"):
        metadata = build_and_save_doc_embeddings(df_collection, model_name=model_name, save_dir=save_name, device="cuda")
    else:
        with open("doc_embeddings_" + save_name + "/metadata.json", "r") as f:
            metadata = json.load(f)
    return metadata

In [16]:
def rerank(df, metadata, tokenizer, model, save_name):
    device = next(model.parameters()).device
    df[save_name + '_scores'] = [[] for _ in range(len(df))]

    doc_embeddings = {}
    for doc_id, data in metadata.items():
        emb = torch.load(data["path"], map_location="cpu")
        doc_embeddings[doc_id] = emb

    with torch.no_grad():
        for idx, row in tqdm(df.iterrows(), total=len(df)):
            tweet_text = row['tweet_text']
            pre_ranked_docs = row['bm25_topk']

            q_emb = get_token_embeddings(tweet_text, tokenizer, model).to(device)
            q_norm = q_emb / q_emb.norm(dim=1, keepdim=True)

            scores = []
            for doc in pre_ranked_docs:
                emb = doc_embeddings[doc].to(device)
                length = metadata[doc]["length"]
                d_emb = emb[:length]
                d_norm = d_emb / d_emb.norm(dim=1, keepdim=True)

                sim_matrix = torch.matmul(q_norm, d_norm.T)
                max_sim_per_q = sim_matrix.max(dim=1).values
                score = max_sim_per_q.sum().item()
                scores.append(score)

            df.at[idx, save_name + '_scores'] = scores

    def sort_docs_by_score(row):
        doc_ids = row['bm25_topk']
        scores = row[save_name + '_scores']
        sorted_docs = [doc for doc, _ in sorted(zip(doc_ids, scores), key=lambda x: x[1], reverse=True)]
        return sorted_docs

    df[save_name + '_topk'] = df.apply(sort_docs_by_score, axis=1)
    return df

In [17]:
# settings for model run:
model_name = "allenai/scibert_scivocab_uncased"
save_name = "scibert_baseline"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

  return self.fget.__get__(instance, owner)()


In [18]:
# pre-compute embeddings
metadata = pre_compute_embeddings(save_name)

In [None]:
# re-rank BM25 list for dev data
df_query_dev = rerank(df_query_dev, metadata, tokenizer, model, save_name)

In [None]:
# re-rank BM25 list for train data
df_query_train = rerank(df_query_train, metadata, tokenizer, model)

### ColBERT w/ fine-tuned SciBERT for docs and CTBERT for queries

In [23]:
from transformers import AutoTokenizer, AutoModel
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn.functional as F
from tqdm import tqdm
import json
import random
import torch.nn as nn

# hyperparameter
BATCH_SIZE = 8
EPOCHS = 4
LR = 2e-5
MARGIN = 0.2
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# load model
query_model_name = "digitalepidemiologylab/covid-twitter-bert"
doc_model_name = "allenai/scibert_scivocab_uncased"

query_tokenizer = AutoTokenizer.from_pretrained(query_model_name)
doc_tokenizer = AutoTokenizer.from_pretrained(doc_model_name)

query_model = AutoModel.from_pretrained(query_model_name).to(DEVICE)
doc_model = AutoModel.from_pretrained(doc_model_name).to(DEVICE)

# projection to same dimensions
query_projection = nn.Linear(1024, 768).to(DEVICE)

# dataset creation
class ColBERTTripletDataset(Dataset):
    def __init__(self, df, metadata, num_negatives=1):
        self.data = []
        self.metadata = metadata
        for _, row in df.iterrows():
            query = row["tweet_text"]
            pos = row["cord_uid"]
            negatives = [doc for doc in row["bm25_topk"] if doc != pos]
            if negatives:
                for _ in range(num_negatives):
                    neg = random.choice(negatives)
                    self.data.append((query, pos, neg))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

# helper functions
def get_token_embeddings(text, tokenizer, model, device='cpu'):
    inputs = tokenizer(text, return_tensors='pt', truncation=True, max_length=512)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    outputs = model(**inputs)
    token_embeddings = outputs.last_hidden_state.squeeze(0)
    attention_mask = inputs['attention_mask'].squeeze(0).bool()
    return token_embeddings[attention_mask]

def get_doc_embedding(doc_id, metadata, tokenizer, model, device):
    text = f"{metadata[doc_id]['title']} {metadata[doc_id]['abstract']} Authors: {metadata[doc_id]['authors']}"
    return get_token_embeddings(text, tokenizer, model, device)

def colbert_score_from_emb(q_emb, d_emb):
    q_norm = q_emb / q_emb.norm(dim=1, keepdim=True)
    d_norm = d_emb / d_emb.norm(dim=1, keepdim=True)
    sim_matrix = torch.matmul(q_norm, d_norm.T)
    return sim_matrix.max(dim=1).values.sum()

# training function
def train_colbert_dual_encoder(df_query_train, metadata):
    dataset = ColBERTTripletDataset(df_query_train, metadata)
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

    query_model.train()
    doc_model.train()
    query_projection.train()

    optimizer = torch.optim.AdamW(
        list(query_model.parameters()) +
        list(doc_model.parameters()) +
        list(query_projection.parameters()),
        lr=LR
    )

    for epoch in range(EPOCHS):
        total_loss = 0.0
        for batch in tqdm(loader, desc=f"Epoch {epoch+1}"):
            queries, pos_ids, neg_ids = batch

            inputs = query_tokenizer(list(queries), return_tensors='pt', padding=True, truncation=True, max_length=512)
            inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
            outputs = query_model(**inputs)
            q_emb_batch = outputs.last_hidden_state
            attention_mask = inputs["attention_mask"].bool()
            q_embs = [query_projection(emb[mask]) for emb, mask in zip(q_emb_batch, attention_mask)]

            score_pos_list = []
            score_neg_list = []

            for i in range(len(queries)):
                d_pos_emb = get_doc_embedding(pos_ids[i], metadata, doc_tokenizer, doc_model, DEVICE)
                d_neg_emb = get_doc_embedding(neg_ids[i], metadata, doc_tokenizer, doc_model, DEVICE)
                q_emb = q_embs[i]

                score_pos = colbert_score_from_emb(q_emb, d_pos_emb)
                score_neg = colbert_score_from_emb(q_emb, d_neg_emb)

                score_pos_list.append(score_pos)
                score_neg_list.append(score_neg)

            score_pos_batch = torch.stack(score_pos_list)
            score_neg_batch = torch.stack(score_neg_list)

            loss = F.relu(MARGIN + score_neg_batch - score_pos_batch).mean()
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            total_loss += loss.item()

        print(f"Epoch {epoch+1} Loss: {total_loss:.4f}")

    query_model.save_pretrained("finetuned_ctbert")
    query_tokenizer.save_pretrained("finetuned_ctbert")
    doc_model.save_pretrained("finetuned_scibert")
    doc_tokenizer.save_pretrained("finetuned_scibert")
    torch.save(query_projection.state_dict(), "query_projection.pt")


In [None]:
train_colbert_dual_encoder(df_query_train, metadata)

In [32]:
def rerank(df, metadata, tokenizer, q_model, d_model, save_name):
    device_q = next(q_model.parameters()).device
    device_d = next(d_model.parameters()).device
    df[save_name + '_scores'] = [[] for _ in range(len(df))]

    query_projection = nn.Linear(1024, 768)
    query_projection.load_state_dict(torch.load("query_projection.pt"))
    query_projection.to(device_q)
    query_projection.eval()

    doc_embeddings = {}
    for doc_id, data in metadata.items():
        emb = torch.load(data["path"], map_location="cpu")
        doc_embeddings[doc_id] = emb

    with torch.no_grad():
        for idx, row in tqdm(df.iterrows(), total=len(df)):
            tweet_text = row['tweet_text']
            pre_ranked_docs = row['bm25_topk']

            q_emb = get_token_embeddings(tweet_text, tokenizer, q_model, device=device_q)
            q_emb = query_projection(q_emb)
            q_norm = q_emb / q_emb.norm(dim=1, keepdim=True)

            scores = []
            for doc in pre_ranked_docs:
                emb = doc_embeddings[doc].to(device_d)
                length = metadata[doc]["length"]
                d_emb = emb[:length]
                d_norm = d_emb / d_emb.norm(dim=1, keepdim=True)

                sim_matrix = torch.matmul(q_norm, d_norm.T)
                max_sim_per_q = sim_matrix.max(dim=1).values
                score = max_sim_per_q.sum().item()
                scores.append(score)

            df.at[idx, save_name + '_scores'] = scores

    def sort_docs_by_score(row):
        doc_ids = row['bm25_topk']
        scores = row[save_name + '_scores']
        sorted_docs = [doc for doc, _ in sorted(zip(doc_ids, scores), key=lambda x: x[1], reverse=True)]
        return sorted_docs

    df[save_name + '_topk'] = df.apply(sort_docs_by_score, axis=1)
    return df

In [25]:
# pre-compute embeddings
metadata = pre_compute_embeddings("finetuned_scibert")

100%|██████████| 7718/7718 [03:40<00:00, 34.94it/s]


In [33]:
# re-rank BM25 list for dev data
df_query_dev = rerank(df_query_dev, metadata, tokenizer, query_model, doc_model, "sciCtBERT-1")

100%|██████████| 1400/1400 [02:21<00:00,  9.89it/s]
