### **Step 1**: Import libraries and load the configuration and dataset paths

In [1]:
import pandas as pd
import json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import random

# Read the shared project configuration (sample counts, split fractions, model
# and training hyperparameters) from the backend package.
with open('../backend/config.json', 'r') as f:
    config = json.load(f)

# Parquet file for each MS MARCO split, keyed by split name.
datasets = {
    split_name: f'../data/ms_marco_{split_name}.parquet'
    for split_name in ('train', 'validation', 'test')
}

### **Step 2**: Process data to get triplets for train, validation and test datasets

In [2]:
import pandas as pd
import json
from backend.data_processing import flatten_data, add_negative_samples_fast, to_triplets, filter_valid_data, convert_to_training_format


# Re-read the config so edits to config.json take effect when this cell is re-run.
# (`datasets`, the split-name -> parquet-path mapping, is defined in Step 1.)
with open('../backend/config.json', 'r') as f:
    config = json.load(f)

# Triplets generated per query by `to_triplets` (override via config.json).
TRIPLETS_PER_QUERY = config.get('TRIPLETS_PER_QUERY', 10)

# Number of raw rows to draw from each split's parquet file. round() rather than
# int(): float error (e.g. 0.29 * 100 == 28.999...) would otherwise drop a row.
VAL_SPLIT = 1 - config['TRAIN_SPLIT'] - config['TEST_SPLIT']
samples_per_dataset = {
    'train': round(config['TOTAL_SAMPLES'] * config['TRAIN_SPLIT']),
    'test': round(config['TOTAL_SAMPLES'] * config['TEST_SPLIT']),
    'validation': round(config['TOTAL_SAMPLES'] * VAL_SPLIT),
}
assert samples_per_dataset['validation'] >= 0, "TRAIN_SPLIT + TEST_SPLIT must not exceed 1"

print("📊 SAMPLING CONFIGURATION:")
print(f"  Total samples to process: {config['TOTAL_SAMPLES']:,}")
print(f"  Train split: {config['TRAIN_SPLIT']*100:.0f}% ({samples_per_dataset['train']:,} samples)")
print(f"  Test split: {config['TEST_SPLIT']*100:.0f}% ({samples_per_dataset['test']:,} samples)")
print(f"  Validation split: {VAL_SPLIT*100:.0f}% ({samples_per_dataset['validation']:,} samples)")


def _build_triplets(name, input_path, target_samples, triplets_per_query):
    """Load one MS MARCO split and return its (query, positive, negative) triplet frame.

    Rows are sampled down to ``target_samples`` *before* ``filter_valid_data`` runs
    (to save processing time), so fewer than ``target_samples`` queries usually
    survive filtering.
    """
    print(f"\n📁 Processing {name.upper()} dataset (target: {target_samples:,} samples)...")
    print(f"Loading: {input_path}")

    # Step 1: Load data
    raw = pd.read_parquet(input_path, engine='fastparquet')
    print(f"  Loaded: {len(raw):,} samples")

    # Step 2: Early sampling - cut here to save processing time
    if len(raw) > target_samples:
        raw = raw.sample(n=target_samples, random_state=config.get('SEED', 42)).reset_index(drop=True)
        print(f"  ✂️ Sampled down to: {len(raw):,} samples")

    # Step 3: Filter valid data
    filtered = filter_valid_data(raw)
    print(f"  After filtering: {len(filtered):,} samples")

    # Step 4: Flatten data (nested passages to flat rows)
    print("  🔄 Flattening data...")
    flattened = flatten_data(filtered)
    print(f"  Flattened: {len(flattened):,} rows")

    # Step 5: Add negative samples
    print("  ➕ Adding negative samples...")
    with_negatives = add_negative_samples_fast(flattened)
    labels = with_negatives['passage_sign_de']
    print(f"  With negatives: {len(with_negatives):,} rows")
    print(f"    - Positive: {int((labels == 1).sum()):,}")
    print(f"    - Negative: {int((labels == 0).sum()):,}")

    # Step 6: Convert to triplets
    print("  🔄 Converting to triplets...")
    triplets = to_triplets(with_negatives, triplets_per_query=triplets_per_query)
    print(f"  Final triplets: {len(triplets):,}")
    print(f"  Unique queries: {triplets['query'].nunique()}")
    return triplets


