In [1]:
import pandas as pd
import numpy as np
from transformers import BertTokenizer, BertModel
import pickle as pkl
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
import json 
from tqdm import tqdm
import os

CHUNK_SIZE = 256
EPOCHS = 3


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda


In [2]:
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.documents import Document

def pre_process(data):
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=CHUNK_SIZE,
        chunk_overlap=0,
        length_function=lambda x: len(tokenizer.encode(x, add_special_tokens=True))
    )
    pairs = [(chunk, row[0]) for row in tqdm(data[["cord_uid", "abstract"]].to_numpy(), desc="Creating pairs") for chunk in splitter.split_text(row[1])]

    docs = [Document(page_content=pair[0], metadata={"cord_uid": pair[1]}) for pair in pairs]
    titles = [Document(page_content=row[1], metadata={"cord_uid": row[0]}) for row in data[["cord_uid", "title"]].to_numpy()]

    docs.extend(titles)
    return docs

In [3]:
%%script false --no-raise-error
# Preprocessing step.
# If docs.json already exists, skip this step.

with open("../X_Data/subtask4b_collection_data.pkl", "rb") as f:
    data = pkl.load(f)
    collection = pd.DataFrame(data)

docs = pre_process(collection)

for doc in docs:
    if len(tokenizer.encode(doc.page_content, add_special_tokens=True)) > CHUNK_SIZE:
        print(f"Document {doc.metadata['cord_uid']} is too long: {len(doc.page_content)} characters")
        raise Exception("Doc too long")

objs = [
    {
        "cord_uid": doc.metadata["cord_uid"],
        "text": doc.page_content,
    } for doc in docs
]

with open("docs.json", "w") as f:
    json.dump(objs, f)

In [4]:
with open("docs.json", "r") as f:
    paper_chunks = json.load(f)

# Group chunks by paper ID
paper_dict = {}
for entry in paper_chunks:
    uid = entry["cord_uid"]
    paper_dict.setdefault(uid, []).append(entry["text"])

df = pd.read_csv("../X_Data/subtask4b_query_tweets_train.tsv", sep="\t")
train_df = df[["tweet_text", "cord_uid"]].dropna()
dev_df = pd.read_csv("../X_Data/subtask4b_query_tweets_dev.tsv", sep="\t")

In [None]:
import random

class TweetPaperDataset(Dataset):
    def __init__(self, df, paper_dict, tokenizer, max_len=256, num_negatives=1):
        self.df = df
        self.paper_dict = paper_dict
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.neg = num_negatives

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        tweet = row["tweet_text"]
        pos_id = row["cord_uid"]

        pos_chunks = self.paper_dict.get(pos_id, [])
        pos_text = random.choice(pos_chunks) if pos_chunks else ""

        all_ids = list(self.paper_dict.keys())
        all_ids.remove(pos_id)
        neg_ids = random.sample(all_ids, self.neg)
        neg_chunks = [random.choice(self.paper_dict[nid]) for nid in neg_ids]

        tweet_inputs = self.tokenizer(tweet, padding="max_length", truncation=True, max_length=self.max_len, return_tensors="pt")
        pos_inputs = self.tokenizer(pos_text, padding="max_length", truncation=True, max_length=self.max_len, return_tensors="pt")
        neg_inputs = [self.tokenizer(chunk, padding="max_length", truncation=True, max_length=self.max_len, return_tensors="pt") for chunk in neg_chunks]

        return {
            "tweet": tweet_inputs,
            "pos": pos_inputs,
            "neg": neg_inputs
        }

In [None]:
class DualBertEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        return outputs.last_hidden_state[:, 0, :]

In [None]:
def compute_loss(tweet_emb, pos_emb, neg_embs):
    # Triplet loss: bring tweet closer to positive than negative
    pos_sim = nn.functional.cosine_similarity(tweet_emb, pos_emb)
    loss = 0
    for neg in neg_embs:
        neg_sim = nn.functional.cosine_similarity(tweet_emb, neg)
        loss += torch.mean(nn.functional.relu(1 - pos_sim + neg_sim))
    return loss

