### **Step 1**: Load the MS MARCO splits and build (query, positive, negative) triplet DataLoaders

In [12]:
import sys
import pandas as pd
import fastparquet
from torch.utils.data import DataLoader, Dataset
import torch
import random
sys.path.append('../')

# --- Manual Data Loading and Preprocessing ---

def _is_nonempty_passage_list(value):
    """True for a non-empty list/tuple/1-D array of passages, False for anything else."""
    if isinstance(value, (list, tuple)):
        return len(value) > 0
    # Some parquet readers return list columns as 1-D numpy arrays instead of Python lists.
    return getattr(value, 'ndim', None) == 1 and len(value) > 0


def load_and_process_parquet(path, triplets_per_query=10, rng=None):
    """Load a parquet file and sample (query, positive, negative) triplets.

    For each valid query, ``triplets_per_query`` triplets are drawn. The positive is a random
    passage from that query's own ``passages.passage_text`` list. The negative is a random
    passage that belongs to a *different* query.

    NOTE(review): every passage listed for a query is treated as a positive. If the file also
    carries a per-passage relevance flag, it is not consulted here. Confirm this is intended.

    Args:
        path: parquet file path, read with the ``fastparquet`` engine.
        triplets_per_query: number of triplets sampled per query (default 10).
        rng: object with a ``choice`` method (e.g. ``random.Random(seed)``). Defaults to the
            global ``random`` module.

    Returns:
        A list of dicts with keys ``'query'``, ``'positive'`` and ``'negative'``. The list is
        empty when there are fewer than two valid queries, because no cross-query negative
        exists then.
    """
    rng = random if rng is None else rng
    print(f"\nProcessing {path}...")
    df = pd.read_parquet(path, engine='fastparquet')

    # Keep queries that are present and have at least one passage.
    passages_col = df['passages.passage_text']
    valid_mask = (
        df['query'].notna()
        & passages_col.notna()
        & passages_col.apply(_is_nonempty_passage_list)
    )
    df = df[valid_mask].reset_index(drop=True)
    print(f"  Found {len(df)} valid queries.")

    queries = df['query'].tolist()
    passage_lists = df['passages.passage_text'].tolist()

    if len(queries) < 2:
        # A negative must come from another query. With fewer than two queries none exists,
        # and rejection sampling below would loop forever.
        if queries:
            print("  Need at least two queries to sample cross-query negatives; skipping.")
        print("  Generated 0 triplets.")
        return []

    # Flat pool of (owning query index, passage) used for negative sampling.
    passage_pool = [(q_idx, passage)
                    for q_idx, passages in enumerate(passage_lists)
                    for passage in passages]

    triplets = []
    for q_idx, (query, own_passages) in enumerate(zip(queries, passage_lists)):
        for _ in range(triplets_per_query):
            positive = rng.choice(own_passages)

            # Rejection-sample a passage owned by another query. This terminates because at
            # least one other query (with >= 1 passage) exists.
            owner, negative = rng.choice(passage_pool)
            while owner == q_idx:
                owner, negative = rng.choice(passage_pool)

            triplets.append({'query': query, 'positive': positive, 'negative': negative})

    print(f"  Generated {len(triplets)} triplets.")
    return triplets


class TripletListDataset(Dataset):
    """Map-style dataset over an already-built list of triplet dicts (no tokenisation)."""

    def __init__(self, triplets):
        # Kept by reference; later cells read it back via `.triplets`.
        self.triplets = triplets

    def __len__(self):
        """Number of stored triplets."""
        return len(self.triplets)

    def __getitem__(self, idx):
        """Return the triplet dict stored at position `idx`, unchanged."""
        item = self.triplets[idx]
        return item

def collate_fn_notebook(batch):
    """Transpose a list of triplet dicts into a dict of parallel lists.

    Args:
        batch: list of dicts with keys 'query', 'positive' and 'negative'.

    Returns:
        A dict with keys 'queries', 'positives' and 'negatives', each a list in batch order.
    """
    field_map = (('queries', 'query'), ('positives', 'positive'), ('negatives', 'negative'))
    return {plural: [item[singular] for item in batch] for plural, singular in field_map}