results = {}

print("\nProcessing datasets to triplet format...")
print("="*50)

for name, input_path in datasets.items():
    results[name] = _build_triplets(name, input_path, samples_per_dataset[name], TRIPLETS_PER_QUERY)
    print(f"  ✅ {name.upper()} dataset completed!")

print("\n" + "="*50)
print("FINAL SUMMARY")
print("="*50)
total_triplets = 0
for name, triplets_df in results.items():
    triplets_count = len(triplets_df)
    total_triplets += triplets_count
    print(f"{name.upper()}: {triplets_count:,} triplets, {triplets_df['query'].nunique():,} unique queries")

print(f"\n🎯 TOTAL TRIPLETS: {total_triplets:,}")

print("\n🎯 Sample triplet from train dataset:")
if 'train' in results and len(results['train']) > 0:
    first_triplet = results['train'].iloc[0]
    print(f"Query: {first_triplet['query'][:80]}...")
    print(f"Positive: {first_triplet['positive_example'][:80]}...")
    print(f"Negative: {first_triplet['negative_example'][:80]}...")

print("\n✅ All datasets processed! Results stored in 'results' dictionary.")
print("Access with: results['train'], results['validation'], results['test']")
# Conversion of `results` to training tuples happens in Step 3.


📊 SAMPLING CONFIGURATION:
  Total samples to process: 1,000
  Train split: 70% (700 samples)
  Test split: 20% (200 samples)
  Validation split: 10% (100 samples)

Processing datasets to triplet format...

📁 Processing TRAIN dataset (target: 700 samples)...
Loading: ../data/ms_marco_train.parquet
  Loaded: 808,731 samples
  ✂️ Sampled down to: 700 samples
  After filtering: 441 samples
  🔄 Flattening data...
  Flattened: 4,394 rows
  ➕ Adding negative samples...
  With negatives: 8,804 rows
    - Positive: 4,394
    - Negative: 4,410
  🔄 Converting to triplets...
  Final triplets: 4,410
  Unique queries: 441
  ✅ TRAIN datasetcompleted!

📁 Processing VALIDATION dataset (target: 100 samples)...
Loading: ../data/ms_marco_validation.parquet
  Loaded: 101,093 samples
  ✂️ Sampled down to: 100 samples
  After filtering: 60 samples
  🔄 Flattening data...
  Flattened: 598 rows
  ➕ Adding negative samples...
  With negatives: 1,198 rows
    - Positive: 598
    - Negative: 600
  🔄 Converting to 

### **Step 3**: Convert the triplets to training format for training, validation and testing

In [3]:
from backend.data_processing import convert_to_training_format

# Turn each processed split from Step 2 into (query, positive, negative) training
# tuples; sampling already happened in Step 2, so nothing is subsampled here.
train_data, val_data, test_data = (
    convert_to_training_format(results[split_name])
    for split_name in ('train', 'validation', 'test')
)

# Show one triplet to confirm the format
print("Sample training triplet:")
sample_triplet = train_data[0]
print(f"Query: {sample_triplet[0][:100]}...")
print(f"Positive: {sample_triplet[1][:100]}...")
print(f"Negative: {sample_triplet[2][:100]}...")

print(f"\nDataset sizes:")
for label, split_triplets in (("Training", train_data), ("Validation", val_data), ("Test", test_data)):
    print(f"  {label}: {len(split_triplets):,} triplets")

# The training setup builds its TripletDataset from `data`.
data = train_data

Sample training triplet:
Query: Causes of Retained Placenta...
Positive: In humans, retained placenta is generally defined as a placenta that has not undergone placental exp...
Negative: Function of neurons. The central nervous system [CNS] is composed entirely of two kinds of specializ...

Dataset sizes:
  Training: 4,410 triplets
  Validation: 600 triplets
  Test: 2,000 triplets