In [None]:
def train(model, dataloader, optimizer, epoch=0):
    model.train()
    total_loss = 0
    for batch in tqdm(dataloader, desc=f"Training epoch {epoch}"):
        tweet_inputs = batch["tweet"]
        pos_inputs = batch["pos"]
        neg_inputs = batch["neg"]

        tweet_emb = model(tweet_inputs["input_ids"].squeeze(1).to(device),
                          tweet_inputs["attention_mask"].squeeze(1).to(device))
        pos_emb = model(pos_inputs["input_ids"].squeeze(1).to(device),
                        pos_inputs["attention_mask"].squeeze(1).to(device))
        neg_embs = [model(n["input_ids"].squeeze(1).to(device),
                          n["attention_mask"].squeeze(1).to(device)) for n in neg_inputs]

        loss = compute_loss(tweet_emb, pos_emb, neg_embs)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    return total_loss / len(dataloader)


In [9]:
from os import putenv

putenv("HSA_OVERRIDE_GFX_VERSION", "11.0.0")
putenv("PYTORCH_ROCM_ARCH", "gfx1100")
putenv("HIP_VISIBLE_DEVICES", "0")
putenv("ROOCM_PATH", "/opt/rocm-6.3.4")
putenv("HIP_PLATFORM", "amd")
putenv("HIP_DEVICE", "0")
putenv("AMD_SERIALIZE_KERNEL", "3")
putenv("AMD_LOG_LEVEL", "5")

print(torch.version.hip)

6.3.42131-fa1d09cbd


In [None]:
train_dataset = TweetPaperDataset(train_df, paper_dict, tokenizer)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

model = DualBertEncoder().to(device)
if os.path.exists("bert_dual_encoder.pt"):
    model.load_state_dict(torch.load("bert_dual_encoder.pt"))

optimizer = AdamW(model.parameters(), lr=2e-5)
if os.path.exists("bert_dual_encoder_optimizer.pt"):
    optimizer.load_state_dict(torch.load("bert_dual_encoder_optimizer.pt"))

for epoch in range(EPOCHS):
    print(f"Epoch {epoch+1}/{EPOCHS}")
    loss = train(model, train_loader, optimizer, epoch=epoch)
    print(f"Loss: {loss:.4f}")

In [71]:
torch.save(model.state_dict(), "bert_dual_encoder.pt")
torch.save(optimizer.state_dict(), "bert_dual_encoder_optimizer.pt")

In [23]:
from torch.nn.functional import cosine_similarity

@torch.no_grad()
def embed_texts(texts, model, tokenizer, max_len=256, batch_size=32):
    model.eval()
    embeddings = []
    for i in tqdm(range(0, len(texts), batch_size)):
        batch = texts[i:i+batch_size]
        encodings = tokenizer(batch, padding=True, truncation=True, max_length=max_len, return_tensors="pt").to(device)
        with torch.no_grad():
            output = model(encodings["input_ids"], encodings["attention_mask"])
        embeddings.extend(output.cpu())
    return torch.stack(embeddings)

In [3]:
paper_chunk_texts = []
paper_chunk_uids = []

for uid, chunks in paper_dict.items():
    for chunk in chunks:
        paper_chunk_texts.append(chunk)
        paper_chunk_uids.append(uid)

# paper_chunk_embeddings = embed_texts(paper_chunk_texts, model, tokenizer)
# print(f"Embedded {len(paper_chunk_texts)} paper chunks.")

In [6]:
dev_tweets = dev_df["tweet_text"].tolist()
dev_ids = dev_df["cord_uid"].tolist()

# tweet_embeddings = embed_texts(dev_tweets, model, tokenizer)
# print(f"Embedded {len(dev_tweets)} dev tweets.")

In [None]:
torch.save({
    "embeddings": paper_chunk_embeddings,
    "uids": paper_chunk_uids
}, "paper_chunk_embeddings.pt")

torch.save({
    "embeddings": tweet_embeddings,
    "tweets": dev_tweets,
    "cord_uids": dev_ids
}, "dev_tweet_embeddings.pt")

In [None]:
from collections import defaultdict

top_k = 5
results = []

