In [1]:
import pandas as pd
import numpy as np
import time
from tqdm import tqdm
from sentence_transformers import SentenceTransformer, util

In [2]:
# Load the baseline SBERT bi-encoder and time how long loading takes.
model_name = "multi-qa-MiniLM-L6-cos-v1"
print("Loading model: {}".format(model_name))

start_time = time.time()
model = SentenceTransformer(model_name_or_path=model_name)
load_time = time.time() - start_time

print("Model loaded in {:.2f} seconds.".format(load_time))

Loading model: multi-qa-MiniLM-L6-cos-v1
Model loaded in 2.73 seconds.


In [3]:
# Load Data
# Paths are relative to the notebook's working directory.
PATH_COLLECTION_DATA = 'subtask4b_collection_data.pkl'
PATH_QUERY_TRAIN_DATA = 'subtask4b_query_tweets_train.tsv'
PATH_QUERY_DEV_DATA = 'subtask4b_query_tweets_dev.tsv'

# NOTE: unpickling can execute arbitrary code; only load a trusted collection file.
df_collection = pd.read_pickle(PATH_COLLECTION_DATA)

# Both query files are tab-separated; read train first, then dev.
df_query_train, df_query_dev = (
    pd.read_csv(query_path, sep='\t')
    for query_path in (PATH_QUERY_TRAIN_DATA, PATH_QUERY_DEV_DATA)
)

In [4]:
# Encode documents: each paper is embedded as "title abstract" (missing parts -> '').
# TODO: the recorded run spent ~1760 s here; consider caching `doc_embeddings` to disk.
print("Encoding documents...")
start = time.time()
df_collection['full_text'] = (
    df_collection['title'].fillna('')
    + " "
    + df_collection['abstract'].fillna('')
)
doc_embeddings = model.encode(
    df_collection['full_text'].tolist(),
    convert_to_tensor=True,
    show_progress_bar=True,
)
doc_encoding_time = time.time() - start
print("Document encoding time: {:.2f} seconds".format(doc_encoding_time))

Encoding documents...


Batches:   0%|          | 0/242 [00:00<?, ?it/s]

Document encoding time: 1761.70 seconds


In [5]:
# Encode the dev-set query tweets with the same model used for the documents.
print("Encoding tweets...")
start = time.time()
tweet_embeddings = model.encode(
    df_query_dev['tweet_text'].tolist(),
    show_progress_bar=True,
    convert_to_tensor=True,
)
tweet_encoding_time = time.time() - start
print("Tweet encoding time: {:.2f} seconds".format(tweet_encoding_time))

Encoding tweets...


Batches:   0%|          | 0/44 [00:00<?, ?it/s]

Tweet encoding time: 183.12 seconds


In [6]:
# Similarity & Top-5 Predictions
# For every dev tweet, rank all papers by cosine similarity and keep the 5 best cord_uids.
import torch

# Parallel per-tweet lists (same order as `tweet_embeddings`).
tweet_texts = df_query_dev['tweet_text'].tolist()
tweet_ids = df_query_dev['post_id'].tolist()
true_labels = df_query_dev['cord_uid'].tolist()

# Parallel per-document lists (same order as `doc_embeddings`).
doc_texts = df_collection['full_text'].tolist()
doc_uids = df_collection['cord_uid'].tolist()

predictions = []
for row in tqdm(range(len(tweet_embeddings))):
    similarities = util.cos_sim(tweet_embeddings[row], doc_embeddings)[0]
    best_doc_positions = torch.topk(similarities, k=5).indices.tolist()
    predictions.append({
        'post_id': tweet_ids[row],
        'tweet_text': tweet_texts[row],
        'true': true_labels[row],
        'preds': [doc_uids[pos] for pos in best_doc_positions],
    })

100%|███████████████████████████████████████| 1400/1400 [00:28<00:00, 48.94it/s]


