# Deep Learning Methods for COVID-QA Medical Retrieval

## DPR and SPLADE for Medical Domain Question Answering

This notebook implements advanced deep learning retrieval methods for medical domain question answering using the COVID-QA dataset:

1. **Load COVID-QA Dataset** - Medical literature Q&A pairs
2. **Implement DPR** - Dense Passage Retrieval with dual encoders
3. **Implement SPLADE** - Sparse Lexical And Expansion model
4. **Define Evaluation Framework** - Retrieval metrics for medical domain
5. **Run and Compare Methods** - Performance analysis

### **Why DPR and SPLADE for Medical Domain?**

- **DPR**: Dense neural retrieval captures semantic similarity between medical questions and literature, handling synonyms and paraphrases common in medical terminology
- **SPLADE**: Neural sparse retrieval maintains interpretability while expanding medical terms, crucial for expert validation in medical applications

### **Focus**: Retrieval effectiveness for medical domain questions, not extractive QA

In [4]:
# Use Weights & Biases (wandb) for experiment tracking.
# The training loops in this notebook only log when `wandb.run` is set.
!pip install wandb
import wandb
wandb.login()




True

In [5]:
import os

# Create the data directory used by the dataset-loading cell.
# `exist_ok=True` makes this idempotent without a separate exists() check (no race).
os.makedirs('data', exist_ok=True)

In [6]:
# Step 1: Import Libraries and Load COVID-QA Dataset with Proper Splits

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from typing import List, Tuple, Dict, Optional
import warnings
warnings.filterwarnings('ignore')

# Core deep learning libraries
# Probe for the transformer stack. The outcome is stored in DL_LIBRARIES_AVAILABLE so that
# later cells (e.g. the DPR initialization) can check it instead of failing on import.
try:
    from transformers import (
        AutoTokenizer, AutoModel,
        BertTokenizer, BertModel,
        TrainingArguments, Trainer
    )
    # NOTE(review): BertTokenizer, BertModel, TrainingArguments, Trainer and
    # SentenceTransformer are not referenced in the cells visible here; confirm
    # they are still needed.
    from sentence_transformers import SentenceTransformer
    print("✅ Transformers and Sentence Transformers loaded successfully!")
    print(f"   🔥 PyTorch version: {torch.__version__}")
    print(f"   🤗 Using device: {'GPU' if torch.cuda.is_available() else 'CPU'}")
    DL_LIBRARIES_AVAILABLE = True
except ImportError:
    # Only a missing package is tolerated here; other import-time errors propagate.
    print("⚠️ Deep learning libraries not available. Please install:")
    print("   pip install torch transformers sentence-transformers")
    DL_LIBRARIES_AVAILABLE = False

# Traditional libraries for comparison
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.model_selection import train_test_split
from collections import Counter
import re
import json

# Load COVID-QA Dataset with Proper Splits
print("📁 Loading COVID-QA dataset with proper train/validation/test splits...")


def clean_dataframe(df, source_name):
    """Return *df*'s question/answer columns as strings, tagged with *source_name*.

    Missing questions or answers become empty strings so they can be dropped below.
    """
    df_clean = df.copy()
    df_clean['dataset_source'] = source_name
    df_clean['question'] = df_clean['question'].fillna('').astype(str)
    df_clean['answer'] = df_clean['answer'].fillna('').astype(str)
    return df_clean[['question', 'answer', 'dataset_source']]


# Only the file reads can raise FileNotFoundError, so keep the try around them alone.
try:
    community_df = pd.read_csv('data/community.csv')
    news_df = pd.read_csv('data/news.csv')
    multilingual_df = pd.read_csv('data/multilingual.csv')
except FileNotFoundError:
    print("❌ COVID-QA data files not found!")
    print("Please ensure these files exist in the 'data/' directory:")
    print("   • community.csv")
    print("   • news.csv")
    print("   • multilingual.csv")
    # Bare `raise` re-raises the original error, which names the missing file.
    raise

# Process all datasets
community_clean = clean_dataframe(community_df, 'community')
news_clean = clean_dataframe(news_df, 'news')
multilingual_clean = clean_dataframe(multilingual_df, 'multilingual')

# Combine datasets
df_combined = pd.concat([community_clean, news_clean, multilingual_clean], ignore_index=True)
num_raw_pairs = len(df_combined)

# Remove empty entries
df_combined = df_combined[(df_combined['question'].str.len() > 0) & (df_combined['answer'].str.len() > 0)]
num_dropped = num_raw_pairs - len(df_combined)

print(f"✅ Successfully loaded COVID-QA dataset!")
print(f"   • Community: {len(community_clean):,} QA pairs")
print(f"   • News: {len(news_clean):,} QA pairs")
print(f"   • Multilingual: {len(multilingual_clean):,} QA pairs")
if num_dropped:
    # Per-source counts above are before filtering; report the difference so they add up.
    print(f"   • Dropped (empty question or answer): {num_dropped:,} QA pairs")
print(f"   • Total combined: {len(df_combined):,} QA pairs")

# Create proper train/validation/test splits
print(f"\n🔄 Creating train/validation/test splits...")

# First split: 80% train, 20% temp (stratified by source so each split keeps the mix)
train_df, temp_df = train_test_split(
    df_combined,
    test_size=0.2,
    random_state=42,
    stratify=df_combined['dataset_source']
)

# Second split: 10% validation, 10% test (from the 20% temp)
val_df, test_df = train_test_split(
    temp_df,
    test_size=0.5,
    random_state=42,
    stratify=temp_df['dataset_source']
)

print(f"   • Training set: {len(train_df):,} QA pairs (80%)")
print(f"   • Validation set: {len(val_df):,} QA pairs (10%)")
print(f"   • Test set: {len(test_df):,} QA pairs (10%)")

# Show dataset source distribution
print(f"\n📊 Dataset source distribution:")
for split_name, split_df in [("Train", train_df), ("Validation", val_df), ("Test", test_df)]:
    source_counts = split_df['dataset_source'].value_counts()
    print(f"   {split_name}:")
    for source, count in source_counts.items():
        print(f"      • {source}: {count:,} ({count/len(split_df)*100:.1f}%)")

# Extract data for each split
train_questions = train_df['question'].tolist()
train_answers = train_df['answer'].tolist()

val_questions = val_df['question'].tolist()
val_answers = val_df['answer'].tolist()

test_questions = test_df['question'].tolist()
test_answers = test_df['answer'].tolist()

# Dataset statistics
print(f"\n📈 Dataset Statistics:")
print(f"   Training set:")
print(f"      • Average question length: {train_df['question'].str.len().mean():.1f} chars")
print(f"      • Average answer length: {train_df['answer'].str.len().mean():.1f} chars")
print(f"   Test set:")
print(f"      • Average question length: {test_df['question'].str.len().mean():.1f} chars")
print(f"      • Average answer length: {test_df['answer'].str.len().mean():.1f} chars")

# Show sample from each split
print(f"\n📝 Sample QA pairs:")
print(f"   Training sample:")
sample_train = train_df.iloc[0]
print(f"      Q: {sample_train['question'][:100]}...")
print(f"      A: {sample_train['answer'][:100]}...")

print(f"   Test sample:")
sample_test = test_df.iloc[0]
print(f"      Q: {sample_test['question'][:100]}...")
print(f"      A: {sample_test['answer'][:100]}...")

# Create corpus for negative sampling (using training answers only)
corpus_for_negatives = train_answers.copy()
print(f"\n🎯 Corpus for negative sampling: {len(corpus_for_negatives):,} passages")


print(f"\n🎯 Dataset Loading Complete!")
print(f"📊 Ready for proper train/validation/test evaluation of deep learning methods!")
print(f"   • Training data: {len(train_questions):,} QA pairs")
print(f"   • Validation data: {len(val_questions):,} QA pairs")
print(f"   • Test data: {len(test_questions):,} QA pairs")

# A row-level split does not rule out identical answer passages appearing in both
# train and test, so check it instead of asserting it.
num_leaked_answers = len(set(test_answers) & set(train_answers))
if num_leaked_answers:
    print(f"   • ⚠️ {num_leaked_answers:,} test answers also appear verbatim in the training set")
else:
    print(f"   • No data leakage: Test set completely separate from training")

✅ Transformers and Sentence Transformers loaded successfully!
   🔥 PyTorch version: 2.6.0+cu124
   🤗 Using device: CPU
📁 Loading COVID-QA dataset with proper train/validation/test splits...
✅ Successfully loaded COVID-QA dataset!
   • Community: 642 QA pairs
   • News: 481 QA pairs
   • Multilingual: 888 QA pairs
   • Total combined: 2,008 QA pairs

🔄 Creating train/validation/test splits...
   • Training set: 1,606 QA pairs (80%)
   • Validation set: 201 QA pairs (10%)
   • Test set: 201 QA pairs (10%)

📊 Dataset source distribution:
   Train:
      • multilingual: 709 (44.1%)
      • community: 514 (32.0%)
      • news: 383 (23.8%)
   Validation:
      • multilingual: 89 (44.3%)
      • community: 64 (31.8%)
      • news: 48 (23.9%)
   Test:
      • multilingual: 89 (44.3%)
      • community: 64 (31.8%)
      • news: 48 (23.9%)

📈 Dataset Statistics:
   Training set:
      • Average question length: 325.9 chars
      • Average answer length: 1057.0 chars
   Test set:
      • Average 

In [7]:
# Step 2: Implement Dense Passage Retrieval (DPR) with Training

import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence
import random
from tqdm import tqdm
# Remove unused imports for BM25
# from sklearn.feature_extraction.text import TfidfVectorizer
# from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import os # Import os for checkpointing

class DPREncoder(nn.Module):
    """
    One tower of the DPR dual encoder.

    A pretrained transformer backbone produces contextual token states. The state of
    the first token ([CLS] for BERT) goes through dropout and a linear projection, which
    gives one vector per input sequence. The same class serves for questions and passages.
    """
    def __init__(self, model_name='bert-base-uncased', projection_dim=768):
        super().__init__()
        backbone = AutoModel.from_pretrained(model_name)
        self.bert = backbone
        hidden_size = backbone.config.hidden_size
        self.projection = nn.Linear(hidden_size, projection_dim)
        self.dropout = nn.Dropout(0.1)

    def forward(self, input_ids, attention_mask):
        """Return the projected first-token embedding, shape [batch_size, projection_dim]."""
        token_states = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        first_token_state = token_states[:, 0, :]
        regularized = self.dropout(first_token_state)
        return self.projection(regularized)

