In [1]:
import pandas as pd
import numpy as np
import time
from tqdm import tqdm
from sentence_transformers import SentenceTransformer, util

In [2]:
# Load SBERT model
# Builds the sentence encoder named by `model_name` and reports how long construction took.
# `model` and `model_name` are notebook globals that later cells read.
model_name = "all-MiniLM-L6-v2"
print(f"Loading model: {model_name}")
start_time = time.time()
model = SentenceTransformer(model_name)
load_time = time.time() - start_time  # elapsed seconds spent constructing the model
print(f"Model loaded in {load_time:.2f} seconds.")

Loading model: all-MiniLM-L6-v2
Model loaded in 2.90 seconds.


In [3]:
# Load Data
# Reads the document collection (pickle) and the train/dev query tweets (tab-separated).
PATH_COLLECTION_DATA = 'subtask4b_collection_data.pkl'

# NOTE(review): unpickling can execute arbitrary code -- only load this file from a trusted source.
df_collection = pd.read_pickle(PATH_COLLECTION_DATA)

PATH_QUERY_TRAIN_DATA = 'subtask4b_query_tweets_train.tsv'
PATH_QUERY_DEV_DATA = 'subtask4b_query_tweets_dev.tsv'

# Both query files are tab-separated; later cells read the columns
# `post_id`, `tweet_text` and `cord_uid` from them.
df_query_train = pd.read_csv(PATH_QUERY_TRAIN_DATA, sep = '\t')
df_query_dev = pd.read_csv(PATH_QUERY_DEV_DATA, sep = '\t')

In [4]:
# Check data
# Print the first tweet of each query split and the first document title as a sanity check.
print("Sample tweet train:", df_query_train['tweet_text'].iloc[0])
print("Sample tweet dev:", df_query_dev['tweet_text'].iloc[0])
print("Sample doc title:", df_collection['title'].iloc[0])

# Last expression of the cell: rendered as the cell's output table.
df_query_dev.head()

Sample tweet train: Oral care in rehabilitation medicine: oral vulnerability, oral muscle wasting, and hospital-associated oral issues
Sample tweet dev: covid recovery: this study from the usa reveals that a proportion of cases experience impairment in some cognitive functions for several months after infection. some possible biases &amp; limitations but more research is required on impact of these long term effects.
Sample doc title: Professional and Home-Made Face Masks Reduce Exposure to Respiratory Infections among the General Population