In [7]:
# MRR@5 Evaluation
def mrr_at_k(predictions, k=5):
    """Mean Reciprocal Rank computed over the first ``k`` predictions per query.

    Args:
        predictions: sequence of dicts, each with a ``'true'`` gold id and a
            ``'preds'`` list of predicted ids ordered best-first.
        k: rank cutoff; a gold id ranked below position ``k`` scores 0.

    Returns:
        Mean over queries of ``1 / rank`` of the gold id (0 when it is absent
        from the top ``k``).

    Raises:
        ValueError: if ``predictions`` is empty (the mean is undefined).
    """
    if not predictions:
        raise ValueError("mrr_at_k needs at least one prediction")
    total_rr = 0.0
    for pred in predictions:
        # Honour the cutoff so longer candidate lists are not scored past rank k.
        top_k = pred['preds'][:k]
        if pred['true'] in top_k:
            total_rr += 1 / (top_k.index(pred['true']) + 1)
    return total_rr / len(predictions)

# Score the baseline model's dev-set rankings.
mrr5 = mrr_at_k(predictions, k=5)
print("MRR@5 for multi-qa-MiniLM-L6-cos-v1: {:.4f}".format(mrr5))

MRR@5 for multi-qa-MiniLM-L6-cos-v1: 0.4938


In [8]:
# Write the baseline submission: one row per dev tweet, its top-5 cord_uids as a list string.
df_out = pd.DataFrame({
    'post_id': [pred['post_id'] for pred in predictions],
    'preds': [str(pred['preds']) for pred in predictions],
})
df_out.to_csv('predictions_multi-qa-MiniLM-L6-cos-v1.tsv', sep='\t', index=False)

In [8]:
# Restrict the error analysis to short tweets (fewer than 80 characters).
# NOTE: needs `predictions` entries that carry 'tweet_text' (the baseline predictions cell);
# the ensemble cell's predictions do not have that key.
df_preds = pd.DataFrame(predictions)
df_preds['length'] = df_preds['tweet_text'].str.len()
is_short = df_preds['length'].lt(80)
short_tweets = df_preds.loc[is_short].copy()

def mrr_score(preds, true_label):
    """Reciprocal rank of ``true_label`` within the ranked list ``preds``.

    Returns ``1 / (position + 1)`` for the first occurrence, or ``0`` when the
    label does not appear in ``preds``.
    """
    try:
        position = preds.index(true_label)
    except ValueError:
        return 0
    return 1 / (position + 1)

# Per-tweet hit@5 and reciprocal rank on the short-tweet subset, then summary stats.
short_tweets['hit_in_top5'] = [
    gold in ranked
    for ranked, gold in zip(short_tweets['preds'], short_tweets['true'])
]
short_tweets['mrr@5'] = [
    mrr_score(ranked, gold)
    for ranked, gold in zip(short_tweets['preds'], short_tweets['true'])
]

print("Short tweet count:", len(short_tweets))
print("MRR@5 on short tweets:", short_tweets['mrr@5'].mean())
print("Accuracy (hit in top-5):", short_tweets['hit_in_top5'].mean())

short_tweets[['post_id', 'tweet_text', 'true', 'preds', 'hit_in_top5', 'mrr@5']].head(10)

Short tweet count: 32
MRR@5 on short tweets: 0.2859375
Accuracy (hit in top-5): 0.4375


Unnamed: 0,post_id,tweet_text,true,preds,hit_in_top5,mrr@5
112,1119,it doesn't stop it. it reduces the risk of it....,w0ebmg16,"[5i35zdmv, rmho6pur, lehzj4d8, 2fryixdh, hl956...",False,0.0
126,1229,How can indoor spread of COVID-19 through the ...,od5nnxvg,"[aawjla6h, od5nnxvg, 5zn5mgi9, 4p6fcy8f, pc2cn...",True,0.5
210,2132,Lives lost to covid-19 in 81 countries #ovhea...,6a728le9,"[pn516wom, 1blzi9r3, n39y3kq2, 7omyaap8, nj94r...",False,0.0
249,2492,Hospital admission rate is 10x higher in unvac...,rpjg4a9i,"[z1y1zgo8, snk26ii3, k2zrdjyo, 6ukt0gbn, cfd1x...",False,0.0
271,2719,Vitamins and risk of COVID-19,ikon1ktb,"[ikon1ktb, gg5c8v7d, m22h669g, lgtpeqhw, ncfvl...",True,1.0
308,3071,Death from COVID-19 isn't the only issue. Brai...,dogsza0f,"[3q3ywthu, ag6lu4em, 7xt894vr, 25aj8rj5, 6mfd3...",False,0.0
349,3491,Bile salts in gut and liver pathophysiology,mlozjg9h,"[mlozjg9h, 4evznllv, 306381wy, r9datawi, k0f4c...",True,1.0
385,3874,the vaccine can worsen covid.... why would you...,rb20ge7e,"[tcby6780, urv9o2f1, vw9jd88a, 72jwlfqr, 7hpor...",False,0.0
417,4218,Masks function. Vaccines function. Your refu...,1s8jzzwg,"[qi1henyy, u8mu4yga, t0iw2vod, 1s8jzzwg, w1bx4...",True,0.25
485,4894,human IgG neutralizing monoclonal antibodies b...,ypls4zau,"[ypls4zau, obhm5mc5, 91ea40nz, nc2sh98g, wq0me...",True,1.0


