In [None]:
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchtext.data.utils import get_tokenizer
from datasets import load_dataset
import json
import pickle
import requests
import io
from tqdm import tqdm
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# -----------------------------
# 1. Load MS MARCO V1.1 training dataset
# -----------------------------

# This will stream the data, you don't have to download the full file
dataset = load_dataset("microsoft/ms_marco", "v1.1", split="train")  # or "validation"

README.md:   0%|          | 0.00/9.48k [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/21.4M [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/175M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/20.5M [00:00<?, ?B/s]

Generating validation split:   0%|          | 0/10047 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/82326 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/9650 [00:00<?, ? examples/s]

TypeError: string indices must be integers, not 'str'

In [None]:
# -----------------------------
# 2. Create triples of (query, relevant_docs, irrelevant_docs)
# -----------------------------

# Create a list of all passages to pass into loop to help create irrelevant documents
# all_passages = set()
# for row in tqdm(dataset, desc="Building passage pool..."):
#     all_passages.update(row['passages']['passage_text'])
# all_passages = list(all_passages)

# # Create triples of (query, relevant_docs, irrelevant_docs)
# triples = []
# for row in tqdm(dataset, desc="Creating triples..."):
#     query = row['query']
#     relevant = row['passages']['passage_text'][:10]
#     relevant_set = set(relevant)
#     irrelevant_pool = list(set(all_passages) - relevant_set)
#     irrelevant = random.sample(irrelevant_pool, 10)
#     triples.append((query, relevant, irrelevant))
#     if len(triples) >= 1000:
#         break

# # Save the triples to a file for later use
# with open("triples_1000.pkl", "wb") as f:
#     pickle.dump(triples, f)

# with open("triples_1000.json", "w", encoding="utf-8") as f:
#     json.dump(triples, f, ensure_ascii=False, indent=2)

# print(triples[0][0])  # query
# print(triples[0][1])  # 10 relevant docs
# print(triples[0][2])  # 10 irrelevant docs

Building passage pool...:   0%|          | 0/82326 [00:00<?, ?it/s]

Building passage pool...: 100%|██████████| 82326/82326 [00:24<00:00, 3327.26it/s]
Creating triples...:   1%|          | 999/82326 [04:14<5:45:39,  3.92it/s] 


what is rba
["Since 2007, the RBA's outstanding reputation has been affected by the 'Securency' or NPA scandal. These RBA subsidiaries were involved in bribing overseas officials so that Australia might win lucrative note-printing contracts. The assets of the bank include the gold and foreign exchange reserves of Australia, which is estimated to have a net worth of A$101 billion. Nearly 94% of the RBA's employees work at its headquarters in Sydney, New South Wales and at the Business Resumption Site.", "The Reserve Bank of Australia (RBA) came into being on 14 January 1960 as Australia 's central bank and banknote issuing authority, when the Reserve Bank Act 1959 removed the central banking functions from the Commonwealth Bank. The assets of the bank include the gold and foreign exchange reserves of Australia, which is estimated to have a net worth of A$101 billion. Nearly 94% of the RBA's employees work at its headquarters in Sydney, New South Wales and at the Business Resumption Site

In [None]:
# -------------------------------
# 2b. Load triples from file
# -------------------------------

# Load .pkl
with open("triples_1000.pkl", "rb") as f:
    triples = pickle.load(f)

# # Load .json
# with open("triples_1000.json", "r", encoding="utf-8") as f:
#     triples = json.load(f)


In [28]:
# -------------------------------
# 3. Tokenise triples
# -------------------------------

tokenizer = get_tokenizer("basic_english")

# Lists to hold the tokenized items
tokenized_queries = []
tokenized_relevant_docs = []
tokenized_irrelevant_docs = []

for query, rels, irrels in triples:
    # Tokenize the query
    tokenized_query = tokenizer(query)
    tokenized_queries.append(tokenized_query)
    
    # Tokenize relevant docs
    tokenized_rels = [tokenizer(doc) for doc in rels]
    tokenized_relevant_docs.append(tokenized_rels)
    
    # Tokenize irrelevant docs
    tokenized_irrels = [tokenizer(doc) for doc in irrels]
    tokenized_irrelevant_docs.append(tokenized_irrels)

tokenized_triples = []
for query, rels, irrels in triples:
    tokenized_query = tokenizer(query)
    tokenized_rels = [tokenizer(doc) for doc in rels]
    tokenized_irrels = [tokenizer(doc) for doc in irrels]
    tokenized_triples.append((tokenized_query, tokenized_rels, tokenized_irrels))

# Check the first tokenized items
print(tokenized_triples[0])  # Tokenized query

(['what', 'is', 'rba'], [['since', '2007', ',', 'the', 'rba', "'", 's', 'outstanding', 'reputation', 'has', 'been', 'affected', 'by', 'the', "'", 'securency', "'", 'or', 'npa', 'scandal', '.', 'these', 'rba', 'subsidiaries', 'were', 'involved', 'in', 'bribing', 'overseas', 'officials', 'so', 'that', 'australia', 'might', 'win', 'lucrative', 'note-printing', 'contracts', '.', 'the', 'assets', 'of', 'the', 'bank', 'include', 'the', 'gold', 'and', 'foreign', 'exchange', 'reserves', 'of', 'australia', ',', 'which', 'is', 'estimated', 'to', 'have', 'a', 'net', 'worth', 'of', 'a$101', 'billion', '.', 'nearly', '94%', 'of', 'the', 'rba', "'", 's', 'employees', 'work', 'at', 'its', 'headquarters', 'in', 'sydney', ',', 'new', 'south', 'wales', 'and', 'at', 'the', 'business', 'resumption', 'site', '.'], ['the', 'reserve', 'bank', 'of', 'australia', '(', 'rba', ')', 'came', 'into', 'being', 'on', '14', 'january', '1960', 'as', 'australia', "'", 's', 'central', 'bank', 'and', 'banknote', 'issuing'

In [25]:
# -----------------------------
# 4. Load Vocabulary of CBOW model
# -----------------------------
with open("vocab_new.json", "r", encoding="utf-8") as f:
    word_to_ix = json.load(f)

ix_to_word = {int(i): w for w, i in word_to_ix.items()}
vocab_size = len(word_to_ix)

# -----------------------------
# 5. Load Pre-trained Embeddings (placeholder)
# -----------------------------
embed_dim = 200  
state = torch.load("text8_cbow_embeddings.pth", map_location='cpu')  # Shape: [vocab_size, embed_dim]
embeddings = state["embeddings.weight"] 

assert embeddings.shape[0] == vocab_size, "Vocab size mismatch!"

In [26]:
# -----------------------------
# 6. CBOW Model
# -----------------------------
class CBOW(nn.Module):
    def __init__(self, vocab_size, embed_dim):
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, embed_dim)
        self.linear = nn.Linear(embed_dim, vocab_size)

    def forward(self, inputs):
        embeds = self.embeddings(inputs).mean(dim=1)
        return self.linear(embeds)

cbow_model = CBOW(vocab_size, embed_dim)
cbow_model.embeddings.weight.data.copy_(embeddings)
cbow_model.embeddings.weight.requires_grad = False  

In [33]:
# -----------------------------
# 7. Create query, relevant document and irrelevant document embeddings
# -----------------------------

num_docs = 10  # number of relevant/irrelevant docs you expect per triple

query_embeddings = []
relevant_doc_embeddings = []
irrelevant_doc_embeddings = []

for i, (tokenized_query, tokenized_rels, tokenized_irrels) in enumerate(tokenized_triples):
    # Query
    q_ids = [word_to_ix[t] for t in tokenized_query if t in word_to_ix]
    if q_ids:
        with torch.no_grad():
            q_vecs = cbow_model.embeddings(torch.tensor(q_ids))
            q_emb = q_vecs.mean(dim=0)
        query_embeddings.append(q_emb)
    else:
        query_embeddings.append(torch.zeros(embed_dim))
    
    # Relevant docs (pad to num_docs)
    rel_embs = []
    for doc_tokens in tokenized_rels[:num_docs]:
        doc_ids = [word_to_ix[t] for t in doc_tokens if t in word_to_ix]
        if doc_ids:
            with torch.no_grad():
                doc_vecs = cbow_model.embeddings(torch.tensor(doc_ids))
                doc_emb = doc_vecs.mean(dim=0)
            rel_embs.append(doc_emb)
        else:
            rel_embs.append(torch.zeros(embed_dim))
    # Pad if fewer than num_docs
    while len(rel_embs) < num_docs:
        rel_embs.append(torch.zeros(embed_dim))
    relevant_doc_embeddings.append(torch.stack(rel_embs))
    
    # Irrelevant docs (pad to num_docs)
    irrel_embs = []
    for doc_tokens in tokenized_irrels[:num_docs]:
        doc_ids = [word_to_ix[t] for t in doc_tokens if t in word_to_ix]
        if doc_ids:
            with torch.no_grad():
                doc_vecs = cbow_model.embeddings(torch.tensor(doc_ids))
                doc_emb = doc_vecs.mean(dim=0)
            irrel_embs.append(doc_emb)
        else:
            irrel_embs.append(torch.zeros(embed_dim))
    while len(irrel_embs) < num_docs:
        irrel_embs.append(torch.zeros(embed_dim))
    irrelevant_doc_embeddings.append(torch.stack(irrel_embs))

# Now you can safely stack
X_queries = torch.stack(query_embeddings)                  # [n, embed_dim]
X_rels = torch.stack(relevant_doc_embeddings)              # [n, 10, embed_dim]
X_irrels = torch.stack(irrelevant_doc_embeddings)          # [n, 10, embed_dim]

print(X_queries.shape)  # Should be [1000, embed_dim]
print(X_rels.shape)     # Should be [1000, 10, embed_dim]
print(X_irrels.shape)   # Should be [1000, 10, embed_dim]

torch.Size([1000, 200])
torch.Size([1000, 10, 200])
torch.Size([1000, 10, 200])


In [None]:
# -----------------------------
# 7. Define Two Tower Model (QueryTower and DocTower)
# -----------------------------

class QueryTower(nn.Module):
    def __init__(self, embed_dim, hidden_dim, num_layers=1, rnn_type='gru'):
        super().__init__()
        if rnn_type == 'gru':
            self.rnn = nn.GRU(embed_dim, hidden_dim, num_layers, batch_first=True)
        elif rnn_type == 'lstm':
            self.rnn = nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=True)
        else:
            raise ValueError("Unknown rnn_type: choose 'gru' or 'lstm'")

    def forward(self, x):
        # x: (batch, seq_len, embed_dim)
        out, h = self.rnn(x)
        # If LSTM, h is a tuple (h_n, c_n); for GRU, h is just h_n
        if isinstance(h, tuple):
            h = h[0]
        # h: (num_layers, batch, hidden_dim)
        # Use the last layer's hidden state
        return h[-1]  # (batch, hidden_dim)

class DocTower(nn.Module):
    def __init__(self, embed_dim, hidden_dim, num_layers=1, rnn_type='gru'):
        super().__init__()
        if rnn_type == 'gru':
            self.rnn = nn.GRU(embed_dim, hidden_dim, num_layers, batch_first=True)
        elif rnn_type == 'lstm':
            self.rnn = nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=True)
        else:
            raise ValueError("Unknown rnn_type: choose 'gru' or 'lstm'")

    def forward(self, x):
        out, h = self.rnn(x)
        if isinstance(h, tuple):
            h = h[0]
        return h[-1]  # (batch, hidden_dim)


In [None]:
# -----------------------------
# 9. 
# -----------------------------

def cosine_similarity(x, y):
    return F.cosine_similarity(x, y, dim=1)  # (batch,)

def cosine_distance(x, y):
    return 1 - cosine_similarity(x, y)       # (batch,)

def triplet_loss_function(query, relevant_doc, irrelevant_doc, distance_function, margin):
    # query, relevant_doc, irrelevant_doc: (batch, dim)
    relevant_distance = distance_function(query, relevant_doc)       # (batch,)
    irrelevant_distance = distance_function(query, irrelevant_doc)   # (batch,)
    # Triplet loss per sample: max(0, rel_dist - irrel_dist + margin)
    triplet_loss = torch.relu(relevant_distance - irrelevant_distance + margin)
    return triplet_loss.mean()  # Mean over batch

# Example setup
embed_dim = 200      # Dimension of your CBOW word embeddings
hidden_dim = 128     # You choose this (can tune it)
num_layers = 1

qry_tower = QueryTower(embed_dim, hidden_dim, num_layers)
doc_tower = DocTower(embed_dim, hidden_dim, num_layers)

# Example data (single instance, batch size = 1)
# Suppose q_embeds, rel_embeds, irrel_embeds are (seq_len, embed_dim)
# Add batch dimension: (1, seq_len, embed_dim)

q_vec = qry_tower(q_embeds.unsqueeze(0))           # (1, hidden_dim)
rel_vec = doc_tower(rel_embeds.unsqueeze(0))       # (1, hidden_dim)
irrel_vec = doc_tower(irrel_embeds.unsqueeze(0))   # (1, hidden_dim)


In [None]:
# Triplet loss function

def triplet_loss_function((query, relevant_document, irrelevant_document), distance_function, margin):
    relevant_distance = distance_function(query, relevant_document)
    irrelevant_distance = distance_function(query, irrelevant_document)
    triplet_loss = max(0, relevant_distance - irrelevant_distance + margin)
    return triplet_loss

In [None]:
# Cosine distance function
def cosine_similarity(x, y):
    return F.cosine_similarity(x, y, dim=1)

cosine_distance(x,y) = 1 - cosine_similarity(x, y)

# To minimise distance between query & relevant doc - Maximise cosine similarity
def distance_function(query, relevant_document):
    return cosine_distance(query, relevant_document)