### **Step 4**: Build the tokenizer, dataset, model and loss; then train, save and evaluate

In [4]:
# --- Tokenizer and Vocab ---
import pickle
import numpy as np
from collections import defaultdict
from itertools import chain

class PretrainedTokenizer:
    """Map text to ids using a pickled ``word -> index`` vocabulary.

    Text is lower-cased and words missing from the vocabulary are dropped
    (there is no unknown-token id).

    Args:
        word_to_idx_path: Path to a pickled dict mapping words to integer ids.
            SECURITY: ``pickle.load`` can execute arbitrary code — only load
            files you trust.
        split_punctuation: If True, split punctuation off words before lookup
            (``"placenta?"`` -> ``"placenta"``, ``"?"``). With the default
            whitespace-only split, a word glued to punctuation is kept only if
            that exact string is in the vocabulary; otherwise it is dropped.
    """

    def __init__(self, word_to_idx_path, split_punctuation=False):
        # Load pretrained word_to_idx mapping
        with open(word_to_idx_path, 'rb') as f:
            self.word2idx = pickle.load(f)
        self.split_punctuation = split_punctuation

        print(f"Loaded vocabulary with {len(self.word2idx):,} tokens")

    def encode(self, sentence):
        """Return the ids of the in-vocabulary tokens of ``sentence``, in order."""
        text = sentence.lower()
        if self.split_punctuation:
            import re
            # Runs of word characters, or single non-space punctuation marks.
            tokens = re.findall(r"\w+|[^\w\s]", text)
        else:
            tokens = text.split()
        vocab = self.word2idx
        return [vocab[token] for token in tokens if token in vocab]

    def vocab_size(self):
        """Number of entries in the vocabulary."""
        return len(self.word2idx)


In [5]:
# Load the pretrained vocabulary. The path comes from config.json when it defines
# WORD_TO_IDX_PATH (the key the artifact-saving cell copies from), otherwise the
# default data location.
tokenizer = PretrainedTokenizer(config.get('WORD_TO_IDX_PATH', '../data/word_to_idx.pkl'))

Loaded vocabulary with 400,000 tokens


In [6]:
# --- Dataset Class ---
class TripletDataset(Dataset):
    """Map-style dataset that turns text triplets into tensors of token ids.

    Args:
        data: Indexable sequence of ``(query, positive_doc, negative_doc)`` strings.
        tokenizer: Object whose ``encode(str)`` returns a list of int ids.
    """

    def __init__(self, data, tokenizer):
        self.data = data
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data)

    def _to_ids(self, text):
        # One string -> 1-D LongTensor of vocabulary ids (may be empty).
        return torch.tensor(self.tokenizer.encode(text), dtype=torch.long)

    def __getitem__(self, idx):
        query_text, positive_text, negative_text = self.data[idx]
        return (self._to_ids(query_text),
                self._to_ids(positive_text),
                self._to_ids(negative_text))

In [7]:
# --- Collate Function ---
def collate_fn(batch):
    """Stack a list of (query, pos, neg) id tensors into three right-padded batches.

    Each field is padded with 0 to the longest sequence of that field in the batch.
    """
    query_seqs, pos_seqs, neg_seqs = zip(*batch)
    padded = [pad_sequence(seqs, batch_first=True) for seqs in (query_seqs, pos_seqs, neg_seqs)]
    return tuple(padded)