# Dataset paths
datasets = {
    'train': '../data/ms_marco_train.parquet',
    'validation': '../data/ms_marco_validation.parquet',
    'test': '../data/ms_marco_test.parquet'
}

# Seed the RNGs so triplet sampling (Python `random`) and train shuffling (torch) are
# reproducible across kernel restarts.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# Create dataloaders
dataloaders = {}

print("Creating DataLoaders with robust direct-loading...")
for split, path in datasets.items():
    # Split-specific names so this cell does not leave generic `dataset`/`triplets` globals behind
    split_triplets = load_and_process_parquet(path)
    split_dataset = TripletListDataset(split_triplets)
    dataloaders[split] = DataLoader(
        split_dataset,
        batch_size=20,
        shuffle=(split == 'train'),
        num_workers=0,
        collate_fn=collate_fn_notebook,
    )
    print(f"✅ {split} dataloader ready!")

Creating DataLoaders with robust direct-loading...

Processing ../data/ms_marco_train.parquet...
  Found 808731 valid queries.
  Generated 8087310 triplets.
✅ train dataloader ready!

Processing ../data/ms_marco_validation.parquet...
  Found 101093 valid queries.
  Generated 1010930 triplets.
✅ validation dataloader ready!

Processing ../data/ms_marco_test.parquet...
  Found 101092 valid queries.
  Generated 1010920 triplets.
✅ test dataloader ready!


### **Step 2**: Sanity-check the training DataLoader by inspecting a few triplet batches

In [13]:
# Test the training dataloader: print summary info and the first few batches
from itertools import islice

train_loader = dataloaders['train']
N_PREVIEW_BATCHES = 3  # number of batches to print

print(f"📊 Dataset Info:")
print(f"  Total batches in train: {len(train_loader)}")
# Read the batch size from the loader so this stays correct if Step 1 changes it
print(f"  Batch size: {train_loader.batch_size}")
# load_and_process_parquet samples 10 triplets per query; keep this in sync if that changes
print(f"  Triplets per query: 10")

print("\n🔍 Testing first few batches...")
# islice yields exactly N_PREVIEW_BATCHES batches and then stops iterating the loader
for batch_idx, batch in enumerate(islice(train_loader, N_PREVIEW_BATCHES)):
    print(f"\nBatch {batch_idx + 1}:")
    print(f"  Queries: {len(batch['queries'])}")
    print(f"  Positives: {len(batch['positives'])}")
    print(f"  Negatives: {len(batch['negatives'])}")

    # Show first triplet in batch
    print(f"\n  📝 Sample Triplet:")
    print(f"    Query: {batch['queries'][0][:100]}...")
    print(f"    Positive: {batch['positives'][0][:100]}...")
    print(f"    Negative: {batch['negatives'][0][:100]}...")

print("\n✅ DataLoader test completed!")


📊 Dataset Info:
  Total batches in train: 404366
  Batch size: 20
  Triplets per query: 10

🔍 Testing first few batches...

Batch 1:
  Queries: 20
  Positives: 20
  Negatives: 20

  📝 Sample Triplet:
    Query: symbolic meaning of heart...
    Positive: A heart is a symbol of love. A heart shape is a symbol of love and compassion. The heart has long be...
    Negative: For example if we have x=953, I want to select specifically the first digit (or the last two) and sa...

Batch 2:
  Queries: 20
  Positives: 20
  Negatives: 20

  📝 Sample Triplet:
    Query: is shotacon illegal...
    Positive: Is lolicon and shotacon illegal in the United States? Even though its drawn to make the characters l...
    Negative: The SRB casings were made of half-inch (12.7 mm) thick steel and were much stronger than the orbiter...

Batch 3:
  Queries: 20
  Positives: 20
  Negatives: 20

  📝 Sample Triplet:
    Query: who is adam levine married to...
    Positive: Before Getting Married, Prinsloo and Levin

### **Step 3**: Convert the triplets into (query, positive, negative) tuples for training, validation and testing

