In [1]:
import pickle
import torch
from transformers import AutoTokenizer, AutoModel
from torch.nn.functional import cosine_similarity
from tqdm import tqdm


In [2]:
# Load the pickled corpus (docs) and query set from data_path.
# NOTE(review): pickle.load can execute arbitrary code — only point this at trusted files.
data_path = "../data"


def _load_pickle(path):
    """Unpickle and return the object stored at ``path``."""
    with open(path, "rb") as fh:
        return pickle.load(fh)


docs = _load_pickle(f"{data_path}/docs.pkl")
queries = _load_pickle(f"{data_path}/queries.pkl")

In [3]:
# Report corpus and query-set sizes (one line each)
print(f"Number of documents: {len(docs)}", f"Number of queries: {len(queries)}", sep="\n")

Number of documents: 210158
Number of queries: 150


In [4]:
# Single source of truth for the checkpoint, so the model and tokenizer can never drift apart.
MODEL_NAME = "bert-base-uncased"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
print(f"Number of GPUs available: {torch.cuda.device_count()}")

# Load BERT, put its parameters on the primary device, then wrap it so each forward
# pass is split across all visible GPUs (DataParallel expects the module's parameters
# to already be on device_ids[0]).
model = AutoModel.from_pretrained(MODEL_NAME).to(device)
model = torch.nn.DataParallel(model)  # Enable data parallelism

model.eval()  # inference only: switch off training-mode layers such as dropout

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

Using device: cuda
Number of GPUs available: 4