class DPRTrainer:
    """
    Dense Passage Retrieval trainer with in-batch negatives.

    Implements the in-batch negative objective of Karpukhin et al. 2020. For a batch of
    (question, positive passage) pairs, every other passage in the batch is a negative
    for each question. No additional hard (e.g. BM25) negatives are mined here.
    """

    # Retrieval cut-offs reported during validation.
    _TOP_KS = (1, 5, 10, 20)

    def __init__(self, model_name='bert-base-uncased', projection_dim=768, learning_rate=2e-5,
                 temperature=0.1):
        """
        Args:
            model_name: Hugging Face checkpoint used for both encoders and the tokenizer.
            projection_dim: Output size of each encoder's linear projection.
            learning_rate: AdamW learning rate shared by both encoders.
            temperature: Divisor applied to cosine similarities before the softmax
                cross-entropy (smaller = sharper). Defaults to 0.1.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Dual encoders for questions and passages (separate weights)
        self.question_encoder = DPREncoder(model_name, projection_dim).to(self.device)
        self.passage_encoder = DPREncoder(model_name, projection_dim).to(self.device)

        # Tokenizer (shared by both towers)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # One optimizer over both encoders
        self.optimizer = optim.AdamW(self._trainable_parameters(), lr=learning_rate)

        self.projection_dim = projection_dim
        self.temperature = temperature

    def _trainable_parameters(self):
        """Return the parameters of both encoders as a single list."""
        return list(self.question_encoder.parameters()) + list(self.passage_encoder.parameters())

    def encode_batch(self, texts, encoder, max_length=512):
        """
        Tokenize *texts* and encode them with *encoder*.

        Gradients are tracked only while *encoder* is in training mode.
        """
        encoded = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors='pt'
        )

        input_ids = encoded['input_ids'].to(self.device)
        attention_mask = encoded['attention_mask'].to(self.device)

        if encoder.training:
            return encoder(input_ids, attention_mask)
        with torch.no_grad():
            return encoder(input_ids, attention_mask)

    def compute_similarity_matrix(self, question_embeddings, passage_embeddings):
        """
        Cosine similarity between every question and every passage.

        Returns:
            Tensor of shape (num_questions, num_passages).
        """
        question_embeddings = F.normalize(question_embeddings, p=2, dim=1)
        passage_embeddings = F.normalize(passage_embeddings, p=2, dim=1)
        return torch.matmul(question_embeddings, passage_embeddings.transpose(0, 1))

    def train_step(self, questions, passages):
        """
        Run one optimizer step with in-batch negatives.

        Passage i is the positive for question i, and all other passages in the batch
        are its negatives. The cross-entropy targets are therefore the diagonal indices.

        Returns:
            The batch loss as a Python float.
        """
        question_embeddings = self.encode_batch(questions, self.question_encoder)
        passage_embeddings = self.encode_batch(passages, self.passage_encoder)

        # (B, B) cosine similarities, sharpened by the temperature
        logits = self.compute_similarity_matrix(question_embeddings, passage_embeddings) / self.temperature
        labels = torch.arange(len(questions), device=self.device)
        loss = F.cross_entropy(logits, labels)

        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self._trainable_parameters(), max_norm=1.0)
        self.optimizer.step()

        return loss.item()

    def train_epoch(self, train_data, batch_size, epoch, val_questions, val_answers):
        """
        Train for one epoch, validate on a subset, log to wandb, and save a checkpoint.

        Args:
            train_data: Sequence of (question, passage) pairs. A shuffled copy is used,
                so the caller's list is not reordered.
            batch_size: Training batch size. Batches of size 1 are skipped because
                in-batch negatives need at least two pairs.
            epoch: Zero-based epoch index, used for logging and the checkpoint file name.
            val_questions: List of validation questions.
            val_answers: List of validation answers aligned with *val_questions*.

        Returns:
            (average training loss, dict of validation metrics for the epoch)
        """
        self.question_encoder.train()
        self.passage_encoder.train()

        train_data = list(train_data)
        random.shuffle(train_data)

        total_loss = 0
        num_batches = 0

        for i in tqdm(range(0, len(train_data), batch_size), desc="Training DPR"):
            batch_data = train_data[i:i + batch_size]
            questions = [item[0] for item in batch_data]
            passages = [item[1] for item in batch_data]

            if len(questions) <= 1:
                print(f"Skipping batch {i//batch_size} with size {len(questions)} as in-batch negatives require batch size > 1.")
                continue

            loss = self.train_step(questions, passages)
            total_loss += loss
            num_batches += 1
            if wandb.run:
                # No explicit `step=`: wandb drops entries whose step is below the run's
                # current step. Auto-incrementing keeps batch logs, epoch logs, and other
                # trainers sharing the run monotonic.
                wandb.log({"dpr/batch_loss": loss, "dpr/epoch": epoch})

        avg_train_loss = total_loss / num_batches if num_batches > 0 else 0

        print(f"   • Performing comprehensive validation...")
        top_k_accuracies, val_exact_match, val_f1, val_subset_size = self._validate(val_questions, val_answers)

        print(f"   • Validation Results (subset size: {val_subset_size}):")
        print(f"     - Top-1 accuracy: {top_k_accuracies[1]:.3f}")
        print(f"     - Top-5 accuracy: {top_k_accuracies[5]:.3f}")
        print(f"     - Top-10 accuracy: {top_k_accuracies[10]:.3f}")
        print(f"     - Top-20 accuracy: {top_k_accuracies[20]:.3f}")
        print(f"     - Exact Match: {val_exact_match:.3f}")
        print(f"     - F1 Score: {val_f1:.3f}")

        if wandb.run:
            wandb.log({
                "dpr/epoch": epoch,
                "dpr/epoch_loss": avg_train_loss,
                "dpr/val_top_1_accuracy": top_k_accuracies[1],
                "dpr/val_top_5_accuracy": top_k_accuracies[5],
                "dpr/val_top_10_accuracy": top_k_accuracies[10],
                "dpr/val_top_20_accuracy": top_k_accuracies[20],
                "dpr/val_exact_match": val_exact_match,
                "dpr/val_f1": val_f1
            })

        epoch_metrics = {
            'loss': avg_train_loss,
            'top_1_accuracy': top_k_accuracies[1],
            'top_5_accuracy': top_k_accuracies[5],
            'top_10_accuracy': top_k_accuracies[10],
            'top_20_accuracy': top_k_accuracies[20],
            'exact_match': val_exact_match,
            'f1': val_f1,
            'val_subset_size': val_subset_size
        }

        self._save_checkpoint(epoch, epoch_metrics)

        return avg_train_loss, epoch_metrics

    def _validate(self, val_questions, val_answers, subset_size=50):
        """
        Retrieve among the first *subset_size* validation answers for each question.

        Returns:
            (top-k accuracy dict, mean exact match, mean F1, subset size used).
            An empty validation set yields zeros instead of failing.
        """
        n = min(subset_size, len(val_questions), len(val_answers))
        if n == 0:
            return {k: 0.0 for k in self._TOP_KS}, 0.0, 0.0, 0

        passages = val_answers[:n]
        passage_embeddings = self.encode_for_retrieval(passages, 'passage')

        hits = {k: 0 for k in self._TOP_KS}
        predictions = []
        for i, question in enumerate(val_questions[:n]):
            results_list = self.retrieve(question, passage_embeddings, passages, top_k=max(self._TOP_KS))
            if not results_list:
                predictions.append("")
                continue

            retrieved_answers = [r[2] for r in results_list]
            predictions.append(retrieved_answers[0])  # top-1 prediction for EM / F1
            for k in self._TOP_KS:
                if val_answers[i] in retrieved_answers[:k]:
                    hits[k] += 1

        top_k_accuracies = {k: hits[k] / n for k in self._TOP_KS}
        val_exact_match = np.mean([self._exact_match(p, g) for p, g in zip(predictions, passages)])
        val_f1 = np.mean([self._f1_score(p, g) for p, g in zip(predictions, passages)])
        return top_k_accuracies, val_exact_match, val_f1, n

    def _save_checkpoint(self, epoch, metrics):
        """Save both encoders, the optimizer and *metrics*; log as a wandb artifact if a run is active."""
        checkpoint_dir = "dpr_checkpoints"
        os.makedirs(checkpoint_dir, exist_ok=True)
        checkpoint_path = os.path.join(checkpoint_dir, f"dpr_epoch_{epoch+1}.pth")
        torch.save({
            'epoch': epoch,
            'question_encoder_state_dict': self.question_encoder.state_dict(),
            'passage_encoder_state_dict': self.passage_encoder.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'loss': metrics['loss'],
            'val_top_1_accuracy': metrics['top_1_accuracy'],
            'val_top_5_accuracy': metrics['top_5_accuracy'],
            'val_top_10_accuracy': metrics['top_10_accuracy'],
            'val_top_20_accuracy': metrics['top_20_accuracy'],
            'val_exact_match': metrics['exact_match'],
            'val_f1': metrics['f1'],
            'val_subset_size': metrics['val_subset_size']
        }, checkpoint_path)
        print(f"Checkpoint saved to {checkpoint_path}")

        if wandb.run:
            artifact = wandb.Artifact(f'dpr-model-epoch-{epoch+1}', type='model')
            artifact.add_file(checkpoint_path)
            wandb.log_artifact(artifact)

    @staticmethod
    def _normalize_answer(text):
        """Lower-case, strip punctuation and collapse whitespace."""
        text = re.sub(r'[^\w\s]', '', text.lower())
        return ' '.join(text.split())

    @staticmethod
    def _exact_match(prediction, ground_truth):
        """Return 1.0 if the normalized strings are equal, else 0.0."""
        return float(DPRTrainer._normalize_answer(prediction) == DPRTrainer._normalize_answer(ground_truth))

    @staticmethod
    def _f1_score(prediction, ground_truth):
        """Return the token-overlap F1 between the normalized strings."""
        pred_tokens = DPRTrainer._normalize_answer(prediction).split()
        gt_tokens = DPRTrainer._normalize_answer(ground_truth).split()

        if not pred_tokens and not gt_tokens:
            return 1.0
        if not pred_tokens or not gt_tokens:
            return 0.0

        num_same = sum((Counter(pred_tokens) & Counter(gt_tokens)).values())
        if num_same == 0:
            return 0.0

        precision = num_same / len(pred_tokens)
        recall = num_same / len(gt_tokens)
        return 2 * precision * recall / (precision + recall)

    def encode_for_retrieval(self, texts, encoder_type='passage', batch_size=32):
        """
        Encode *texts* in inference mode with the question or passage encoder.

        Returns:
            Tensor of shape (len(texts), projection_dim). An empty input gives an
            empty (0, projection_dim) tensor.
        """
        encoder = self.question_encoder if encoder_type == 'question' else self.passage_encoder
        encoder.eval()

        if not texts:
            return torch.empty((0, self.projection_dim), device=self.device)

        with torch.no_grad():
            all_embeddings = [
                self.encode_batch(texts[i:i + batch_size], encoder)
                for i in range(0, len(texts), batch_size)
            ]

        return torch.cat(all_embeddings, dim=0)

    def retrieve(self, query, passage_embeddings, passages, top_k=5):
        """
        Retrieve the top-k passages for *query*.

        Returns:
            List of (passage_index, cosine_score, passage_text), best first.
        """
        query_embedding = self.encode_for_retrieval([query], 'question')
        similarities = self.compute_similarity_matrix(query_embedding, passage_embeddings).squeeze(0)

        top_k_scores, top_k_indices = torch.topk(similarities, min(top_k, len(similarities)))

        return [
            (int(idx), float(score), passages[int(idx)])
            for score, idx in zip(top_k_scores, top_k_indices)
        ]

# Build the DPR trainer when the transformer stack imported cleanly.
# On any failure (or missing libraries) `dpr_trainer` ends up as None.
print("🧠 Initializing Dense Passage Retrieval (DPR) with In-Batch Negatives...")

dpr_trainer = None
if not DL_LIBRARIES_AVAILABLE:
    print("⚠️ Deep learning libraries not available")
else:
    try:
        dpr_trainer = DPRTrainer(
            model_name='bert-base-uncased',
            projection_dim=768,
            learning_rate=2e-5,
        )
        print("✅ DPR trainer initialized successfully!")
        print(f"   🔥 Device: {dpr_trainer.device}")
        print(f"   📐 Projection dim: {dpr_trainer.projection_dim}")
        print(f"   🔀 Dual encoders: Question + Passage")
        print(f"   🎯 Negative sampling strategy: In-batch negatives only")
    except Exception as init_error:
        print(f"❌ Failed to initialize DPR trainer: {init_error}")
        dpr_trainer = None

🧠 Initializing Dense Passage Retrieval (DPR) with In-Batch Negatives...


config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

✅ DPR trainer initialized successfully!
   🔥 Device: cpu
   📐 Projection dim: 768
   🔀 Dual encoders: Question + Passage
   🎯 Negative sampling strategy: In-batch negatives only


In [8]:
# Step 3: Implement SPLADE Neural Sparse Retrieval (Correct Implementation)

class SPLADEModel(nn.Module):
    """
    SPLADE (Sparse Lexical And Expansion) model, after Formal et al. 2021.

    Each token's contextual embedding is projected onto the vocabulary and passed
    through log(1 + ReLU(.)). Padding is masked out, and the result is summed over the
    sequence, giving one non-negative vocabulary-sized weight vector per input.

    NOTE(review): the vocabulary projection is a single linear layer initialized from
    the input word embeddings. It omits BERT's MLM transform (dense + activation +
    LayerNorm), so it only approximates the pretrained MLM head.
    """

    def __init__(self, model_name='bert-base-uncased'):
        super().__init__()

        self.bert = AutoModel.from_pretrained(model_name)
        self.vocab_size = self.bert.config.vocab_size  # 30,522 for bert-base-uncased

        # Projection from hidden states to vocabulary logits
        self.mlm_head = nn.Linear(self.bert.config.hidden_size, self.vocab_size)
        self._init_mlm_head_from_embeddings()

        self.dropout = nn.Dropout(0.1)

    def _init_mlm_head_from_embeddings(self):
        """
        Copy the input word-embedding matrix into the vocabulary projection.

        BERT's MLM decoder is tied to its word embeddings, so this gives the head a
        pretrained starting point instead of random weights, without loading a second
        model. The bias is zeroed. If the embedding matrix is unavailable or its shape
        differs from (vocab_size, hidden_size), the default initialization is kept.
        """
        embeddings = self.bert.get_input_embeddings()
        if embeddings is None or embeddings.weight.shape != self.mlm_head.weight.shape:
            return
        with torch.no_grad():
            self.mlm_head.weight.copy_(embeddings.weight)
            self.mlm_head.bias.zero_()

    def forward(self, input_ids, attention_mask):
        """
        Forward pass of the SPLADE model.

        Args:
            input_ids: Token IDs [batch_size, seq_len]
            attention_mask: Attention mask [batch_size, seq_len] (1 = real token)

        Returns:
            Sparse representation [batch_size, vocab_size] (non-negative)
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = outputs.last_hidden_state  # [batch_size, seq_len, hidden_size]

        # Per-token vocabulary logits
        token_logits = self.mlm_head(self.dropout(last_hidden_state))  # [batch_size, seq_len, vocab_size]

        # log(1 + ReLU(x)): non-negative, saturating term weights (log1p is more accurate near 0)
        token_weights = torch.log1p(F.relu(token_logits))

        # Zero out padded positions before pooling
        token_weights = token_weights * attention_mask.unsqueeze(-1).to(token_weights.dtype)

        # Sum over the sequence: aggregate every token's contribution per vocabulary term
        return torch.sum(token_weights, dim=1)  # [batch_size, vocab_size]

