### Active Mini-Batch Stochastic Variational Inference

Authors: Sushil Bohara, Dequan Yang, Bishnu Dev

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset, Subset, Dataset
from torchvision import datasets, transforms
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
from tqdm import tqdm
import time
import random
from collections import Counter
from sklearn.cluster import KMeans

##### Setup

In [None]:
# Reproducibility: seed the Pyro, PyTorch and NumPy RNGs from one value.
SEED = 42
pyro.set_rng_seed(SEED)
torch.manual_seed(SEED)
np.random.seed(SEED)

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# --- Data loading -----------------------------------------------------------
DATA_DIR = './data'
BATCH_SIZE = 128
TEST_BATCH_SIZE = 1000

# --- Model shape ------------------------------------------------------------
INPUT_SIZE = 28 * 28   # each Fashion-MNIST image is 28x28 pixels, flattened
OUTPUT_SIZE = 10       # number of clothing categories

# --- Optimisation -----------------------------------------------------------
NUM_EPOCHS = 5
LEARNING_RATE = 0.01

# --- Active mini-batch selection --------------------------------------------
UNCERTAINTY_UPDATE_INTERVAL = 2  # iterations between uncertainty refreshes
ACTIVE_BATCH_SIZE = 128          # size of an actively selected mini-batch
WARM_UP_ITERATIONS = 200         # iterations of plain random sampling first
EXPLORATION_RATIO_START = 0.7    # initial exploration share
EXPLORATION_RATIO_END = 0.3      # final exploration share
MAX_ITERATIONS_TO_REUSE = 100    # limits how soon samples may be reused

# --- Artificial class imbalance ---------------------------------------------
RARE_CLASSES = [7, 9]            # Sneaker (7) and Ankle boot (9) become rare
IMBALANCE_RATIO = 20             # keep 1 out of every 20 rare-class samples

In [None]:
class ImbalancedDataset(Dataset):
    """
    Creates an imbalanced dataset from a base dataset
    by undersampling specified rare classes.

    Every sample of a non-rare class is kept. For each class listed in
    ``rare_classes`` the class's indices are shuffled with Python's ``random``
    module and only ``len // imbalance_ratio`` of them (at least one) are kept.
    The resulting index list is grouped by ascending class label.
    """
    def __init__(self, base_dataset, rare_classes, imbalance_ratio):
        """
        Args:
            base_dataset: Original dataset; ``base_dataset[i]`` must return an
                ``(input, label)`` pair with a hashable label.
            rare_classes: List of class indices to make rare
            imbalance_ratio: Ratio to undersample rare classes (e.g., 10 means keep 1/10 of samples)
        """
        self.base_dataset = base_dataset
        self.rare_classes = rare_classes
        self.imbalance_ratio = imbalance_ratio
        # Cache of every base label; filled lazily by _load_labels so the base
        # dataset (whose items also carry the input) is only traversed once.
        self._labels = None

        # Create an index mapping from our dataset to the base dataset
        self.indices = self._create_imbalanced_indices()

        print(f"Original dataset size: {len(base_dataset)}")
        print(f"Imbalanced dataset size: {len(self.indices)}")

        # Count class distribution from the cached labels (no second pass over the data)
        labels = self._load_labels()
        self.class_counts = Counter(labels[i] for i in self.indices)
        # Always report classes 0-9, plus any other label actually present.
        for class_idx in sorted(set(range(10)) | set(self.class_counts)):
            count = self.class_counts[class_idx]
            is_rare = "RARE" if class_idx in rare_classes else "common"
            print(f"Class {class_idx} ({is_rare}): {count} samples")

    def _load_labels(self):
        """Return the label of every base sample, reading the base dataset only once."""
        if self._labels is None:
            self._labels = [self.base_dataset[i][1] for i in range(len(self.base_dataset))]
        return self._labels

    def _create_imbalanced_indices(self):
        """Create indices list with rare classes undersampled"""
        # Group base indices by label; works for any hashable label, not only 0-9.
        class_indices = {}
        for base_idx, label in enumerate(self._load_labels()):
            class_indices.setdefault(label, []).append(base_idx)

        indices = []
        # Visit classes in ascending label order so the sequence of
        # random.shuffle calls (and hence the kept rare samples) is deterministic.
        for class_idx in sorted(class_indices):
            class_specific_indices = class_indices[class_idx]
            if class_idx in self.rare_classes:
                # Undersample rare classes
                random.shuffle(class_specific_indices)
                num_to_keep = max(1, len(class_specific_indices) // self.imbalance_ratio)
                indices.extend(class_specific_indices[:num_to_keep])
            else:
                # Keep all samples of common classes
                indices.extend(class_specific_indices)

        return indices

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        base_idx = self.indices[idx]
        return self.base_dataset[base_idx]


class OptimizedActiveDataset(Dataset):
    """
    Dataset wrapper for active mini-batch selection.

    Holds one uncertainty score per sample (set through ``update_uncertainties``),
    counts how often each index is handed out, and remembers the indices used by
    the most recent stratified batches so they are avoided while enough other
    candidates exist.
    """
    def __init__(self, dataset, max_recent_batches=None):
        """
        Args:
            dataset: Indexable dataset to wrap.
            max_recent_batches: Number of most recent stratified batches whose
                indices are avoided when possible. ``None`` (default) uses the
                module-level ``MAX_ITERATIONS_TO_REUSE``.
        """
        if max_recent_batches is None:
            max_recent_batches = MAX_ITERATIONS_TO_REUSE
        self.dataset = dataset
        self.uncertainties = torch.zeros(len(dataset))
        self.sample_usage_counter = Counter()  # index -> number of times selected
        self.update_required = True  # stays True until update_uncertainties is called
        self.max_recent_batches = max_recent_batches
        # Index lists of the most recent stratified batches, oldest first.
        self._recent_batches = []
        # Union of the indices in _recent_batches, kept for O(1) membership tests.
        self.previously_selected = set()

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]

    def update_uncertainties(self, uncertainties):
        """Update the uncertainty values for all samples"""
        assert len(uncertainties) == len(self.dataset), "Uncertainties must match dataset size"
        self.uncertainties = uncertainties
        self.update_required = False

    def _remember_batch(self, batch):
        """Record ``batch`` as the newest selection, forgetting batches beyond the limit."""
        self._recent_batches.append(list(batch))
        excess = len(self._recent_batches) - self.max_recent_batches
        if excess > 0:
            # Drop the oldest batches; a plain set has no order, so recency must
            # be tracked per batch rather than by slicing the set.
            del self._recent_batches[:excess]
        self.previously_selected = set().union(*self._recent_batches)

    def get_stratified_batch_indices(self, batch_size):
        """
        Return indices with stratified sampling based on uncertainty:
        - Select from different uncertainty levels proportionally
        - Ensures diversity across uncertainty spectrum

        Indices used in the last ``max_recent_batches`` stratified batches are
        skipped while a bucket still has enough other candidates. Falls back to
        uniform random sampling until uncertainties have been provided.

        Returns:
            1-D ``torch.long`` tensor of ``batch_size`` indices (fewer only if the
            dataset itself is smaller than ``batch_size``).
        """
        if self.update_required:
            # No uncertainty scores yet: sample uniformly (usage is still tracked).
            return self.get_random_batch_indices(batch_size)

        num_items = len(self.dataset)
        recent = self.previously_selected

        # Divide samples into 5 uncertainty buckets (quintiles), most certain first;
        # the last bucket also absorbs the remainder of num_items // 5.
        _, sorted_indices = torch.sort(self.uncertainties)
        sorted_indices = sorted_indices.tolist()
        bucket_size = num_items // 5
        buckets = [sorted_indices[k * bucket_size:(k + 1) * bucket_size] for k in range(4)]
        buckets.append(sorted_indices[4 * bucket_size:])

        # More from higher uncertainty buckets, but some from all
        proportions = [0.1, 0.15, 0.2, 0.25, 0.3]

        selected_indices = []
        for bucket, proportion in zip(buckets, proportions):
            n_samples = int(batch_size * proportion)
            candidates = [i for i in bucket if i not in recent]
            if len(candidates) < n_samples:
                # Not enough fresh samples: allow recently used ones from this bucket.
                candidates = bucket
            if candidates:
                selected_indices.extend(
                    random.sample(candidates, min(n_samples, len(candidates)))
                )

        # Integer truncation of the proportions can leave a few slots empty;
        # fill them uniformly, preferring samples not used recently.
        remaining = batch_size - len(selected_indices)
        if remaining > 0:
            taken = set(selected_indices)
            pool = [i for i in range(num_items) if i not in taken and i not in recent]
            if len(pool) < remaining:
                pool = [i for i in range(num_items) if i not in taken]
            selected_indices.extend(random.sample(pool, min(remaining, len(pool))))

        # Ensure we have batch_size samples
        if len(selected_indices) > batch_size:
            del selected_indices[batch_size:]
        elif len(selected_indices) < batch_size:
            # Dataset smaller than batch_size: pad with random (possibly repeated) indices.
            remaining = batch_size - len(selected_indices)
            selected_indices.extend(torch.randperm(num_items)[:remaining].tolist())

        self._remember_batch(selected_indices)
        self.sample_usage_counter.update(selected_indices)
        return torch.tensor(selected_indices, dtype=torch.long)

    def get_random_batch_indices(self, batch_size):
        """Return up to ``batch_size`` distinct uniformly random indices as a long tensor."""
        indices = torch.randperm(len(self.dataset))[:batch_size].tolist()
        self.sample_usage_counter.update(indices)
        # Explicit dtype so an empty batch is still usable as an index tensor.
        return torch.tensor(indices, dtype=torch.long)


