In [33]:
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from torchtext.data.utils import get_tokenizer
from datasets import load_dataset
from collections import defaultdict
import pickle
import json
import re
from tqdm import tqdm

In [34]:
# # -----------------------------
# # 1. Load MS MARCO V1.1 training dataset
# # -----------------------------

# # This will stream the data, you don't have to download the full file
# dataset = load_dataset("microsoft/ms_marco", "v1.1", split="train")  # or "validation"

In [35]:
# # -----------------------------
# # 2. Create triples of (query, relevant_docs, irrelevant_docs)
# # -----------------------------

# # 1. Build passage pool and index mapping for fast sampling and seed for reproducibility
# passage_to_idx = dict()
# idx_to_passage = []
# for row in tqdm(dataset, desc="Building passage pool..."):
#     for p in row['passages']['passage_text']:
#         if p not in passage_to_idx:
#             passage_to_idx[p] = len(idx_to_passage)
#             idx_to_passage.append(p)
# num_passages = len(idx_to_passage)

# # 2. For each query, map relevant passage indices
# triples = []
# for row in tqdm(dataset, desc="Creating triples..."):
#     query = row['query']
#     relevant_passages = row['passages']['passage_text'][:10]
#     relevant_indices = [passage_to_idx[p] for p in relevant_passages]
    
#     # For fast sampling: mask out relevant indices
#     mask = np.ones(num_passages, dtype=bool)
#     mask[relevant_indices] = False
#     irrelevant_indices = np.random.choice(np.where(mask)[0], 10, replace=False)
#     irrelevant_passages = [idx_to_passage[i] for i in irrelevant_indices]

#     triples.append((query, relevant_passages, irrelevant_passages))

# with open("triples_full.pkl", "wb") as f:
#     pickle.dump(triples, f)

# print(triples[0][0])  # query
# print(triples[0][1])  # 10 relevant docs
# print(triples[0][2])  # 10 irrelevant docs

In [None]:
# -------------------------------
# 1b. Load triples from file & tokenize
# -------------------------------

# Load triples
with open("triples_full.pkl", "rb") as f:
    triples = pickle.load(f)


In [None]:
# -------------------------------
# 1c. Tokenize triples
# -------------------------------

def preprocess(text):
    text = text.lower()
    text = re.sub(r'[^a-z0-9 ]+', '', text)
    return text.split()

tokenized_triples = []
for query, rel_docs, irrel_docs in tqdm(triples, desc="Tokenizing triples"):
    tokenized_query = preprocess(query)
    tokenized_rels = [preprocess(doc) for doc in rel_docs]
    tokenized_irrels = [preprocess(doc) for doc in irrel_docs]
    tokenized_triples.append((tokenized_query, tokenized_rels, tokenized_irrels))


Tokenizing triples: 100%|██████████| 82326/82326 [03:59<00:00, 343.64it/s] 


In [None]:
# -----------------------------
# 2. Load Vocabulary & Pre-trained Embeddings
# -----------------------------
with open("vocab_new.json", "r", encoding="utf-8") as f:
    word_to_ix = json.load(f)

ix_to_word = {int(i): w for w, i in word_to_ix.items()}
vocab_size = len(word_to_ix)

embed_dim = 200  
state = torch.load("text8_cbow_embeddings.pth", map_location='cpu')  # Shape: [vocab_size, embed_dim]
embeddings = state["embeddings.weight"] 

assert embeddings.shape[0] == vocab_size, "Vocab size mismatch!"

In [None]:
# -----------------------------
# 3. CBOW Model
# -----------------------------
class CBOW(nn.Module):
    def __init__(self, vocab_size, embed_dim):
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, embed_dim)
        self.linear = nn.Linear(embed_dim, vocab_size)

    def forward(self, inputs):
        embeds = self.embeddings(inputs).mean(dim=1)
        return self.linear(embeds)

cbow_model = CBOW(vocab_size, embed_dim)
cbow_model.embeddings.weight.data.copy_(embeddings)
cbow_model.embeddings.weight.requires_grad = False 


In [None]:
# # -------------------------------
# # 4. Embed queries, rel and irrel documents using pre-trained CBOW model
# # -------------------------------

# device = torch.device("cpu")
# #torch.set_num_threads(12)

# num_docs = 10  # number of relevant/irrelevant docs you expect per triple
# batch_size = 64  # adjust to taste; big enough to speed up, small enough to never threaten RAM