class SPLADETrainer:
    """
    SPLADE trainer with proper ranking loss and FLOPS regularization
    """

    def __init__(self, model_name='bert-base-uncased', learning_rate=5e-6):
        """
        Build the SPLADE encoder, its tokenizer, an AdamW optimizer and FLOPS weights.

        Args:
            model_name: Hugging Face checkpoint for the model and tokenizer.
            learning_rate: AdamW learning rate (kept low for SPLADE fine-tuning).
        """
        has_gpu = torch.cuda.is_available()
        self.device = torch.device('cuda' if has_gpu else 'cpu')

        # One shared encoder for queries and documents
        self.model = SPLADEModel(model_name).to(self.device)

        tokenizer = AutoTokenizer.from_pretrained(model_name)
        if tokenizer.pad_token is None:
            # Fall back to EOS so batch padding works for tokenizers without a pad token
            tokenizer.pad_token = tokenizer.eos_token
        self.tokenizer = tokenizer

        self.optimizer = optim.AdamW(self.model.parameters(), lr=learning_rate)

        # FLOPS regularization strengths: query side, then document side
        self.lambda_q, self.lambda_d = 0.06, 0.02

        self.vocab_size = self.model.vocab_size

    def encode_batch(self, texts, max_length=512):
        """
        Encode batch of texts to sparse representations
        """
        # Tokenize
        encoded = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors='pt'
        )

        input_ids = encoded['input_ids'].to(self.device)
        attention_mask = encoded['attention_mask'].to(self.device)

        # Get sparse representations
        sparse_repr = self.model(input_ids, attention_mask)

        return sparse_repr

    def compute_flops_regularization(self, sparse_repr):
        """
        FLOPS regularizer (Paria et al. 2020, as used by SPLADE, Formal et al. 2021).

        The penalty is sum_j (mean_i w_ij)^2, i.e. the squared batch-average weight of
        each vocabulary term, summed over the vocabulary. Averaging over the batch, rather
        than summing, keeps the penalty's scale independent of batch size. A batch sum
        would inflate it by batch_size**2 and swamp the ranking loss.

        Args:
            sparse_repr: Sparse representation [batch_size, vocab_size]

        Returns:
            Scalar FLOPS regularization loss
        """
        mean_term_weights = torch.mean(sparse_repr, dim=0)  # [vocab_size]
        return torch.sum(mean_term_weights ** 2)

    def compute_ranking_loss(self, query_repr, pos_doc_repr, neg_doc_repr, margin=1.0):
        """
        Pairwise margin (hinge) ranking loss over explicit negatives.

        Each query is compared only with its own positive and its own negative
        document. Other documents in the batch are not used as negatives.

        Args:
            query_repr: Query sparse representations [batch_size, vocab_size]
            pos_doc_repr: Positive document representations [batch_size, vocab_size]
            neg_doc_repr: Negative document representations [batch_size, vocab_size]
            margin: Required score gap between positive and negative (default 1.0)

        Returns:
            Scalar: mean over the batch of max(0, margin - s(q, d+) + s(q, d-)), where
            s is the dot product of the sparse vectors.
        """
        pos_scores = torch.sum(query_repr * pos_doc_repr, dim=1)  # [batch_size]
        neg_scores = torch.sum(query_repr * neg_doc_repr, dim=1)  # [batch_size]

        return F.relu(margin - pos_scores + neg_scores).mean()

    def train_step(self, queries, positive_docs, negative_docs):
        """
        Single training step with ranking loss and FLOPS regularization

        Args:
            queries: List of query strings
            positive_docs: List of positive document strings
            negative_docs: List of negative document strings

        Returns:
            Total loss, ranking loss, regularization loss
        """
        # Encode queries and documents
        query_repr = self.encode_batch(queries)
        pos_doc_repr = self.encode_batch(positive_docs)
        neg_doc_repr = self.encode_batch(negative_docs)

        # Compute ranking loss
        ranking_loss = self.compute_ranking_loss(query_repr, pos_doc_repr, neg_doc_repr)

        # Compute FLOPS regularization
        query_flops = self.compute_flops_regularization(query_repr)
        doc_flops = self.compute_flops_regularization(pos_doc_repr)

        # Total regularization loss
        reg_loss = self.lambda_q * query_flops + self.lambda_d * doc_flops

        # Total loss
        total_loss = ranking_loss + reg_loss

        # Backward pass
        self.optimizer.zero_grad()
        total_loss.backward()

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

        self.optimizer.step()

        return total_loss.item(), ranking_loss.item(), reg_loss.item()

    def train_epoch(self, train_data: List[Tuple[str, str, str]], batch_size: int, epoch: int, val_questions: List[str], val_answers: List[str]):
        """
        Train SPLADE for one epoch, then validate, checkpoint and log the metrics.

        Training: the (query, positive, negative) triples are shuffled on a copy
        (the caller's list keeps its order) and fed to ``self.train_step`` in
        mini-batches. Validation: the first ``min(50, len(val_questions),
        len(val_answers))`` validation answers form the retrieval corpus; each of
        the matching questions is scored by whether its own answer is retrieved in
        the top-k, plus Exact Match / token F1 of the top-1 retrieved passage.

        wandb logging uses a single global step,
        ``epoch * steps_per_epoch + batch_number`` with
        ``steps_per_epoch = ceil(len(train_data) / batch_size)``, and epoch-level
        metrics are logged at the epoch's last batch step. wandb ignores logs whose
        step is lower than one already recorded, so steps must never go backwards.

        Args:
            train_data: List of (query, positive_doc, negative_doc) tuples
            batch_size: Training batch size; must be positive
            epoch: Zero-based epoch index (used for logging and checkpoint naming)
            val_questions: List of validation questions
            val_answers: List of validation answers aligned with ``val_questions``

        Returns:
            Dictionary with average training losses and validation metrics
            (top-1/5/10/20 accuracy, exact match, F1, validation subset size).

        Raises:
            ValueError: if ``batch_size`` is not positive.
        """
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        self.model.train()

        total_loss = 0
        total_ranking_loss = 0
        total_reg_loss = 0
        num_batches = 0

        # Shuffle a copy so the caller's list is not reordered as a side effect.
        shuffled_data = list(train_data)
        random.shuffle(shuffled_data)

        # Ceil division: the trailing partial batch is a step too. With floor
        # division the last step of one epoch collided with the first of the next.
        steps_per_epoch = (len(shuffled_data) + batch_size - 1) // batch_size
        global_step = epoch * steps_per_epoch

        for i in tqdm(range(0, len(shuffled_data), batch_size), desc=f"Epoch {epoch+1} Training SPLADE"):
            batch_data = shuffled_data[i:i + batch_size]

            # Split the batch of triples into parallel lists
            queries = [item[0] for item in batch_data]
            positive_docs = [item[1] for item in batch_data]
            negative_docs = [item[2] for item in batch_data]

            loss, ranking_loss, reg_loss = self.train_step(queries, positive_docs, negative_docs)

            total_loss += loss
            total_ranking_loss += ranking_loss
            total_reg_loss += reg_loss
            num_batches += 1
            global_step = epoch * steps_per_epoch + num_batches

            # Log batch loss to wandb
            if wandb.run:
                wandb.log({
                    "splade/batch_total_loss": loss,
                    "splade/batch_ranking_loss": ranking_loss,
                    "splade/batch_reg_loss": reg_loss
                }, step=global_step)

        avg_total_loss = total_loss / num_batches if num_batches > 0 else 0
        avg_ranking_loss = total_ranking_loss / num_batches if num_batches > 0 else 0
        avg_reg_loss = total_reg_loss / num_batches if num_batches > 0 else 0

        # ----- Validation on a small subset (keeps per-epoch validation cheap) -----
        print(f"   • Performing comprehensive validation...")
        val_subset_size = min(50, len(val_questions), len(val_answers))
        val_corpus = val_answers[:val_subset_size]

        hits_at_k = {1: 0, 5: 0, 10: 0, 20: 0}
        val_predictions = []

        # Skip encoding/retrieval entirely when there is nothing to validate on.
        if val_subset_size > 0:
            val_doc_representations = self.encode_for_retrieval(val_corpus)

            for i, question in enumerate(val_questions[:val_subset_size]):
                results_list = self.retrieve(question, val_doc_representations, val_corpus, top_k=20)

                if not results_list:
                    val_predictions.append("")
                    continue

                # The top-1 passage serves as the predicted answer for EM / F1
                val_predictions.append(results_list[0][2])

                # Question i's gold answer is val_answers[i]; credit every cutoff k
                # at which it appears among the retrieved passages.
                correct_answer = val_answers[i]
                retrieved_answers = [r[2] for r in results_list]
                for k in hits_at_k:
                    if correct_answer in retrieved_answers[:k]:
                        hits_at_k[k] += 1

        top_k_accuracies = {
            k: (hits / val_subset_size if val_subset_size > 0 else 0.0)
            for k, hits in hits_at_k.items()
        }

        # QA performance metrics (F1 and Exact Match)
        from collections import Counter
        import re

        def normalize_answer(text):
            # Lowercase, drop punctuation, collapse whitespace
            text = text.lower()
            text = re.sub(r'[^\w\s]', '', text)
            text = ' '.join(text.split())
            return text

        def exact_match(prediction, ground_truth):
            return float(normalize_answer(prediction) == normalize_answer(ground_truth))

        def f1_score(prediction, ground_truth):
            # Token-overlap F1 between normalized prediction and ground truth
            pred_tokens = normalize_answer(prediction).split()
            gt_tokens = normalize_answer(ground_truth).split()

            if len(pred_tokens) == 0 and len(gt_tokens) == 0:
                return 1.0
            if len(pred_tokens) == 0 or len(gt_tokens) == 0:
                return 0.0

            common_tokens = Counter(pred_tokens) & Counter(gt_tokens)
            num_same = sum(common_tokens.values())

            if num_same == 0:
                return 0.0

            precision = num_same / len(pred_tokens)
            recall = num_same / len(gt_tokens)

            return 2 * precision * recall / (precision + recall)

        exact_matches = [exact_match(pred, gt) for pred, gt in zip(val_predictions, val_corpus)]
        f1_scores = [f1_score(pred, gt) for pred, gt in zip(val_predictions, val_corpus)]

        val_exact_match = np.mean(exact_matches) if exact_matches else 0.0
        val_f1 = np.mean(f1_scores) if f1_scores else 0.0

        print(f"   • Validation Results (subset size: {val_subset_size}):")
        print(f"     - Top-1 accuracy: {top_k_accuracies[1]:.3f}")
        print(f"     - Top-5 accuracy: {top_k_accuracies[5]:.3f}")
        print(f"     - Top-10 accuracy: {top_k_accuracies[10]:.3f}")
        print(f"     - Top-20 accuracy: {top_k_accuracies[20]:.3f}")
        print(f"     - Exact Match: {val_exact_match:.3f}")
        print(f"     - F1 Score: {val_f1:.3f}")

        # Epoch metrics go on the same global step as the epoch's last batch
        # (logging at step=epoch would be below the batch steps and get dropped).
        if wandb.run:
            wandb.log({
                "splade/epoch": epoch,
                "splade/epoch_total_loss": avg_total_loss,
                "splade/epoch_ranking_loss": avg_ranking_loss,
                "splade/epoch_reg_loss": avg_reg_loss,
                "splade/val_top_1_accuracy": top_k_accuracies[1],
                "splade/val_top_5_accuracy": top_k_accuracies[5],
                "splade/val_top_10_accuracy": top_k_accuracies[10],
                "splade/val_top_20_accuracy": top_k_accuracies[20],
                "splade/val_exact_match": val_exact_match,
                "splade/val_f1": val_f1
            }, step=global_step)

        # Save checkpoint after each epoch with comprehensive metrics
        checkpoint_dir = "splade_checkpoints"
        os.makedirs(checkpoint_dir, exist_ok=True)
        checkpoint_path = os.path.join(checkpoint_dir, f"splade_epoch_{epoch+1}.pth")
        torch.save({
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'total_loss': avg_total_loss,
            'ranking_loss': avg_ranking_loss,
            'reg_loss': avg_reg_loss,
            'val_top_1_accuracy': top_k_accuracies[1],
            'val_top_5_accuracy': top_k_accuracies[5],
            'val_top_10_accuracy': top_k_accuracies[10],
            'val_top_20_accuracy': top_k_accuracies[20],
            'val_exact_match': val_exact_match,
            'val_f1': val_f1,
            'val_subset_size': val_subset_size
        }, checkpoint_path)
        print(f"Checkpoint saved to {checkpoint_path}")

        # Log checkpoint artifact to wandb
        if wandb.run:
            artifact = wandb.Artifact(f'splade-model-epoch-{epoch+1}', type='model')
            artifact.add_file(checkpoint_path)
            wandb.log_artifact(artifact)

        return {
            'total_loss': avg_total_loss,
            'ranking_loss': avg_ranking_loss,
            'reg_loss': avg_reg_loss,
            'top_1_accuracy': top_k_accuracies[1],
            'top_5_accuracy': top_k_accuracies[5],
            'top_10_accuracy': top_k_accuracies[10],
            'top_20_accuracy': top_k_accuracies[20],
            'exact_match': val_exact_match,
            'f1': val_f1,
            'val_subset_size': val_subset_size
        }


    def encode_for_retrieval(self, texts, batch_size=32):
        """
        Encode texts into vocabulary-space vectors for retrieval (inference mode).

        Puts ``self.model`` in eval mode (and leaves it there; training code must
        call ``self.model.train()`` again) and encodes the texts in chunks of
        ``batch_size`` via ``self.encode_batch`` with gradients disabled.

        Args:
            texts: Sequence of strings to encode.
            batch_size: Number of texts per ``encode_batch`` call.

        Returns:
            Tensor of the per-batch outputs concatenated along dim 0, one row per
            text in input order. For empty ``texts`` an empty
            ``(0, self.vocab_size)`` tensor on ``self.device`` is returned instead
            of raising (``torch.cat`` rejects an empty list).
        """
        self.model.eval()

        if len(texts) == 0:
            return torch.zeros((0, self.vocab_size), device=self.device)

        all_representations = []

        with torch.no_grad():
            for i in range(0, len(texts), batch_size):
                all_representations.append(self.encode_batch(texts[i:i + batch_size]))

        return torch.cat(all_representations, dim=0)

    def retrieve(self, query, document_representations, documents, top_k=5):
        """
        Rank ``documents`` against ``query`` by dot-product score.

        Args:
            query: Query string; encoded with ``self.encode_for_retrieval``.
            document_representations: Precomputed document vectors, one row per
                entry of ``documents``.
            documents: Document strings aligned with the rows above.
            top_k: Maximum number of hits to return (clipped to the corpus size).

        Returns:
            List of ``(doc_idx, score, doc_text)`` tuples, highest score first.
        """
        encoded_query = self.encode_for_retrieval([query])

        # Elementwise product summed over the last axis == per-document dot product
        scores = (document_representations * encoded_query).sum(dim=1)

        k = min(top_k, len(scores))
        best_scores, best_indices = torch.topk(scores, k)

        return [
            (int(doc_idx), float(doc_score), documents[int(doc_idx)])
            for doc_score, doc_idx in zip(best_scores, best_indices)
        ]

    def get_sparse_statistics(self, sparse_repr):
        """
        Summarize how sparse a batch of representations is.

        Args:
            sparse_repr: Tensor with one row per text; positive entries are counted
                per row along dim=1 (assumes a 2-D ``(batch, self.vocab_size)`` tensor).

        Returns:
            Dict with:
              - 'avg_non_zero_terms': mean number of positive (> 0) entries per row
              - 'avg_sparsity': mean fraction of the vocabulary that is not positive
              - 'max_weight': largest entry in the batch
              - 'avg_weight': mean of the positive entries, or 0.0 when there are
                none (e.g. a fully collapsed batch) instead of NaN
        """
        positive_mask = sparse_repr > 0
        non_zero_count = torch.sum(positive_mask, dim=1).float()
        sparsity = 1.0 - (non_zero_count / self.vocab_size)

        # torch.mean of an empty tensor is NaN; report 0.0 for an all-zero batch.
        positive_weights = sparse_repr[positive_mask]
        avg_weight = positive_weights.mean().item() if positive_weights.numel() > 0 else 0.0

        return {
            'avg_non_zero_terms': torch.mean(non_zero_count).item(),
            'avg_sparsity': torch.mean(sparsity).item(),
            'max_weight': torch.max(sparse_repr).item(),
            'avg_weight': avg_weight
        }