class BayesianLogisticRegression(nn.Module):
    """
    Bayesian Logistic Regression model for Fashion-MNIST classification.

    A single ``nn.Linear(input_size, output_size)`` layer exposed as
    ``self.linear``; ``forward`` returns unnormalised class logits.
    """
    def __init__(self, input_size, output_size):
        super().__init__()
        self.linear = nn.Linear(input_size, output_size)

    def forward(self, x):
        """Flatten ``x`` to ``(-1, input_size)`` and return the class logits."""
        # Flatten with the layer's own input width rather than the global
        # INPUT_SIZE, so the module works for any input_size it was built with.
        # reshape (not view) also accepts non-contiguous inputs.
        x = x.reshape(-1, self.linear.in_features)
        return self.linear(x)


def model(x_data, y_data=None):
    """
    Bayesian logistic regression model.

    Places independent standard-normal priors on a ``(OUTPUT_SIZE, INPUT_SIZE)``
    weight matrix (site ``'linear.weight'``) and an ``OUTPUT_SIZE`` bias vector
    (site ``'linear.bias'``), computes logits ``x @ W.T + b``, and scores the
    labels with a categorical likelihood inside a ``'data'`` plate.

    Args:
        x_data: Input batch; flattened to ``(-1, INPUT_SIZE)``.
        y_data: Optional labels for the batch. When ``None`` the ``'obs'`` site
            is sampled instead of observed.

    Returns:
        Tuple ``(logits, obs)``.
    """
    x = x_data.view(-1, INPUT_SIZE)

    # Create the prior parameters on the same device as the data. Plain Python
    # floats would give CPU tensors, and scoring CUDA guide samples against an
    # expanded CPU prior fails with a device-mismatch error.
    loc = torch.zeros((), device=x.device)
    scale = torch.ones((), device=x.device)
    weight_prior = dist.Normal(loc, scale).expand([OUTPUT_SIZE, INPUT_SIZE]).to_event(2)
    bias_prior = dist.Normal(loc, scale).expand([OUTPUT_SIZE]).to_event(1)

    # Sample from priors (site names match the guide)
    weight = pyro.sample('linear.weight', weight_prior)
    bias = pyro.sample('linear.bias', bias_prior)

    # Forward pass
    logits = torch.matmul(x, weight.t()) + bias

    # Sample y from the logistic model
    with pyro.plate('data', x.shape[0]):
        obs = pyro.sample('obs', dist.Categorical(logits=logits), obs=y_data)

    return logits, obs