Unnamed: 0,post_id,tweet_text,cord_uid
0,16,covid recovery: this study from the usa reveal...,3qvh482o
1,69,"""Among 139 clients exposed to two symptomatic ...",r58aohnu
2,73,I recall early on reading that researchers who...,sts48u9i
3,93,You know you're credible when NIH website has ...,3sr2exq9
4,96,Resistance to antifungal medications is a grow...,ybwwmyqy


In [5]:
# Encode documents
# Builds a `full_text` column (title + space + abstract) and encodes every document with `model`.
print("Encoding documents...")
start = time.time()
# Missing titles/abstracts are replaced by '' so the concatenation never produces NaN.
# This adds the `full_text` column to `df_collection` in place; later cells read it.
df_collection['full_text'] = df_collection['title'].fillna('') + " " + df_collection['abstract'].fillna('')
# convert_to_tensor=True makes encode() return a tensor rather than a list/ndarray.
doc_embeddings = model.encode(df_collection['full_text'].tolist(), show_progress_bar=True, convert_to_tensor=True)
doc_encoding_time = time.time() - start
print(f"Document encoding time: {doc_encoding_time:.2f} seconds")

Encoding documents...


Batches:   0%|          | 0/242 [00:00<?, ?it/s]

Document encoding time: 584.28 seconds


In [6]:
# Encode the dev-set tweets with the same `model` used for the documents, timing the call.
print("Encoding tweets...")
start = time.time()
# Rows of `tweet_embeddings` follow the row order of `df_query_dev`.
tweet_embeddings = model.encode(df_query_dev['tweet_text'].tolist(), convert_to_tensor=True, show_progress_bar=True)
tweet_encoding_time = time.time() - start
print(f"Tweet encoding time: {tweet_encoding_time:.2f} seconds")

Encoding tweets...


Batches:   0%|          | 0/44 [00:00<?, ?it/s]

Tweet encoding time: 142.30 seconds


In [7]:
# Similarity & Top-5 Predictions
# For every dev tweet, rank all documents by cosine similarity and keep the best five cord_uids.
import torch

predictions = []

tweet_texts = df_query_dev['tweet_text'].tolist()
tweet_ids = df_query_dev['post_id'].tolist()
true_labels = df_query_dev['cord_uid'].tolist()

doc_texts = df_collection['full_text'].tolist()
doc_uids = df_collection['cord_uid'].tolist()

# torch.topk raises if k exceeds the number of candidates, so never ask for
# more results than there are documents.
n_preds = min(5, len(doc_uids))

# One batched call scores every tweet against every document (row i = tweet i,
# column j = document j) instead of one cos_sim + topk call per tweet in Python.
# The matrix holds n_tweets * n_docs scores, so it must fit in memory.
all_scores = util.cos_sim(tweet_embeddings, doc_embeddings)
top_doc_indices = torch.topk(all_scores, k=n_preds, dim=1).indices.tolist()

for i in tqdm(range(len(top_doc_indices))):
    predictions.append({
        'post_id': tweet_ids[i],
        'tweet_text': tweet_texts[i],
        'true': true_labels[i],
        # Map document row positions back to their cord_uid, best match first.
        'preds': [doc_uids[idx] for idx in top_doc_indices[i]]
    })


100%|███████████████████████████████████████| 1400/1400 [00:27<00:00, 51.78it/s]


In [8]:
# MRR@5 Evaluation
def mrr_at_k(predictions, k=5):
    """Mean reciprocal rank of the true document within the top-k predictions.

    Args:
        predictions: iterable-with-length of dicts holding 'true' (the gold
            id) and 'preds' (ranked list of predicted ids, best first).
        k: only the first `k` entries of each 'preds' list are considered.

    Returns:
        The mean over all queries of 1/rank of the gold id (0 for a query
        whose gold id is not in its top k). Returns 0.0 for empty input.
    """
    if not predictions:
        # Avoid ZeroDivisionError when there is nothing to evaluate.
        return 0.0
    total_mrr = 0.0
    for pred in predictions:
        # Honour the cutoff: ranks beyond k must not contribute to MRR@k.
        top_k = list(pred['preds'])[:k]
        if pred['true'] in top_k:
            rank = top_k.index(pred['true']) + 1
            total_mrr += 1 / rank
    return total_mrr / len(predictions)

# Score the SBERT predictions built in the previous cell and print the result to 4 decimals.
mrr5 = mrr_at_k(predictions, k=5)
print(f"MRR@5 for all-MiniLM-L6-v2: {mrr5:.4f}")

MRR@5 for all-MiniLM-L6-v2: 0.4897


In [9]:
# Build both output columns up front and construct the frame in one call.
# `preds` is stored as the string form of the ranked id list.
df_out = pd.DataFrame({
    'post_id': [entry['post_id'] for entry in predictions],
    'preds': [str(entry['preds']) for entry in predictions],
})

# Tab-separated, without the pandas index column.
df_out.to_csv('predictions_all-MiniLM-L6-v2.tsv', sep='\t', index=False)

In [10]:
# Tabulate the per-tweet predictions and record each tweet's character count.
df_preds = pd.DataFrame(predictions)
df_preds['length'] = df_preds['tweet_text'].str.len()

# Tweets under 80 characters form the "short" subset; copy it so that new
# columns can be added without touching `df_preds`.
is_short = df_preds['length'] < 80
short_tweets = df_preds.loc[is_short].copy()

def mrr_score(preds, true_label):
    """Reciprocal rank of `true_label` within the ranked list `preds`.

    Returns 1 / (1-based position of the first occurrence), or 0 when the
    label does not appear in `preds`.
    """
    if true_label not in preds:
        return 0
    position = preds.index(true_label)
    return 1 / (position + 1)

# Per-row metrics for the short-tweet subset:
#   hit_in_top5 - whether the gold id appears anywhere in the predicted list
#   mrr@5       - reciprocal rank of the gold id (0 when absent), via mrr_score
short_tweets['hit_in_top5'] = short_tweets.apply(lambda row: row['true'] in row['preds'], axis=1)
short_tweets['mrr@5'] = short_tweets.apply(lambda row: mrr_score(row['preds'], row['true']), axis=1)

# Aggregate over the subset: the mean of the boolean column is the hit rate.
print("Short tweet count:", len(short_tweets))
print("MRR@5 on short tweets:", short_tweets['mrr@5'].mean())
print("Accuracy (hit in top-5):", short_tweets['hit_in_top5'].mean())

# Last expression of the cell: first ten short tweets, rendered as the cell output.
short_tweets[['post_id', 'tweet_text', 'true', 'preds', 'hit_in_top5', 'mrr@5']].head(10)

Short tweet count: 32
MRR@5 on short tweets: 0.30989583333333337
Accuracy (hit in top-5): 0.40625


Unnamed: 0,post_id,tweet_text,true,preds,hit_in_top5,mrr@5
112,1119,it doesn't stop it. it reduces the risk of it....,w0ebmg16,"[9ezhwvv9, pecyac7l, wk61uyrt, 65n6p550, ropgq...",False,0.0
126,1229,How can indoor spread of COVID-19 through the ...,od5nnxvg,"[aawjla6h, 4p6fcy8f, gqwwfpch, 5zn5mgi9, je585...",False,0.0
210,2132,Lives lost to covid-19 in 81 countries #ovhea...,6a728le9,"[pn516wom, 6a728le9, 5053t5ki, xolflz8g, ef4wy...",True,0.5
249,2492,Hospital admission rate is 10x higher in unvac...,rpjg4a9i,"[xjc0l0tv, z1y1zgo8, 929rrh59, 9sh9mk6p, yaedo...",False,0.0
271,2719,Vitamins and risk of COVID-19,ikon1ktb,"[l4zku2e9, z2jtzsl6, md0drb25, lgtpeqhw, o8nf7...",False,0.0
308,3071,Death from COVID-19 isn't the only issue. Brai...,dogsza0f,"[x7qlnugx, ok9o9tta, byvsuvn0, 6mfd3n4s, xolfl...",False,0.0
349,3491,Bile salts in gut and liver pathophysiology,mlozjg9h,"[mlozjg9h, 306381wy, 2199ydle, fjzhe9tp, 3xpfj...",True,1.0
385,3874,the vaccine can worsen covid.... why would you...,rb20ge7e,"[tcby6780, rfv0omd6, pofysmv8, 72jwlfqr, 20xgq...",False,0.0
417,4218,Masks function. Vaccines function. Your refu...,1s8jzzwg,"[lkb09vs8, u8mu4yga, 4mx9t5td, z86g8dzs, 9qsqj...",False,0.0
485,4894,human IgG neutralizing monoclonal antibodies b...,ypls4zau,"[ryj83uw3, 0tn06al2, a1pa6g5c, 40fvjskj, hb2bp...",False,0.0


In [11]:
# Error analysis: split dev tweets into those whose gold document appears in
# the top-5 retrieved documents ("correct") and those where it does not ("wrong"),
# keeping the retrieved ids and their similarity scores for inspection.
from IPython.display import display

result_columns = ['post_id', 'tweet_text', 'top_5', 'true_cord_uid', 'similarity_scores']

top_k = 5  # loop-invariant, so it is set once rather than on every iteration
# torch.topk raises when k exceeds the number of candidates.
n_results = min(top_k, len(doc_uids))

# Score all tweets against all documents in one batched call
# (row i = tweet i, column j = document j), then take the best n_results per row.
all_scores = util.cos_sim(tweet_embeddings, doc_embeddings)
top_results = torch.topk(all_scores, k=n_results, dim=1)
top_indices_per_tweet = top_results.indices.tolist()
top_scores_per_tweet = top_results.values.tolist()

wrong_matches = []
correct_matches = []
for tweet_id, tweet_text, true_uid, top_indices, top_similarities in zip(
    tweet_ids, tweet_texts, true_labels, top_indices_per_tweet, top_scores_per_tweet
):
    top_cord_uids = [doc_uids[idx] for idx in top_indices]

    result = {
        'post_id': tweet_id,
        'tweet_text': tweet_text,
        'top_5': top_cord_uids,
        'true_cord_uid': true_uid,
        'similarity_scores': top_similarities
    }

    if true_uid in top_cord_uids:
        correct_matches.append(result)
    else:
        wrong_matches.append(result)

# Explicit columns keep the frames well-formed even when one partition is empty;
# a frame built from [] has no columns and the selections below would raise KeyError.
df_correct_matches = pd.DataFrame(correct_matches, columns=result_columns)
df_wrong_matches = pd.DataFrame(wrong_matches, columns=result_columns)

print(f"Correct matches: {len(df_correct_matches)}")
print(f"Wrong matches: {len(df_wrong_matches)}")

print("Sample Wrong Matches:")
display(df_wrong_matches[result_columns].head())

print("Sample Correct Matches:")
display(df_correct_matches[result_columns].head())

Correct matches: 846
Wrong matches: 554
Sample Wrong Matches:


Unnamed: 0,post_id,tweet_text,top_5,true_cord_uid,similarity_scores
0,116,IL-6 seems to be a primary catalyst of this un...,"[vx9vqr1k, z9jqbliw, pxhetma5, zt5alyy2, 3r418...",8cvjsisw,"[0.7486677169799805, 0.6814455389976501, 0.671..."
1,150,Significant vitamin D deficiency in people wit...,"[n24k9s1s, dx2cx9lx, z2jtzsl6, wjdif3r3, l4zku...",be8eu3qi,"[0.8709943890571594, 0.8707948923110962, 0.870..."
2,158,The wearing of masks is associated with reduce...,"[f96qs295, jjh1z5c6, zue5hnal, 0clp6zt6, 4mx9t...",9b6cepf4,"[0.5498321056365967, 0.5204370617866516, 0.520..."
3,169,Let's not forget our unheralded heroes? peer-...,"[24u6q3ae, xnxn506o, bw6a5gmy, s3vaa0yc, imheo...",z9vjo98p,"[0.8383280038833618, 0.7862822413444519, 0.777..."
4,173,"Here's your proof, Hannah betting on ""herd im...","[dpws8p4y, vjpf1fk6, ruewzstg, 5kz2s7ag, u1ilc...",q77tr31d,"[0.635127604007721, 0.6033356785774231, 0.5958..."


Sample Correct Matches:


Unnamed: 0,post_id,tweet_text,top_5,true_cord_uid,similarity_scores
0,16,covid recovery: this study from the usa reveal...,"[3qvh482o, jrqlhjsm, 8t2tic9n, nksd3wuw, styav...",3qvh482o,"[0.827496349811554, 0.7675719261169434, 0.7385..."
1,69,"""Among 139 clients exposed to two symptomatic ...","[r58aohnu, eay6qfhz, 8je46886, u8mu4yga, yrowv...",r58aohnu,"[0.8370003700256348, 0.6915689706802368, 0.665..."
2,73,I recall early on reading that researchers who...,"[qkg8fwbp, sts48u9i, u5nxm9tu, lp0r7j5c, myqli...",sts48u9i,"[0.6119166016578674, 0.5853146910667419, 0.585..."
3,93,You know you're credible when NIH website has ...,"[jo38hjqa, mgtxchud, 9mdf927z, l5ogbl5p, 3sr2e...",3sr2exq9,"[0.6730215549468994, 0.659010648727417, 0.6428..."
4,96,Resistance to antifungal medications is a grow...,"[ybwwmyqy, vabb2f26, rs3umc1x, 3l6ipiwk, lzddn...",ybwwmyqy,"[0.7652779817581177, 0.5990149974822998, 0.565..."


In [13]:
from gensim.models import KeyedVectors

# Load pretrained word vectors from a text-format (binary=False) word2vec file.
# NOTE(review): this rebinds the notebook-global `model`, which earlier cells use for the
# SentenceTransformer encoder; re-running those cells after this one would use the wrong object.
model = KeyedVectors.load_word2vec_format('vec/wiki-news-300d-1M.vec', binary=False)

def get_avg_embedding(text, model):
    """Average the word vectors of the whitespace-separated tokens in `text`.

    Args:
        text: the string to embed. Non-string values (e.g. None, or the float
            NaN pandas uses for missing text) are treated as having no tokens.
        model: word-vector lookup supporting `word in model`, `model[word]`
            and exposing `vector_size`.

    Returns:
        The element-wise mean of the vectors of all tokens found in `model`,
        or a zero vector of length `model.vector_size` when no token is found.
    """
    # Missing text would otherwise crash on `.split()`; handle it like text
    # whose tokens are all out of vocabulary.
    if not isinstance(text, str):
        return np.zeros(model.vector_size)
    # Tokens are matched as-is (case-sensitive, punctuation attached).
    embeddings = [model[word] for word in text.split() if word in model]
    if not embeddings:
        return np.zeros(model.vector_size)
    return np.mean(embeddings, axis=0)

from sklearn.metrics.pairwise import cosine_similarity

# Averaged-word-vector baseline: embed every document and every dev tweet with
# get_avg_embedding, rank documents by cosine similarity, keep the top k.

# Precompute document embeddings
doc_texts = df_collection['full_text'].tolist()
doc_ids = df_collection['cord_uid'].tolist()
# Convert to a single 2-D array once; passing a Python list would make
# cosine_similarity redo the list -> array conversion for every query.
doc_embs = np.asarray([get_avg_embedding(text, model) for text in tqdm(doc_texts, desc='Encoding docs')])

k = 5
predictions = []
for _, row in tqdm(df_query_dev.iterrows(), total=len(df_query_dev), desc='Processing queries'):
    tweet_vec = get_avg_embedding(row['tweet_text'], model)
    similarities = cosine_similarity([tweet_vec], doc_embs)[0]
    # argsort is ascending: take the last k positions and reverse for best-first order.
    top_k_indices = np.argsort(similarities)[-k:][::-1]
    top_k_doc_ids = [doc_ids[i] for i in top_k_indices]

    predictions.append({
        'post_id': row['post_id'],
        'true': row['cord_uid'],
        'preds': top_k_doc_ids
    })

# Stored under its own name: assigning the result to `mrr_score` would overwrite
# the `mrr_score` function defined in an earlier cell.
mrr_wiki_news = mrr_at_k(predictions, k=k)
print(f"MRR@{k}: {mrr_wiki_news:.4f}")

Encoding docs: 100%|██████████████████████| 7718/7718 [00:04<00:00, 1655.84it/s]
Processing queries: 100%|███████████████████| 1400/1400 [00:23<00:00, 60.13it/s]

MRR@5: 0.2972





In [15]:
# Same averaged-word-vector baseline as the previous experiment, using the
# crawl-300d-2M vectors. Relies on `doc_texts`, `doc_ids`, `get_avg_embedding`
# and `cosine_similarity` defined by the earlier word-vector cell.
# NOTE(review): `model` is a notebook-global that is rebound here to a KeyedVectors object.
model = KeyedVectors.load_word2vec_format('vec/crawl-300d-2M.vec', binary=False)
# Convert to a single 2-D array once; passing a Python list would make
# cosine_similarity redo the list -> array conversion for every query.
doc_embs = np.asarray([get_avg_embedding(text, model) for text in tqdm(doc_texts, desc='Encoding docs')])
k = 5
predictions = []
for _, row in tqdm(df_query_dev.iterrows(), total=len(df_query_dev), desc='Processing queries'):
    tweet_vec = get_avg_embedding(row['tweet_text'], model)
    similarities = cosine_similarity([tweet_vec], doc_embs)[0]
    # argsort is ascending: take the last k positions and reverse for best-first order.
    top_k_indices = np.argsort(similarities)[-k:][::-1]
    top_k_doc_ids = [doc_ids[i] for i in top_k_indices]

    predictions.append({
        'post_id': row['post_id'],
        'true': row['cord_uid'],
        'preds': top_k_doc_ids
    })

# Stored under its own name: assigning the result to `mrr_score` would overwrite
# the `mrr_score` function defined in an earlier cell.
mrr_crawl = mrr_at_k(predictions, k=k)
print(f"MRR@{k}: {mrr_crawl:.4f}")

Encoding docs: 100%|██████████████████████| 7718/7718 [00:05<00:00, 1513.87it/s]
Processing queries: 100%|███████████████████| 1400/1400 [00:20<00:00, 67.56it/s]

MRR@5: 0.3402





In [16]:
# Build both output columns up front and construct the frame in one call.
# `preds` is stored as the string form of the ranked id list.
df_out_fasttext_crawl = pd.DataFrame({
    'post_id': [entry['post_id'] for entry in predictions],
    'preds': [str(entry['preds']) for entry in predictions],
})

# Tab-separated, without the pandas index column.
df_out_fasttext_crawl.to_csv('predictions_fasttext_crawl-300d.tsv', sep='\t', index=False)

In [None]:
from sentence_transformers import SentenceTransformer, InputExample, losses, util
from torch.utils.data import DataLoader
import random

model = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')

# Build (tweet, positive document, random negative document) triplets from the
# training queries.

# Dedicated, seeded generator so the sampled negatives are the same on every run.
negative_rng = random.Random(42)

# cord_uid -> "title abstract", computed once instead of scanning the whole
# collection for every training row. Missing title/abstract become '' (the same
# rule used for the `full_text` column) instead of raising on `str + NaN`.
# keep='first' matches taking the first matching row for a repeated cord_uid.
unique_docs = df_collection.drop_duplicates(subset='cord_uid', keep='first')
doc_text_by_uid = dict(zip(
    unique_docs['cord_uid'],
    unique_docs['title'].fillna('') + " " + unique_docs['abstract'].fillna(''),
))

all_doc_ids = df_collection['cord_uid'].tolist()
# A negative can only exist if the collection holds at least two distinct ids;
# this also guarantees that the sampling loop below terminates.
has_negative_candidates = len(doc_text_by_uid) > 1

train_examples = []
for tweet, pos_uid in zip(df_query_train['tweet_text'], df_query_train['cord_uid']):
    pos_text = doc_text_by_uid.get(pos_uid)
    if pos_text is None or not has_negative_candidates:
        # Gold document not present in the collection (or nothing to contrast with).
        continue

    # Rejection sampling: draw uniformly from all ids and retry on the positive,
    # which avoids rebuilding an "everything except pos_uid" list per row.
    neg_uid = negative_rng.choice(all_doc_ids)
    while neg_uid == pos_uid:
        neg_uid = negative_rng.choice(all_doc_ids)
    neg_text = doc_text_by_uid.get(neg_uid)
    if neg_text is None:
        continue

    train_examples.append(InputExample(texts=[tweet, pos_text, neg_text]))

# Fine-tune `model` on the triplets built above and save the result to disk.
# Batches of 16 examples, reshuffled by the DataLoader.
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)
# Each example's texts are [tweet, positive doc, negative doc].
train_loss = losses.TripletLoss(model=model)

model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=3,
    warmup_steps=100,
    show_progress_bar=True
)