# Build the SPLADE trainer when the deep-learning stack imported cleanly.
# On a missing stack, or on any error while building/reporting, splade_trainer is None.
print("🔍 Initializing SPLADE Neural Sparse Retrieval (Correct Implementation)...")

if not DL_LIBRARIES_AVAILABLE:
    print("⚠️ Deep learning libraries not available")
    splade_trainer = None
else:
    try:
        splade_trainer = SPLADETrainer(
            model_name='bert-base-uncased',
            learning_rate=5e-6,  # deliberately small learning rate for SPLADE
        )

        # Summary of the configured trainer (kept inside the try so that a
        # failure while reporting also resets splade_trainer)
        print("✅ SPLADE trainer initialized successfully!")
        print(f"   🔥 Device: {splade_trainer.device}")
        print(f"   📚 Vocabulary size: {splade_trainer.vocab_size}")
        print("   🎯 Key features:")
        print("      • MLM head projection to full vocabulary")
        print("      • Log saturation: log(1 + ReLU(x))")
        print(f"      • FLOPS regularization (λ_q={splade_trainer.lambda_q}, λ_d={splade_trainer.lambda_d})")
        print("      • Ranking loss with in-batch negatives")
        print("      • Learns sparse 'bag-of-expanded-words' representations")
    except Exception as e:
        print(f"❌ Failed to initialize SPLADE trainer: {e}")
        splade_trainer = None

🔍 Initializing SPLADE Neural Sparse Retrieval (Correct Implementation)...
✅ SPLADE trainer initialized successfully!
   🔥 Device: cpu
   📚 Vocabulary size: 30522
   🎯 Key features:
      • MLM head projection to full vocabulary
      • Log saturation: log(1 + ReLU(x))
      • FLOPS regularization (λ_q=0.06, λ_d=0.02)
      • Ranking loss with in-batch negatives
      • Learns sparse 'bag-of-expanded-words' representations


In [9]:
# Step 4: Define Evaluation Framework for Deep Learning Methods

# Imported here so this cell does not depend on another cell having imported them.
import re
from collections import Counter