def guide(x_data, y_data=None):
    """
    Mean-field normal variational guide for the Bayesian logistic regression model.

    Registers four Pyro params on ``device``:
    - ``w_loc`` (zeros) and ``w_scale`` (ones, constrained positive), shape
      ``(OUTPUT_SIZE, INPUT_SIZE)``;
    - ``b_loc`` (zeros) and ``b_scale`` (ones, constrained positive), shape
      ``(OUTPUT_SIZE,)``.
    It then samples ``'linear.weight'`` and ``'linear.bias'`` from independent
    Normals with those parameters. The inputs are accepted only to match the
    model's signature and are not read.
    """
    positive = dist.constraints.positive
    weight_shape = (OUTPUT_SIZE, INPUT_SIZE)
    bias_shape = (OUTPUT_SIZE,)

    weight_loc = pyro.param('w_loc', torch.zeros(weight_shape).to(device))
    weight_scale = pyro.param('w_scale', torch.ones(weight_shape).to(device),
                              constraint=positive)
    bias_loc = pyro.param('b_loc', torch.zeros(bias_shape).to(device))
    bias_scale = pyro.param('b_scale', torch.ones(bias_shape).to(device),
                            constraint=positive)

    pyro.sample('linear.weight', dist.Normal(weight_loc, weight_scale).to_event(2))
    pyro.sample('linear.bias', dist.Normal(bias_loc, bias_scale).to_event(1))


def compute_bayesian_uncertainty(model, guide, x_data, num_samples=10, uncertainty_type='entropy'):
    """
    Compute uncertainty measures for each sample.

    Draws ``num_samples`` weight/bias samples from ``guide``, computes softmax
    class probabilities for every row of ``x_data`` under each sample, and
    reduces them to one score per row.

    Args:
        model: Bayesian model (accepted for interface symmetry; not used here)
        guide: Variational guide
        x_data: Input data tensor; flattened to ``(-1, INPUT_SIZE)``
        num_samples: Number of posterior samples to use (must be >= 1)
        uncertainty_type: Type of uncertainty measure to use
                         'entropy': predictive entropy
                         'bald': Bayesian Active Learning by Disagreement
                         'variation_ratio': 1 - max probability

    Returns:
        1-D tensor with one uncertainty score per input row.

    Raises:
        ValueError: if ``uncertainty_type`` is unknown or ``num_samples`` < 1.
    """
    # Validate up front, before the posterior passes over the whole input.
    if uncertainty_type not in ('entropy', 'bald', 'variation_ratio'):
        raise ValueError(f"Unknown uncertainty type: {uncertainty_type}")
    if num_samples < 1:
        raise ValueError(f"num_samples must be at least 1, got {num_samples}")

    def _entropy(probs, dim):
        # Shannon entropy along `dim`; the epsilon guards against log(0).
        return -torch.sum(probs * torch.log(probs + 1e-10), dim=dim)

    x = x_data.view(-1, INPUT_SIZE)
    all_probs = []
    # Scores are never back-propagated, so skip building autograd graphs for
    # the guide samples as well as for the forward passes.
    with torch.no_grad():
        for _ in range(num_samples):
            guide_trace = poutine.trace(guide).get_trace(x_data)
            sampled_weights = guide_trace.nodes['linear.weight']['value']
            sampled_bias = guide_trace.nodes['linear.bias']['value']
            logits = torch.matmul(x, sampled_weights.t()) + sampled_bias
            all_probs.append(F.softmax(logits, dim=1))

    # [num_samples, batch_size, num_classes]
    stacked_probs = torch.stack(all_probs)
    # [batch_size, num_classes]
    mean_probs = stacked_probs.mean(0)

    if uncertainty_type == 'entropy':
        # Predictive entropy: -sum p * log(p) of the averaged prediction
        return _entropy(mean_probs, dim=1)

    if uncertainty_type == 'bald':
        # BALD = H(y|x) - E_theta[H(y|x,theta)]: entropy of the mean prediction
        # minus the mean entropy of the individual posterior-sample predictions.
        return _entropy(mean_probs, dim=1) - _entropy(stacked_probs, dim=2).mean(0)

    # 'variation_ratio': 1 - max probability of the averaged prediction
    max_probs, _ = torch.max(mean_probs, dim=1)
    return 1.0 - max_probs


def predict(model, guide, x, num_samples=10):
    """
    Predict classes by averaging softmax outputs over posterior samples.

    Each of ``num_samples`` draws traces ``guide`` to obtain sampled
    ``'linear.weight'`` / ``'linear.bias'`` values and evaluates
    ``softmax(x @ W.T + b)`` on ``x`` flattened to ``(-1, INPUT_SIZE)``.
    ``model`` is accepted for interface symmetry but not used.

    Returns:
        Tuple ``(predicted_class, mean_probs)``: the arg-max class per row and
        the sample-averaged class probabilities.
    """
    flat_x = x.view(-1, INPUT_SIZE)

    def _sample_probs():
        # One posterior draw from the guide, then a forward pass without autograd.
        trace = poutine.trace(guide).get_trace(x)
        weights = trace.nodes['linear.weight']['value']
        bias = trace.nodes['linear.bias']['value']
        with torch.no_grad():
            return F.softmax(flat_x @ weights.t() + bias, dim=1)

    averaged = torch.stack([_sample_probs() for _ in range(num_samples)]).mean(0)
    predicted_class = averaged.max(dim=1).indices
    return predicted_class, averaged