for idx, tweet_emb in tqdm(enumerate(tweet_embeddings), desc="Tweets"):
    scores = defaultdict(float)
    
    for chunk_idx, chunk_emb in tqdm(enumerate(paper_chunk_embeddings), desc="Chunks", leave=False):
        uid = paper_chunk_uids[chunk_idx]
        sim = cosine_similarity(tweet_emb.unsqueeze(0), chunk_emb.unsqueeze(0)).item()
        scores[uid] = max(scores[uid], sim)  # Keep best similarity per paper

    # Sort by similarity and get top_k papers
    top_papers = sorted(scores.items(), key=lambda x: x[1], reverse=True)[:top_k]
    results.append({
        "tweet": dev_tweets[idx],
        "gold_cord_uid": dev_ids[idx],
        "retrieved": [uid for uid, score in top_papers]
    })

# Display a sample
for r in results[5:10]:
    print(f"\nTweet: {r['tweet']}")
    print(f"Gold paper: {r['gold_cord_uid']}")
    print("Top 5 retrieved:", r["retrieved"])

Tweets: 1400it [18:21,  1.27it/s]


Tweet: covid recovery: this study from the usa reveals that a proportion of cases experience impairment in some cognitive functions for several months after infection. some possible biases &amp; limitations but more research is required on impact of these long term effects.
Gold paper: 3qvh482o
Top 5 retrieved: ['yf3z913h', 'ws75gpsc', 'lj37a4xn', 'n2rec4i8', 'ho6qjkyr']

Tweet: "Among 139 clients exposed to two symptomatic hair stylists with confirmed COVID-19, while both the stylists & the clients wore face masks, no additional symptomatic cases were reported; among 67 clients tested for SARS-CoV-2, all test results were negative"
Gold paper: r58aohnu
Top 5 retrieved: ['r58aohnu', 'astxi4el', '2e02uktc', 'h3tor88n', 'ydv0hc0m']