# query_embeddings = []
# relevant_doc_embeddings = []
# irrelevant_doc_embeddings = []

# def process_and_save_embeddings(
#     tokenized_triples, word_to_ix, cbow_model, 
#     query_embeds_path, rel_doc_embeds_path, irrel_doc_embeds_path,
#     batch_size=batch_size, num_docs=num_docs
# ):
#     query_embeds_batch = []
#     rel_doc_embeds_batch = []
#     irrel_doc_embeds_batch = []
    
#     for i, (tokenized_query, tokenized_rels, tokenized_irrels) in enumerate(
#         tqdm(tokenized_triples, desc="CBOW embedding + streaming", total=len(tokenized_triples))
#     ):
#         # Query: embeddings per token
#         q_ids = [word_to_ix[t] for t in tokenized_query if t in word_to_ix]
#         if q_ids:
#             with torch.no_grad():
#                 q_vecs = cbow_model.embeddings(torch.tensor(q_ids))
#             query_embeds_batch.append(q_vecs)  # shape: [query_len, embed_dim]
#         else:
#             query_embeds_batch.append(torch.zeros(1, cbow_model.embeddings.embedding_dim))
        
#         # Relevant docs: list of (doc_len, embed_dim)
#         rel_embs = []
#         for doc_tokens in tokenized_rels[:num_docs]:
#             doc_ids = [word_to_ix[t] for t in doc_tokens if t in word_to_ix]
#             if doc_ids:
#                 with torch.no_grad():
#                     doc_vecs = cbow_model.embeddings(torch.tensor(doc_ids))
#                 rel_embs.append(doc_vecs)
#             else:
#                 rel_embs.append(torch.zeros(1, cbow_model.embeddings.embedding_dim))
#         while len(rel_embs) < num_docs:
#             rel_embs.append(torch.zeros(1, cbow_model.embeddings.embedding_dim))
#         rel_doc_embeds_batch.append(rel_embs)

#         # Irrelevant docs: list of (doc_len, embed_dim)
#         irrel_embs = []
#         for doc_tokens in tokenized_irrels[:num_docs]:
#             doc_ids = [word_to_ix[t] for t in doc_tokens if t in word_to_ix]
#             if doc_ids:
#                 with torch.no_grad():
#                     doc_vecs = cbow_model.embeddings(torch.tensor(doc_ids))
#                 irrel_embs.append(doc_vecs)
#             else:
#                 irrel_embs.append(torch.zeros(1, cbow_model.embeddings.embedding_dim))(0)
#         while len(irrel_embs) < num_docs:
#             irrel_embs.append(torch.zeros(1, cbow_model.embeddings.embedding_dim))
#         irrel_doc_embeds_batch.append(irrel_embs)

#         # Save every batch_size triples
#         if (i + 1) % batch_size == 0 or (i + 1) == len(tokenized_triples):
#             # Save as batch-lists (not stacked), because sequences are ragged
#             with open(query_embeds_path, 'ab') as fq:
#                 pickle.dump(query_embeds_batch, fq)
#             with open(rel_doc_embeds_path, 'ab') as fr:
#                 pickle.dump(rel_doc_embeds_batch, fr)
#             with open(irrel_doc_embeds_path, 'ab') as fi:
#                 pickle.dump(irrel_doc_embeds_batch, fi)

#             # Free RAM
#             query_embeds_batch.clear()
#             rel_doc_embeds_batch.clear()
#             irrel_doc_embeds_batch.clear()


# # Usage example:
# process_and_save_embeddings(
#     tokenized_triples[:5120], word_to_ix, cbow_model,
#     "query_embeds.pkl", "rel_doc_embeds.pkl", "irrel_doc_embeds.pkl",
#     batch_size=batch_size, num_docs=num_docs
# )

# # # Usage:
# # process_and_save_embeddings(
# #     tokenized_triples, word_to_ix, cbow_model,
# #     "query_embeds.pkl", "rel_doc_embeds.pkl", "irrel_doc_embeds.pkl",
# #     batch_size=2048, num_docs=10
# # )

CBOW embedding + streaming: 100%|██████████| 5120/5120 [00:36<00:00, 141.12it/s]


In [None]:
# -----------------------------
# 5. Load Embeddings
# -----------------------------