# Saved under a local directory name; the following lines reload the model from this path.
model.save('fine-tuned-multi-qa-MiniLM-L6-cos-v1')

# Reload the fine-tuned model from the directory it was just saved to, then re-encode
# the documents and dev tweets with it. This rebinds the notebook globals `model_name`,
# `model`, `doc_embeddings` and `tweet_embeddings`.
model_name = "fine-tuned-multi-qa-MiniLM-L6-cos-v1"
print(f"Loading model: {model_name}")
start_time = time.time()
model = SentenceTransformer(model_name)
load_time = time.time() - start_time
print(f"Model loaded in {load_time:.2f} seconds.")

print("Encoding documents...")
start = time.time()
# Title + space + abstract, with missing values replaced by ''.
df_collection['full_text'] = df_collection['title'].fillna('') + " " + df_collection['abstract'].fillna('')
doc_embeddings = model.encode(df_collection['full_text'].tolist(), show_progress_bar=True, convert_to_tensor=True)
doc_encoding_time = time.time() - start
print(f"Document encoding time: {doc_encoding_time:.2f} seconds")

print("Encoding tweets...")
start = time.time()
# Rows of `tweet_embeddings` follow the row order of `df_query_dev`.
tweet_embeddings = model.encode(df_query_dev['tweet_text'].tolist(), convert_to_tensor=True, show_progress_bar=True)
tweet_encoding_time = time.time() - start
print(f"Tweet encoding time: {tweet_encoding_time:.2f} seconds")