class DeepLearningEvaluator:
    """
    Answer-matching metrics for retrieval output: Exact Match and token-overlap F1.

    Texts are compared after normalization (lowercase, punctuation removed,
    whitespace collapsed).
    """

    # Any character that is neither a word character nor whitespace
    _PUNCTUATION_RE = re.compile(r'[^\w\s]')

    def __init__(self):
        # Metric storage; not written to by the methods of this class.
        self.metrics = {}

    def normalize_answer(self, text: str) -> str:
        """Lowercase ``text``, strip punctuation and collapse runs of whitespace."""
        text = text.lower()
        text = self._PUNCTUATION_RE.sub('', text)
        return ' '.join(text.split())

    def exact_match(self, prediction: str, ground_truth: str) -> float:
        """Return 1.0 if the normalized texts are identical, else 0.0."""
        return float(self.normalize_answer(prediction) == self.normalize_answer(ground_truth))

    def f1_score(self, prediction: str, ground_truth: str) -> float:
        """
        Token-level F1 between normalized ``prediction`` and ``ground_truth``.

        Two empty texts score 1.0; exactly one empty text scores 0.0. Token
        overlap counts multiplicity (multiset intersection).
        """
        pred_tokens = self.normalize_answer(prediction).split()
        gt_tokens = self.normalize_answer(ground_truth).split()

        if len(pred_tokens) == 0 and len(gt_tokens) == 0:
            return 1.0
        if len(pred_tokens) == 0 or len(gt_tokens) == 0:
            return 0.0

        common_tokens = Counter(pred_tokens) & Counter(gt_tokens)
        num_same = sum(common_tokens.values())

        if num_same == 0:
            return 0.0

        precision = num_same / len(pred_tokens)
        recall = num_same / len(gt_tokens)

        return 2 * precision * recall / (precision + recall)

    def evaluate_qa_performance(self, predictions: List[str], ground_truths: List[str]) -> Dict[str, float]:
        """
        Average Exact Match and F1 over aligned prediction / ground-truth pairs.

        Args:
            predictions: List of predicted answers
            ground_truths: List of ground truth answers (same length)

        Returns:
            Dict with 'exact_match' and 'f1' (means; 0.0 for empty input rather
            than NaN) and 'count' (number of pairs).

        Raises:
            ValueError: if the two lists differ in length.
        """
        if len(predictions) != len(ground_truths):
            raise ValueError("Predictions and ground truths must have same length")

        if not predictions:
            # np.mean([]) would be NaN (with a RuntimeWarning)
            return {'exact_match': 0.0, 'f1': 0.0, 'count': 0}

        exact_matches = [self.exact_match(pred, gt) for pred, gt in zip(predictions, ground_truths)]
        f1_scores = [self.f1_score(pred, gt) for pred, gt in zip(predictions, ground_truths)]

        return {
            'exact_match': np.mean(exact_matches),
            'f1': np.mean(f1_scores),
            'count': len(predictions)
        }

# Shared evaluator instance used by the training / evaluation cells that follow
evaluator = DeepLearningEvaluator()
print(
    "✅ Deep learning evaluation framework ready!",
    "   📊 Metrics: Exact Match, F1 Score",
    "   🎯 Focus: Retrieval performance for medical domain QA",
    sep="\n",
)

✅ Deep learning evaluation framework ready!
   📊 Metrics: Exact Match, F1 Score
   🎯 Focus: Retrieval performance for medical domain QA


In [10]:
# Step 5: Train and Evaluate Deep Learning Methods with Proper Splits with DPR


def _placeholder_dpr_results():
    """Return the all-zero DPR result record used when DPR cannot be trained or evaluated."""
    return {
        'exact_match': 0.0, 'f1': 0.0, 'count': 0, 'test_size': 0,
        'top_1_accuracy': 0.0, 'top_5_accuracy': 0.0,
        'top_10_accuracy': 0.0, 'top_20_accuracy': 0.0,
        'training_losses': [], 'epoch_metrics_history': [],
    }


print("🚀 Training and Evaluating Deep Learning Methods with Proper Dataset Splits")
print("=" * 80)

# Initialize results storage
results = {}

# ===== TRAINING PHASE =====
print("🎯 TRAINING PHASE")
print("-" * 40)

# Train on the full training set; lower train_subset_size to train on a prefix only.
train_subset_size = len(train_questions)

train_subset_questions = train_questions[:train_subset_size]
train_subset_answers = train_answers[:train_subset_size]

print(f"📊 Training data preparation:")
print(f"   • Full training set: {len(train_questions):,} QA pairs")
print(f"   • Training subset: {train_subset_size:,} QA pairs (for efficiency)")
print(f"   • Validation set: {len(val_questions):,} QA pairs")
print(f"   • Test set: {len(test_questions):,} QA pairs")

# Create corpus for negative sampling (using training answers only)
corpus_for_negatives = train_answers.copy()
print(f"🎯 Corpus for negative sampling: {len(corpus_for_negatives):,} passages")

# ===== DPR TRAINING AND EVALUATION =====
print(f"\n🧠 DENSE PASSAGE RETRIEVAL (DPR) TRAINING & EVALUATION")
print("-" * 60)

if DL_LIBRARIES_AVAILABLE and dpr_trainer is not None:
    try:
        # (question, answer) training pairs
        train_data = [(q, a) for q, a in zip(train_subset_questions, train_subset_answers)]

        # Training parameters (kept small for resource efficiency)
        num_epochs = 6
        batch_size = 8

        # --- Optionally resume from a checkpoint stored as a wandb artifact ---
        load_from_wandb = True  # Set to True to load from wandb
        wandb_project = "Domain QA"  # Replace with your project name
        wandb_artifact_name = "dpr-model-epoch-1:v0"  # Replace with your artifact name and version

        start_epoch = 0  # Default start epoch

        if load_from_wandb:
            print(f"Attempting to load model from wandb artifact: {wandb_artifact_name}")
            try:
                with wandb.init(project=wandb_project, resume=True) as run:
                    artifact = run.use_artifact(wandb_artifact_name, type='model')
                    artifact_dir = artifact.download()
                    # Assumes artifact "dpr-model-epoch-<N>:<version>" holds dpr_epoch_<N>.pth
                    artifact_epoch = wandb_artifact_name.split(':')[0].split('-')[-1]
                    checkpoint_path = os.path.join(artifact_dir, f"dpr_epoch_{artifact_epoch}.pth")

                    # map_location='cpu' lets a checkpoint saved on GPU load on a CPU-only
                    # machine; load_state_dict copies tensors onto each module's device.
                    checkpoint = torch.load(checkpoint_path, map_location='cpu')
                    dpr_trainer.question_encoder.load_state_dict(checkpoint['question_encoder_state_dict'])
                    dpr_trainer.passage_encoder.load_state_dict(checkpoint['passage_encoder_state_dict'])
                    dpr_trainer.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                    start_epoch = checkpoint['epoch'] + 1  # Continue after the loaded epoch
                    print(f"✅ Successfully loaded checkpoint from epoch {checkpoint['epoch']}. Resuming training from epoch {start_epoch}.")

            except Exception as e:
                print(f"❌ Failed to load checkpoint from wandb: {e}")
                print("Starting training from scratch.")
                start_epoch = 0

        print(f"🎯 DPR Training configuration:")
        print(f"   • Epochs: {num_epochs}")
        print(f"   • Batch size: {batch_size}")
        print(f"   • Training pairs: {len(train_data):,}")
        print(f"   • Negative sampling: In-batch negatives only")
        print(f"   • Starting epoch: {start_epoch}")

        # Optional memory monitoring. psutil is imported once, up front, so a missing
        # package is noticed before training instead of after the first epoch.
        try:
            import psutil
            memory_process = psutil.Process(os.getpid())
        except ImportError:
            memory_process = None
            print("   • psutil not installed - skipping memory monitoring")

        print(f"\n🏃 Training DPR...")
        training_losses = []
        epoch_metrics_history = []  # Per-epoch validation metrics

        with wandb.init(project="Domain QA", name="DPR Training", config={
            "dataset_train_size": len(train_data),
            "dataset_val_size": len(val_questions),
            "dataset_test_size": len(test_questions),
            "model_name": 'bert-base-uncased',  # Assuming this is the model name
            "projection_dim": dpr_trainer.projection_dim,
            "learning_rate": dpr_trainer.optimizer.param_groups[0]['lr'],  # Current LR
            "num_epochs": num_epochs,
            "batch_size": batch_size,
            "negative_sampling": "in-batch only"
        }) as run:
            # Remember this run so test metrics can be attached to it later.
            training_run_id = run.id

            # ===== DATASET SPLIT HANDLING WITHIN TRAINING RUN =====
            print("📊 Checking for existing dataset splits in wandb...")
            dataset_loaded_from_wandb = False
            dataset_artifact_name = "covid-qa-splits"  # Standard artifact name

            # Reuse the logged splits if they exist; otherwise log the current ones.
            try:
                existing_artifact = run.use_artifact(f"{dataset_artifact_name}:latest", type="dataset_split")
                print(f"✅ Found existing dataset splits, loading from wandb:")
                print(f"   • Using artifact: {existing_artifact.name}")

                train_df_loaded = existing_artifact.get("train_data_frame").get_dataframe()
                val_df_loaded = existing_artifact.get("val_data_frame").get_dataframe()
                test_df_loaded = existing_artifact.get("test_data_frame").get_dataframe()

                # Use loaded splits for consistency across runs
                train_questions = train_df_loaded['question'].tolist()
                train_answers = train_df_loaded['answer'].tolist()
                val_questions = val_df_loaded['question'].tolist()
                val_answers = val_df_loaded['answer'].tolist()
                test_questions = test_df_loaded['question'].tolist()
                test_answers = test_df_loaded['answer'].tolist()

                # Rebuild training data and corpus from the loaded splits
                train_subset_questions = train_questions[:train_subset_size]
                train_subset_answers = train_answers[:train_subset_size]
                train_data = [(q, a) for q, a in zip(train_subset_questions, train_subset_answers)]
                corpus_for_negatives = train_answers.copy()

                dataset_loaded_from_wandb = True
                print(f"   • Training set: {len(train_questions):,} QA pairs")
                print(f"   • Validation set: {len(val_questions):,} QA pairs")
                print(f"   • Test set: {len(test_questions):,} QA pairs")

            except Exception as e:
                print(f"ℹ️ No existing dataset splits found: {e}")
                print(f"📊 Creating new dataset splits artifact...")

                split_artifact = wandb.Artifact(dataset_artifact_name, type="dataset_split")

                train_split_df = pd.DataFrame({
                    'question': train_questions,
                    'answer': train_answers,
                    'dataset_source': ['mixed'] * len(train_questions)  # Add source info
                })
                val_split_df = pd.DataFrame({
                    'question': val_questions,
                    'answer': val_answers,
                    'dataset_source': ['mixed'] * len(val_questions)
                })
                test_split_df = pd.DataFrame({
                    'question': test_questions,
                    'answer': test_answers,
                    'dataset_source': ['mixed'] * len(test_questions)
                })

                # Log as wandb Tables
                split_artifact.add(wandb.Table(dataframe=train_split_df), 'train_data_frame')
                split_artifact.add(wandb.Table(dataframe=val_split_df), 'val_data_frame')
                split_artifact.add(wandb.Table(dataframe=test_split_df), 'test_data_frame')

                run.log_artifact(split_artifact)
                print(f"✅ Dataset splits saved as wandb artifact: {split_artifact.name}")
                dataset_loaded_from_wandb = False

            # Log dataset info
            wandb.log({
                "dataset/train_size": len(train_questions),
                "dataset/val_size": len(val_questions),
                "dataset/test_size": len(test_questions),
                "dataset/total_size": len(train_questions) + len(val_questions) + len(test_questions),
                "dataset/loaded_from_wandb": dataset_loaded_from_wandb
            })

            # Refresh hyperparameters. Sizes may differ from the values passed to
            # wandb.init after loading the splits; wandb rejects changing an existing
            # config key unless allow_val_change=True.
            if wandb.run:
                wandb.config.update({
                    "dataset_train_size": len(train_data),
                    "dataset_val_size": len(val_questions),
                    "dataset_test_size": len(test_questions),
                    "model_name": 'bert-base-uncased',  # Assuming this is the model name
                    "projection_dim": dpr_trainer.projection_dim,
                    "learning_rate": dpr_trainer.optimizer.param_groups[0]['lr'],  # Current LR
                    "num_epochs": num_epochs,
                    "batch_size": batch_size,
                    "negative_sampling": "in-batch only",
                    "dataset_loaded_from_wandb": dataset_loaded_from_wandb
                }, allow_val_change=True)
                print("Logged dataset size and hyperparameters to wandb.")

            # Run num_epochs epochs starting after any resumed checkpoint
            for epoch in range(start_epoch, start_epoch + num_epochs):
                print(f"\n📅 Epoch {epoch + 1}/{start_epoch + num_epochs}")

                avg_loss, epoch_metrics = dpr_trainer.train_epoch(train_data, batch_size=batch_size, epoch=epoch, val_questions=val_questions, val_answers=val_answers)

                training_losses.append(avg_loss)
                epoch_metrics_history.append(epoch_metrics)

                print(f"   • Training loss: {avg_loss:.4f}")
                print(f"   • Validation Results:")
                print(f"     - Top-1 accuracy: {epoch_metrics['top_1_accuracy']:.3f}")
                print(f"     - Top-5 accuracy: {epoch_metrics['top_5_accuracy']:.3f}")
                print(f"     - Top-10 accuracy: {epoch_metrics['top_10_accuracy']:.3f}")
                print(f"     - Top-20 accuracy: {epoch_metrics['top_20_accuracy']:.3f}")
                print(f"     - Exact Match: {epoch_metrics['exact_match']:.3f}")
                print(f"     - F1 Score: {epoch_metrics['f1']:.3f}")

                # Resident memory of this process (CPU RAM)
                if memory_process is not None:
                    mem_info = memory_process.memory_info()
                    print(f"   • Memory usage after epoch {epoch+1}: {mem_info.rss / (1024 * 1024):.2f} MB")

        print(f"✅ DPR training completed!")

        # ===== DPR EVALUATION ON TEST SET =====
        print(f"\n🔍 Evaluating DPR on test set...")

        # Use full test set for evaluation; the test answers are the retrieval corpus
        test_size = len(test_questions)
        print(f"   • Test set size: {test_size:,} QA pairs")

        test_passage_embeddings = dpr_trainer.encode_for_retrieval(test_answers, 'passage')

        dpr_predictions = []
        top_k_accuracies = {1: 0, 5: 0, 10: 0, 20: 0}

        for i, question in enumerate(test_questions):
            results_list = dpr_trainer.retrieve(question, test_passage_embeddings, test_answers, top_k=20)

            if results_list:
                # Top-1 passage is the predicted answer
                best_answer = results_list[0][2]
                dpr_predictions.append(best_answer)

                # Credit every cutoff k at which question i's own answer is retrieved
                correct_answer = test_answers[i]
                retrieved_answers = [r[2] for r in results_list]

                for k in [1, 5, 10, 20]:
                    if correct_answer in retrieved_answers[:k]:
                        top_k_accuracies[k] += 1
            else:
                dpr_predictions.append("")

        for k in top_k_accuracies:
            top_k_accuracies[k] = top_k_accuracies[k] / test_size if test_size > 0 else 0.0

        # QA performance metrics
        dpr_qa_results = evaluator.evaluate_qa_performance(dpr_predictions, test_answers)

        results['DPR_Trained'] = {
            **dpr_qa_results,
            'top_1_accuracy': top_k_accuracies[1],
            'top_5_accuracy': top_k_accuracies[5],
            'top_10_accuracy': top_k_accuracies[10],
            'top_20_accuracy': top_k_accuracies[20],
            'training_losses': training_losses,
            'epoch_metrics_history': epoch_metrics_history,
            'test_size': test_size,
            'dataset_loaded_from_wandb': dataset_loaded_from_wandb
        }

        print(f"✅ DPR Test Results:")
        print(f"   • Test set size: {test_size:,}")
        print(f"   • Top-1 accuracy: {top_k_accuracies[1]:.3f}")
        print(f"   • Top-5 accuracy: {top_k_accuracies[5]:.3f}")
        print(f"   • Top-10 accuracy: {top_k_accuracies[10]:.3f}")
        print(f"   • Top-20 accuracy: {top_k_accuracies[20]:.3f}")
        print(f"   • QA Exact Match: {dpr_qa_results['exact_match']:.3f}")
        print(f"   • QA F1 Score: {dpr_qa_results['f1']:.3f}")

        # Attach test metrics to the training run. resume=True without an id does
        # not reliably target that run, so pass the captured run id explicitly.
        with wandb.init(project="Domain QA", name="DPR Training", id=training_run_id, resume="allow") as run:
            wandb.log({
                "dpr/test_top_1_accuracy": top_k_accuracies[1],
                "dpr/test_top_5_accuracy": top_k_accuracies[5],
                "dpr/test_top_10_accuracy": top_k_accuracies[10],
                "dpr/test_top_20_accuracy": top_k_accuracies[20],
                "dpr/test_exact_match": dpr_qa_results['exact_match'],
                "dpr/test_f1": dpr_qa_results['f1']
            })

        # Print epoch-by-epoch summary
        print(f"\n📊 DPR Epoch-by-Epoch Summary:")
        print(f"{'Epoch':<6} {'Loss':<8} {'Top-1':<8} {'Top-5':<8} {'Top-10':<8} {'Top-20':<8} {'EM':<8} {'F1':<8}")
        print("-" * 65)
        for i, (loss, metrics) in enumerate(zip(training_losses, epoch_metrics_history)):
            epoch_num = start_epoch + i + 1
            print(f"{epoch_num:<6} {loss:<8.4f} {metrics['top_1_accuracy']:<8.3f} {metrics['top_5_accuracy']:<8.3f} {metrics['top_10_accuracy']:<8.3f} {metrics['top_20_accuracy']:<8.3f} {metrics['exact_match']:<8.3f} {metrics['f1']:<8.3f}")

    except Exception as e:
        print(f"❌ DPR training/evaluation failed: {e}")
        import traceback
        traceback.print_exc()
        results['DPR_Trained'] = _placeholder_dpr_results()