def load_all_batches(path):
    all_data = []
    with open(path, "rb") as f:
        while True:
            try:
                batch = pickle.load(f)
                all_data.extend(batch)  # for lists of tensors, this flattens batches
            except EOFError:
                break
    return all_data

# Load them all
query_embeds = load_all_batches("query_embeds.pkl")           # list of [query_len, embed_dim] tensors
rel_doc_embeds = load_all_batches("rel_doc_embeds.pkl")       # list of lists: each is [num_docs] of [doc_len, embed_dim] tensors
irrel_doc_embeds = load_all_batches("irrel_doc_embeds.pkl")

In [None]:
# -----------------------------
# 4. Define distance function & Triplet loss function
# -----------------------------

# Cosine similarity for calculation of cosine distance
def cosine_similarity(x, y):
    return F.cosine_similarity(x, y, dim=1)

# Triplet loss function - will compute the loss for a batch of triplets
def triplet_loss_function(query, relevant_doc, irrelevant_doc, distance_function, margin):
    rel_dist = distance_function(query, relevant_doc)         # (batch,)
    irrel_dist = distance_function(query, irrelevant_doc)     # (batch,)
    triplet_loss = torch.relu(rel_dist - irrel_dist + margin)
    return triplet_loss.mean()                                # Average over batch

In [None]:
# -----------------------------
# 5. Define Two Tower Model (QueryTower and DocTower)
# -----------------------------

class QueryTower(nn.Module):
    def __init__(self, embed_dim, hidden_dim, num_layers=1, rnn_type='gru'):
        super().__init__()
        if rnn_type == 'gru':
            self.rnn = nn.GRU(embed_dim, hidden_dim, num_layers, batch_first=True)
        elif rnn_type == 'lstm':
            self.rnn = nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=True)
        else:
            raise ValueError("Unknown rnn_type: choose 'gru' or 'lstm'")

    def forward(self, x, lengths):
        # x: (batch, seq_len, embed_dim)
        packed = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
        out, h = self.rnn(packed)
        if isinstance(h, tuple):  # LSTM
            h = h[0]
        return h[-1]  # (batch, hidden_dim)

class DocTower(nn.Module):
    def __init__(self, embed_dim, hidden_dim, num_layers=1, rnn_type='gru'):
        super().__init__()
        if rnn_type == 'gru':
            self.rnn = nn.GRU(embed_dim, hidden_dim, num_layers, batch_first=True)
        elif rnn_type == 'lstm':
            self.rnn = nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=True)
        else:
            raise ValueError("Unknown rnn_type: choose 'gru' or 'lstm'")

    def forward(self, x, lengths):
        packed = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
        out, h = self.rnn(packed)
        if isinstance(h, tuple):
            h = h[0]
        return h[-1]


In [None]:
# -----------------------------
# 6. Prepare dataset for DataLoader
# -----------------------------

class TripleDataset(Dataset):
    def __init__(self, X_queries, X_rels, X_irrels):
        self.X_queries = X_queries    # list of N tensors (seq_len, embed_dim)
        self.X_rels = X_rels          # list of N lists of 10 tensors (seq_len, embed_dim)
        self.X_irrels = X_irrels      # list of N lists of 10 tensors (seq_len, embed_dim)
        self.n = len(X_queries)

    def __len__(self):
        return self.n * 10

    def __getitem__(self, idx):
        triple_idx = idx // 10
        doc_idx = idx % 10

        qry = self.X_queries[triple_idx]              # (q_seq_len, embed_dim)
        rel = self.X_rels[triple_idx][doc_idx]        # (rel_seq_len, embed_dim)
        irrel = self.X_irrels[triple_idx][doc_idx]    # (irrel_seq_len, embed_dim)

        return qry, rel, irrel
    
def collate_fn(batch):
    q_seqs, r_seqs, i_seqs = zip(*batch)
    q_lens = [x.shape[0] for x in q_seqs]
    r_lens = [x.shape[0] for x in r_seqs]
    i_lens = [x.shape[0] for x in i_seqs]
    q_padded = pad_sequence(q_seqs, batch_first=True)
    r_padded = pad_sequence(r_seqs, batch_first=True)
    i_padded = pad_sequence(i_seqs, batch_first=True)
    return q_padded, r_padded, i_padded, q_lens, r_lens, i_lens


# Create the dataset and dataloader for batching
batch_size = 64
dataset = TripleDataset(query_embeds, rel_doc_embeds, irrel_doc_embeds)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)