In [8]:
# --- Dual RNN Encoder Model ---
class RNNEncoder(nn.Module):
    """Embed token ids and encode them with a single-layer GRU.

    Args:
        vocab_size: Number of rows in the embedding table.
        embed_dim: Embedding width.
        hidden_dim: GRU hidden size (and size of the returned vectors).
        pretrained_embeddings: Optional numpy array copied into the embedding
            table; ``copy_`` raises if it cannot be broadcast to
            (vocab_size, embed_dim). Embeddings stay trainable.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, pretrained_embeddings=None):
        super().__init__()
        # padding_idx=0: the row for id 0 receives no gradient updates.
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)

        # Load pretrained embeddings if provided
        if pretrained_embeddings is not None:
            self.embedding.weight.data.copy_(torch.from_numpy(pretrained_embeddings))
            # Keep embeddings trainable (they are by default)

        self.rnn = nn.GRU(embed_dim, hidden_dim, batch_first=True)

    def forward(self, x, lengths=None):
        """Encode a right-padded batch of token ids.

        Sequences are packed by length so the returned state is the GRU state
        after each sequence's last real token; padding therefore does not change
        a sequence's encoding.

        Args:
            x: LongTensor (batch, seq_len), right-padded with 0.
            lengths: Optional true lengths, one per row. If omitted, each length
                is the position just after the last non-zero id in the row.
                NOTE(review): this assumes id 0 is only used for padding; a
                trailing real token with id 0 would be treated as padding.

        Returns:
            Tensor of shape (batch, hidden_dim).
        """
        if x.size(1) == 0:
            # The GRU cannot run on an empty time axis; use a single pad step.
            x = x.new_zeros((x.size(0), 1))
        if lengths is None:
            positions = torch.arange(1, x.size(1) + 1, device=x.device)
            lengths = ((x != 0) * positions).amax(dim=1)
        # pack_padded_sequence needs CPU int64 lengths >= 1 (all-padding rows -> 1).
        lengths = torch.as_tensor(lengths).to('cpu', torch.int64).clamp(min=1)
        packed = nn.utils.rnn.pack_padded_sequence(
            self.embedding(x), lengths, batch_first=True, enforce_sorted=False)
        _, h_n = self.rnn(packed)
        return h_n.squeeze(0)  # shape: (batch, hidden_dim)

In [9]:
# --- Triplet Loss Function ---
def triplet_loss_function(triplet, distance_function, margin):
    """Mean hinge triplet loss: ``mean(max(0, d(q, pos) - d(q, neg) + margin))``.

    Args:
        triplet: ``(query, positive, negative)`` batches of encodings.
        distance_function: Callable returning one distance per row.
        margin: Required gap between the negative and positive distances.
    """
    anchor, positive, negative = triplet
    violation = distance_function(anchor, positive) - distance_function(anchor, negative) + margin
    return violation.clamp(min=0.0).mean()


In [11]:
# --- Training Setup ---
# Seed every RNG used from here on (weight init, DataLoader shuffling) so runs are
# reproducible. Override with a SEED entry in config.json.
SEED = config.get('SEED', 42)
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

VOCAB_SIZE = tokenizer.vocab_size()

# Load pretrained embeddings; row i must be the vector for word id i.
pretrained_embeddings = np.load(config.get('EMBEDDINGS_PATH', '../data/embeddings.npy'))
EMBED_DIM = pretrained_embeddings.shape[1]  # Get embedding dimension from loaded embeddings

print(f"Loaded pretrained embeddings: {pretrained_embeddings.shape}")
print(f"Vocabulary size: {VOCAB_SIZE}")
print(f"Embedding dimension: {EMBED_DIM}")

# Fail early with a readable message instead of a shape error inside copy_() or an
# out-of-range index inside nn.Embedding mid-training.
assert pretrained_embeddings.shape[0] == VOCAB_SIZE, (
    f"embeddings have {pretrained_embeddings.shape[0]:,} rows but the vocabulary has {VOCAB_SIZE:,} entries")
assert max(tokenizer.word2idx.values(), default=-1) < VOCAB_SIZE, "vocabulary contains ids >= VOCAB_SIZE"

# Prefer CUDA, then Apple-Silicon MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Two towers with separate weights, both initialised from the pretrained embeddings.
query_encoder = RNNEncoder(VOCAB_SIZE, EMBED_DIM, config['HIDDEN_DIM'], pretrained_embeddings).to(device)
doc_encoder = RNNEncoder(VOCAB_SIZE, EMBED_DIM, config['HIDDEN_DIM'], pretrained_embeddings).to(device)

optimizer = torch.optim.Adam(list(query_encoder.parameters()) + list(doc_encoder.parameters()), lr=config['LR'])

dataset = TripletDataset(data, tokenizer)
# pin_memory only speeds up host->CUDA copies, so enable it for CUDA alone.
# NOTE(review): with the 'spawn' start method (macOS/Windows default), worker
# processes may be unable to import TripletDataset/collate_fn defined in this
# notebook; set NUM_WORKERS to 0 in config.json if workers crash.
dataloader = DataLoader(dataset, batch_size=config['BATCH_SIZE'], shuffle=True, collate_fn=collate_fn,
                        num_workers=config.get('NUM_WORKERS', 2), pin_memory=device.type == 'cuda')

Loaded pretrained embeddings: (400000, 200)
Vocabulary size: 400000
Embedding dimension: 200

🔍 DEVICE DETECTION:
PyTorch version: 2.1.2
MPS available: False
MPS built: False

⚠️  MPS not available, using CPU. If you're on a Mac with Apple Silicon, your dev container's PyTorch might not have MPS support.

🎯 Final device: cpu

🏗️  Initializing models...
✅ Models moved to device

📊 Setting up DataLoader...
✅ DataLoader configured for CPU
📈 Batch size: 64
📊 Total batches: 69


In [None]:
# --- Training Loop ---
import time

LOG_EVERY = 50  # print a progress line every N batches

# Initialised up front so later cells (e.g. artifact saving) still find a value
# when config['EPOCHS'] is 0.
avg_loss = float('nan')
epoch_losses = []

# Explicitly enable training mode in case an earlier cell switched to eval().
query_encoder.train()
doc_encoder.train()

print("🚀 Starting training...")
start_time = time.time()

for epoch in range(config['EPOCHS']):
    epoch_start = time.time()
    total_loss = 0.0
    num_batches = 0

    for query_batch, pos_batch, neg_batch in dataloader:
        # Move tensors to the training device
        query_batch = query_batch.to(device)
        pos_batch = pos_batch.to(device)
        neg_batch = neg_batch.to(device)

        q_vec = query_encoder(query_batch)
        pos_vec = doc_encoder(pos_batch)
        neg_vec = doc_encoder(neg_batch)

        # Triplet loss on Euclidean distance between query and document vectors.
        loss = triplet_loss_function((q_vec, pos_vec, neg_vec), F.pairwise_distance, config['MARGIN'])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

        if num_batches % LOG_EVERY == 0:
            print(f"  Batch {num_batches}/{len(dataloader)}, Loss: {loss.item():.4f}")

    epoch_time = time.time() - epoch_start
    # An empty DataLoader would otherwise raise ZeroDivisionError.
    avg_loss = total_loss / num_batches if num_batches else float('nan')
    epoch_losses.append(avg_loss)
    print(f"Epoch {epoch+1}/{config['EPOCHS']}, Avg Loss: {avg_loss:.4f}, Time: {epoch_time:.1f}s")

total_time = time.time() - start_time
print(f"\n✅ Training completed! Total time: {total_time/60:.1f} minutes")



🚀 Starting training...


In [None]:

# --- Automatic Model Saving ---
import os
import json
import shutil
from datetime import datetime

# Create artifacts directory
artifacts_dir = "../artifacts"
os.makedirs(artifacts_dir, exist_ok=True)

# One timestamped directory per run so earlier artifacts are never overwritten.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
run_dir = os.path.join(artifacts_dir, f"two_tower_run_{timestamp}")
os.makedirs(run_dir, exist_ok=True)

print(f"💾 Saving artifacts to: {run_dir}")

# Weights + optimizer state: reload by rebuilding RNNEncoder and calling load_state_dict.
torch.save({
    'query_encoder_state_dict': query_encoder.state_dict(),
    'doc_encoder_state_dict': doc_encoder.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'epoch': config['EPOCHS'],
    'final_loss': avg_loss
}, os.path.join(run_dir, 'model_checkpoint.pth'))

# Whole-module pickles. NOTE: pickle stores the class by reference, so loading
# these requires RNNEncoder to be importable under the module it was defined in
# (this notebook); prefer the state dicts above for anything long-lived.
torch.save(query_encoder, os.path.join(run_dir, 'query_encoder_full.pth'))
torch.save(doc_encoder, os.path.join(run_dir, 'doc_encoder_full.pth'))

# Save training configuration
training_config = {
    'model_config': {
        'vocab_size': VOCAB_SIZE,
        'embed_dim': EMBED_DIM,
        'hidden_dim': config['HIDDEN_DIM'],
        'margin': config['MARGIN']
    },
    'training_config': {
        'epochs': config['EPOCHS'],
        'batch_size': config['BATCH_SIZE'],
        'learning_rate': config['LR'],
        'device': str(device)
    },
    'data_config': {
        'train_samples': len(train_data),
        'val_samples': len(val_data),
        'test_samples': len(test_data),
        'total_triplets': len(train_data) + len(val_data) + len(test_data)
    },
    'training_results': {
        'final_avg_loss': avg_loss
    }
}

with open(os.path.join(run_dir, 'training_config.json'), 'w') as f:
    json.dump(training_config, f, indent=2)

# Copy the vocabulary next to the weights. Fall back to the default location the
# tokenizer cell loads from when config.json has no WORD_TO_IDX_PATH entry.
word_to_idx_path = config.get('WORD_TO_IDX_PATH', '../data/word_to_idx.pkl')
tokenizer_saved = os.path.exists(word_to_idx_path)
if tokenizer_saved:
    shutil.copy2(word_to_idx_path, os.path.join(run_dir, 'word_to_idx.pkl'))

print(f"✅ Saved artifacts:")
print(f"  📁 Directory: {run_dir}")
print(f"  🧠 Models: model_checkpoint.pth, *_encoder_full.pth")
print(f"  ⚙️  Config: training_config.json")
if tokenizer_saved:
    print(f"  📝 Tokenizer: word_to_idx.pkl")
else:
    print(f"  ⚠️  Tokenizer not copied: {word_to_idx_path} not found")

In [15]:
# --- MPS-Optimized Inference Function ---
def search(query_text, documents, tokenizer, query_encoder, doc_encoder, metric="cosine"):
    """Rank ``documents`` against ``query_text`` with the two-tower encoders.

    Args:
        query_text: Raw query string.
        documents: Sequence of raw document strings to rank.
        tokenizer: Object whose ``encode(str)`` returns a list of int ids.
        query_encoder: Module expected to return one vector per input row.
        doc_encoder: Module expected to return one vector per input row.
        metric: ``"cosine"`` (default) ranks by cosine similarity.
            ``"euclidean"`` ranks by negative ``F.pairwise_distance``, the
            distance the training loop's triplet loss optimises.

    Returns:
        List of ``(document, score)`` pairs sorted by descending score (higher
        means more relevant). Empty when ``documents`` is empty.

    Raises:
        ValueError: If ``metric`` is not recognised.
    """
    if metric not in ("cosine", "euclidean"):
        raise ValueError(f"Unknown metric {metric!r}; expected 'cosine' or 'euclidean'")
    if len(documents) == 0:
        # pad_sequence cannot build a batch from zero sequences.
        return []

    # Run on whatever device the encoder weights live on.
    device = next(query_encoder.parameters()).device

    with torch.no_grad():
        query_ids = torch.tensor(tokenizer.encode(query_text), dtype=torch.long)
        query_tensor = pad_sequence([query_ids], batch_first=True).to(device)
        query_vec = query_encoder(query_tensor)

        doc_ids = [torch.tensor(tokenizer.encode(doc), dtype=torch.long) for doc in documents]
        doc_tensors = pad_sequence(doc_ids, batch_first=True).to(device)
        doc_vecs = doc_encoder(doc_tensors)

        if metric == "cosine":
            # The single query row broadcasts against every document row.
            scores = F.cosine_similarity(query_vec, doc_vecs)
        else:
            # Negated so that, as with cosine, a higher score is more relevant.
            scores = -F.pairwise_distance(query_vec.expand_as(doc_vecs), doc_vecs)

        order = torch.argsort(scores, descending=True).tolist()
        score_list = scores.tolist()
        results = [(documents[i], score_list[i]) for i in order]

        # Clear MPS cache after inference if available
        if device.type == 'mps' and hasattr(torch.mps, 'empty_cache'):
            torch.mps.empty_cache()

    return results

In [None]:
# --- Comprehensive Testing with Real Data ---
import random
from collections import defaultdict

def evaluate_retrieval(test_data, query_encoder, doc_encoder, tokenizer, k=10,
                       max_triplets=100, num_queries=5, num_distractor_queries=3,
                       show_top=3):
    """Spot-check retrieval on held-out triplets and report MRR@k.

    The first ``max_triplets`` triplets are grouped by query. For each of the
    first ``num_queries`` queries, the candidate pool holds:

    - every positive and negative passage seen for that query;
    - up to two passages from each of ``num_distractor_queries`` other randomly
      chosen queries.

    Only passages that appear as a *positive* for the query count as relevant.

    Args:
        test_data: Sequence of ``(query, positive, negative)`` string triplets.
        query_encoder: Passed to ``search``.
        doc_encoder: Passed to ``search``.
        tokenizer: Passed to ``search``.
        k: Rank cutoff for the reciprocal-rank metric.
        max_triplets: How many leading triplets of ``test_data`` to use (speed).
        num_queries: How many queries to evaluate.
        num_distractor_queries: How many other queries to borrow passages from.
        show_top: How many top-ranked results to print per query.

    Returns:
        Mean over evaluated queries of 1/rank of the first relevant passage
        within the top ``k`` results (0 for a query with none in the top ``k``).
        Returns 0.0 when no queries are evaluated.
    """
    print("🔍 COMPREHENSIVE RETRIEVAL EVALUATION")
    print("="*50)

    # Dicts instead of lists keep first-seen order while dropping duplicates.
    # Duplicates arise because one query contributes several triplets.
    query_to_positives = defaultdict(dict)
    query_to_docs = defaultdict(dict)
    for query, pos_doc, neg_doc in test_data[:max_triplets]:
        query_to_positives[query][pos_doc] = None
        query_to_docs[query][pos_doc] = None
        query_to_docs[query][neg_doc] = None

    all_queries = list(query_to_docs)
    reciprocal_ranks = []

    for i, query in enumerate(all_queries[:num_queries]):
        print(f"\n🔎 TEST QUERY {i+1}: {query[:100]}...")
        print("-" * 60)

        positives = query_to_positives[query]
        candidates = dict(query_to_docs[query])

        # Harder test: add passages from other queries. Never pick this query
        # itself, and never ask for more queries than exist.
        other_queries = [q for q in all_queries if q != query]
        for other_query in random.sample(other_queries, min(num_distractor_queries, len(other_queries))):
            for doc in list(query_to_docs[other_query])[:2]:
                candidates.setdefault(doc, None)

        all_documents = list(candidates)
        random.shuffle(all_documents)

        print(f"📚 Searching through {len(all_documents)} documents...")

        # Run search
        results = search(query, all_documents, tokenizer, query_encoder, doc_encoder)

        reciprocal_rank = 0.0
        for rank, (doc, _) in enumerate(results[:k], start=1):
            if doc in positives:
                reciprocal_rank = 1.0 / rank
                break
        reciprocal_ranks.append(reciprocal_rank)

        print(f"\n🏆 TOP {min(show_top, len(results))} RESULTS:")
        for j, (doc, score) in enumerate(results[:show_top]):
            relevance = "✅ RELEVANT" if doc in positives else "❌ NOT RELEVANT"
            print(f"{j+1}. Score: {score:.4f} {relevance}")
            print(f"   Doc: {doc[:80]}...")
            print()

    mrr = sum(reciprocal_ranks) / len(reciprocal_ranks) if reciprocal_ranks else 0.0
    print(f"\n📈 MRR@{k} over {len(reciprocal_ranks)} queries: {mrr:.4f}")
    return mrr

# Run the retrieval spot-check on the held-out test triplets, using the
# evaluate_retrieval defaults for how many triplets/queries to sample.
evaluate_retrieval(test_data, query_encoder, doc_encoder, tokenizer)