else:
    print("⚠️ DPR trainer not available - using placeholder results")
    results['DPR_Trained'] = _placeholder_dpr_results()

🚀 Training and Evaluating Deep Learning Methods with Proper Dataset Splits
🎯 TRAINING PHASE
----------------------------------------
📊 Training data preparation:
   • Full training set: 1,606 QA pairs
   • Training subset: 1,606 QA pairs (for efficiency)
   • Validation set: 201 QA pairs
   • Test set: 201 QA pairs
🎯 Corpus for negative sampling: 1,606 passages

🧠 DENSE PASSAGE RETRIEVAL (DPR) TRAINING & EVALUATION
------------------------------------------------------------
Attempting to load model from wandb artifact: dpr-model-epoch-1:v0


[34m[1mwandb[0m: Downloading large artifact dpr-model-epoch-1:v0, 2510.82MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:15.1 (166.3MB/s)


✅ Successfully loaded checkpoint from epoch 0. Resuming training from epoch 1.


🎯 DPR Training configuration:
   • Epochs: 6
   • Batch size: 8
   • Training pairs: 1,606
   • Negative sampling: In-batch negatives only
   • Starting epoch: 1

🏃 Training DPR...


📊 Checking for existing dataset splits in wandb...
✅ Found existing dataset splits, loading from wandb:
   • Using artifact: covid-qa-splits:latest


[34m[1mwandb[0m:   3 of 3 files downloaded.  
[34m[1mwandb[0m:   3 of 3 files downloaded.  
[34m[1mwandb[0m:   3 of 3 files downloaded.  


   • Training set: 1,606 QA pairs
   • Validation set: 201 QA pairs
   • Test set: 201 QA pairs
Logged dataset size and hyperparameters to wandb.

📅 Epoch 2/7


Training DPR: 100%|██████████| 201/201 [1:01:07<00:00, 18.25s/it]


   • Performing comprehensive validation...
   • Validation Results (subset size: 50):
     - Top-1 accuracy: 0.420
     - Top-5 accuracy: 0.920
     - Top-10 accuracy: 0.960
     - Top-20 accuracy: 0.980
     - Exact Match: 0.420
     - F1 Score: 0.535




Checkpoint saved to dpr_checkpoints/dpr_epoch_2.pth
   • Training loss: 0.5116
   • Validation Results:
     - Top-1 accuracy: 0.420
     - Top-5 accuracy: 0.920
     - Top-10 accuracy: 0.960
     - Top-20 accuracy: 0.980
     - Exact Match: 0.420
     - F1 Score: 0.535
   • Memory usage after epoch 2: 21358.45 MB

📅 Epoch 3/7


Training DPR: 100%|██████████| 201/201 [1:00:47<00:00, 18.15s/it]


   • Performing comprehensive validation...
   • Validation Results (subset size: 50):
     - Top-1 accuracy: 0.520
     - Top-5 accuracy: 0.980
     - Top-10 accuracy: 0.980
     - Top-20 accuracy: 0.980
     - Exact Match: 0.520
     - F1 Score: 0.611




Checkpoint saved to dpr_checkpoints/dpr_epoch_3.pth
   • Training loss: 0.3490
   • Validation Results:
     - Top-1 accuracy: 0.520
     - Top-5 accuracy: 0.980
     - Top-10 accuracy: 0.980
     - Top-20 accuracy: 0.980
     - Exact Match: 0.520
     - F1 Score: 0.611
   • Memory usage after epoch 3: 22185.36 MB

📅 Epoch 4/7


Training DPR: 100%|██████████| 201/201 [1:03:50<00:00, 19.06s/it]


   • Performing comprehensive validation...
   • Validation Results (subset size: 50):
     - Top-1 accuracy: 0.500
     - Top-5 accuracy: 0.980
     - Top-10 accuracy: 0.980
     - Top-20 accuracy: 0.980
     - Exact Match: 0.500
     - F1 Score: 0.591




Checkpoint saved to dpr_checkpoints/dpr_epoch_4.pth
   • Training loss: 0.2620
   • Validation Results:
     - Top-1 accuracy: 0.500
     - Top-5 accuracy: 0.980
     - Top-10 accuracy: 0.980
     - Top-20 accuracy: 0.980
     - Exact Match: 0.500
     - F1 Score: 0.591
   • Memory usage after epoch 4: 22547.08 MB

📅 Epoch 5/7


Training DPR: 100%|██████████| 201/201 [1:03:16<00:00, 18.89s/it]


   • Performing comprehensive validation...
   • Validation Results (subset size: 50):
     - Top-1 accuracy: 0.600
     - Top-5 accuracy: 0.980
     - Top-10 accuracy: 0.980
     - Top-20 accuracy: 0.980
     - Exact Match: 0.600
     - F1 Score: 0.668




Checkpoint saved to dpr_checkpoints/dpr_epoch_5.pth
   • Training loss: 0.2126
   • Validation Results:
     - Top-1 accuracy: 0.600
     - Top-5 accuracy: 0.980
     - Top-10 accuracy: 0.980
     - Top-20 accuracy: 0.980
     - Exact Match: 0.600
     - F1 Score: 0.668
   • Memory usage after epoch 5: 22416.53 MB

📅 Epoch 6/7


Training DPR: 100%|██████████| 201/201 [1:01:23<00:00, 18.33s/it]


   • Performing comprehensive validation...
   • Validation Results (subset size: 50):
     - Top-1 accuracy: 0.580
     - Top-5 accuracy: 0.940
     - Top-10 accuracy: 0.980
     - Top-20 accuracy: 1.000
     - Exact Match: 0.580
     - F1 Score: 0.657




Checkpoint saved to dpr_checkpoints/dpr_epoch_6.pth
   • Training loss: 0.2022
   • Validation Results:
     - Top-1 accuracy: 0.580
     - Top-5 accuracy: 0.940
     - Top-10 accuracy: 0.980
     - Top-20 accuracy: 1.000
     - Exact Match: 0.580
     - F1 Score: 0.657
   • Memory usage after epoch 6: 21917.97 MB

📅 Epoch 7/7


Training DPR: 100%|██████████| 201/201 [1:00:30<00:00, 18.06s/it]


   • Performing comprehensive validation...
   • Validation Results (subset size: 50):
     - Top-1 accuracy: 0.580
     - Top-5 accuracy: 0.960
     - Top-10 accuracy: 0.980
     - Top-20 accuracy: 1.000
     - Exact Match: 0.580
     - F1 Score: 0.663




Checkpoint saved to dpr_checkpoints/dpr_epoch_7.pth
   • Training loss: 0.1712
   • Validation Results:
     - Top-1 accuracy: 0.580
     - Top-5 accuracy: 0.960
     - Top-10 accuracy: 0.980
     - Top-20 accuracy: 1.000
     - Exact Match: 0.580
     - F1 Score: 0.663
   • Memory usage after epoch 7: 22172.33 MB


0,1
dataset/test_size,▁
dataset/total_size,▁
dataset/train_size,▁
dataset/val_size,▁
dpr/batch_loss,▆▅▄▄▅█▇▆▄▇▃▄▄▁▄▂▃▂▄▁▃▄▂▂▂▅▂▃▃▃▂▂▂▄▁▃▂▂▂▂

0,1
dataset/loaded_from_wandb,True
dataset/test_size,201
dataset/total_size,2008
dataset/train_size,1606
dataset/val_size,201
dpr/batch_loss,0.25357


✅ DPR training completed!

🔍 Evaluating DPR on test set...
   • Test set size: 201 QA pairs
✅ DPR Test Results:
   • Test set size: 201
   • Top-1 accuracy: 0.353
   • Top-5 accuracy: 0.692
   • Top-10 accuracy: 0.896
   • Top-20 accuracy: 0.950
   • QA Exact Match: 0.353
   • QA F1 Score: 0.454


0,1
dpr/test_exact_match,▁
dpr/test_f1,▁
dpr/test_top_10_accuracy,▁
dpr/test_top_1_accuracy,▁
dpr/test_top_20_accuracy,▁
dpr/test_top_5_accuracy,▁

0,1
dpr/test_exact_match,0.35323
dpr/test_f1,0.4543
dpr/test_top_10_accuracy,0.89552
dpr/test_top_1_accuracy,0.35323
dpr/test_top_20_accuracy,0.95025
dpr/test_top_5_accuracy,0.69154



📊 DPR Epoch-by-Epoch Summary:
Epoch  Loss     Top-1    Top-5    Top-10   Top-20   EM       F1      
-----------------------------------------------------------------
2      0.5116   0.420    0.920    0.960    0.980    0.420    0.535   
3      0.3490   0.520    0.980    0.980    0.980    0.520    0.611   
4      0.2620   0.500    0.980    0.980    0.980    0.500    0.591   
5      0.2126   0.600    0.980    0.980    0.980    0.600    0.668   
6      0.2022   0.580    0.940    0.980    1.000    0.580    0.657   
7      0.1712   0.580    0.960    0.980    1.000    0.580    0.663   


In [None]:
# ===== SPLADE TRAINING AND EVALUATION =====
# Train SPLADE on the shared COVID-QA splits, evaluate it on the test split and
# store the outcome in `results['SPLADE_Trained']`.
#
# Globals expected from earlier cells: DL_LIBRARIES_AVAILABLE, splade_trainer,
# evaluator, results, np and wandb. If the "covid-qa-splits" wandb artifact
# cannot be loaded, this cell instead uses the train/val/test question/answer
# lists, train_subset_questions, train_subset_answers and corpus_for_negatives
# already defined earlier in the notebook.
print(f"\n🔍 SPLADE NEURAL SPARSE RETRIEVAL TRAINING & EVALUATION")
print("-" * 60)

if DL_LIBRARIES_AVAILABLE and splade_trainer is not None:
    try:
        # Training parameters
        num_epochs = 7 # Reduced for resource efficiency
        batch_size = 4 # Reduced for resource efficiency

        print(f"🎯 SPLADE Training configuration:")
        print(f"   • Epochs: {num_epochs}")
        print(f"   • Batch size: {batch_size}")
        print(f"   • Regularization: FLOPS (λ_q={splade_trainer.lambda_q}, λ_d={splade_trainer.lambda_d})")

        # Training loop with comprehensive epoch-by-epoch metrics tracking
        print(f"\n🏃 Training SPLADE...")
        training_losses = []
        epoch_metrics_history = []  # Store all epoch metrics
        splade_train_data = []      # (question, positive answer, negative answer) triplets
        splade_run_id = None        # Kept so test metrics can be logged to the same wandb run

        # Use wandb.init() as a context manager for training
        with wandb.init(project="Domain QA", name="SPLADE Training", config={
            "model_name": 'bert-base-uncased', # Assuming this is the model name
            "learning_rate": splade_trainer.optimizer.param_groups[0]['lr'], # Get current LR
            "num_epochs": num_epochs,
            "batch_size": batch_size,
            "lambda_q": splade_trainer.lambda_q,
            "lambda_d": splade_trainer.lambda_d,
            "negative_sampling_strategy": "random_from_corpus"
        }) as run:
            splade_run_id = run.id

            # ===== DATASET SPLIT HANDLING WITHIN TRAINING RUN =====
            print("📊 Checking for existing dataset splits in wandb...")
            dataset_loaded_from_wandb = False
            dataset_artifact_name = "covid-qa-splits"  # Same artifact name as DPR for consistency

            # Check if artifact already exists and load or create
            try:
                existing_artifact = run.use_artifact(f"{dataset_artifact_name}:latest", type="dataset_split")
                print(f"✅ Found existing dataset splits, loading from wandb (same as DPR):")
                print(f"   • Using artifact: {existing_artifact.name}")

                # Load the dataset splits from wandb tables
                train_table = existing_artifact.get("train_data_frame")
                val_table = existing_artifact.get("val_data_frame")
                test_table = existing_artifact.get("test_data_frame")

                # Convert back to lists
                train_df_loaded = train_table.get_dataframe()
                val_df_loaded = val_table.get_dataframe()
                test_df_loaded = test_table.get_dataframe()

                # Update variables to use loaded splits (same as DPR)
                train_questions_loaded = train_df_loaded['question'].tolist()
                train_answers_loaded = train_df_loaded['answer'].tolist()
                val_questions_loaded = val_df_loaded['question'].tolist()
                val_answers_loaded = val_df_loaded['answer'].tolist()
                test_questions_loaded = test_df_loaded['question'].tolist()
                test_answers_loaded = test_df_loaded['answer'].tolist()

                # Use loaded splits for consistency with DPR
                train_questions = train_questions_loaded
                train_answers = train_answers_loaded
                val_questions = val_questions_loaded
                val_answers = val_answers_loaded
                test_questions = test_questions_loaded
                test_answers = test_answers_loaded

                # Update training data and corpus (full training set is used)
                train_subset_questions = list(train_questions)
                train_subset_answers = list(train_answers)
                corpus_for_negatives = train_answers.copy()

                dataset_loaded_from_wandb = True
                print(f"   • Training set: {len(train_questions):,} QA pairs")
                print(f"   • Validation set: {len(val_questions):,} QA pairs")
                print(f"   • Test set: {len(test_questions):,} QA pairs")

            except Exception as e:
                print(f"ℹ️ No existing dataset splits found: {e}")
                print(f"📊 SPLADE will use current dataset splits (not creating new artifact)")
                dataset_loaded_from_wandb = False
                # Note: Don't create new artifact here since DPR should have created it

            # Prepare training data with negatives using the correct dataset
            print("🔧 Preparing SPLADE training data with negative sampling...")

            # Prefer negatives that are not answers of the training subset. The set
            # is built once so each membership test is O(1) rather than a list scan.
            train_answer_set = set(train_subset_answers)
            possible_negatives_corpus = [a for a in corpus_for_negatives if a not in train_answer_set]
            if not possible_negatives_corpus:
                # Fallback to the full corpus, e.g. when the corpus is the training
                # answers themselves (as in the wandb-loaded path above).
                possible_negatives_corpus = corpus_for_negatives

            for i, (question, pos_answer) in enumerate(zip(train_subset_questions, train_subset_answers)):
                # Random negative from corpus that differs from the positive sample
                possible_negatives = [a for a in possible_negatives_corpus if a != pos_answer]
                if possible_negatives: # Only add if a negative sample can be found
                    neg_answer = np.random.choice(possible_negatives)
                    splade_train_data.append((question, pos_answer, neg_answer))
                else:
                    print(f"Skipping training example {i} due to lack of suitable negative sample.")

            if not splade_train_data:
                # `return` is not valid at notebook-cell level, so the remaining
                # training/evaluation steps are guarded by this check instead.
                print("⚠️ No valid SPLADE training data triplets could be generated. Skipping SPLADE training.")
            else:
                print(f"   • Training triplets: {len(splade_train_data):,}")
                print(f"   • Corpus for negatives: {len(corpus_for_negatives):,} passages")

                # Log dataset info
                wandb.log({
                    "dataset/train_size": len(train_questions),
                    "dataset/val_size": len(val_questions),
                    "dataset/test_size": len(test_questions),
                    "dataset/total_size": len(train_questions) + len(val_questions) + len(test_questions),
                    "dataset/loaded_from_wandb": dataset_loaded_from_wandb,
                    "dataset/splade_triplets": len(splade_train_data),
                    "dataset/unified_with_dpr": True
                })

                # Log initial dataset and hyperparameters
                if wandb.run:
                    wandb.config.update({
                        "dataset_train_size": len(splade_train_data),
                        "dataset_val_size": len(val_questions),
                        "dataset_test_size": len(test_questions),
                        "model_name": 'bert-base-uncased', # Assuming this is the model name
                        "learning_rate": splade_trainer.optimizer.param_groups[0]['lr'], # Get current LR
                        "num_epochs": num_epochs,
                        "batch_size": batch_size,
                        "lambda_q": splade_trainer.lambda_q,
                        "lambda_d": splade_trainer.lambda_d,
                        "unified_dataset": True,  # Flag to indicate same dataset as DPR
                        "negative_sampling_strategy": "random_from_corpus",
                        "dataset_loaded_from_wandb": dataset_loaded_from_wandb
                    })
                    print("Logged dataset size and hyperparameters to wandb.")

                for epoch in range(num_epochs):
                    print(f"\n📅 Epoch {epoch + 1}/{num_epochs}")

                    # Train one epoch and get comprehensive validation metrics
                    epoch_metrics = splade_trainer.train_epoch(splade_train_data, batch_size=batch_size, epoch=epoch, val_questions=val_questions, val_answers=val_answers)

                    # Store metrics
                    training_losses.append(epoch_metrics['total_loss'])
                    epoch_metrics_history.append(epoch_metrics)

                    print(f"   • Training Losses:")
                    print(f"     - Total loss: {epoch_metrics['total_loss']:.4f}")
                    print(f"     - Ranking loss: {epoch_metrics['ranking_loss']:.4f}")
                    print(f"     - Regularization loss: {epoch_metrics['reg_loss']:.4f}")
                    print(f"   • Validation Results:")
                    print(f"     - Top-1 accuracy: {epoch_metrics['top_1_accuracy']:.3f}")
                    print(f"     - Top-5 accuracy: {epoch_metrics['top_5_accuracy']:.3f}")
                    print(f"     - Top-10 accuracy: {epoch_metrics['top_10_accuracy']:.3f}")
                    print(f"     - Top-20 accuracy: {epoch_metrics['top_20_accuracy']:.3f}")
                    print(f"     - Exact Match: {epoch_metrics['exact_match']:.3f}")
                    print(f"     - F1 Score: {epoch_metrics['f1']:.3f}")

        if not splade_train_data:
            # Nothing was trained: record a placeholder so later cells still work.
            results['SPLADE_Trained'] = {'exact_match': 0.0, 'f1': 0.0, 'count': 0, 'test_size': 0, 'epoch_metrics_history': []}
        else:
            print(f"✅ SPLADE training completed!")

            # ===== SPLADE EVALUATION ON TEST SET =====
            print(f"\n🔍 Evaluating SPLADE on test set...")
            print(f"   • Using same test set as DPR for fair comparison")

            # Use full test set for evaluation (same as DPR)
            test_size = len(test_questions)
            print(f"   • Test set size: {test_size:,} QA pairs")

            # Encode test documents
            test_doc_representations = splade_trainer.encode_for_retrieval(test_answers)

            # Get sparsity statistics
            sparsity_stats = splade_trainer.get_sparse_statistics(test_doc_representations)
            print(f"   • Sparsity statistics:")
            print(f"      - Avg non-zero terms: {sparsity_stats['avg_non_zero_terms']:.1f}")
            print(f"      - Avg sparsity: {sparsity_stats['avg_sparsity']:.3f}")

            # Evaluate on test set
            splade_predictions = []
            top_k_accuracies = {1: 0, 5: 0, 10: 0, 20: 0}

            for i, question in enumerate(test_questions):
                # Each retrieved item is indexed with [2] for the answer text below
                results_list = splade_trainer.retrieve(question, test_doc_representations, test_answers, top_k=20)

                if results_list:
                    # Get top-1 prediction
                    best_answer = results_list[0][2]
                    splade_predictions.append(best_answer)

                    # Calculate top-k accuracies
                    correct_answer = test_answers[i]
                    retrieved_answers = [r[2] for r in results_list]

                    for k in [1, 5, 10, 20]:
                        if correct_answer in retrieved_answers[:k]:
                            top_k_accuracies[k] += 1
                else:
                    splade_predictions.append("")

            # Calculate final metrics
            for k in top_k_accuracies:
                top_k_accuracies[k] = top_k_accuracies[k] / test_size if test_size > 0 else 0.0

            # QA performance metrics
            splade_qa_results = evaluator.evaluate_qa_performance(splade_predictions, test_answers)

            # Store results with epoch-by-epoch metrics
            results['SPLADE_Trained'] = {
                **splade_qa_results,
                'top_1_accuracy': top_k_accuracies[1],
                'top_5_accuracy': top_k_accuracies[5],
                'top_10_accuracy': top_k_accuracies[10],
                'top_20_accuracy': top_k_accuracies[20],
                'training_losses': training_losses,
                'epoch_metrics_history': epoch_metrics_history,  # Store comprehensive epoch history
                'sparsity_stats': sparsity_stats,
                'test_size': test_size,
                'unified_dataset': True,  # Flag to indicate same dataset as DPR
                'dataset_loaded_from_wandb': dataset_loaded_from_wandb
            }

            print(f"✅ SPLADE Test Results:")
            print(f"   • Test set size: {test_size:,}")
            print(f"   • Top-1 accuracy: {top_k_accuracies[1]:.3f}")
            print(f"   • Top-5 accuracy: {top_k_accuracies[5]:.3f}")
            print(f"   • Top-10 accuracy: {top_k_accuracies[10]:.3f}")
            print(f"   • Top-20 accuracy: {top_k_accuracies[20]:.3f}")
            print(f"   • QA Exact Match: {splade_qa_results['exact_match']:.3f}")
            print(f"   • QA F1 Score: {splade_qa_results['f1']:.3f}")

            # Log test results to the training run. Resuming by id attaches the
            # metrics to that run instead of silently creating a new one.
            with wandb.init(project="Domain QA", name="SPLADE Training", id=splade_run_id, resume="allow") as run:
                wandb.log({
                    "splade/test_top_1_accuracy": top_k_accuracies[1],
                    "splade/test_top_5_accuracy": top_k_accuracies[5],
                    "splade/test_top_10_accuracy": top_k_accuracies[10],
                    "splade/test_top_20_accuracy": top_k_accuracies[20],
                    "splade/test_exact_match": splade_qa_results['exact_match'],
                    "splade/test_f1": splade_qa_results['f1'],
                    "splade/avg_non_zero_terms": sparsity_stats['avg_non_zero_terms'],
                    "splade/avg_sparsity": sparsity_stats['avg_sparsity']
                })

            # Print epoch-by-epoch summary
            print(f"\n📊 SPLADE Epoch-by-Epoch Summary:")
            print(f"{'Epoch':<6} {'Total Loss':<12} {'Ranking Loss':<14} {'Reg Loss':<10} {'Top-1':<8} {'Top-5':<8} {'Top-10':<8} {'Top-20':<8} {'EM':<8} {'F1':<8}")
            print("-" * 95)
            for epoch_num, metrics in enumerate(epoch_metrics_history, start=1):
                print(f"{epoch_num:<6} {metrics['total_loss']:<12.4f} {metrics['ranking_loss']:<14.4f} {metrics['reg_loss']:<10.4f} {metrics['top_1_accuracy']:<8.3f} {metrics['top_5_accuracy']:<8.3f} {metrics['top_10_accuracy']:<8.3f} {metrics['top_20_accuracy']:<8.3f} {metrics['exact_match']:<8.3f} {metrics['f1']:<8.3f}")

    except Exception as e:
        print(f"❌ SPLADE training/evaluation failed: {e}")
        import traceback
        traceback.print_exc()
        results['SPLADE_Trained'] = {'exact_match': 0.0, 'f1': 0.0, 'count': 0, 'test_size': 0, 'epoch_metrics_history': []}

else:
    print("⚠️ SPLADE trainer not available - using placeholder results")
    results['SPLADE_Trained'] = {'exact_match': 0.0, 'f1': 0.0, 'count': 0, 'test_size': 0, 'epoch_metrics_history': []}

In [None]:
# Step 6: Analysis and Comparison with Traditional Methods
#
# Compares the deep-learning entries of the notebook-level `results` dict
# (e.g. 'SPLADE_Trained', written by the SPLADE cell) against the hard-coded
# TF-IDF / BM25 baselines. Deep-learning method keys are matched by prefix
# ("DPR...", "SPLADE..."), so suffixed keys such as 'SPLADE_Trained' are
# recognised.

print("🔬 COMPREHENSIVE ANALYSIS: DEEP LEARNING vs TRADITIONAL METHODS")
print("=" * 70)

# Load traditional methods results (from Chapter 4)
traditional_results = {
    'TF-IDF': {'exact_match': 0.275, 'f1': 0.403, 'count': 200},
    'BM25': {'exact_match': 0.385, 'f1': 0.505, 'count': 200}
}

# Combine all results (handle case where results might be empty)
if results:
    all_results = {**traditional_results, **results}
else:
    all_results = traditional_results

print("📊 COMPLETE PERFORMANCE COMPARISON:")
print(f"{'Method':<15} {'Type':<12} {'Exact Match':<12} {'F1 Score':<10} {'Evaluated':<10}")
print("-" * 70)

# Traditional methods
for method, result in traditional_results.items():
    print(f"{method:<15} {'Traditional':<12} {result['exact_match']:<12.3f} {result['f1']:<10.3f} {result['count']:<10}")

# Deep learning methods
for method, result in results.items():
    # Prefix match so keys like 'DPR_Trained' are classified as dense retrieval
    method_type = 'Dense' if method.upper().startswith('DPR') else 'Neural Sparse'
    # Not every result dict carries 'count'; fall back to its test-set size
    evaluated = result.get('count', result.get('test_size', 0))
    print(f"{method:<15} {method_type:<12} {result['exact_match']:<12.3f} {result['f1']:<10.3f} {evaluated:<10}")

# Find best overall method
best_method = max(all_results.keys(), key=lambda x: all_results[x]['f1'])
print(f"\n🏆 BEST OVERALL METHOD: {best_method}")
print(f"   • F1 Score: {all_results[best_method]['f1']:.3f}")
print(f"   • Exact Match: {all_results[best_method]['exact_match']:.3f}")

# Method comparison analysis
print(f"\n💡 DETAILED ANALYSIS:")
print("=" * 50)

print(f"1. TRADITIONAL METHODS PERFORMANCE:")
avg_traditional_f1 = np.mean([traditional_results[m]['f1'] for m in traditional_results])
print(f"   • Average F1: {avg_traditional_f1:.3f}")
print(f"   • BM25 outperforms TF-IDF by {traditional_results['BM25']['f1'] - traditional_results['TF-IDF']['f1']:.3f} F1 points")
print(f"   • Strong baseline for medical domain terminology")

if results:
    print(f"\n2. DEEP LEARNING METHODS PERFORMANCE:")
    avg_dl_f1 = np.mean([results[m]['f1'] for m in results])
    print(f"   • Average F1: {avg_dl_f1:.3f}")

    # Locate DPR / SPLADE entries regardless of suffix (e.g. 'SPLADE_Trained')
    dpr_key = next((m for m in results if m.upper().startswith('DPR')), None)
    splade_key = next((m for m in results if m.upper().startswith('SPLADE')), None)
    if dpr_key is not None and splade_key is not None:
        dpr_vs_splade = results[dpr_key]['f1'] - results[splade_key]['f1']
        better_dl = 'DPR' if dpr_vs_splade > 0 else 'SPLADE'
        print(f"   • {better_dl} outperforms by {abs(dpr_vs_splade):.3f} F1 points")

    print(f"   • Semantic understanding vs interpretability trade-off")

    print(f"\n3. TRADITIONAL vs DEEP LEARNING:")
    best_traditional = max(traditional_results.keys(), key=lambda x: traditional_results[x]['f1'])
    best_dl = max(results.keys(), key=lambda x: results[x]['f1'])

    improvement = results[best_dl]['f1'] - traditional_results[best_traditional]['f1']
    if improvement > 0:
        print(f"   • Deep learning improves by {improvement:.3f} F1 points")
        print(f"   • {best_dl} beats {best_traditional} by {improvement:.3f}")
    else:
        print(f"   • Traditional methods competitive: {abs(improvement):.3f} gap")
        print(f"   • {best_traditional} beats {best_dl} by {abs(improvement):.3f}")

print(f"\n4. MEDICAL DOMAIN INSIGHTS:")
print(f"   • COVID-19 terminology challenges all methods")
print(f"   • Semantic methods help with medical synonyms")
print(f"   • Sparse methods maintain interpretability")
print(f"   • Long medical answers reduce exact match scores")

print(f"\n5. METHOD CHARACTERISTICS:")
print(f"   • TF-IDF: Fast, interpretable, keyword-based")
print(f"   • BM25: Probabilistic ranking, length normalization")
print(f"   • DPR: Dense semantic similarity, handles synonyms")
print(f"   • SPLADE: Sparse neural expansion, interpretable")

print(f"\n🎯 RECOMMENDATIONS:")
print("=" * 50)
print(f"• For PRODUCTION: BM25 (reliable, fast, interpretable)")
print(f"• For RESEARCH: DPR/SPLADE (semantic understanding)")
print(f"• For MEDICAL DOMAIN: Hybrid approach combining methods")
print(f"• For INTERPRETABILITY: SPLADE (sparse + neural)")

print(f"\n✅ COMPREHENSIVE ANALYSIS COMPLETE!")
print(f"📊 All methods evaluated on COVID-QA medical domain dataset")
print(f"🔬 Results demonstrate trade-offs between speed, accuracy, and interpretability")