In [14]:
# Extract triplets from DataLoaders to create train/val/test lists
def _triplet_tuples(triplet_dicts):
    """Turn {'query', 'positive', 'negative'} dicts into (query, positive, negative) tuples."""
    return [(t['query'], t['positive'], t['negative']) for t in triplet_dicts]


train_triplets, val_triplets, test_triplets = (
    dataloaders[split].dataset.triplets for split in ('train', 'validation', 'test')
)

# The model expects data as a list of (query, positive_doc, negative_doc) tuples
train_data = _triplet_tuples(train_triplets)
val_data = _triplet_tuples(val_triplets)
test_data = _triplet_tuples(test_triplets)

# Print a sample and the split sizes to verify the format
sample_query, sample_positive, sample_negative = train_data[0]
print("Sample training triplet:")
print(f"Query: {sample_query[:100]}...")
print(f"Positive: {sample_positive[:100]}...")
print(f"Negative: {sample_negative[:100]}...")
print("\nDataset sizes:")
for label, split_data in (("Training", train_data), ("Validation", val_data), ("Test", test_data)):
    print(f"  {label}: {len(split_data):,} triplets")

# Use training data for the model
data = train_data

Sample training triplet:
Query: )what was the immediate impact of the success of the manhattan project?...
Positive: The Manhattan Project. This once classified photograph features the first atomic bomb — a weapon tha...
Negative: Muscle twitching, also known as muscle fasciculation, is marked by small muscle contractions in the ...

Dataset sizes:
  Training: 8,087,310 triplets
  Validation: 1,010,930 triplets
  Test: 1,010,920 triplets


### **Step 4**: Build the tokenizer from the pretrained word-to-index vocabulary

In [15]:
# --- Tokenizer and Vocab ---
import pickle
import numpy as np
from collections import defaultdict
from itertools import chain

class PretrainedTokenizer:
    """Whitespace tokenizer backed by a pickled ``word -> index`` dictionary.

    Words are lower-cased before lookup. Words missing from the vocabulary are dropped, since
    there is no unknown-token id.
    """

    def __init__(self, word_to_idx_path):
        # NOTE: pickle.load can execute arbitrary code; only load vocabulary files you trust.
        with open(word_to_idx_path, 'rb') as handle:
            vocab = pickle.load(handle)
        self.word2idx = vocab
        print(f"Loaded vocabulary with {len(self.word2idx):,} tokens")

    def encode(self, sentence):
        """Map ``sentence`` to a list of vocabulary ids, skipping out-of-vocabulary words."""
        ids = []
        for raw_word in sentence.split():
            key = raw_word.lower()
            if key in self.word2idx:
                ids.append(self.word2idx[key])
        return ids

    def vocab_size(self):
        """Number of entries in the vocabulary dictionary."""
        return len(self.word2idx)


In [16]:
# Build the tokenizer from the pickled word -> index vocabulary
VOCAB_PATH = '../data/word_to_idx.pkl'
tokenizer = PretrainedTokenizer(VOCAB_PATH)

Loaded vocabulary with 400,000 tokens


In [17]:
# --- Dataset Class ---
class TripletDataset(Dataset):
    """Tokenizes (query, positive_doc, negative_doc) text triplets on the fly.

    Each item is a tuple of three 1-D ``torch.long`` tensors of token ids from
    ``tokenizer.encode``. Their lengths vary, so batching needs a padding collate function.
    """

    def __init__(self, data, tokenizer):
        self.data = data
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        query, pos_doc, neg_doc = self.data[idx]
        return tuple(self._to_ids(text) for text in (query, pos_doc, neg_doc))

    def _to_ids(self, text):
        """Encode one string as a 1-D long tensor of token ids."""
        return torch.tensor(self.tokenizer.encode(text), dtype=torch.long)

In [18]:
# --- Collate Function ---
# Imported here so this cell does not silently depend on a later cell's imports
from torch.nn.utils.rnn import pad_sequence