Tweet: I recall early on reading that researchers who had examined other coronaviruses discovered that individuals could contract the same virus multiple times within the same year.  I even located a source for it!
Gold paper: sts48u9i
Top 5 retrieved: ['r67




In [28]:
for r in results[5:10]:
    print(f"\nTweet: {r['tweet']}")
    print(f"Gold paper: {r['gold_cord_uid']}")
    print("Top 5 retrieved:", r["retrieved"])

def compute_mrr5(results):
    mrr = 0
    for result in results:
        gold_uid = result["gold_cord_uid"]
        retrieved = result["retrieved"]
        if gold_uid in retrieved:
            rank = retrieved.index(gold_uid) + 1
            mrr += 1 / rank
    return mrr / len(results)

mrr = compute_mrr5(results)
print(f"Mean Reciprocal Rank (MRR@5): {mrr:.4f}")


Tweet: IL-6 seems to be a primary catalyst of this uncontrolled inflammation in #covid19, and #tocilizumab, a mab IL-6 receptor blocker, has been used in small series of severe covid-19 cases with early reports of success
Gold paper: 8cvjsisw
Top 5 retrieved: ['ysd9pmq1', 'et1ekgdl', '36zu137v', '3bo4md44', '0b1dbz6q']

Tweet: macro-level, multi-national analysis shows that public mask compliance is negligible for stopping the spread of covid-19  — it appears to be socially damaging.
Gold paper: tra5ewc5
Top 5 retrieved: ['f96qs295', '9b6cepf4', 'nv1w6juh', 'zl4ixyg1', 'opjfy3xr']

Tweet: Long, slender transmission chains of severe acute respiratory syndrome coronavirus (sars-cov-2) may go undetected for several weeks at low to moderate reproductive numbers: implications for containment and elimination strategy  [🚨preprint]
Gold paper: yoiq6cgt
Top 5 retrieved: ['yoiq6cgt', '3o5c0l24', 'x0cs571f', 'yjbmi8ur', 'wwt7mn55']

Tweet: Significant vitamin D deficiency in people with COVID-19

In [9]:
# Compare to performance with OpenAI embeddings
from openai import OpenAI
import time

openai = OpenAI()

EMBEDDING_BATCH_SIZE = 1000 

def get_openai_embeddings(texts: list[str], sleep=False):
    embeddings = []
    for i in tqdm(range(0, len(texts), EMBEDDING_BATCH_SIZE)):
        batch = texts[i:i+EMBEDDING_BATCH_SIZE]
        batch = [t.replace("\n", "").strip() for t in batch]
        print(len(batch))

        embedding_response = openai.embeddings.create(
            input=batch,
            model="text-embedding-3-large"
        )
        print("Embeddings")
        embeddings.extend([np.array(embedding.embedding) for embedding in embedding_response.data])
        if sleep:
            time.sleep(3)
    return embeddings

openai_tweet_embeddings = get_openai_embeddings(dev_tweets)

openai_paper_chunk_embeddings = get_openai_embeddings(paper_chunk_texts, sleep=True)

torch_tweet_embeddings = [torch.from_numpy(a) for a in openai_tweet_embeddings]
torch_paper_chunk_embeddings = [torch.from_numpy(a) for a in openai_paper_chunk_embeddings]

# Compute Cosine Similarity
similarities = []
for tweet_emb in tqdm(openai_tweet_embeddings, desc="Tweets"):
    scores = defaultdict(float)
    
    for chunk_emb in tqdm(openai_paper_chunk_embeddings, desc="Chunks", leave=False):
        sim = cosine_similarity(tweet_emb.unsqueeze(0), chunk_emb.unsqueeze(0)).item()
        scores[uid] = max(scores[uid], sim)  # Keep best similarity per paper

    # Sort by similarity and get top_k papers
    top_papers = sorted(scores.items(), key=lambda x: x[1], reverse=True)[:top_k]
    similarities.append({
        "tweet": dev_tweets[idx],
        "gold_cord_uid": dev_ids[idx],
        "retrieved": [uid for uid, score in top_papers]
    })

# Display a sample
for r in similarities[5:10]:
    print(f"\nTweet: {r['tweet']}")
    print(f"Gold paper: {r['gold_cord_uid']}")
    print("Top 5 retrieved:", r["retrieved"])

def compute_mrr5(results):
    mrr = 0
    for result in results:
        gold_uid = result["gold_cord_uid"]
        retrieved = result["retrieved"]
        if gold_uid in retrieved:
            rank = retrieved.index(gold_uid) + 1
            mrr += 1 / rank
    return mrr / len(results)
mrr = compute_mrr5(similarities)
print(f"Mean Reciprocal Rank (MRR@5): {mrr:.4f}")


OpenAIError: The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable

In [70]:
torch_tweet_embeddings = [torch.from_numpy(a) for a in openai_tweet_embeddings]
torch_paper_chunk_embeddings = [torch.from_numpy(a) for a in openai_paper_chunk_embeddings]

torch.save({
    "embeddings": torch_paper_chunk_embeddings,
    "uids": paper_chunk_uids
}, "paper_chunk_embeddings.pt")
torch.save({
    "embeddings": torch_tweet_embeddings,
    "tweets": dev_tweets,
    "cord_uids": dev_ids
}, "dev_tweet_embeddings.pt")

# similarities = []
# for idx, tweet_emb in tqdm(enumerate(torch_tweet_embeddings), desc="Tweets"):
#     scores = defaultdict(float)
    
#     for chunk_idx, chunk_emb in tqdm(enumerate(torch_paper_chunk_embeddings), desc="Chunks", leave=False):
#         uid = paper_chunk_uids[chunk_idx]
#         sim = cosine_similarity(tweet_emb.unsqueeze(0), chunk_emb.unsqueeze(0)).item()
#         scores[uid] = max(scores[uid], sim)  # Keep best similarity per paper

#     # Sort by similarity and get top_k papers
#     top_papers = sorted(scores.items(), key=lambda x: x[1], reverse=True)[:top_k]
#     similarities.append({
#         "tweet": dev_tweets[idx],
#         "gold_cord_uid": dev_ids[idx],
#         "retrieved": [uid for uid, score in top_papers]
#     })

# # Display a sample
# for r in similarities[5:10]:
#     print(f"\nTweet: {r['tweet']}")
#     print(f"Gold paper: {r['gold_cord_uid']}")
#     print("Top 5 retrieved:", r["retrieved"])

# def compute_mrr5(results):
#     mrr = 0
#     for result in results:
#         gold_uid = result["gold_cord_uid"]
#         retrieved = result["retrieved"]
#         if gold_uid in retrieved:
#             rank = retrieved.index(gold_uid) + 1
#             mrr += 1 / rank
#     return mrr / len(results)
# mrr = compute_mrr5(similarities)
# print(f"Mean Reciprocal Rank (MRR@5): {mrr:.4f}")

In [None]:
from sentence_transformers import SentenceTransformer
from collections import defaultdict
from torch.nn.functional import cosine_similarity


model = SentenceTransformer('all-MiniLM-L6-v2')

def mini_embed_texts(texts, model, batch_size=32):
    embeddings = []
    for i in tqdm(range(0, len(texts), batch_size)):
        batch = texts[i:i+batch_size]
        output = model.encode(batch, convert_to_tensor=True)
        embeddings.extend(output)
    return embeddings

mini_paper_chunk_embeddings = mini_embed_texts(paper_chunk_texts, model)
mini_tweet_embeddings = mini_embed_texts(dev_tweets, model)

def compute_similarities(tweet_embeds, paper_embeds, paper_uids, top_k=5):
    similarities = []
    for idx, tweet_emb in tqdm(enumerate(tweet_embeds), desc="Tweets"):
        scores = defaultdict(float)
        
        for chunk_idx, chunk_emb in tqdm(enumerate(paper_embeds), desc="Chunks", leave=False):
            uid = paper_uids[chunk_idx]
            sim = cosine_similarity(tweet_emb, chunk_emb, dim=0).item()
            scores[uid] = max(scores[uid], sim)  # Keep best similarity per paper
    
        # Sort by similarity and get top_k papers
        top_papers = sorted(scores.items(), key=lambda x: x[1], reverse=True)[:top_k]
        similarities.append({
            "tweet": dev_tweets[idx],
            "gold_cord_uid": dev_ids[idx],
            "retrieved": [uid for uid, score in top_papers]
        })
    return similarities

similarities = compute_similarities(mini_tweet_embeddings, mini_paper_chunk_embeddings, paper_chunk_uids)

# Display a sample
for r in similarities[5:10]:
    print(f"\nTweet: {r['tweet']}")
    print(f"Gold paper: {r['gold_cord_uid']}")
    print("Top 5 retrieved:", r["retrieved"])

def compute_mrr5(results):
    mrr = 0
    for result in results:
        gold_uid = result["gold_cord_uid"]
        retrieved = result["retrieved"]
        if gold_uid in retrieved:
            rank = retrieved.index(gold_uid) + 1
            mrr += 1 / rank
    return mrr / len(results)

result = compute_mrr5(similarities) 

print(f"Mean Reciprocal Rank (MRR@5): {result:.4f}")

0.30.2


100%|██████████| 1584/1584 [00:26<00:00, 59.85it/s]
100%|██████████| 44/44 [00:00<00:00, 76.31it/s]
Tweets: 1400it [1:37:12,  4.17s/it]


Tweet: IL-6 seems to be a primary catalyst of this uncontrolled inflammation in #covid19, and #tocilizumab, a mab IL-6 receptor blocker, has been used in small series of severe covid-19 cases with early reports of success
Gold paper: 8cvjsisw
Top 5 retrieved: ['z9jqbliw', '3r418rss', 'vx9vqr1k', 'zt5alyy2', 'xh723tgl']

Tweet: macro-level, multi-national analysis shows that public mask compliance is negligible for stopping the spread of covid-19  — it appears to be socially damaging.
Gold paper: tra5ewc5
Top 5 retrieved: ['qi1henyy', '1s8jzzwg', 'zycgczqy', '763v4duh', 'jjh1z5c6']

Tweet: Long, slender transmission chains of severe acute respiratory syndrome coronavirus (sars-cov-2) may go undetected for several weeks at low to moderate reproductive numbers: implications for containment and elimination strategy  [🚨preprint]
Gold paper: yoiq6cgt
Top 5 retrieved: ['yoiq6cgt', 'vmmztj0a', 'ueb7mjnv', 'hbkl5cam', 'w1azm2mc']

Tweet: Significant vitamin D deficiency in people with COVID-19