def evaluate_per_class(model, guide, test_loader, class_names):
    """
    Evaluate the model on test data and report per-class accuracy.

    Prints the overall accuracy, mean predictive entropy, per-class accuracy,
    and the combined accuracy over ``RARE_CLASSES``.

    Args:
        model: Passed through to ``predict``.
        guide: Variational guide used by ``predict``.
        test_loader: Iterable of ``(data, target)`` batches.
        class_names: Sequence of display names indexed by class id; ``None``
            falls back to the class ids as strings.

    Returns:
        Tuple ``(overall_accuracy, rare_class_accuracy, uncertainties)``:
        accuracies in percent (``nan`` when there were no examples to score),
        and the predictive entropy of every test example in loader order.
    """
    if class_names is None:
        class_names = [str(i) for i in range(OUTPUT_SIZE)]

    class_correct = torch.zeros(OUTPUT_SIZE, dtype=torch.long)
    class_total = torch.zeros(OUTPUT_SIZE, dtype=torch.long)
    uncertainties = []

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            predicted_class, probs = predict(model, guide, data)

            # Calculate prediction entropy as uncertainty measure
            entropy = -torch.sum(probs * torch.log(probs + 1e-10), dim=1)
            uncertainties.extend(entropy.cpu().numpy())

            # Vectorised per-class tally instead of a per-example .item() loop.
            target_cpu = target.cpu()
            hits = predicted_class.cpu() == target_cpu
            class_total += torch.bincount(target_cpu, minlength=OUTPUT_SIZE)
            class_correct += torch.bincount(target_cpu[hits], minlength=OUTPUT_SIZE)

    class_correct = class_correct.tolist()
    class_total = class_total.tolist()
    overall_correct = sum(class_correct)
    overall_total = sum(class_total)

    # Guard against an empty loader instead of dividing by zero.
    overall_accuracy = 100.0 * overall_correct / overall_total if overall_total else float('nan')
    avg_uncertainty = float(np.mean(uncertainties)) if uncertainties else float('nan')

    print(f'Overall Test Accuracy: {overall_accuracy:.2f}%')
    print(f'Average Uncertainty: {avg_uncertainty:.4f}')

    # Print per-class accuracy
    print("\nPer-class accuracy:")
    for i in range(OUTPUT_SIZE):
        if class_total[i] > 0:
            accuracy = 100.0 * class_correct[i] / class_total[i]
            print(f'  Class {i} ({class_names[i]}): {accuracy:.2f}% ({class_correct[i]}/{class_total[i]})')
        else:
            print(f'  Class {i} ({class_names[i]}): No test samples')

    # Rare class performance is especially important
    rare_class_correct = sum(class_correct[c] for c in RARE_CLASSES)
    rare_class_total = sum(class_total[c] for c in RARE_CLASSES)
    # Defined even when no rare-class examples were seen (previously an
    # UnboundLocalError at the return statement).
    rare_class_accuracy = float('nan')
    if rare_class_total > 0:
        rare_class_accuracy = 100.0 * rare_class_correct / rare_class_total
        print(f'\nRare class accuracy: {rare_class_accuracy:.2f}% ({rare_class_correct}/{rare_class_total})')

    return overall_accuracy, rare_class_accuracy, uncertainties


def train_with_optimized_active_minibatches(model, guide, active_dataset, optimizer, num_epochs, 
                                 batch_size=BATCH_SIZE, update_interval=UNCERTAINTY_UPDATE_INTERVAL,
                                 uncertainty_type='bald', class_names=None):
    """
    Train the model using SVI with optimized active minibatch selection.

    The first ``WARM_UP_ITERATIONS`` steps use uniformly random batches. After
    that, uncertainties for the whole training set are recomputed every
    ``update_interval`` steps (method ``uncertainty_type``), and batches come
    from ``active_dataset.get_stratified_batch_indices``. The Fashion-MNIST test
    set is evaluated once before training, every ``update_interval`` steps, and
    after each epoch.

    Args:
        model, guide: Pyro model and guide passed to ``SVI``.
        active_dataset: ``OptimizedActiveDataset`` wrapping the training data.
        optimizer: Pyro optimizer.
        num_epochs: Number of epochs of ``len(active_dataset) // batch_size`` steps.
        batch_size: Examples per SVI step.
        update_interval: Steps between uncertainty refreshes / evaluations.
        uncertainty_type: Passed to ``compute_bayesian_uncertainty``.
        class_names: Display names for per-class reports; ``None`` uses class ids.

    Returns:
        Tuple ``(train_losses, test_overall_accuracies, test_rare_accuracies)``:
        per-epoch loss divided by the number of examples processed, and lists of
        ``(iteration, accuracy)`` tuples.
    """
    if class_names is None:
        # evaluate_per_class indexes class_names by class id.
        class_names = [str(i) for i in range(OUTPUT_SIZE)]

    # Define SVI
    svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

    train_losses = []
    test_overall_accuracies = []
    test_rare_accuracies = []

    # Materialise the whole training set once so batches can be gathered by
    # index and uncertainties computed in a single pass.
    all_data = torch.stack([active_dataset[i][0] for i in range(len(active_dataset))]).to(device)
    all_labels = torch.tensor([active_dataset[i][1] for i in range(len(active_dataset))]).to(device)

    # Create test loader for evaluation
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.2860,), (0.3530,))  # Fashion-MNIST mean and std
    ])
    test_dataset = datasets.FashionMNIST(DATA_DIR, train=False, transform=transform)
    test_loader = DataLoader(test_dataset, batch_size=TEST_BATCH_SIZE, shuffle=False)

    # Evaluate model before training begins to establish a baseline
    print("\nEvaluating baseline model performance...")
    initial_accuracy, initial_rare_accuracy, _ = evaluate_per_class(model, guide, test_loader, class_names)
    test_overall_accuracies.append((0, initial_accuracy))
    test_rare_accuracies.append((0, initial_rare_accuracy))

    # Number of batches per epoch
    num_batches = len(active_dataset) // batch_size

    iteration = 0
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        processed_batches = 0

        for batch_idx in tqdm(range(num_batches), desc=f"Epoch {epoch+1}/{num_epochs}"):
            if iteration < WARM_UP_ITERATIONS:
                # Initial warm-up phase with random batches
                batch_indices = active_dataset.get_random_batch_indices(batch_size)
            else:
                # Update uncertainties periodically
                if iteration % update_interval == 0 or active_dataset.update_required:
                    print("\nUpdating uncertainties...")
                    uncertainties = compute_bayesian_uncertainty(
                        model, guide, all_data,
                        num_samples=10,
                        uncertainty_type=uncertainty_type
                    )
                    active_dataset.update_uncertainties(uncertainties.cpu())

                # NOTE: EXPLORATION_RATIO_START/END are not consulted; the
                # stratified sampler uses its own fixed bucket proportions.
                batch_indices = active_dataset.get_stratified_batch_indices(batch_size)

            batch_data = all_data[batch_indices]
            batch_labels = all_labels[batch_indices]

            # One SVI step on this batch; returns the batch's ELBO loss
            batch_loss = svi.step(batch_data, batch_labels)
            epoch_loss += batch_loss
            processed_batches += 1

            if batch_idx % 20 == 0:
                print(f'Batch {batch_idx}/{num_batches}, Loss: {batch_loss / batch_size:.6f}')

            # Track metrics consistently, including during warm-up phase
            if iteration % update_interval == 0:
                accuracy, rare_accuracy, _ = evaluate_per_class(model, guide, test_loader, class_names)
                test_overall_accuracies.append((iteration, accuracy))
                test_rare_accuracies.append((iteration, rare_accuracy))

            iteration += 1

        # Average loss over the epoch; nan if the dataset is smaller than one batch.
        if processed_batches:
            avg_epoch_loss = epoch_loss / (processed_batches * batch_size)
        else:
            avg_epoch_loss = float('nan')
        train_losses.append(avg_epoch_loss)

        print(f'Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_epoch_loss:.6f}')

        # Evaluate model after each epoch
        accuracy, rare_accuracy, _ = evaluate_per_class(model, guide, test_loader, class_names)
        test_overall_accuracies.append((iteration, accuracy))
        test_rare_accuracies.append((iteration, rare_accuracy))

    return train_losses, test_overall_accuracies, test_rare_accuracies


