In [None]:
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from torchtext.data.utils import get_tokenizer
from datasets import load_dataset
import json
import pickle
from tqdm import tqdm
import matplotlib.pyplot as plt
%matplotlib inline

In [3]:
# -----------------------------
# 1. Load MS MARCO V1.1 training dataset
# -----------------------------

# This will stream the data, you don't have to download the full file
dataset = load_dataset("microsoft/ms_marco", "v1.1", split="train")  # or "validation"

In [None]:
# -----------------------------
# 2. Create triples of (query, relevant_docs, irrelevant_docs)
# -----------------------------

# Create a list of all passages to pass into loop to help create irrelevant documents
all_passages = set()
for row in tqdm(dataset, desc="Building passage pool..."):
    all_passages.update(row['passages']['passage_text'])
all_passages = list(all_passages)

# Create triples of (query, relevant_docs, irrelevant_docs)
triples = []
for row in tqdm(dataset, desc="Creating triples..."):
    query = row['query']
    relevant = row['passages']['passage_text'][:10]
    relevant_set = set(relevant)
    irrelevant_pool = list(set(all_passages) - relevant_set)
    irrelevant = random.sample(irrelevant_pool, 10)
    triples.append((query, relevant, irrelevant))
    if len(triples) >= 1000:
        break

# Save the triples to a file for later use
with open("triples_1000.pkl", "wb") as f:
    pickle.dump(triples, f)

with open("triples_1000.json", "w", encoding="utf-8") as f:
    json.dump(triples, f, ensure_ascii=False, indent=2)

print(triples[0][0])  # query
print(triples[0][1])  # 10 relevant docs
print(triples[0][2])  # 10 irrelevant docs

Building passage pool...:   0%|          | 0/82326 [00:00<?, ?it/s]

Building passage pool...: 100%|██████████| 82326/82326 [00:24<00:00, 3327.26it/s]
Creating triples...:   1%|          | 999/82326 [04:14<5:45:39,  3.92it/s] 