In [9]:
# Split dev tweets into hits (gold paper in the top 5) and misses, keeping the
# similarity scores so failure cases can be inspected side by side.
from IPython.display import display

top_k = 5  # loop-invariant; hoisted out of the loop
match_columns = ['post_id', 'tweet_text', 'top_5', 'true_cord_uid', 'similarity_scores']

wrong_matches = []
correct_matches = []
for tweet_id, tweet_text, true_uid, tweet_emb in zip(tweet_ids, tweet_texts, true_labels, tweet_embeddings):
    cosine_scores = util.cos_sim(tweet_emb, doc_embeddings)[0]
    top_results = torch.topk(cosine_scores, k=top_k)

    top_indices = top_results.indices.tolist()
    top_similarities = top_results.values.tolist()
    top_cord_uids = [doc_uids[idx] for idx in top_indices]

    result = {
        'post_id': tweet_id,
        'tweet_text': tweet_text,
        'top_5': top_cord_uids,
        'true_cord_uid': true_uid,
        'similarity_scores': top_similarities
    }

    if true_uid in top_cord_uids:
        correct_matches.append(result)
    else:
        wrong_matches.append(result)

# Passing `columns=` keeps the schema even when a bucket is empty; a plain
# pd.DataFrame([]) has no columns and the selections below would raise KeyError.
df_correct_matches = pd.DataFrame(correct_matches, columns=match_columns)
df_wrong_matches = pd.DataFrame(wrong_matches, columns=match_columns)

print(f"Correct matches: {len(df_correct_matches)}")
print(f"Wrong matches: {len(df_wrong_matches)}")

print("Sample Wrong Matches:")
display(df_wrong_matches[match_columns].head())

print("Sample Correct Matches:")
display(df_correct_matches[match_columns].head())

Correct matches: 848
Wrong matches: 552
Sample Wrong Matches:


Unnamed: 0,post_id,tweet_text,top_5,true_cord_uid,similarity_scores
0,116,IL-6 seems to be a primary catalyst of this un...,"[zt5alyy2, vx9vqr1k, ozbmgd70, xh723tgl, 9l3x3...",8cvjsisw,"[0.6958510875701904, 0.6509284973144531, 0.650..."
1,150,Significant vitamin D deficiency in people wit...,"[gg5c8v7d, vbnke2q5, vzloj6b3, 0a1m1niu, tpmb3...",be8eu3qi,"[0.8673973083496094, 0.8592214584350586, 0.854..."
2,158,The wearing of masks is associated with reduce...,"[f96qs295, 1s8jzzwg, jjh1z5c6, umvrwgaw, w1bx4...",9b6cepf4,"[0.6995384693145752, 0.6373393535614014, 0.623..."
3,169,Let's not forget our unheralded heroes? peer-...,"[edz3up3a, imheos0p, 78kbutc3, foy3dsq4, cfkh0...",z9vjo98p,"[0.7646439075469971, 0.758554220199585, 0.7375..."
4,173,"Here's your proof, Hannah betting on ""herd im...","[urv9o2f1, li5cw8xx, ztxfa5b8, o9me37ri, lbd6h...",q77tr31d,"[0.6718001961708069, 0.6341763734817505, 0.632..."


Sample Correct Matches:


Unnamed: 0,post_id,tweet_text,top_5,true_cord_uid,similarity_scores
0,16,covid recovery: this study from the usa reveal...,"[3qvh482o, 8t2tic9n, jrqlhjsm, rthsl7a9, trrg1...",3qvh482o,"[0.732806921005249, 0.6959764957427979, 0.6864..."
1,69,"""Among 139 clients exposed to two symptomatic ...","[r58aohnu, atn333j9, d06npvro, a66sszp2, cpbu3...",r58aohnu,"[0.7869398593902588, 0.7137505412101746, 0.651..."
2,73,I recall early on reading that researchers who...,"[sts48u9i, qkg8fwbp, ujq9mxk7, dgizpo1z, ec6ov...",sts48u9i,"[0.6247978210449219, 0.5791239738464355, 0.554..."
3,93,You know you're credible when NIH website has ...,"[i03mrw1i, 6x33a6g6, hapu56t4, 3sr2exq9, f8yph...",3sr2exq9,"[0.677467942237854, 0.6468799114227295, 0.6437..."
4,96,Resistance to antifungal medications is a grow...,"[ybwwmyqy, ierqfgo5, vabb2f26, qh6rif48, fiicx...",ybwwmyqy,"[0.7711814641952515, 0.6013863682746887, 0.572..."


In [11]:
# Ensemble: average the per-document cosine similarities of two SBERT models
# and rank papers by the mean score.
# NOTE: `tweet_embeddings` / `doc_embeddings` must still hold the
# multi-qa-MiniLM-L6-cos-v1 embeddings. The fine-tuning cell rebinds them, so
# re-running this cell after that one silently changes the ensemble.
model_name2 = "all-MiniLM-L6-v2"
model2 = SentenceTransformer(model_name2)
doc_embeddings2 = model2.encode(df_collection['full_text'].tolist(), show_progress_bar=True, convert_to_tensor=True)
tweet_embeddings2 = model2.encode(df_query_dev['tweet_text'].tolist(), convert_to_tensor=True, show_progress_bar=True)

predictions = []
k = 5

for i in range(len(tweet_texts)):
    scores1 = util.cos_sim(tweet_embeddings[i], doc_embeddings)[0]    # multi-qa-MiniLM-L6-cos-v1
    scores2 = util.cos_sim(tweet_embeddings2[i], doc_embeddings2)[0]  # all-MiniLM-L6-v2
    avg_scores = (scores1 + scores2) / 2.0

    top_indices = torch.topk(avg_scores, k=k).indices.tolist()
    predictions.append({
        'post_id': tweet_ids[i],
        # Kept so analysis cells that read 'tweet_text' still work on these predictions.
        'tweet_text': tweet_texts[i],
        'true': true_labels[i],
        'preds': [doc_uids[idx] for idx in top_indices],
    })

# Distinct name: binding the result to `mrr_score` used to shadow the
# `mrr_score()` helper defined in the short-tweet analysis cell.
mrr_ensemble = mrr_at_k(predictions, k=k)
print(f"MRR@{k} (score average): {mrr_ensemble:.4f}")

Batches:   0%|          | 0/242 [00:00<?, ?it/s]

Batches:   0%|          | 0/44 [00:00<?, ?it/s]

MRR@5 (score average): 0.5291


In [12]:
# Write the score-average ensemble (multi-qa + all-MiniLM) predictions.
# NOTE: despite its name, `df_out_fasttext_crawl` holds the SBERT ensemble
# output, not fastText; TODO rename once nothing else depends on it.
df_out_fasttext_crawl = pd.DataFrame({
    'post_id': [pred['post_id'] for pred in predictions],
    'preds': [str(pred['preds']) for pred in predictions],
})
df_out_fasttext_crawl.to_csv('predictions_all-MiniLM_multi-qa.tsv', sep='\t', index=False)

In [21]:
# Fine-tune multi-qa-MiniLM-L6-cos-v1 with a triplet objective on the training
# tweets, then re-run dev-set retrieval and MRR@5 with the fine-tuned model.
# NOTE: this cell rebinds `model`, `doc_embeddings`, `tweet_embeddings` and
# `predictions`; earlier cells that read them (e.g. the ensemble) give
# different results if re-run after this one.
from sentence_transformers import SentenceTransformer, InputExample, losses, util
from torch.utils.data import DataLoader
import random

FINETUNE_SEED = 42
negative_rng = random.Random(FINETUNE_SEED)  # local RNG for negative sampling
torch.manual_seed(FINETUNE_SEED)             # reproducible DataLoader shuffling

model = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')

# --- Build (tweet, gold paper, random other paper) training triplets ---------
# fillna('') avoids the TypeError that `title + " " + abstract` raises on
# papers with a missing title or abstract.
df_collection['full_text'] = df_collection['title'].fillna('') + " " + df_collection['abstract'].fillna('')

# Constant-time uid -> text lookup instead of a boolean scan of the collection
# per training row; keeps the first row per uid, as `.iloc[0]` did.
uid_to_text = (
    df_collection.drop_duplicates(subset='cord_uid')
    .set_index('cord_uid')['full_text']
    .to_dict()
)
all_doc_ids = df_collection['cord_uid'].tolist()
if len(uid_to_text) < 2:
    raise ValueError("Need at least two distinct documents to sample negatives.")

train_examples = []
for tweet, pos_uid in zip(df_query_train['tweet_text'], df_query_train['cord_uid']):
    pos_text = uid_to_text.get(pos_uid)
    if pos_text is None:
        continue  # gold paper not in the collection: no positive to train on
    # Rejection sampling draws uniformly from the non-gold entries without
    # rebuilding a collection-sized list for every training tweet.
    neg_uid = negative_rng.choice(all_doc_ids)
    while neg_uid == pos_uid:
        neg_uid = negative_rng.choice(all_doc_ids)
    train_examples.append(InputExample(texts=[tweet, pos_text, uid_to_text[neg_uid]]))

# --- Train -------------------------------------------------------------------
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)
train_loss = losses.TripletLoss(model=model)

# NOTE(review): the recorded `NameError: name 'Dataset' is not defined` most
# likely comes from inside `model.fit`. On sentence-transformers v3+, `fit` wraps
# the Trainer API, which needs the `datasets` (and `accelerate`) packages. Confirm
# they are installed in the kernel environment.
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=3,
    warmup_steps=100,
    show_progress_bar=True
)

model_name = "fine-tuned-multi-qa-MiniLM-L6-cos-v1"
model.save(model_name)

# Reload the saved model from disk, timing the load like the baseline cell.
print(f"Loading model: {model_name}")
start_time = time.time()
model = SentenceTransformer(model_name)
load_time = time.time() - start_time
print(f"Model loaded in {load_time:.2f} seconds.")

# --- Re-encode collection and dev tweets with the fine-tuned model -----------
print("Encoding documents...")
start = time.time()
doc_embeddings = model.encode(df_collection['full_text'].tolist(), show_progress_bar=True, convert_to_tensor=True)
doc_encoding_time = time.time() - start
print(f"Document encoding time: {doc_encoding_time:.2f} seconds")

print("Encoding tweets...")
start = time.time()
tweet_embeddings = model.encode(df_query_dev['tweet_text'].tolist(), convert_to_tensor=True, show_progress_bar=True)
tweet_encoding_time = time.time() - start
print(f"Tweet encoding time: {tweet_encoding_time:.2f} seconds")

# --- Top-5 retrieval ---------------------------------------------------------
predictions = []

tweet_texts = df_query_dev['tweet_text'].tolist()
tweet_ids = df_query_dev['post_id'].tolist()
true_labels = df_query_dev['cord_uid'].tolist()

doc_texts = df_collection['full_text'].tolist()
doc_uids = df_collection['cord_uid'].tolist()

for i in tqdm(range(len(tweet_embeddings))):
    cosine_scores = util.cos_sim(tweet_embeddings[i], doc_embeddings)[0]
    top_indices = torch.topk(cosine_scores, k=5).indices.tolist()
    predictions.append({
        'post_id': tweet_ids[i],
        'tweet_text': tweet_texts[i],
        'true': true_labels[i],
        'preds': [doc_uids[idx] for idx in top_indices]
    })

# --- Evaluate and save -------------------------------------------------------
# `mrr_at_k` is the helper from the MRR@5 evaluation cell (no duplicate definition).
mrr5 = mrr_at_k(predictions, k=5)
print(f"MRR@5 for {model_name}: {mrr5:.4f}")

df_out = pd.DataFrame()
df_out['post_id'] = [p['post_id'] for p in predictions]
df_out['preds'] = [str(p['preds']) for p in predictions]

# Model-specific filename so the baseline predictions file is not overwritten.
df_out.to_csv(f'predictions_{model_name}.tsv', sep='\t', index=False)

NameError: name 'Dataset' is not defined