def train_with_random_minibatches(model, guide, active_dataset, optimizer, num_epochs, 
                                  batch_size=BATCH_SIZE, class_names=None):
    """
    Train the model using SVI with random minibatch selection (baseline).

    Every step draws a uniformly random batch via
    ``active_dataset.get_random_batch_indices``. The Fashion-MNIST test set is
    evaluated every ``UNCERTAINTY_UPDATE_INTERVAL`` steps (skipping step 0) and
    after each epoch.

    Args:
        model, guide: Pyro model and guide passed to ``SVI``.
        active_dataset: ``OptimizedActiveDataset`` wrapping the training data.
        optimizer: Pyro optimizer.
        num_epochs: Number of epochs of ``len(active_dataset) // batch_size`` steps.
        batch_size: Examples per SVI step.
        class_names: Display names for per-class reports; ``None`` uses class ids.

    Returns:
        Tuple ``(train_losses, test_overall_accuracies, test_rare_accuracies)``:
        per-epoch loss divided by the number of examples processed, and lists of
        ``(iteration, accuracy)`` tuples.
    """
    if class_names is None:
        # evaluate_per_class indexes class_names by class id.
        class_names = [str(i) for i in range(OUTPUT_SIZE)]

    # Define SVI
    svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

    train_losses = []
    test_overall_accuracies = []
    test_rare_accuracies = []

    # Materialise the whole training set once so batches can be gathered by index.
    all_data = torch.stack([active_dataset[i][0] for i in range(len(active_dataset))]).to(device)
    all_labels = torch.tensor([active_dataset[i][1] for i in range(len(active_dataset))]).to(device)

    # Create test loader for evaluation
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.2860,), (0.3530,))  # Fashion-MNIST mean and std
    ])
    test_dataset = datasets.FashionMNIST(DATA_DIR, train=False, transform=transform)
    test_loader = DataLoader(test_dataset, batch_size=TEST_BATCH_SIZE, shuffle=False)

    # Number of batches per epoch
    num_batches = len(active_dataset) // batch_size

    iteration = 0
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        processed_batches = 0

        for batch_idx in tqdm(range(num_batches), desc=f"Epoch {epoch+1}/{num_epochs}"):
            # Get random minibatch indices
            batch_indices = active_dataset.get_random_batch_indices(batch_size)

            batch_data = all_data[batch_indices]
            batch_labels = all_labels[batch_indices]

            # One SVI step on this batch; returns the batch's ELBO loss
            batch_loss = svi.step(batch_data, batch_labels)
            epoch_loss += batch_loss
            processed_batches += 1

            if batch_idx % 20 == 0:
                print(f'Batch {batch_idx}/{num_batches}, Loss: {batch_loss / batch_size:.6f}')

            # For monitoring - compute test accuracy periodically
            if iteration % UNCERTAINTY_UPDATE_INTERVAL == 0 and iteration > 0:
                accuracy, rare_accuracy, _ = evaluate_per_class(model, guide, test_loader, class_names)
                test_overall_accuracies.append((iteration, accuracy))
                test_rare_accuracies.append((iteration, rare_accuracy))

            iteration += 1

        # Average loss over the epoch; nan if the dataset is smaller than one batch.
        if processed_batches:
            avg_epoch_loss = epoch_loss / (processed_batches * batch_size)
        else:
            avg_epoch_loss = float('nan')
        train_losses.append(avg_epoch_loss)

        print(f'Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_epoch_loss:.6f}')

        # Evaluate model after each epoch
        accuracy, rare_accuracy, _ = evaluate_per_class(model, guide, test_loader, class_names)
        test_overall_accuracies.append((iteration, accuracy))
        test_rare_accuracies.append((iteration, rare_accuracy))

    return train_losses, test_overall_accuracies, test_rare_accuracies