# Top-5 retrieval with the fine-tuned model: rank all documents per dev tweet
# by cosine similarity and keep the best five cord_uids.
import torch  # imported here so this cell does not depend on an earlier cell's import

predictions = []

tweet_texts = df_query_dev['tweet_text'].tolist()
tweet_ids = df_query_dev['post_id'].tolist()
true_labels = df_query_dev['cord_uid'].tolist()

doc_texts = df_collection['full_text'].tolist()
doc_uids = df_collection['cord_uid'].tolist()

# torch.topk raises if k exceeds the number of candidates, so never ask for
# more results than there are documents.
n_preds = min(5, len(doc_uids))

# One batched call scores every tweet against every document (row i = tweet i,
# column j = document j) instead of one cos_sim + topk call per tweet in Python.
# The matrix holds n_tweets * n_docs scores, so it must fit in memory.
all_scores = util.cos_sim(tweet_embeddings, doc_embeddings)
top_doc_indices = torch.topk(all_scores, k=n_preds, dim=1).indices.tolist()

for i in tqdm(range(len(top_doc_indices))):
    predictions.append({
        'post_id': tweet_ids[i],
        'tweet_text': tweet_texts[i],
        'true': true_labels[i],
        # Map document row positions back to their cord_uid, best match first.
        'preds': [doc_uids[idx] for idx in top_doc_indices[i]]
    })