def collate_fn(batch):
    """Pad a batch of (query, positive, negative) id tensors into three (B, max_len) tensors.

    Args:
        batch: non-empty list of 3-tuples of 1-D ``torch.long`` tensors (as produced by
            ``TripletDataset``).

    Returns:
        A tuple ``(queries, positives, negatives)``. Each is right-padded with id 0 to the longest
        sequence in its group, with batch as the first dimension.
    """
    queries, pos_docs, neg_docs = zip(*batch)
    return tuple(
        pad_sequence(group, batch_first=True, padding_value=0)
        for group in (queries, pos_docs, neg_docs)
    )

In [20]:
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

# --- Dual RNN Encoder Model ---
class RNNEncoder(nn.Module):
    """GRU text encoder mapping a batch of token-id sequences to final hidden states.

    Id 0 is the embedding's ``padding_idx``. Trailing runs of 0 (the right-padding that
    ``pad_sequence`` produces) are excluded from the recurrence by packing. A sequence therefore
    gets the same encoding no matter how much it was padded within its batch.

    Args:
        vocab_size: number of rows in the embedding table.
        embed_dim: embedding width.
        hidden_dim: GRU hidden size, which is also the output width.
        pretrained_embeddings: optional numpy array copied into the embedding table. It must be
            shape-compatible with ``(vocab_size, embed_dim)``. The embeddings stay trainable.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, pretrained_embeddings=None):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)

        if pretrained_embeddings is not None:
            # Copy without recording autograd history; the weight remains a trainable parameter.
            with torch.no_grad():
                self.embedding.weight.copy_(torch.from_numpy(pretrained_embeddings))

        self.rnn = nn.GRU(embed_dim, hidden_dim, batch_first=True)

    def forward(self, x, lengths=None):
        """Encode token ids.

        Args:
            x: ``(batch, seq_len)`` long tensor of token ids. An unbatched 1-D tensor is also
                accepted and returns a ``(hidden_dim,)`` vector.
            lengths: optional true length per row. If omitted, each row's length is the
                position of its last non-padding token.

        Returns:
            ``(batch, hidden_dim)`` tensor holding the GRU's final hidden state per row.
        """
        embedded = self.embedding(x)

        if x.dim() != 2:
            # Unbatched input has no padding to skip.
            _, h_n = self.rnn(embedded)
            return h_n.squeeze(0)

        if lengths is None:
            lengths = self._trailing_lengths(x)
        # pack_padded_sequence needs CPU int64 lengths >= 1; an all-padding row runs one step.
        lengths = torch.as_tensor(lengths, dtype=torch.long).clamp(min=1).cpu()
        packed = pack_padded_sequence(embedded, lengths, batch_first=True, enforce_sorted=False)
        # With enforce_sorted=False, h_n is returned in the original batch order.
        _, h_n = self.rnn(packed)
        return h_n.squeeze(0)  # shape: (batch, hidden_dim)

    def _trailing_lengths(self, x):
        """Per-row length up to and including the last non-padding token (interior 0s are kept)."""
        positions = torch.arange(1, x.size(1) + 1, device=x.device)
        return ((x != self.embedding.padding_idx).long() * positions).max(dim=1).values

In [21]:
# --- Triplet Loss Function ---
def triplet_loss_function(triplet, distance_function, margin):
    """Mean hinge triplet loss: max(0, d(q, pos) - d(q, neg) + margin), averaged over the batch.

    Args:
        triplet: (query, positive, negative) tensors accepted by ``distance_function``.
        distance_function: callable returning per-example distances, e.g. ``F.pairwise_distance``.
        margin: required gap between the negative and positive distances.
    """
    anchor, positive, negative = triplet
    gap = distance_function(anchor, positive) - distance_function(anchor, negative)
    return torch.clamp(gap + margin, min=0.0).mean()


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import random
import json
import numpy as np

# Hyperparameters (HIDDEN_DIM, LR, BATCH_SIZE, EPOCHS, MARGIN, ...) come from the backend config
with open('../backend/config.json', 'r') as f:
    config = json.load(f)

# --- Training Setup ---
VOCAB_SIZE = tokenizer.vocab_size()

# Load pretrained embeddings
pretrained_embeddings = np.load('../data/embeddings.npy')
EMBED_DIM = pretrained_embeddings.shape[1]  # Get embedding dimension from loaded embeddings

print(f"Loaded pretrained embeddings: {pretrained_embeddings.shape}")
print(f"Vocabulary size: {VOCAB_SIZE}")
print(f"Embedding dimension: {EMBED_DIM}")

# The encoders size their embedding table by VOCAB_SIZE and copy these rows in, so the row
# count must match the tokenizer vocabulary. Fail early with a readable message.
if pretrained_embeddings.shape[0] != VOCAB_SIZE:
    raise ValueError(
        f"Embedding rows ({pretrained_embeddings.shape[0]}) != tokenizer vocab size ({VOCAB_SIZE})"
    )

# Pick the fastest available backend: CUDA, then Apple MPS, then CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Initialize both towers from the same pretrained embeddings and move them to `device`
query_encoder = RNNEncoder(VOCAB_SIZE, EMBED_DIM, config['HIDDEN_DIM'], pretrained_embeddings).to(device)
doc_encoder = RNNEncoder(VOCAB_SIZE, EMBED_DIM, config['HIDDEN_DIM'], pretrained_embeddings).to(device)

# One optimizer updates both towers
optimizer = torch.optim.Adam(
    list(query_encoder.parameters()) + list(doc_encoder.parameters()), lr=config['LR']
)

# Tokenize lazily per item; batch size is taken from the config
dataset = TripletDataset(data, tokenizer)
dataloader = DataLoader(
    dataset,
    batch_size=config['BATCH_SIZE'],
    shuffle=True,
    collate_fn=collate_fn,
    # TripletDataset and collate_fn are defined in this notebook. Worker processes started with
    # the "spawn" method (the default on macOS and Windows) cannot import them, so load in-process.
    num_workers=0,
    # Pinned host memory only speeds up host-to-CUDA copies.
    pin_memory=(device.type == 'cuda'),
)

Loaded pretrained embeddings: (400000, 200)
Vocabulary size: 400000
Embedding dimension: 200
Using device: cpu


NameError: name 'config' is not defined

In [None]:
# --- Training Loop ---
import time

print("🚀 Starting training...")
start_time = time.time()

# Make sure both towers are in training mode (e.g. if an evaluation cell ran first)
query_encoder.train()
doc_encoder.train()

n_batches_per_epoch = len(dataloader)
avg_loss = float('nan')  # stays NaN if no batch runs; the saving cell reads this value

for epoch in range(config['EPOCHS']):
    epoch_start = time.time()
    total_loss = 0.0
    num_batches = 0

    for query_batch, pos_batch, neg_batch in dataloader:
        # Move tensors to the training device
        query_batch = query_batch.to(device)
        pos_batch = pos_batch.to(device)
        neg_batch = neg_batch.to(device)

        q_vec = query_encoder(query_batch)
        pos_vec = doc_encoder(pos_batch)
        neg_vec = doc_encoder(neg_batch)

        loss = triplet_loss_function((q_vec, pos_vec, neg_vec), F.pairwise_distance, config['MARGIN'])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        num_batches += 1

        # Progress indicator every 50 batches
        if num_batches % 50 == 0:
            print(f"  Batch {num_batches}/{n_batches_per_epoch}, Loss: {batch_loss:.4f}")

    epoch_time = time.time() - epoch_start
    # Guard against ZeroDivisionError when the DataLoader yields no batches
    avg_loss = total_loss / num_batches if num_batches else float('nan')
    print(f"Epoch {epoch+1}/{config['EPOCHS']}, Avg Loss: {avg_loss:.4f}, Time: {epoch_time:.1f}s")

total_time = time.time() - start_time
print(f"\n✅ Training completed! Total time: {total_time/60:.1f} minutes")



🚀 Starting training...


In [None]:

# --- Automatic Model Saving ---
import os
import json
import shutil
from datetime import datetime

# Create artifacts directory
artifacts_dir = "../artifacts"
os.makedirs(artifacts_dir, exist_ok=True)

# Create timestamped run directory
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
run_dir = os.path.join(artifacts_dir, f"two_tower_run_{timestamp}")
os.makedirs(run_dir, exist_ok=True)

print(f"💾 Saving artifacts to: {run_dir}")

# Save model/optimizer state dictionaries (the portable way to reload weights)
torch.save({
    'query_encoder_state_dict': query_encoder.state_dict(),
    'doc_encoder_state_dict': doc_encoder.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'epoch': config['EPOCHS'],
    'final_loss': avg_loss
}, os.path.join(run_dir, 'model_checkpoint.pth'))

# Also save whole modules for convenience. These are pickled with a reference to the
# RNNEncoder class, so loading them requires that class to be defined/importable.
torch.save(query_encoder, os.path.join(run_dir, 'query_encoder_full.pth'))
torch.save(doc_encoder, os.path.join(run_dir, 'doc_encoder_full.pth'))

# Save training configuration
training_config = {
    'model_config': {
        'vocab_size': VOCAB_SIZE,
        'embed_dim': EMBED_DIM,
        'hidden_dim': config['HIDDEN_DIM'],
        'margin': config['MARGIN']
    },
    'training_config': {
        'epochs': config['EPOCHS'],
        'batch_size': config['BATCH_SIZE'],
        'learning_rate': config['LR'],
        'device': str(device)
    },
    'data_config': {
        'train_samples': len(train_data),
        'val_samples': len(val_data),
        'test_samples': len(test_data),
        'total_triplets': len(train_data) + len(val_data) + len(test_data)
    },
    'training_results': {
        'final_avg_loss': avg_loss
    }
}

with open(os.path.join(run_dir, 'training_config.json'), 'w') as f:
    json.dump(training_config, f, indent=2)

# Copy the vocabulary file next to the weights. The tokenizer in this notebook is loaded from
# '../data/word_to_idx.pkl', so fall back to that when the config does not name a path.
word_to_idx_path = config.get('WORD_TO_IDX_PATH', '../data/word_to_idx.pkl')
tokenizer_copied = os.path.exists(word_to_idx_path)
if tokenizer_copied:
    shutil.copy2(word_to_idx_path, os.path.join(run_dir, 'word_to_idx.pkl'))

print(f"✅ Saved artifacts:")
print(f"  📁 Directory: {run_dir}")
print(f"  🧠 Models: model_checkpoint.pth, *_encoder_full.pth")
print(f"  ⚙️  Config: training_config.json")
# Only claim the tokenizer was saved if it actually was
if tokenizer_copied:
    print(f"  📝 Tokenizer: word_to_idx.pkl")
else:
    print(f"  ⚠️  Tokenizer NOT copied: {word_to_idx_path} does not exist")

In [15]:
# --- MPS-Optimized Inference Function ---
def search(query_text, documents, tokenizer, query_encoder, doc_encoder, top_k=None):
    """Rank ``documents`` by cosine similarity to ``query_text`` under the two-tower encoders.

    Inference runs on whatever device the query encoder's parameters live on. On MPS the
    allocator cache is released afterwards.

    Args:
        query_text: raw query string.
        documents: iterable of raw document strings.
        tokenizer: object with ``encode(str) -> list[int]``.
        query_encoder: module mapping a ``(1, L)`` id tensor to query vectors.
        doc_encoder: module mapping an ``(N, L)`` id tensor to document vectors.
        top_k: if given, return at most this many results.

    Returns:
        A list of ``(document, score)`` pairs sorted by descending score. It is empty when
        ``documents`` is empty.
    """
    documents = list(documents)
    if not documents:
        # pad_sequence cannot pad an empty list; nothing to rank.
        return []

    # Get device from encoder
    device = next(query_encoder.parameters()).device

    with torch.no_grad():
        # Encode the query as a batch of one
        query_tensor = pad_sequence(
            [torch.tensor(tokenizer.encode(query_text), dtype=torch.long)], batch_first=True
        ).to(device)
        query_vec = query_encoder(query_tensor)

        # Encode all documents in one padded batch
        doc_tensors = pad_sequence(
            [torch.tensor(tokenizer.encode(doc), dtype=torch.long) for doc in documents],
            batch_first=True,
        ).to(device)
        doc_vecs = doc_encoder(doc_tensors)

        # (1, H) vs (N, H) broadcasts to one score per document
        scores = F.cosine_similarity(query_vec, doc_vecs)
        order = torch.argsort(scores, descending=True)
        if top_k is not None:
            order = order[:top_k]

        # Single device->host transfer instead of one .item() sync per document
        score_list = scores.cpu().tolist()
        results = [(documents[i], score_list[i]) for i in order.cpu().tolist()]

    # Clear MPS cache after inference if available
    if device.type == 'mps' and hasattr(torch.mps, 'empty_cache'):
        torch.mps.empty_cache()

    return results

In [None]:
# --- Comprehensive Testing with Real Data ---
import random
from collections import defaultdict

def evaluate_retrieval(test_data, query_encoder, doc_encoder, tokenizer, k=10,
                       max_triplets=100, n_queries=5, n_distractor_queries=3):
    """Qualitative retrieval check on held-out (query, positive, negative) triplets.

    The first ``max_triplets`` triplets are grouped by query. For each of up to ``n_queries``
    queries, the function ranks:

    - that query's own positives and negatives, and
    - one positive plus one negative from each of up to ``n_distractor_queries`` *other* queries.

    It then prints the top ``min(3, k)`` hits.

    A hit is labelled relevant only if it is one of the query's *positive* documents. Negatives
    were sampled from other queries and are not relevant, even though they are in the query's pool.

    Args:
        test_data: sequence of (query, positive_doc, negative_doc) string tuples.
        query_encoder, doc_encoder, tokenizer: passed through to ``search``.
        k: upper bound on the number of hits printed per query.
        max_triplets: how many leading triplets to use (keeps the check fast).
        n_queries: how many queries to evaluate.
        n_distractor_queries: how many other queries contribute distractor documents.

    Returns:
        None. Results are printed.
    """
    print("🔍 COMPREHENSIVE RETRIEVAL EVALUATION")
    print("="*50)

    # Keep positives and negatives apart so relevance can be judged correctly
    positives = defaultdict(list)
    negatives = defaultdict(list)
    for query, pos_doc, neg_doc in test_data[:max_triplets]:
        positives[query].append(pos_doc)
        negatives[query].append(neg_doc)

    all_queries = list(positives)

    for i, query in enumerate(all_queries[:n_queries]):
        print(f"\n🔎 TEST QUERY {i+1}: {query[:100]}...")
        print("-" * 60)

        relevant = set(positives[query])
        own_docs = positives[query] + negatives[query]

        # Distractors: one positive and one negative from each of a few *other* queries
        # (sample size is capped so small test sets do not raise ValueError)
        other_queries = [q for q in all_queries if q != query]
        other_docs = []
        for other_query in random.sample(other_queries, min(n_distractor_queries, len(other_queries))):
            other_docs.extend([positives[other_query][0], negatives[other_query][0]])

        # The same passage can be sampled for several triplets; de-duplicate, preserving order
        all_documents = list(dict.fromkeys(own_docs + other_docs))
        random.shuffle(all_documents)

        print(f"📚 Searching through {len(all_documents)} documents...")

        # Run search
        results = search(query, all_documents, tokenizer, query_encoder, doc_encoder)

        n_show = min(3, k, len(results))
        print(f"\n🏆 TOP {n_show} RESULTS:")
        for j, (doc, score) in enumerate(results[:n_show]):
            relevance = "✅ RELEVANT" if doc in relevant else "❌ NOT RELEVANT"
            print(f"{j+1}. Score: {score:.4f} {relevance}")
            print(f"   Doc: {doc[:80]}...")
            print()

# Run the qualitative retrieval check on the held-out test triplets
evaluate_retrieval(
    test_data=test_data,
    query_encoder=query_encoder,
    doc_encoder=doc_encoder,
    tokenizer=tokenizer,
)