what is rba
["Since 2007, the RBA's outstanding reputation has been affected by the 'Securency' or NPA scandal. These RBA subsidiaries were involved in bribing overseas officials so that Australia might win lucrative note-printing contracts. The assets of the bank include the gold and foreign exchange reserves of Australia, which is estimated to have a net worth of A$101 billion. Nearly 94% of the RBA's employees work at its headquarters in Sydney, New South Wales and at the Business Resumption Site.", "The Reserve Bank of Australia (RBA) came into being on 14 January 1960 as Australia 's central bank and banknote issuing authority, when the Reserve Bank Act 1959 removed the central banking functions from the Commonwealth Bank. The assets of the bank include the gold and foreign exchange reserves of Australia, which is estimated to have a net worth of A$101 billion. Nearly 94% of the RBA's employees work at its headquarters in Sydney, New South Wales and at the Business Resumption Site

In [4]:
# -------------------------------
# 2b. Load triples from file
# -------------------------------

# Load .pkl
with open("triples_1000.pkl", "rb") as f:
    triples = pickle.load(f)

# # Load .json
# with open("triples_1000.json", "r", encoding="utf-8") as f:
#     triples = json.load(f)


In [5]:
# -------------------------------
# 3. Tokenise triples
# -------------------------------

tokenizer = get_tokenizer("basic_english")

# Lists to hold the tokenized items
tokenized_queries = []
tokenized_relevant_docs = []
tokenized_irrelevant_docs = []

for query, rels, irrels in triples:
    # Tokenize the query
    tokenized_query = tokenizer(query)
    tokenized_queries.append(tokenized_query)
    
    # Tokenize relevant docs
    tokenized_rels = [tokenizer(doc) for doc in rels]
    tokenized_relevant_docs.append(tokenized_rels)
    
    # Tokenize irrelevant docs
    tokenized_irrels = [tokenizer(doc) for doc in irrels]
    tokenized_irrelevant_docs.append(tokenized_irrels)

tokenized_triples = []
for query, rels, irrels in triples:
    tokenized_query = tokenizer(query)
    tokenized_rels = [tokenizer(doc) for doc in rels]
    tokenized_irrels = [tokenizer(doc) for doc in irrels]
    tokenized_triples.append((tokenized_query, tokenized_rels, tokenized_irrels))

# Check the first tokenized items
print(tokenized_triples[0])  # Tokenized query

(['what', 'is', 'rba'], [['since', '2007', ',', 'the', 'rba', "'", 's', 'outstanding', 'reputation', 'has', 'been', 'affected', 'by', 'the', "'", 'securency', "'", 'or', 'npa', 'scandal', '.', 'these', 'rba', 'subsidiaries', 'were', 'involved', 'in', 'bribing', 'overseas', 'officials', 'so', 'that', 'australia', 'might', 'win', 'lucrative', 'note-printing', 'contracts', '.', 'the', 'assets', 'of', 'the', 'bank', 'include', 'the', 'gold', 'and', 'foreign', 'exchange', 'reserves', 'of', 'australia', ',', 'which', 'is', 'estimated', 'to', 'have', 'a', 'net', 'worth', 'of', 'a$101', 'billion', '.', 'nearly', '94%', 'of', 'the', 'rba', "'", 's', 'employees', 'work', 'at', 'its', 'headquarters', 'in', 'sydney', ',', 'new', 'south', 'wales', 'and', 'at', 'the', 'business', 'resumption', 'site', '.'], ['the', 'reserve', 'bank', 'of', 'australia', '(', 'rba', ')', 'came', 'into', 'being', 'on', '14', 'january', '1960', 'as', 'australia', "'", 's', 'central', 'bank', 'and', 'banknote', 'issuing'

In [6]:
# -----------------------------
# 4. Load Vocabulary of CBOW model
# -----------------------------
with open("vocab_new.json", "r", encoding="utf-8") as f:
    word_to_ix = json.load(f)

ix_to_word = {int(i): w for w, i in word_to_ix.items()}
vocab_size = len(word_to_ix)

# -----------------------------
# 5. Load Pre-trained Embeddings (placeholder)
# -----------------------------
embed_dim = 200  
state = torch.load("text8_cbow_embeddings.pth", map_location='cpu')  # Shape: [vocab_size, embed_dim]
embeddings = state["embeddings.weight"] 

assert embeddings.shape[0] == vocab_size, "Vocab size mismatch!"

In [7]:
# -----------------------------
# 6. CBOW Model
# -----------------------------
class CBOW(nn.Module):
    def __init__(self, vocab_size, embed_dim):
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, embed_dim)
        self.linear = nn.Linear(embed_dim, vocab_size)

    def forward(self, inputs):
        embeds = self.embeddings(inputs).mean(dim=1)
        return self.linear(embeds)

cbow_model = CBOW(vocab_size, embed_dim)
cbow_model.embeddings.weight.data.copy_(embeddings)
cbow_model.embeddings.weight.requires_grad = False  

In [None]:
query_embeddings = []
relevant_doc_embeddings = []
irrelevant_doc_embeddings = []

for i, (tokenized_query, tokenized_rels, tokenized_irrels) in enumerate(tokenized_triples):
    # --- Query: embed each token (do not average)
    q_ids = [word_to_ix[t] for t in tokenized_query if t in word_to_ix]
    if q_ids:
        with torch.no_grad():
            q_vecs = cbow_model.embeddings(torch.tensor(q_ids))  # (query_seq_len, embed_dim)
        query_embeddings.append(q_vecs)
    else:
        query_embeddings.append(torch.zeros(1, embed_dim))  # Empty query, 1 "dummy" token

    # --- Relevant docs (list of (doc_seq_len, embed_dim))
    rel_embs = []
    for doc_tokens in tokenized_rels[:num_docs]:
        doc_ids = [word_to_ix[t] for t in doc_tokens if t in word_to_ix]
        if doc_ids:
            with torch.no_grad():
                doc_vecs = cbow_model.embeddings(torch.tensor(doc_ids))  # (doc_seq_len, embed_dim)
            rel_embs.append(doc_vecs)
        else:
            rel_embs.append(torch.zeros(1, embed_dim))  # 1 "dummy" token if empty
    # Pad if fewer than num_docs
    while len(rel_embs) < num_docs:
        rel_embs.append(torch.zeros(1, embed_dim))
    relevant_doc_embeddings.append(rel_embs)  # List of 10 tensors (doc_seq_len, embed_dim)

    # --- Irrelevant docs
    irrel_embs = []
    for doc_tokens in tokenized_irrels[:num_docs]:
        doc_ids = [word_to_ix[t] for t in doc_tokens if t in word_to_ix]
        if doc_ids:
            with torch.no_grad():
                doc_vecs = cbow_model.embeddings(torch.tensor(doc_ids))
            irrel_embs.append(doc_vecs)
        else:
            irrel_embs.append(torch.zeros(1, embed_dim))
    while len(irrel_embs) < num_docs:
        irrel_embs.append(torch.zeros(1, embed_dim))
    irrelevant_doc_embeddings.append(irrel_embs)  # List of 10 tensors

# Note: You will NOT stack here because each sequence has a different length.
# Instead, you keep:
# - query_embeddings: List of N tensors, each (query_seq_len, embed_dim)
# - relevant_doc_embeddings: List of N lists of 10 tensors (each (doc_seq_len, embed_dim))
# - irrelevant_doc_embeddings: List of N lists of 10 tensors (each (doc_seq_len, embed_dim))

print(len(query_embeddings))        # Should be N
print(len(relevant_doc_embeddings)) # Should be N
print(len(irrelevant_doc_embeddings)) # Should be N
print(query_embeddings[0].shape)    # (seq_len, embed_dim)
print(relevant_doc_embeddings[0][0].shape) # (doc_seq_len, embed_dim)

1000
1000
1000
torch.Size([2, 200])
torch.Size([70, 200])


In [44]:
# -----------------------------
# 7. Define distance functions & Triplet loss function
# -----------------------------

# Cosine similarity for calculation of cosine distance
def cosine_similarity(x, y):
    return F.cosine_similarity(x, y, dim=1)

# Cosine - Smallest value means most similar
def cosine_distance(x, y):
    return 1 - cosine_similarity(x, y)

# Euclidean (L2) - Smallest value means most similar
def euclidean_distance(x, y):
    return torch.norm(x - y, p=2, dim=1)

# Squared Euclidean - Smallest value means most similar
def squared_euclidean_distance(x, y):
    return torch.sum((x - y) ** 2, dim=1)

# Manhattan (L1) - Smallest value means most similar
def manhattan_distance(x, y):
    return torch.norm(x - y, p=1, dim=1)

# Chebyshev (L-infinity) - Smallest value means most similar
def chebyshev_distance(x, y):
    return torch.max(torch.abs(x - y), dim=1).values

# Minkowski - Smallest value means most similar
def minkowski_distance(x, y, p=3):
    return torch.norm(x - y, p=p, dim=1)

# Triplet loss function - will compute the loss for a batch of triplets
def triplet_loss_function(query, relevant_doc, irrelevant_doc, distance_function, margin):
    rel_dist = distance_function(query, relevant_doc)         # (batch,)
    irrel_dist = distance_function(query, irrelevant_doc)     # (batch,)
    triplet_loss = torch.relu(rel_dist - irrel_dist + margin)
    return triplet_loss.mean()                                # Average over batch

In [51]:
# -----------------------------
# 8. Define Two Tower Model (QueryTower and DocTower)
# -----------------------------

class QueryTower(nn.Module):
    def __init__(self, embed_dim, hidden_dim, num_layers=1, rnn_type='gru'):
        super().__init__()
        if rnn_type == 'gru':
            self.rnn = nn.GRU(embed_dim, hidden_dim, num_layers, batch_first=True)
        elif rnn_type == 'lstm':
            self.rnn = nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=True)
        else:
            raise ValueError("Unknown rnn_type: choose 'gru' or 'lstm'")

    def forward(self, x, lengths):
        # x: (batch, seq_len, embed_dim)
        packed = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
        out, h = self.rnn(packed)
        if isinstance(h, tuple):  # LSTM
            h = h[0]
        return h[-1]  # (batch, hidden_dim)

class DocTower(nn.Module):
    def __init__(self, embed_dim, hidden_dim, num_layers=1, rnn_type='gru'):
        super().__init__()
        if rnn_type == 'gru':
            self.rnn = nn.GRU(embed_dim, hidden_dim, num_layers, batch_first=True)
        elif rnn_type == 'lstm':
            self.rnn = nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=True)
        else:
            raise ValueError("Unknown rnn_type: choose 'gru' or 'lstm'")

    def forward(self, x, lengths):
        packed = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
        out, h = self.rnn(packed)
        if isinstance(h, tuple):
            h = h[0]
        return h[-1]



In [52]:
# -----------------------------
# 9. Prepare dataset for DataLoader
# -----------------------------

class TripleDataset(Dataset):
    def __init__(self, X_queries, X_rels, X_irrels):
        self.X_queries = X_queries    # list of N tensors (seq_len, embed_dim)
        self.X_rels = X_rels          # list of N lists of 10 tensors (seq_len, embed_dim)
        self.X_irrels = X_irrels      # list of N lists of 10 tensors (seq_len, embed_dim)
        self.n = len(X_queries)

    def __len__(self):
        return self.n * 10

    def __getitem__(self, idx):
        triple_idx = idx // 10
        doc_idx = idx % 10

        qry = self.X_queries[triple_idx]              # (q_seq_len, embed_dim)
        rel = self.X_rels[triple_idx][doc_idx]        # (rel_seq_len, embed_dim)
        irrel = self.X_irrels[triple_idx][doc_idx]    # (irrel_seq_len, embed_dim)

        return qry, rel, irrel
    
def collate_fn(batch):
    q_seqs, r_seqs, i_seqs = zip(*batch)
    q_lens = [x.shape[0] for x in q_seqs]
    r_lens = [x.shape[0] for x in r_seqs]
    i_lens = [x.shape[0] for x in i_seqs]
    q_padded = pad_sequence(q_seqs, batch_first=True)
    r_padded = pad_sequence(r_seqs, batch_first=True)
    i_padded = pad_sequence(i_seqs, batch_first=True)
    return q_padded, r_padded, i_padded, q_lens, r_lens, i_lens


# Create the dataset and dataloader for batching
batch_size = 32
dataset = TripleDataset(query_embeddings, relevant_doc_embeddings, irrelevant_doc_embeddings)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)