In [None]:
# -----------------------------
# 7. Train the model
# -----------------------------

# embed_dim defined in section 3.

# Define hyperparameters
hidden_dim = 128  # Dimension of the hidden state in RNNs (GRU/LSTM - can be adjusted)
margin = 0.2  # Margin for triplet loss (can be adjusted)
distance_function = cosine_similarity)

qry_tower = QueryTower(embed_dim, hidden_dim, rnn_type='gru')  # or 'lstm'
doc_tower = DocTower(embed_dim, hidden_dim, rnn_type='gru')  # or 'lstm'

optimizer = torch.optim.Adam(list(qry_tower.parameters()) + list(doc_tower.parameters()), lr=1e-3)
num_epochs = 5  # Set as needed

for epoch in range(num_epochs):
    epoch_loss = 0.0
    for batch in tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        qry_embeds, rel_embeds, irrel_embeds, q_lens, r_lens, i_lens = batch
        # qry_embeds: (batch, q_seq_len, embed_dim)
        # rel_embeds: (batch, r_seq_len, embed_dim)
        # irrel_embeds: (batch, i_seq_len, embed_dim)

        # Mask: keep items where rel doc is *not* all zeros
        # mask shape: (batch,)
        mask = ~torch.all(rel_embeds == 0, dim=(1,2))

        # If mask is all False, skip batch (shouldn't happen)
        if mask.sum() == 0:
            continue

        # Only keep non-padded triples
        qry_embeds    = qry_embeds[mask]
        rel_embeds    = rel_embeds[mask]
        irrel_embeds  = irrel_embeds[mask]
        q_lens        = [q_lens[i] for i in range(len(mask)) if mask[i]]
        r_lens        = [r_lens[i] for i in range(len(mask)) if mask[i]]
        i_lens        = [i_lens[i] for i in range(len(mask)) if mask[i]]

        qry_vecs    = qry_tower(qry_embeds, q_lens)
        rel_vecs    = doc_tower(rel_embeds, r_lens)
        irrel_vecs  = doc_tower(irrel_embeds, i_lens)

        loss = triplet_loss_function(
            qry_vecs, rel_vecs, irrel_vecs,
            distance_function=distance_function,
            margin=margin
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    print(f"Epoch {epoch+1}/{num_epochs} | Loss: {epoch_loss / len(dataloader):.4f}")


# Save final model after all epochs are done
torch.save({
    'epoch': num_epochs,
    'qry_tower_state_dict': qry_tower.state_dict(),
    'doc_tower_state_dict': doc_tower.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': epoch_loss,  # from last epoch
}, "two_tower_final.pt")

Epoch 1/5: 100%|██████████| 1600/1600 [10:28<00:00,  2.55it/s]


Epoch 1/5 | Loss: 0.0916


Epoch 2/5: 100%|██████████| 1600/1600 [08:48<00:00,  3.03it/s]


Epoch 2/5 | Loss: 0.0322


Epoch 3/5: 100%|██████████| 1600/1600 [09:45<00:00,  2.73it/s]


Epoch 3/5 | Loss: 0.0166


Epoch 4/5: 100%|██████████| 1600/1600 [10:03<00:00,  2.65it/s]


Epoch 4/5 | Loss: 0.0088


Epoch 5/5: 100%|██████████| 1600/1600 [11:44<00:00,  2.27it/s]


Epoch 5/5 | Loss: 0.0055


In [None]:
# -----------------------------
# 8. Load the trained model
# -----------------------------
hidden_dim = 128

# 1. Define your model classes (as in your code)
qry_tower = QueryTower(embed_dim, hidden_dim, num_layers=1, rnn_type='gru')
doc_tower = DocTower(embed_dim, hidden_dim, num_layers=1, rnn_type='gru')

# 2. Create optimizer with the same params as training
optimizer = torch.optim.Adam(list(qry_tower.parameters()) + list(doc_tower.parameters()), lr=1e-3)

# 3. Load checkpoint
checkpoint = torch.load("two_tower_final.pt", map_location="cpu")  # or map_location="cuda" if using GPU

# 4. Restore state dicts
qry_tower.load_state_dict(checkpoint['qry_tower_state_dict'])
doc_tower.load_state_dict(checkpoint['doc_tower_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

# 5. (Optional) get other info
epoch = checkpoint['epoch']
last_loss = checkpoint['loss']


In [None]:
# -----------------------------
# 8. Evaluate the model using Recall@K
# -----------------------------

def evaluate_model(qry_tower, doc_tower, val_data, distance_fn, K=1, device="cpu"):
    qry_tower.eval()
    doc_tower.eval()
    num_correct = 0
    total = 0

    with torch.no_grad():
        for q_embed, rel_embeds, irrel_embeds in val_data:
            # q_embed: (q_seq_len, embed_dim)
            # rel_embeds: list of (d_seq_len, embed_dim) [just 1, or list if more]
            # irrel_embeds: list of (d_seq_len, embed_dim)
            
            # Stack all candidate docs: relevant + irrelevants
            candidates = [rel_embeds] + list(irrel_embeds)
            all_doc_tensors = [doc.to(device) for doc in candidates]
            doc_lens = [doc.shape[0] for doc in all_doc_tensors]
            padded_docs = torch.nn.utils.rnn.pad_sequence(all_doc_tensors, batch_first=True)
            
            # Encode query
            q_input = q_embed.unsqueeze(0).to(device)              # (1, q_seq_len, embed_dim)
            q_len = [q_embed.shape[0]]
            q_vec = QueryTower(q_input, q_len)                     # (1, hidden_dim)
            
            # Encode all docs in batch
            d_vecs = DocTower(padded_docs, doc_lens)               # (num_candidates, hidden_dim)

            # Compute distances (query vs. all docs)
            dists = distance_fn(q_vec.repeat(len(doc_lens), 1), d_vecs)  # (num_candidates,)
            sorted_indices = torch.argsort(dists)  # Smallest = most similar

            # Recall@K: Is the relevant doc (index 0) in top K?
            if 0 in sorted_indices[:K]:
                num_correct += 1
            total += 1

    recall_at_k = num_correct / total
    return recall_at_k

In [None]:
# -----------------------------
# 9. Load MS MARCO V1.1 validation dataset, process into triples and tokenize & embed
# -----------------------------

# Load triples
with open("triples_val.pkl", "rb") as f:
    triples_val = pickle.load(f)

def preprocess(text):
    text = text.lower()
    text = re.sub(r'[^a-z0-9 ]+', '', text)
    return text.split()

tokenized_triples_val = []
for query, rel_docs, irrel_docs in tqdm(triples_val, desc="Tokenizing triples"):
    tokenized_query = preprocess(query)
    tokenized_rels = [preprocess(doc) for doc in rel_docs]
    tokenized_irrels = [preprocess(doc) for doc in irrel_docs]
    tokenized_triples_val.append((tokenized_query, tokenized_rels, tokenized_irrels))


num_docs = 10  # number of relevant/irrelevant docs you expect per triple
batch_size = 64  # adjust to taste; big enough to speed up, small enough to never threaten RAM

query_embeddings_val = []
relevant_doc_embeddings_val = []
irrelevant_doc_embeddings_val = []

def process_and_save_embeddings(
    tokenized_triples_val, word_to_ix, cbow_model, 
    query_embeds_path, rel_doc_embeds_path, irrel_doc_embeds_path,
    batch_size=batch_size, num_docs=num_docs
):
    query_embeds_batch = []
    rel_doc_embeds_batch = []
    irrel_doc_embeds_batch = []
    
    for i, (tokenized_query, tokenized_rels, tokenized_irrels) in enumerate(
        tqdm(tokenized_triples_val, desc="CBOW embedding + streaming", total=len(tokenized_triples_val))
    ):
        # Query: embeddings per token
        q_ids = [word_to_ix[t] for t in tokenized_query if t in word_to_ix]
        if q_ids:
            with torch.no_grad():
                q_vecs = cbow_model.embeddings(torch.tensor(q_ids))
            query_embeds_batch.append(q_vecs)  # shape: [query_len, embed_dim]
        else:
            query_embeds_batch.append(torch.zeros(1, cbow_model.embeddings.embedding_dim))
        
        # Relevant docs: list of (doc_len, embed_dim)
        rel_embs = []
        for doc_tokens in tokenized_rels[:num_docs]:
            doc_ids = [word_to_ix[t] for t in doc_tokens if t in word_to_ix]
            if doc_ids:
                with torch.no_grad():
                    doc_vecs = cbow_model.embeddings(torch.tensor(doc_ids))
                rel_embs.append(doc_vecs)
            else:
                rel_embs.append(torch.zeros(1, cbow_model.embeddings.embedding_dim))
        while len(rel_embs) < num_docs:
            rel_embs.append(torch.zeros(1, cbow_model.embeddings.embedding_dim))
        rel_doc_embeds_batch.append(rel_embs)

        # Irrelevant docs: list of (doc_len, embed_dim)
        irrel_embs = []
        for doc_tokens in tokenized_irrels[:num_docs]:
            doc_ids = [word_to_ix[t] for t in doc_tokens if t in word_to_ix]
            if doc_ids:
                with torch.no_grad():
                    doc_vecs = cbow_model.embeddings(torch.tensor(doc_ids))
                irrel_embs.append(doc_vecs)
            else:
                irrel_embs.append(torch.zeros(1, cbow_model.embeddings.embedding_dim))
        while len(irrel_embs) < num_docs:
            irrel_embs.append(torch.zeros(1, cbow_model.embeddings.embedding_dim))
        irrel_doc_embeds_batch.append(irrel_embs)

        # Save every batch_size triples
        if (i + 1) % batch_size == 0 or (i + 1) == len(tokenized_triples_val):
            # Save as batch-lists (not stacked), because sequences are ragged
            with open(query_embeds_path, 'ab') as fq:
                pickle.dump(query_embeds_batch, fq)
            with open(rel_doc_embeds_path, 'ab') as fr:
                pickle.dump(rel_doc_embeds_batch, fr)
            with open(irrel_doc_embeds_path, 'ab') as fi:
                pickle.dump(irrel_doc_embeds_batch, fi)

            # Free RAM
            query_embeds_batch.clear()
            rel_doc_embeds_batch.clear()
            irrel_doc_embeds_batch.clear()


# Usage example:
process_and_save_embeddings(
    tokenized_triples_val, word_to_ix, cbow_model,
    "query_embeds_val.pkl", "rel_doc_embeds_val.pkl", "irrel_doc_embeds_val.pkl",
    batch_size=batch_size, num_docs=num_docs
)

Tokenizing triples: 100%|██████████| 10047/10047 [00:06<00:00, 1594.66it/s]
CBOW embedding + streaming: 100%|██████████| 10047/10047 [00:59<00:00, 168.78it/s]


In [None]:
# -----------------------------
# 9b. Load validation data embeddings
# -----------------------------

# Load them all
query_embeds_val = load_all_batches("query_embeds_val.pkl")           # list of [query_len, embed_dim] tensors
rel_doc_embeds_val = load_all_batches("rel_doc_embeds_val.pkl")       # list of lists: each is [num_docs] of [doc_len, embed_dim] tensors
irrel_doc_embeds_val = load_all_batches("irrel_doc_embeds_val.pkl")

val_data = []
# Ensure lengths match (or handle indexing errors)
for i in range(len(query_embeds_val)):
    q_embed = query_embeds_val[i]                       # [q_len, embed_dim]
    rel_embed = rel_doc_embeds_val[i][0]                # [rel_len, embed_dim]; use the first relevant doc
    irrel_embed = irrel_doc_embeds_val[i][0]            # [irrel_len, embed_dim]; use the first irrelevant doc
    val_data.append((q_embed, rel_embed, [irrel_embed])) 


TypeError: Module.eval() missing 1 required positional argument: 'self'

In [None]:
qry_tower.eval()
doc_tower.eval()

recall = evaluate_model(qry_tower, doc_tower, val_data, cosine_distance, K=1)
print(f"Recall@1: {recall:.4f}")

  all_doc_tensors = [torch.tensor(doc, device=device) for doc in candidates]


Recall@1: 0.9021


In [None]:
# -----------------------------
# 10. Hyperparameter Search & Evaluation
# -----------------------------

margins = [0.2, 0.1, 0.05]
hidden_dims = [128]
distance_function = cosine_similarity

best_result = {"score": 0, "settings": None}

for margin in margins:
    for hidden_dim in hidden_dims:
        qry_tower = QueryTower(embed_dim, hidden_dim, rnn_type='lstm')  # or 'lstm'
        doc_tower = DocTower(embed_dim, hidden_dim, rnn_type='lstm')    # or 'lstm'

        optimizer = torch.optim.Adam(list(qry_tower.parameters()) + list(doc_tower.parameters()), lr=1e-3)
        num_epochs = 3  # Set as needed

        for epoch in range(num_epochs):
            epoch_loss = 0.0
            for batch in tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
                qry_embeds, rel_embeds, irrel_embeds, q_lens, r_lens, i_lens = batch

                mask = ~torch.all(rel_embeds == 0, dim=(1,2))

                if mask.sum() == 0:
                    continue

                qry_embeds = qry_embeds[mask]
                rel_embeds = rel_embeds[mask]
                irrel_embeds = irrel_embeds[mask]
                q_lens = [q_lens[i] for i in range(len(mask)) if mask[i]]
                r_lens = [r_lens[i] for i in range(len(mask)) if mask[i]]
                i_lens = [i_lens[i] for i in range(len(mask)) if mask[i]]

                qry_vecs = qry_tower(qry_embeds, q_lens)
                rel_vecs = doc_tower(rel_embeds, r_lens)
                irrel_vecs = doc_tower(irrel_embeds, i_lens)

                loss = triplet_loss_function(
                    qry_vecs, rel_vecs, irrel_vecs,
                    distance_function=distance_function,
                    margin=margin
                )

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()

            print(f"Margin: {margin}, Hidden: {hidden_dim} | Epoch {epoch+1}/{num_epochs} | Loss: {epoch_loss / len(dataloader):.4f}")

        recall = evaluate_model(qry_tower, doc_tower, val_data, distance_function, K=1, device=device)
        print(f"Margin: {margin}, Hidden: {hidden_dim}, Recall@1: {recall:.4f}")
        if recall > best_result["score"]:
            best_result = {"score": recall, "settings": (margin, hidden_dim)}

print("Best config:", best_result)

Epoch 1/3: 100%|██████████| 800/800 [14:41<00:00,  1.10s/it]


Margin: 0.2, Hidden: 128 | Epoch 1/3 | Loss: 0.1212


Epoch 2/3: 100%|██████████| 800/800 [14:49<00:00,  1.11s/it]


Margin: 0.2, Hidden: 128 | Epoch 2/3 | Loss: 0.0529


Epoch 3/3: 100%|██████████| 800/800 [18:02<00:00,  1.35s/it] 


Margin: 0.2, Hidden: 128 | Epoch 3/3 | Loss: 0.0310
Margin: 0.2, Hidden: 128, Recall@1: 0.8677


Epoch 1/3:   2%|▎         | 20/800 [00:49<32:18,  2.48s/it] 


KeyboardInterrupt: 

In [None]:
# -----------------------------
# 11. Inference: Encode documents and queries, find top-k relevant docs
# -----------------------------

def encode_documents(doc_tower, all_doc_embeds, all_doc_lens, device='cpu', batch_size=128):
    doc_tower.eval()
    all_vecs = []
    with torch.no_grad():
        for i in range(0, len(all_doc_embeds), batch_size):
            batch_embeds = all_doc_embeds[i:i+batch_size].to(device)
            batch_lens = all_doc_lens[i:i+batch_size]
            vecs = doc_tower(batch_embeds, batch_lens)  # Shape: (batch, hidden_dim)
            all_vecs.append(vecs.cpu())
    return torch.cat(all_vecs, dim=0)  # Shape: (num_docs, hidden_dim)

def encode_query(qry_tower, query_embed, query_len, device='cpu'):
    qry_tower.eval()
    with torch.no_grad():
        query_vec = qry_tower(query_embed.to(device), query_len)
    return query_vec.cpu()  # Shape: (1, hidden_dim)

def find_top_k(query_vec, doc_vecs, k=5, distance_fn=None):
    # query_vec: (1, hidden_dim), doc_vecs: (num_docs, hidden_dim)
    if distance_fn is None:
        # Default to cosine distance
        def distance_fn(q, d):
            return 1 - torch.nn.functional.cosine_similarity(q, d)
    distances = distance_fn(query_vec, doc_vecs)
    # If query_vec is (1,hidden_dim), expand to (num_docs,hidden_dim)
    if query_vec.shape[0] == 1:
        distances = distance_fn(query_vec.expand_as(doc_vecs), doc_vecs)
    # Get top k smallest distances
    topk = torch.topk(-distances, k)  # negative because smallest distance = highest relevance
    indices = topk.indices.cpu().numpy()
    scores = -topk.values.cpu().numpy()
    return indices, scores


In [None]:
# -----------------------------
# 12. Calculate all document embeddings and lengths
# -----------------------------

all_doc_texts = []  # To hold all relevant doc texts for padding
all_query_texts = []

for query, rels, irrels in triples:
    all_query_texts.append(query)
    # rels is the list of relevant doc texts for this query
    # We slice to num_docs to match the embedding logic, then pad if needed
    for doc in rels[:num_docs]:
        all_doc_texts.append(doc)
    # Pad if fewer than num_docs
    while len(rels) < num_docs:
        all_doc_texts.append("")
        rels.append("")  # So embedding and text padding always match

# Step 1: Gather your document embedding sequences and their lengths
all_doc_embeds_list = []
all_doc_lens_list = []

for batch in dataloader:
    rel_embeds, r_lens = batch[1], batch[4]  # rel_embeds: (batch, seq_len, embed_dim)
    for i in range(rel_embeds.shape[0]):
        all_doc_embeds_list.append(rel_embeds[i])  # (seq_len, embed_dim)
        all_doc_lens_list.append(r_lens[i])

# Step 2: Pad all embeddings to max seq_len
# pad_sequence wants a list of (seq_len, embed_dim), returns (max_seq_len, num_docs, embed_dim)
padded = pad_sequence(all_doc_embeds_list, batch_first=True)  # (num_docs, max_seq_len, embed_dim)

all_doc_embeds = padded  # (num_docs, max_seq_len, embed_dim)
all_doc_lens = torch.tensor(all_doc_lens_list)

print("all_doc_embeds shape:", all_doc_embeds.shape)
print("all_doc_lens shape:", all_doc_lens.shape)
print("all_doc_texts length:", len(all_doc_texts))
print("all_query_texts length:", len(all_query_texts))


In [None]:
# -----------------------------
# 13. Encode documents and queries, find top-k relevant docs
# -----------------------------

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# 1. Encode all documents
doc_vecs = encode_documents(doc_tower, all_doc_embeds, all_doc_lens, device=device)
#print("Encoded document vectors shape:", doc_vecs.shape)

# 2. Encode a sample query (here, just using first doc for demo)
query_embed = all_doc_embeds[0].unsqueeze(0)
query_len = all_doc_lens[0].unsqueeze(0)
query_vec = encode_query(qry_tower, query_embed, query_len, device=device)
#print("Encoded query vector shape:", query_vec.shape)

# 3. Retrieve top k relevant documents
k = 5
indices, scores = find_top_k(query_vec, doc_vecs, k=k)

print("Query: ", all_query_texts[0])
print("Top document matches:")
for rank, i in enumerate(indices):
    print(f"{rank+1}: {all_doc_texts[i]}")
    print(f"   (score: {scores[rank]:.4f})")


NameError: name 'encode_documents' is not defined

In [None]:
# -----------------------------
# 14. Custom Query Inference
# -----------------------------

# Hardcode your query here
custom_query = "what is machine learning?"

# Tokenize and embed the custom query
tokenized_query = tokenizer(custom_query)  # returns List[int] or torch.Tensor

# Map tokens to ids, filtering out tokens not in vocab
q_ids = [word_to_ix[t] for t in tokenized_query if t in word_to_ix]

if q_ids:
    with torch.no_grad():
        q_vecs = cbow_model.embeddings(torch.tensor(q_ids))  # (seq_len, embed_dim)
else:
    q_vecs = torch.zeros(1, embed_dim)  # Fallback for empty/unknown queries)

# Pad to (1, seq_len, embed_dim) for model
query_embed = q_vecs.unsqueeze(0)  # (1, seq_len, embed_dim)
query_len = torch.tensor([q_vecs.shape[0]])

# Move to device if needed
query_embed = query_embed.to(device)
query_len = query_len.to(device)

# Encode the hardcoded query
query_vec = encode_query(qry_tower, query_embed, query_len, device=device)

# Retrieve top k relevant documents
k = 5
indices, scores = find_top_k(query_vec, doc_vecs, k=k)

print("Custom Query:", custom_query)
print("Top document matches:")
for rank, i in enumerate(indices):
    print(f"{rank+1}: {all_doc_texts[i]}")
    print(f"   (score: {scores[rank]:.4f})")


In [None]:
# -----------------------------
# 15. Redis test
# -----------------------------

r = redis.Redis(host='localhost', port=6379, decode_responses=True)

r.set('foo', 'bar')
# True
r.get('foo')
# bar