In [None]:
def plot_comparison(active_losses, random_losses, 
                   active_overall_accuracies, random_overall_accuracies,
                   active_rare_accuracies, random_rare_accuracies):
    """
    Plot comparison between active and random training methods with cumulative averages.

    Draws four stacked panels (loss per epoch, overall accuracy, rare-class
    accuracy, cumulative average improvement of active over random), prints the
    final cumulative improvements, saves the figure to
    ``imbalanced_comparison.png`` and shows it.

    Args:
        active_losses: List of losses from active minibatch training
        random_losses: List of losses from random minibatch training
        active_overall_accuracies: List of (iteration, accuracy) tuples from active minibatch
        random_overall_accuracies: List of (iteration, accuracy) tuples from random minibatch
        active_rare_accuracies: List of (iteration, accuracy) tuples for rare classes from active
        random_rare_accuracies: List of (iteration, accuracy) tuples for rare classes from random
    """
    plt.figure(figsize=(15, 20))

    # Plot losses
    plt.subplot(4, 1, 1)
    plt.plot(active_losses, 'b-', label='Optimized Active Minibatch')
    plt.plot(random_losses, 'r-', label='Random Minibatch')
    plt.xlabel('Epoch', fontsize=14, fontweight='bold')
    plt.ylabel('ELBO Loss', fontsize=14, fontweight='bold')
    plt.title('Training Loss Comparison', fontsize=16, fontweight='bold')
    plt.legend(fontsize=12)
    plt.grid(True)
    plt.tick_params(axis='both', which='major', labelsize=12)

    # Extract iteration and accuracy values
    active_iters = np.array([acc[0] for acc in active_overall_accuracies])
    active_acc = np.array([acc[1] for acc in active_overall_accuracies])

    random_iters = np.array([acc[0] for acc in random_overall_accuracies])
    random_acc = np.array([acc[1] for acc in random_overall_accuracies])

    # Extract iteration and accuracy values for rare classes
    active_rare_iters = np.array([acc[0] for acc in active_rare_accuracies])
    active_rare_acc = np.array([acc[1] for acc in active_rare_accuracies])

    random_rare_iters = np.array([acc[0] for acc in random_rare_accuracies])
    random_rare_acc = np.array([acc[1] for acc in random_rare_accuracies])

    # Plot raw overall accuracies
    plt.subplot(4, 1, 2)
    plt.plot(active_iters, active_acc, 'b-', alpha=0.5, label='Active (Raw)')
    plt.plot(random_iters, random_acc, 'r-', alpha=0.5, label='Random (Raw)')
    plt.xlabel('Iteration', fontsize=14, fontweight='bold')
    plt.ylabel('Overall Test Accuracy (%)', fontsize=14, fontweight='bold')
    plt.title('Overall Test Accuracy Comparison', fontsize=16, fontweight='bold')
    plt.legend(fontsize=12)
    plt.grid(True)
    plt.tick_params(axis='both', which='major', labelsize=12)

    # Plot raw rare class accuracies
    plt.subplot(4, 1, 3)
    plt.plot(active_rare_iters, active_rare_acc, 'b-', alpha=0.5, label='Active (Raw)')
    plt.plot(random_rare_iters, random_rare_acc, 'r-', alpha=0.5, label='Random (Raw)')
    plt.xlabel('Iteration', fontsize=14, fontweight='bold')
    plt.ylabel('Rare Class Test Accuracy (%)', fontsize=14, fontweight='bold')
    plt.title('Rare Class Test Accuracy Comparison (Sneakers & Ankle Boots)', fontsize=16, fontweight='bold')
    plt.legend(fontsize=12)
    plt.grid(True)
    plt.tick_params(axis='both', which='major', labelsize=12)

    # Plot relative improvement: only cumulative average improvements
    plt.subplot(4, 1, 4)

    def interpolate_at_iterations(source_iters, source_values, target_iters):
        """Linearly interpolate ``source_values`` (taken at ``source_iters``) onto ``target_iters``."""
        if len(source_iters) == 0:
            # Nothing to interpolate from: report "unknown" rather than a fake 0.
            return np.full(len(target_iters), np.nan)
        # np.interp needs increasing sample points; a single point yields a constant.
        order = np.argsort(source_iters, kind='stable')
        return np.interp(target_iters, source_iters[order], source_values[order])

    # Compare only on the iteration range covered by both runs.
    min_iter = max(min(active_iters), min(random_iters))
    max_iter = min(max(active_iters), max(random_iters))
    comparison_iters = np.linspace(min_iter, max_iter, 100)

    active_interp = interpolate_at_iterations(active_iters, active_acc, comparison_iters)
    random_interp = interpolate_at_iterations(random_iters, random_acc, comparison_iters)
    overall_improvement = active_interp - random_interp

    active_rare_interp = interpolate_at_iterations(active_rare_iters, active_rare_acc, comparison_iters)
    random_rare_interp = interpolate_at_iterations(random_rare_iters, random_rare_acc, comparison_iters)
    rare_improvement = active_rare_interp - random_rare_interp

    # Cumulative averages only
    cumulative_avg_overall = np.cumsum(overall_improvement) / (np.arange(len(overall_improvement)) + 1)
    cumulative_avg_rare = np.cumsum(rare_improvement) / (np.arange(len(rare_improvement)) + 1)

    # Print the final cumulative average improvements
    print(f"Final Cumulative Average Overall Improvement: {cumulative_avg_overall[-1]:.2f}%")
    print(f"Final Cumulative Average Rare Class Improvement: {cumulative_avg_rare[-1]:.2f}%")

    plt.plot(comparison_iters, cumulative_avg_overall, 'g-', linewidth=2, 
             label='Cumulative Avg Overall Improvement')
    plt.plot(comparison_iters, cumulative_avg_rare, 'm-', linewidth=2, 
             label='Cumulative Avg Rare Class Improvement')

    plt.axhline(y=0, color='k', linestyle='--', alpha=0.3)
    plt.xlabel('Iteration', fontsize=14, fontweight='bold')
    plt.ylabel('Average Improvement (Active - Random) (%)', fontsize=13, fontweight='bold')
    plt.title('Average Improvement of Active Learning over Random Sampling', fontsize=14, fontweight='bold')
    plt.legend(fontsize=12)
    plt.grid(True)
    plt.tick_params(axis='both', which='major', labelsize=12)

    plt.tight_layout()
    plt.savefig('imbalanced_comparison.png', dpi=300)
    plt.show()