def mrr_at_k(predictions, k=5):
    """Mean reciprocal rank of the true document within the top-k predictions.

    Args:
        predictions: iterable-with-length of dicts holding 'true' (the gold
            id) and 'preds' (ranked list of predicted ids, best first).
        k: only the first `k` entries of each 'preds' list are considered.

    Returns:
        The mean over all queries of 1/rank of the gold id (0 for a query
        whose gold id is not in its top k). Returns 0.0 for empty input.
    """
    if not predictions:
        # Avoid ZeroDivisionError when there is nothing to evaluate.
        return 0.0
    total_mrr = 0.0
    for pred in predictions:
        # Honour the cutoff: ranks beyond k must not contribute to MRR@k.
        top_k = list(pred['preds'])[:k]
        if pred['true'] in top_k:
            rank = top_k.index(pred['true']) + 1
            total_mrr += 1 / rank
    return total_mrr / len(predictions)

mrr5 = mrr_at_k(predictions, k=5)
# Label the score with the model that was actually evaluated (`model_name` is set
# when the fine-tuned model is reloaded above); the previous hard-coded label
# named a different model.
print(f"MRR@5 for {model_name}: {mrr5:.4f}")

# Build both output columns up front and construct the frame in one call.
# `preds` is stored as the string form of the ranked id list.
df_out = pd.DataFrame({
    'post_id': [entry['post_id'] for entry in predictions],
    'preds': [str(entry['preds']) for entry in predictions],
})

# Tab-separated, without the pandas index column.
df_out.to_csv('predictions_multi-qa-MiniLM-L6-cos-v1.tsv', sep='\t', index=False)