In [5]:
def batch_embed_texts(
    texts,
    tokenizer,
    model,
    batch_size=4096,
    device=None,
    pooling="cls",
    show_progress=True,
):
    """Embed ``texts`` in batches and return one CPU embedding tensor per text.

    Args:
        texts: Sliceable sequence of strings (e.g. a list).
        tokenizer: Hugging Face-style tokenizer. It is called on a list of
            strings and must return a mapping of tensors (e.g. ``BatchEncoding``).
        model: Encoder whose output has ``last_hidden_state``, indexed here as
            (batch, seq, hidden).
        batch_size: Number of texts tokenized and encoded per forward pass.
        device: Device the tokenized inputs are moved to. ``None`` (default)
            uses the device of the model's first parameter, or CPU if the
            model has no parameters.
        pooling: ``"cls"`` (default) keeps the first token's hidden state.
            ``"mean"`` averages hidden states over non-padding tokens, and
            requires the tokenizer output to contain ``attention_mask``.
        show_progress: Wrap the batch loop in a ``tqdm`` progress bar.

    Returns:
        list of 1-D tensors on CPU, in the same order as ``texts``
        (empty list for empty input).

    Raises:
        ValueError: if ``pooling`` is not ``"cls"`` or ``"mean"``.
    """
    if pooling not in ("cls", "mean"):
        raise ValueError(f"unknown pooling {pooling!r}; expected 'cls' or 'mean'")

    if device is None:
        # Follow the model instead of a notebook global; with DataParallel the
        # wrapped module's parameters sit on the primary device.
        params = model.parameters() if hasattr(model, "parameters") else iter(())
        first_param = next(iter(params), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    starts = range(0, len(texts), batch_size)
    if show_progress:
        starts = tqdm(starts, desc="Batch embedding", unit="batch")

    embeddings = []
    for start in starts:
        batch_texts = list(texts[start : start + batch_size])
        encoded = tokenizer(
            batch_texts,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=512,
        )
        # Works for BatchEncoding and for any plain mapping of tensors.
        inputs = {name: tensor.to(device) for name, tensor in encoded.items()}
        with torch.no_grad():
            hidden = model(**inputs).last_hidden_state
        if pooling == "cls":
            pooled = hidden[:, 0, :]
        else:
            # Zero out padding positions, then divide by the real token count.
            mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
            pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
        # Iterating a (batch, hidden) tensor yields one 1-D row per text.
        embeddings.extend(pooled.cpu())
    return embeddings


def compute_similarity(query_embedding, doc_embedding):
    """Return the cosine similarity of two embeddings as a Python float.

    Each argument gets a leading batch dimension of size one before the
    comparison, so the intended inputs are single 1-D embedding vectors of
    equal length.
    """
    query_row = query_embedding.unsqueeze(dim=0)
    doc_row = doc_embedding.unsqueeze(dim=0)
    pair_score = cosine_similarity(query_row, doc_row)  # shape (1,)
    return float(pair_score)


In [6]:
print(f"number of documents: {len(docs)}")

# Drop documents whose text is missing (None), empty, or whitespace-only. Such texts
# tokenize to little more than special tokens and only add noise to the ranking.
# Queries are not filtered here: they are addressed by position later (query_no), so
# dropping one would shift every index after it.
# One filtering pass builds a new list. Calling list.remove() per empty doc rescans
# the whole list each time.
docs = [doc for doc in docs if doc.text is not None and doc.text.strip() != ""]
print(f"number of documents: {len(docs)}")

number of documents: 210158
number of documents: 210157


In [7]:
# Collect raw texts first. Each embedding list comes back in the same order as its
# source list, so doc_embeddings[i] belongs to docs[i] (likewise for queries).
doc_texts = [d.text for d in docs]
query_texts = [q.query for q in queries]

# Embed the corpus, then the queries, in batches
doc_embeddings = batch_embed_texts(doc_texts, tokenizer, model)
query_embeddings = batch_embed_texts(query_texts, tokenizer, model)


Batch embedding: 100%|██████████| 52/52 [06:19<00:00,  7.31s/batch]
Batch embedding: 100%|██████████| 1/1 [00:00<00:00,  8.42batch/s]


In [8]:
# The search maps an embedding index straight back to docs[idx], so the embedding
# lists must stay index-aligned with their source lists.
assert len(doc_embeddings) == len(docs), "doc embeddings out of sync with docs"
assert len(query_embeddings) == len(queries), "query embeddings out of sync with queries"
len(doc_embeddings), len(query_embeddings)

(210157, 150)

In [31]:

# Rank every document against one query by cosine similarity and show the top 10.
query_no = 20
# NOTE(review): 103216 looks like a hand-picked probe document (perhaps a known relevant
# doc for this query) — TODO confirm and record where it comes from.
probe_doc_idx = 103216

query_embedding = query_embeddings[query_no]
print(compute_similarity(query_embedding, doc_embeddings[probe_doc_idx]))

# Score all documents in one vectorised call instead of one Python-level call per doc.
# torch.stack copies the per-doc vectors into a single (num_docs, hidden) matrix.
doc_matrix = torch.stack(doc_embeddings)
scores = cosine_similarity(query_embedding.unsqueeze(0), doc_matrix)

# (doc index, similarity) pairs, highest similarity first (stable for ties)
similar_docs = sorted(enumerate(scores.tolist()), key=lambda x: x[1], reverse=True)

# get top 10
top_10_docs = similar_docs[:10]

for doc_no, similarity in top_10_docs:
    print(f"Doc: {doc_no}:{docs[doc_no].doc_no}\t Similarity: {similarity:.4f}")


0.6789228916168213
Doc: 35468:FT941-9178	 Similarity: 0.8791
Doc: 154820:FT941-451	 Similarity: 0.8700
Doc: 25759:FT942-5917	 Similarity: 0.8682
Doc: 37007:FT941-13076	 Similarity: 0.8639
Doc: 15490:FT944-9379	 Similarity: 0.8637
Doc: 143993:FT932-7852	 Similarity: 0.8621
Doc: 161869:FT932-11021	 Similarity: 0.8617
Doc: 143323:FT943-2100	 Similarity: 0.8614
Doc: 199180:FT942-2839	 Similarity: 0.8609
Doc: 69284:FT944-7539	 Similarity: 0.8591