In [None]:
def analyze_dataset_usage(active_dataset, class_names):
    """
    Report and plot how often each training sample was selected.

    Prints overall usage statistics (min / max / mean selections and the
    share of samples never selected), then the same statistics per class.
    Also draws a histogram of per-sample selection counts, saves it to
    ``sample_usage_distribution.png`` and shows it.

    Args:
        active_dataset: Active-learning dataset wrapper. It must expose
            ``sample_usage_counter``, a mapping that presumably goes from
            sample index to the number of times that sample was selected
            (TODO confirm against OptimizedActiveDataset). It must also
            expose ``dataset``, the wrapped dataset, whose items are indexed
            as ``(data, label)``.
        class_names: Sequence of class names indexed by label.
    """
    counter = active_dataset.sample_usage_counter
    num_samples = len(active_dataset.dataset)

    if not counter or num_samples == 0:
        print("No usage data available")
        return

    # Read counts by index rather than via counter.values(). This keeps each
    # count aligned with the label at the same index, whatever the mapping's
    # insertion order. Samples absent from a sparse counter (e.g. a
    # collections.Counter) are reported as used 0 times instead of dropped.
    usage_counts = [counter.get(i, 0) for i in range(num_samples)]

    min_usage = min(usage_counts)
    max_usage = max(usage_counts)
    mean_usage = sum(usage_counts) / num_samples
    never_used = sum(1 for count in usage_counts if count == 0)

    # Histogram: how many samples were selected a given number of times
    plt.figure(figsize=(10, 6))
    plt.hist(usage_counts, bins=20, alpha=0.7)
    plt.axvline(mean_usage, color='r', linestyle='dashed', linewidth=1, label=f'Mean: {mean_usage:.2f}')
    plt.xlabel('Number of Times Selected')
    plt.ylabel('Number of Samples')
    plt.title('Sample Usage Distribution')
    plt.legend()
    plt.grid(True)

    print("Sample Usage Statistics:")
    print(f"  - Min usage: {min_usage}")
    print(f"  - Max usage: {max_usage}")
    print(f"  - Mean usage: {mean_usage:.2f}")
    print(f"  - Never used: {never_used} samples ({never_used/num_samples*100:.2f}%)")

    # Group counts by class label. int() lets tensor labels act as dict keys.
    # Note: this indexes every item of the wrapped dataset, so it runs its
    # transform once per sample.
    class_labels = [int(active_dataset.dataset[i][1]) for i in range(num_samples)]
    usage_by_class = {c: [] for c in range(OUTPUT_SIZE)}
    for count, label in zip(usage_counts, class_labels):
        usage_by_class[label].append(count)

    print("\nUsage statistics by class:")
    for class_idx in range(OUTPUT_SIZE):
        counts = usage_by_class[class_idx]
        if not counts:
            continue
        class_mean = sum(counts) / len(counts)
        class_never = sum(1 for c in counts if c == 0)

        rare_label = "(RARE)" if class_idx in RARE_CLASSES else ""
        print(f"  - Class {class_idx} ({class_names[class_idx]}) {rare_label}:")
        print(f"      Mean usage: {class_mean:.2f}")
        print(f"      Max usage: {max(counts)}")
        print(f"      Never used: {class_never} samples ({class_never/len(counts)*100:.2f}%)")

    plt.savefig('sample_usage_distribution.png')
    plt.show()

In [None]:
def visualize_uncertain_samples(active_dataset, k=10, class_names=None):
    """
    Plot the ``k`` samples with the highest stored uncertainty.

    Samples are ranked by ``active_dataset.uncertainties`` in descending
    order. Each is shown as a 28x28 grayscale image titled with its label
    (and class name when ``class_names`` is given), a "(RARE)" tag for
    classes in ``RARE_CLASSES``, and its uncertainty value. The grid has at
    most 5 columns and as many rows as needed; the default k=10 gives a
    2x5 grid. The figure is saved to ``most_uncertain_samples.png`` and
    shown.

    Args:
        active_dataset: Dataset exposing ``uncertainties`` (a 1-D tensor with
            one score per sample) and ``__getitem__`` returning items whose
            ``[0]`` is the image data and ``[1]`` the label.
        k: Number of samples to show. It is clamped to the dataset size.
        class_names: Optional sequence of class names indexed by label.
    """
    _, order = torch.sort(active_dataset.uncertainties, descending=True)

    # Never request more samples than exist; non-positive k means nothing to show.
    k = min(k, len(order))
    if k <= 0:
        print("No samples to visualize")
        return
    top_k_indices = order[:k].cpu().numpy()

    # Fetch each item once; the original indexed the dataset twice per sample.
    samples, labels, uncertainty_values = [], [], []
    for idx in top_k_indices:
        item = active_dataset[idx]
        samples.append(item[0].cpu().numpy().reshape(28, 28))
        labels.append(item[1])
        uncertainty_values.append(active_dataset.uncertainties[idx].item())

    # Derive the grid from k instead of hard-coding 2x5, which crashed for
    # k > 10. Figure size scales with the grid: k=10 gives (15, 8) as before.
    cols = min(5, k)
    rows = (k + cols - 1) // cols
    plt.figure(figsize=(3 * cols, 4 * rows))
    for i in range(k):
        plt.subplot(rows, cols, i + 1)
        plt.imshow(samples[i], cmap='gray')
        class_label = f"{labels[i]} ({class_names[labels[i]]})" if class_names else str(labels[i])
        rare_label = "(RARE)" if labels[i] in RARE_CLASSES else ""
        plt.title(f"Label: {class_label}\n{rare_label}\nUncertainty: {uncertainty_values[i]:.4f}")
        plt.axis('off')

    plt.tight_layout()
    plt.savefig('most_uncertain_samples.png')
    plt.show()