In [None]:
# -----------------------------
# 10. Train the model (change hyperparameters as needed)
# -----------------------------

# embed_dim defined in section 5. CBOW Model

# Define hyperparameters
hidden_dim = 128  # Dimension of the hidden state in RNNs (GRU/LSTM - can be adjusted)
margin = 0.2  # Margin for triplet loss (can be adjusted)
distance_function = cosine_distance  # Choose distance function (cosine_distance, euclidean_distance, manhattan_distance, squared_euclidean_distance, chebyshev_distance, minkowski_distance)

qry_tower = QueryTower(embed_dim, hidden_dim)
doc_tower = DocTower(embed_dim, hidden_dim)

optimizer = torch.optim.Adam(list(qry_tower.parameters()) + list(doc_tower.parameters()), lr=1e-3)
num_epochs = 10  # Set as needed

for epoch in range(num_epochs):
    epoch_loss = 0.0
    for batch in tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        qry_embeds, rel_embeds, irrel_embeds, q_lens, r_lens, i_lens = batch

        qry_vecs    = qry_tower(qry_embeds, q_lens)
        rel_vecs    = doc_tower(rel_embeds, r_lens)
        irrel_vecs  = doc_tower(irrel_embeds, i_lens)

        loss = triplet_loss_function(
            qry_vecs, rel_vecs, irrel_vecs,
            distance_function=distance_function,
            margin=margin
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    print(f"Epoch {epoch+1}/{num_epochs} | Loss: {epoch_loss / len(dataloader):.4f}")


# Save final model after all epochs are done
torch.save({
    'epoch': num_epochs,
    'qry_tower_state_dict': qry_tower.state_dict(),
    'doc_tower_state_dict': doc_tower.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': epoch_loss,  # from last epoch
}, "twotower_final.pt")


Epoch 1/10: 100%|██████████| 313/313 [01:36<00:00,  3.24it/s]


Epoch 1/10 | Loss: 0.1173


Epoch 2/10: 100%|██████████| 313/313 [01:46<00:00,  2.93it/s]


Epoch 2/10 | Loss: 0.0490


Epoch 3/10: 100%|██████████| 313/313 [01:38<00:00,  3.16it/s]


Epoch 3/10 | Loss: 0.0368


Epoch 4/10: 100%|██████████| 313/313 [02:02<00:00,  2.55it/s]


Epoch 4/10 | Loss: 0.0879


Epoch 5/10: 100%|██████████| 313/313 [02:35<00:00,  2.01it/s]


Epoch 5/10 | Loss: 0.1506


Epoch 6/10: 100%|██████████| 313/313 [02:35<00:00,  2.01it/s]


Epoch 6/10 | Loss: 0.1439


Epoch 7/10: 100%|██████████| 313/313 [02:34<00:00,  2.03it/s]


Epoch 7/10 | Loss: 0.1373


Epoch 8/10: 100%|██████████| 313/313 [01:50<00:00,  2.84it/s]


Epoch 8/10 | Loss: 0.1374


Epoch 9/10: 100%|██████████| 313/313 [02:28<00:00,  2.11it/s]


Epoch 9/10 | Loss: 0.1361


Epoch 10/10: 100%|██████████| 313/313 [02:16<00:00,  2.29it/s]

Epoch 10/10 | Loss: 0.1296





In [None]:
# -----------------------------
# 11. Inference: Encode documents and queries, find top-k relevant docs
# -----------------------------

def encode_documents(doc_tower, all_doc_embeds, all_doc_lens, device='cpu', batch_size=128):
    doc_tower.eval()
    all_vecs = []
    with torch.no_grad():
        for i in range(0, len(all_doc_embeds), batch_size):
            batch_embeds = all_doc_embeds[i:i+batch_size].to(device)
            batch_lens = all_doc_lens[i:i+batch_size]
            vecs = doc_tower(batch_embeds, batch_lens)  # Shape: (batch, hidden_dim)
            all_vecs.append(vecs.cpu())
    return torch.cat(all_vecs, dim=0)  # Shape: (num_docs, hidden_dim)

def encode_query(qry_tower, query_embed, query_len, device='cpu'):
    qry_tower.eval()
    with torch.no_grad():
        query_vec = qry_tower(query_embed.to(device), query_len)
    return query_vec.cpu()  # Shape: (1, hidden_dim)

def find_top_k(query_vec, doc_vecs, k=5, distance_fn=None):
    # query_vec: (1, hidden_dim), doc_vecs: (num_docs, hidden_dim)
    if distance_fn is None:
        # Default to cosine distance
        def distance_fn(q, d):
            return 1 - torch.nn.functional.cosine_similarity(q, d)
    distances = distance_fn(query_vec, doc_vecs)
    # If query_vec is (1,hidden_dim), expand to (num_docs,hidden_dim)
    if query_vec.shape[0] == 1:
        distances = distance_fn(query_vec.expand_as(doc_vecs), doc_vecs)
    # Get top k smallest distances
    topk = torch.topk(-distances, k)  # negative because smallest distance = highest relevance
    indices = topk.indices.cpu().numpy()
    scores = -topk.values.cpu().numpy()
    return indices, scores


In [None]:
# -----------------------------
# 12. Inference test example
# -----------------------------

# Assume you have: 
#   - all_doc_embeds (Tensor: [num_docs, seq_len, embed_dim])
#   - all_doc_lens (List/Tensor: [num_docs])
#   - doc_tower, qry_tower loaded/trained

device = 'cuda' if torch.cuda.is_available() else 'cpu'
doc_tower = doc_tower.to(device)
qry_tower = qry_tower.to(device)

# 1. Encode all documents once
doc_vecs = encode_documents(doc_tower, all_doc_embeds, all_doc_lens, device=device)

# 2. Encode input query
# query_embed: shape [1, seq_len, embed_dim]
# query_len: shape [1]
query_vec = encode_query(qry_tower, query_embed, query_len, device=device)

# 3. Find top-k relevant docs (change k as you like)
k = 5
indices, scores = find_top_k(query_vec, doc_vecs, k=k)

# 4. Output top k docs
print(f"Top {k} document indices:", indices)
print(f"Corresponding scores (lower = more similar):", scores)
# If you have the original doc texts, you can do:
print([all_doc_texts[i] for i in indices])