In [None]:
def plot_class_distribution(dataset, class_names):
    """
    Draw a bar chart of per-class sample counts for ``dataset``.

    Bars are ordered by class index and labelled "<index> (<name>)". Classes
    listed in ``RARE_CLASSES`` are drawn in red and all others in light
    blue, and each bar is annotated with its height. The figure is saved to
    ``class_distribution.png`` and then shown.

    Args:
        dataset: Indexable dataset whose items are ``(data, label)`` pairs.
        class_names: Sequence of class names indexed by label.
    """
    from matplotlib.patches import Patch

    # Tally labels across the whole dataset
    tally = Counter(dataset[idx][1] for idx in range(len(dataset)))

    bar_labels = []
    bar_heights = []
    bar_colors = []
    for cls in sorted(tally):
        bar_labels.append(f"{cls} ({class_names[cls]})")
        bar_heights.append(tally[cls])
        bar_colors.append('red' if cls in RARE_CLASSES else 'lightblue')

    plt.figure(figsize=(12, 6))
    bars = plt.bar(bar_labels, bar_heights, color=bar_colors)
    plt.xlabel('Classes')
    plt.ylabel('Number of Samples')
    plt.title('Class Distribution in Imbalanced Dataset')
    plt.xticks(rotation=45, ha='right')

    plt.legend(handles=[
        Patch(facecolor='lightblue', label='Common Classes'),
        Patch(facecolor='red', label='Rare Classes'),
    ])

    # Write each bar's height just above it
    for rect in bars:
        h = rect.get_height()
        x_mid = rect.get_x() + rect.get_width() / 2.
        plt.text(x_mid, h + 5, f'{h}', ha='center', va='bottom')

    plt.tight_layout()
    plt.savefig('class_distribution.png')
    plt.show()


### Experiments

In [None]:
# Human-readable Fashion-MNIST class names, indexed by label id
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

# Image pipeline: PIL -> tensor, then standardize with the
# Fashion-MNIST mean (0.2860) and std (0.3530)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.2860,), (0.3530,)),
])

# Training split (downloaded into DATA_DIR if missing)
base_dataset = datasets.FashionMNIST(
    DATA_DIR, train=True, download=True, transform=transform
)

# Undersample the rare classes, then wrap for active minibatch selection
imbalanced_dataset = ImbalancedDataset(base_dataset, RARE_CLASSES, IMBALANCE_RATIO)
active_dataset = OptimizedActiveDataset(imbalanced_dataset)

# Show the resulting class balance
plot_class_distribution(imbalanced_dataset, class_names)

##### Running Optimized Active Minibatch Selection

In [None]:
# Train with optimized active minibatch selection.
# Start from a fresh parameter store and re-seed every RNG (pyro.set_rng_seed
# seeds python, numpy and torch). This run's initialization and sampling then
# do not depend on which cells executed before it, and re-running this cell
# reproduces the same result.
pyro.clear_param_store()
pyro.set_rng_seed(SEED)
active_optimizer = Adam({"lr": LEARNING_RATE})
active_losses, active_overall_accuracies, active_rare_accuracies = train_with_optimized_active_minibatches(
    model, guide, active_dataset, active_optimizer, NUM_EPOCHS,
    batch_size=ACTIVE_BATCH_SIZE, update_interval=UNCERTAINTY_UPDATE_INTERVAL,
    uncertainty_type='bald',  # Use BALD measure for better performance
    class_names=class_names
)

# Snapshot the trained parameters (name -> detached copy of the stored value)
# so they can be restored after the random baseline overwrites the store.
active_model_params = {name: param.data.clone() for name, param in pyro.get_param_store().items()}

##### Running Random Minibatch Selection (Baseline)

In [None]:
# Train the random-minibatch baseline.
# Start from a fresh parameter store and re-seed every RNG. Otherwise the
# baseline's initialization and sampling depend on how many random draws
# earlier cells consumed, which makes the comparison non-reproducible.
pyro.clear_param_store()
pyro.set_rng_seed(SEED)
random_optimizer = Adam({"lr": LEARNING_RATE})
random_losses, random_overall_accuracies, random_rare_accuracies = train_with_random_minibatches(
    model, guide, active_dataset, random_optimizer, NUM_EPOCHS,
    batch_size=ACTIVE_BATCH_SIZE,
    class_names=class_names
)

# Snapshot the baseline's trained parameters (name -> detached copy)
random_model_params = {name: param.data.clone() for name, param in pyro.get_param_store().items()}

##### Comparison

In [None]:
plot_comparison(
    active_losses, random_losses,
    active_overall_accuracies, random_overall_accuracies,
    active_rare_accuracies, random_rare_accuracies
)

# Restore the parameters saved after the active-learning run, so the analysis
# below uses the active model rather than the random baseline left in the store.
# NOTE(review): param-store items() yields constrained values and
# pyro.param(name, value) re-registers them without their original
# constraints. That is fine for inference here, but not for resuming training.
pyro.clear_param_store()
for name, param in active_model_params.items():
    pyro.param(name, param)

# Compute final uncertainties for every sample. No gradients are needed, so
# disable autograd rather than build a graph over the entire dataset.
# NOTE: training selected batches with 'bald'; this ranking uses predictive entropy.
all_data = torch.stack([active_dataset[i][0] for i in range(len(active_dataset))]).to(device)
with torch.no_grad():
    uncertainties = compute_bayesian_uncertainty(
        model, guide, all_data, num_samples=10, uncertainty_type='entropy'
    )
active_dataset.update_uncertainties(uncertainties.cpu())

# Analyze dataset usage patterns
analyze_dataset_usage(active_dataset, class_names)

# Visualize most uncertain samples
visualize_uncertain_samples(active_dataset, k=10, class_names=class_names)