In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import math
from torch.utils.data import DataLoader
import torch.optim as optim
import time

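# Reweighted Gradient Message (RGM): linear layer whose backward pass randomly drops weight entries when propagating the gradient to the input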
class RGMLinearLayer(Function):
    """Affine map with "backward dropout" on the gradient sent to the layer input.

    Forward computes ``input @ weight.T + bias`` exactly.

    Backward:
      * ``grad_weight`` and ``grad_bias`` are the exact, unmasked gradients.
      * ``grad_input`` is computed through a randomly masked copy of ``weight``. Each
        entry is kept with probability ``1 - backward_dropout_rate``, and entries on
        the main diagonal are always kept. With ``apply_reweighting`` the kept
        off-diagonal entries are scaled by ``1 / (1 - backward_dropout_rate)``.

    ``input`` may carry any number of leading (batch) dimensions; the last dimension
    must equal ``weight.shape[1]``.
    """

    @staticmethod
    def forward(ctx, input, weight, bias, backward_dropout_rate, apply_reweighting=True):
        if not 0.0 <= backward_dropout_rate <= 1.0:
            raise ValueError(
                f"backward_dropout_rate must be in [0, 1], got {backward_dropout_rate}")
        if apply_reweighting and backward_dropout_rate >= 1.0:
            # 1 / (1 - rate) would be infinite and turn grad_input into NaN.
            raise ValueError("backward_dropout_rate must be < 1 when apply_reweighting=True")

        output = torch.matmul(input, weight.T) + bias
        # bias is not needed in backward, so only input and weight are saved.
        ctx.save_for_backward(input, weight)
        ctx.backward_dropout_rate = backward_dropout_rate
        ctx.apply_reweighting = apply_reweighting
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        dropout_rate = ctx.backward_dropout_rate
        grad_input = grad_weight = grad_bias = None

        # Only build the mask when the input actually needs a gradient (e.g. not for raw data).
        if ctx.needs_input_grad[0]:
            # rand_like samples [0, 1), so `>=` keeps an entry with probability exactly
            # 1 - rate; in particular a rate of 0 keeps every entry.
            keep = torch.rand_like(weight) >= dropout_rate
            keep.fill_diagonal_(True)
            if ctx.apply_reweighting:
                scale = torch.full_like(weight, 1.0 / (1.0 - dropout_rate))
                scale.fill_diagonal_(1.0)  # diagonal is never dropped, so never rescaled
                weighted_mask = scale * keep
            else:
                weighted_mask = keep.to(weight.dtype)
            grad_input = grad_output @ (weight * weighted_mask)

        # Collapse leading dims so parameter gradients also work for (..., in) inputs.
        flat_grad_output = grad_output.reshape(-1, grad_output.shape[-1])
        if ctx.needs_input_grad[1]:
            grad_weight = flat_grad_output.T @ input.reshape(-1, input.shape[-1])
        if ctx.needs_input_grad[2]:
            grad_bias = flat_grad_output.sum(dim=0)

        # No gradients for the two non-tensor arguments.
        return grad_input, grad_weight, grad_bias, None, None


class RGMLinear(nn.Module):
    """Fully connected layer whose gradient w.r.t. its input uses backward dropout.

    The forward/backward computation is delegated to ``RGMLinearLayer``.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        backward_dropout_rate: probability of dropping a weight entry when propagating
            the gradient to the input.
        apply_reweighting: whether kept entries are rescaled by ``1 / (1 - rate)``.
    """

    def __init__(self, in_features, out_features, backward_dropout_rate=0.0, apply_reweighting=True):
        super().__init__()
        self.in_features, self.out_features = in_features, out_features
        # Parameters are allocated uninitialised here and filled by reset_parameters().
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        self.bias = nn.Parameter(torch.empty(out_features))
        self.backward_dropout_rate = backward_dropout_rate
        self.apply_reweighting = apply_reweighting
        self.reset_parameters()

    def reset_parameters(self):
        """Draw ``weight`` with Kaiming-uniform (a=sqrt(5)) and ``bias`` uniformly in +/-1/sqrt(fan_in)."""
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is None:
            return
        fan_in = nn.init._calculate_fan_in_and_fan_out(self.weight)[0]
        bound = 0 if fan_in <= 0 else 1 / math.sqrt(fan_in)
        nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x):
        """Apply the affine map through ``RGMLinearLayer`` with this layer's dropout settings."""
        rgm_fn = RGMLinearLayer.apply
        return rgm_fn(x, self.weight, self.bias, self.backward_dropout_rate, self.apply_reweighting)

# SBP implementation (following the SBP paper): only a random keep_ratio fraction of the batch rows back-propagates gradients
class SBPLinear2D(nn.Module):
    """Linear layer that back-propagates through only a random subset of the batch rows.

    On each forward pass with autograd enabled, ``int(batch_size * keep_ratio)`` rows
    are chosen uniformly at random; at least one row is chosen when ``keep_ratio > 0``.
    Those rows are computed with autograd tracking. The remaining rows are computed
    under ``torch.no_grad()``, so they contribute neither to the parameter gradients
    nor to the gradient w.r.t. ``x``. Every row's forward value equals
    ``F.linear(x, weight, bias)`` (up to floating-point rounding). When autograd is
    disabled, the whole batch goes through ``F.linear`` directly.

    Args:
        in_features: size of each input row.
        out_features: size of each output row.
        bias: if False, the layer has no additive bias.
        keep_ratio: fraction of the batch whose gradient is kept.
    """

    def __init__(self, in_features, out_features, bias=True, keep_ratio=0.5):
        super(SBPLinear2D, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.keep_ratio = keep_ratio

        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        if bias:
            self.bias = nn.Parameter(torch.randn(out_features))
        else:
            self.register_parameter('bias', None)
        self._reset_parameters()

    def _reset_parameters(self):
        """Kaiming-uniform weight (a=sqrt(5)); bias uniform in +/-1/sqrt(fan_in)."""
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            # Same guard as RGMLinear: avoid dividing by zero for a zero-width input.
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)

    def _generate_keep_mask(self, batch_size, device=None):
        """Return a bool mask of length ``batch_size`` marking the rows that keep gradients."""
        num_keep = int(batch_size * self.keep_ratio)
        if num_keep == 0 and batch_size > 0 and self.keep_ratio > 0:
            # A small (e.g. ragged last) batch would otherwise keep no rows at all, giving
            # this layer no gradient and, in the last layer, an output with no grad_fn.
            num_keep = 1
        keep_mask = torch.zeros(batch_size, dtype=torch.bool, device=device)
        keep_mask[torch.randperm(batch_size, device=device)[:num_keep]] = True
        return keep_mask

    def forward(self, x):
        # Nothing will be back-propagated (e.g. evaluation under torch.no_grad()),
        # so there is nothing to drop: compute the whole batch at once.
        if not torch.is_grad_enabled():
            return F.linear(x, self.weight, self.bias)

        batch_size = x.shape[0]
        keep_mask = self._generate_keep_mask(batch_size, device=x.device)
        keep_indices = keep_mask.nonzero(as_tuple=True)[0]
        drop_indices = (~keep_mask).nonzero(as_tuple=True)[0]

        output = torch.zeros(batch_size, self.out_features,
                             device=x.device, dtype=x.dtype)

        # Kept rows: computed normally so autograd records them.
        if keep_indices.numel() > 0:
            output[keep_indices] = F.linear(x[keep_indices], self.weight, self.bias)

        # Dropped rows: computed without building any graph.
        if drop_indices.numel() > 0:
            with torch.no_grad():
                out_drop = F.linear(x[drop_indices], self.weight, self.bias)
            output[drop_indices] = out_drop

        return output


class SimpleNet(nn.Module):
    """Fully connected classifier built from a configurable stack of linear layers.

    Every layer except the last is followed by the chosen activation.

    Args:
        input_dim: number of features per sample after flattening.
        output_dim: number of output logits.
        hidden_dims: sizes of the hidden layers (any sequence of ints, e.g. list or tuple).
        layer_type: 'standard' (nn.Linear), 'rgm', 'rgm-no-reweighting' or 'sbp'.
        keep_ratio: keep ratio passed to the RGM / SBP layers; unused for 'standard'.
        activation: 'relu', 'tanh', 'sigmoid', 'leaky_relu' or 'gelu'.

    Raises:
        ValueError: for an unknown ``layer_type`` or ``activation``.
    """

    def __init__(self,
                 input_dim=196,
                 output_dim=10,
                 hidden_dims=(128, 64, 32),
                 layer_type='standard',
                 keep_ratio=0.5,
                 activation='relu'):
        super().__init__()
        self.layer_type = layer_type
        self.input_dim = input_dim

        # Full chain of widths: input -> hidden... -> output (always a fresh list).
        self.layer_dims = [input_dim, *hidden_dims, output_dim]

        self.layers = nn.ModuleList([
            self._make_layer(in_dim, out_dim, layer_type, keep_ratio)
            for in_dim, out_dim in zip(self.layer_dims[:-1], self.layer_dims[1:])
        ])

        activations = {
            'relu': nn.ReLU,
            'tanh': nn.Tanh,
            'sigmoid': nn.Sigmoid,
            'leaky_relu': nn.LeakyReLU,
            'gelu': nn.GELU,
        }
        if activation not in activations:
            raise ValueError(f"Unknown activation: {activation}")
        self.activation = activations[activation]()

    @staticmethod
    def _make_layer(in_dim, out_dim, layer_type, keep_ratio):
        """Construct one linear layer of the requested ``layer_type``."""
        if layer_type == 'standard':
            return nn.Linear(in_dim, out_dim)
        if layer_type == 'rgm':
            return RGMLinear(in_dim, out_dim, backward_dropout_rate=1 - keep_ratio)
        if layer_type == 'rgm-no-reweighting':
            return RGMLinear(in_dim, out_dim, backward_dropout_rate=1 - keep_ratio,
                             apply_reweighting=False)
        if layer_type == 'sbp':
            return SBPLinear2D(in_dim, out_dim, keep_ratio=keep_ratio)
        raise ValueError(f"Unknown layer_type: {layer_type}")

    def forward(self, x):
        # reshape (not view) so non-contiguous inputs such as permuted images also work.
        x = x.reshape(-1, self.input_dim)

        # Hidden layers with activation, then the final layer producing raw logits.
        for layer in self.layers[:-1]:
            x = self.activation(layer(x))
        return self.layers[-1](x)

    def get_architecture_info(self):
        """Return information about the network architecture"""
        return {
            'layer_dims': self.layer_dims,
            'layer_type': self.layer_type,
            'num_layers': len(self.layers),
            'total_params': sum(p.numel() for p in self.parameters())
        }

def get_gradients(model):
    """Return a list of ``clone()``-d copies of every non-None ``.grad`` in ``model.parameters()``.

    Parameters whose ``.grad`` is ``None`` are skipped, so the list may be shorter than
    the parameter list. Order follows ``model.parameters()``.
    """
    return [p.grad.clone() for p in model.parameters() if p.grad is not None]


def train_and_evaluate_models_time_and_space_analysis(config=None):
    """
    Train baseline, RGM and SBP models, recording test accuracy, wall-clock time and GPU memory.

    One baseline model (standard backpropagation) is trained once. For every keep ratio in
    ``config['keep_ratios']``, an RGM model and an SBP model are trained. Every run passes
    ``track_time=True`` and ``track_memory=True`` to ``train_model``, so each result
    contains the timing entries and, on CUDA only, the ``gpu_memory_usage`` entry that
    ``plot_training_results_time_and_space_analysis`` plots.

    Args:
        config (dict, optional): Overrides for the default configuration below. Missing keys
            are filled from the defaults; the caller's dict is not modified.

    Returns:
        dict: ``{'config': merged config, 'keep_ratios': [...], 'models': {'baseline': result,
        'custom': {keep_ratio: result}, 'sbp': {keep_ratio: result}}}``, where each
        ``result`` is the dict returned by ``train_model``.

    Raises:
        ValueError: for an unsupported ``dataset`` or ``optimizer``.
    """
    default_config = {
        'num_epochs': 1,
        'learning_rate': 0.001,
        'keep_ratios': [0.3, 0.5, 0.7],
        'batch_size_train': 1024,
        'batch_size_test': 1024,
        'image_size': (14, 14),
        'dataset': 'MNIST',  # Currently only MNIST supported
        'normalize_mean': (0.1307,),
        'normalize_std': (0.3081,),
        'clear_memory_between_models': True,
        'verbose': True,
        'optimizer': 'adam',  # 'adam', 'sgd', 'rmsprop' or 'adamw'
        'optimizer_params': {}  # extra keyword arguments for the optimizer
    }

    # Merge without mutating the caller's dict; user-supplied values take precedence.
    config = {**default_config, **(config or {})}
    verbose = config['verbose']

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if verbose:
        print(f"🚀 Starting Training Comparison on {device}")
        print("=" * 60)

    # Load data based on configuration
    transform = transforms.Compose([
        transforms.Resize(config['image_size']),
        transforms.ToTensor(),
        transforms.Normalize(config['normalize_mean'], config['normalize_std'])
    ])

    if config['dataset'] == 'MNIST':
        train_dataset = torchvision.datasets.MNIST(
            root='./data', train=True, download=True, transform=transform
        )
        test_dataset = torchvision.datasets.MNIST(
            root='./data', train=False, download=True, transform=transform
        )
    else:
        raise ValueError(f"Dataset {config['dataset']} not supported yet")

    train_loader = DataLoader(train_dataset, batch_size=config['batch_size_train'], shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=config['batch_size_test'], shuffle=False)

    results = {
        'config': config,
        'keep_ratios': config['keep_ratios'],
        'models': {
            'baseline': {},
            'custom': {},
            'sbp': {}
        }
    }

    optimizer_classes = {
        'adam': optim.Adam,
        'sgd': optim.SGD,
        'rmsprop': optim.RMSprop,
        'adamw': optim.AdamW,
    }

    def create_optimizer(model_parameters):
        """Instantiate the configured optimizer with ``learning_rate`` plus ``optimizer_params``."""
        optimizer_name = config['optimizer'].lower()
        if optimizer_name not in optimizer_classes:
            raise ValueError(f"Optimizer '{optimizer_name}' not supported. Use 'adam', 'sgd', 'rmsprop', or 'adamw'")
        params = {'lr': config['learning_rate'], **config['optimizer_params']}
        return optimizer_classes[optimizer_name](model_parameters, **params)

    if verbose:
        print("Training Configuration:")
        print(f"  • Dataset: {config['dataset']} ({config['image_size'][0]}×{config['image_size'][1]})")
        print("  • Architecture: 4-layer MLP (196→128→64→32→10)")
        print(f"  • Epochs: {config['num_epochs']}")
        print(f"  • Learning Rate: {config['learning_rate']}")
        print(f"  • Optimizer: {config['optimizer'].upper()}")
        print(f"  • Batch Size: {config['batch_size_train']} (train), {config['batch_size_test']} (test)")
        print(f"  • Keep Ratios: {config['keep_ratios']}")
        print("=" * 60)

    def clear_memory_if_needed():
        """Empty the CUDA cache and reset peak stats when configured and running on CUDA."""
        if config['clear_memory_between_models'] and device.type == 'cuda':
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()
            torch.cuda.synchronize()
            if verbose:
                print("🧹 GPU memory cleared")

    def train_variant(model_name, **net_kwargs):
        """Build ``SimpleNet(**net_kwargs)``, train it with time/memory tracking, return its results.

        The model and optimizer go out of scope when this returns, so earlier models do not
        stay resident and inflate the next model's VRAM measurement.
        """
        clear_memory_if_needed()
        model = SimpleNet(**net_kwargs).to(device)
        return train_model(
            model=model,
            optimizer=create_optimizer(model.parameters()),
            train_loader=train_loader,
            test_loader=test_loader,
            num_epochs=config['num_epochs'],
            device=device,
            model_name=model_name,
            clear_memory_first=False,  # clear_memory_if_needed() already ran
            track_time=True,
            track_memory=True,
        )

    if verbose:
        print("\n🔧 Training Baseline Model (Standard Backpropagation)...")
    results['models']['baseline'] = train_variant("Baseline", layer_type='standard')

    for keep_ratio in config['keep_ratios']:
        if verbose:
            print(f"\n🔧 Training Models with Keep Ratio: {keep_ratio}")
            print(f"  → Reweighted Gradient Message (Keep Ratio={keep_ratio:.1f})...")
        results['models']['custom'][keep_ratio] = train_variant(
            f"Reweighted Gradient Message-{keep_ratio}", layer_type='rgm', keep_ratio=keep_ratio)

        if verbose:
            print(f"  → SBP Method (keep_ratio={keep_ratio:.1f})...")
        results['models']['sbp'][keep_ratio] = train_variant(
            f"SBP-{keep_ratio}", layer_type='sbp', keep_ratio=keep_ratio)

    return results


def evaluate_model(model, test_loader, device):
    """Return the top-1 accuracy of ``model`` on ``test_loader``, in percent.

    The model is put in eval mode for the pass and then restored to its previous
    train/eval mode. Calling this between training epochs therefore does not leave
    the model in eval mode.

    Args:
        model: module mapping a batch of inputs to per-class scores (argmax taken over dim 1).
        test_loader: iterable of ``(data, target)`` batches.
        device: device each batch is moved to.

    Returns:
        float: ``100 * correct / total``.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                pred = model(data).argmax(dim=1)
                correct += pred.eq(target.view_as(pred)).sum().item()
                total += target.size(0)
    finally:
        # Restore the caller's mode even if evaluation raises.
        model.train(was_training)

    if total == 0:
        raise ValueError("evaluate_model: test_loader yielded no samples")
    return 100. * correct / total


def train_model(model, optimizer, train_loader, test_loader, num_epochs, device, model_name,
                clear_memory_first=True, track_time=False, track_memory=False):
    """Train ``model`` with cross-entropy for ``num_epochs`` and return a results dictionary.

    Each epoch makes one pass over ``train_loader`` and then evaluates on ``test_loader``
    with ``evaluate_model``.

    Args:
        model: the network to train (updated in place).
        optimizer: optimizer over ``model``'s parameters.
        train_loader / test_loader: iterables of ``(data, target)`` batches.
        num_epochs: number of training epochs. With 0, the untrained model is evaluated once.
        device: device batches are moved to.
        model_name: label used in the printed summary and the results.
        clear_memory_first: on CUDA, empty the cache and reset peak-memory stats first.
        track_time: record wall-clock training time (including per-epoch evaluation).
        track_memory: on CUDA, record allocated and peak GPU memory.

    Returns:
        dict with 'model_name', 'test_accuracy' (last evaluation), 'epoch_losses' (mean of
        per-batch losses per epoch) and 'epoch_test_accuracies'. It also contains
        'training_time_seconds'/'training_time_minutes' when ``track_time`` is set, and
        'gpu_memory_usage' when ``track_memory`` is set on CUDA.
    """
    is_cuda = device.type == 'cuda'

    # Clear GPU memory before starting if requested
    if clear_memory_first and is_cuda:
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats(device)
        torch.cuda.synchronize()

    criterion = nn.CrossEntropyLoss()

    if track_time:
        if is_cuda:
            # Don't let previously queued GPU work leak into this model's timing.
            torch.cuda.synchronize(device)
        start_time = time.perf_counter()

    if track_memory and is_cuda:
        torch.cuda.reset_peak_memory_stats(device)
        initial_gpu_memory = torch.cuda.memory_allocated(device) / 1024 / 1024  # MB
        max_gpu_memory = initial_gpu_memory
    else:
        initial_gpu_memory = 0
        max_gpu_memory = 0

    epoch_losses = []
    epoch_test_accuracies = []

    for epoch in range(num_epochs):
        # evaluate_model() may leave the model in eval mode after the previous epoch,
        # so re-enter training mode explicitly every epoch.
        model.train()
        total_loss = 0
        correct = 0
        total = 0

        for data, target in train_loader:
            data, target = data.to(device), target.to(device)

            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            total += target.size(0)

            if track_memory and is_cuda:
                current_gpu_memory = torch.cuda.memory_allocated(device) / 1024 / 1024  # MB
                max_gpu_memory = max(max_gpu_memory, current_gpu_memory)

        avg_loss = total_loss / len(train_loader)
        epoch_losses.append(avg_loss)

        test_acc = evaluate_model(model, test_loader, device)
        epoch_test_accuracies.append(test_acc)

        # Print progress every 5 epochs
        if (epoch + 1) % 5 == 0:
            train_acc = 100. * correct / total
            print(f" Epoch {epoch+1:2d}/{num_epochs}: Loss={avg_loss:.4f}, "
                  f"Train Acc={train_acc:.2f}%, Test Acc={test_acc:.2f}%")

    if track_time:
        if is_cuda:
            # CUDA kernels run asynchronously; wait for them before stopping the clock.
            torch.cuda.synchronize(device)
        training_time = time.perf_counter() - start_time

    # Use the last epoch's test accuracy; with zero epochs, evaluate the untrained model.
    if epoch_test_accuracies:
        final_test_acc = epoch_test_accuracies[-1]
    else:
        final_test_acc = evaluate_model(model, test_loader, device)
    print(f" ✅ {model_name} Final Test Accuracy: {final_test_acc:.2f}%")

    results = {
        'model_name': model_name,
        'test_accuracy': final_test_acc,
        'epoch_losses': epoch_losses,
        'epoch_test_accuracies': epoch_test_accuracies,
    }

    if track_time:
        results.update({
            'training_time_seconds': training_time,
            'training_time_minutes': training_time / 60,
        })

    if track_memory and is_cuda:
        peak_gpu_memory = torch.cuda.max_memory_allocated(device) / 1024 / 1024  # MB
        results['gpu_memory_usage'] = {
            'initial_gpu_memory_mb': initial_gpu_memory,
            'max_gpu_memory_mb': max_gpu_memory,
            'peak_gpu_memory_mb': peak_gpu_memory,
            'gpu_memory_increase_mb': peak_gpu_memory - initial_gpu_memory
        }
        print(f" 🎮 GPU memory usage: {peak_gpu_memory:.2f}MB (increase: {peak_gpu_memory-initial_gpu_memory:.2f}MB)")
    elif track_memory and not is_cuda:
        print(f" ⚠️ GPU not available - no VRAM tracking")

    if track_time:
        print(f" ⏱️ Training time: {training_time:.2f}s ({training_time/60:.2f}m)")

    # Return cached allocator blocks for the next model. The model itself stays alive for
    # as long as the caller holds a reference to it.
    if is_cuda:
        torch.cuda.empty_cache()

    return results


def plot_training_results_time_and_space_analysis(results):
    """Plot final test accuracy and peak VRAM for the baseline, RGM ('custom') and SBP models.

    Args:
        results (dict): needs ``results['keep_ratios']`` and ``results['models']`` with keys
            'baseline' (one ``train_model`` result) and 'custom' / 'sbp' (keep ratio ->
            ``train_model`` result).

    Returns:
        dict: ``results``, unchanged.

    Peak VRAM is read from each entry's ``gpu_memory_usage['peak_gpu_memory_mb']`` and
    shown in GB. Entries without it are drawn as 0. If no entry has it, the VRAM panel
    states that no data was recorded, instead of showing an empty chart.
    """
    plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'default')
    fig, (ax_acc, ax_vram) = plt.subplots(1, 2, figsize=(15, 5))

    keep_ratios = results['keep_ratios']
    models = results['models']
    n_ratios = len(keep_ratios)
    x = np.arange(n_ratios)
    width = 0.25
    tick_labels = [f'{r:.1f}' for r in keep_ratios]

    def peak_vram_gb(entry):
        """Peak VRAM of one ``train_model`` result in GB, or 0 if it was not recorded."""
        if 'gpu_memory_usage' in entry:
            return entry['gpu_memory_usage']['peak_gpu_memory_mb'] / 1024
        return 0

    baseline = models['baseline']
    # (bar offset, legend label, colour, accuracies, VRAM in GB) per method. The baseline
    # is trained once, so its single value is repeated under every keep ratio.
    methods = [
        (-width, 'Baseline (Standard BP)', 'green',
         [baseline['test_accuracy']] * n_ratios,
         [peak_vram_gb(baseline)] * n_ratios),
        (0, 'Custom Backward Dropout', 'blue',
         [models['custom'][kr]['test_accuracy'] for kr in keep_ratios],
         [peak_vram_gb(models['custom'][kr]) for kr in keep_ratios]),
        (width, 'SBP Method', 'red',
         [models['sbp'][kr]['test_accuracy'] for kr in keep_ratios],
         [peak_vram_gb(models['sbp'][kr]) for kr in keep_ratios]),
    ]

    # Plot 1: Final Accuracies Comparison
    for offset, label, color, accuracies, _ in methods:
        bars = ax_acc.bar(x + offset, accuracies, width, label=label, alpha=0.8, color=color)
        for bar in bars:
            height = bar.get_height()
            ax_acc.text(bar.get_x() + bar.get_width() / 2., height + 0.1,
                        f'{height:.1f}%', ha='center', va='bottom', fontsize=10, fontweight='bold')

    ax_acc.set_xlabel('Keep Ratio')
    ax_acc.set_ylabel('Test Accuracy (%)')
    ax_acc.set_title('Final Test Accuracy Comparison')
    ax_acc.set_xticks(x)
    ax_acc.set_xticklabels(tick_labels)
    ax_acc.legend()
    ax_acc.grid(True, alpha=0.3)
    ax_acc.set_ylim([0, 100])

    # Plot 2: VRAM Usage Comparison
    for offset, label, color, _, vram in methods:
        bars = ax_vram.bar(x + offset, vram, width, label=label, alpha=0.8, color=color)
        for bar in bars:
            height = bar.get_height()
            if height > 0:  # 0 means VRAM was not recorded for this entry
                ax_vram.text(bar.get_x() + bar.get_width() / 2., height + 0.01,
                             f'{height:.2f}GB', ha='center', va='bottom', fontsize=10, fontweight='bold')

    if not any(v > 0 for method in methods for v in method[4]):
        ax_vram.text(0.5, 0.5, 'No VRAM data recorded\n(CPU run or track_memory=False)',
                     transform=ax_vram.transAxes, ha='center', va='center', fontsize=12)

    ax_vram.set_xlabel('Keep Ratio')
    ax_vram.set_ylabel('Max VRAM Usage (GB)')
    ax_vram.set_title('Maximum VRAM Usage Comparison')
    ax_vram.set_xticks(x)
    ax_vram.set_xticklabels(tick_labels)
    ax_vram.legend()
    ax_vram.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    return results

In [None]:
# config = {
#     'num_epochs': 1,
#     'learning_rate': 0.001,
#     'keep_ratios': [0.1, 0.3, 0.5, 0.7, 0.9],
#     'batch_size_train': 1024,
#     'batch_size_test': 1024,
#     'clear_memory_between_models': True,
#     'verbose': True,
#     'optimizer': 'sgd',  # New: optimizer type ('adam', 'sgd', 'rmsprop', 'adamw')
#     'optimizer_params': {}  # New: additional optimizer parameters
# }
#
# training_results = train_and_evaluate_models_time_and_space_analysis(config)
# final_plot = plot_training_results_time_and_space_analysis(training_results)

In [None]:
def train_and_evaluate_models(config=None):
    """
    Train baseline, RGM, RGM-without-reweighting and SBP models and plot their training curves.

    One baseline model (standard backpropagation) is trained once. For every keep ratio in
    ``config['keep_ratios']``, three more models are trained ('rgm', 'rgm-no-reweighting',
    'sbp'). Time and memory tracking are disabled. Loss and test-accuracy curves are drawn
    with ``_plot_training_curves`` before returning.

    Args:
        config (dict, optional): Overrides for the default configuration below. Missing keys
            are filled from the defaults; the caller's dict is not modified.

    Returns:
        dict: ``{'config': merged config, 'keep_ratios': [...], 'models': {'baseline': result,
        'rgm': {...}, 'rgm-no-reweighting': {...}, 'sbp': {...}}}``. The last three map each
        keep ratio to the dict returned by ``train_model``.

    Raises:
        ValueError: for an unsupported ``dataset`` or ``optimizer``.
    """
    default_config = {
        'num_epochs': 20,  # several epochs so the training curves have a shape
        'learning_rate': 0.001,
        'keep_ratios': [0.3, 0.5, 0.7],
        'batch_size_train': 1024,
        'batch_size_test': 1024,
        'image_size': (14, 14),
        'dataset': 'MNIST',
        'normalize_mean': (0.1307,),
        'normalize_std': (0.3081,),
        'clear_memory_between_models': True,
        'verbose': True,
        'optimizer': 'adam',
        'optimizer_params': {}
    }

    # Merge without mutating the caller's dict; user-supplied values take precedence.
    config = {**default_config, **(config or {})}
    verbose = config['verbose']

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if verbose:
        print(f"🚀 Starting Training Comparison on {device}")
        print("=" * 60)

    # Load data based on configuration
    transform = transforms.Compose([
        transforms.Resize(config['image_size']),
        transforms.ToTensor(),
        transforms.Normalize(config['normalize_mean'], config['normalize_std'])
    ])

    if config['dataset'] == 'MNIST':
        train_dataset = torchvision.datasets.MNIST(
            root='./data', train=True, download=True, transform=transform
        )
        test_dataset = torchvision.datasets.MNIST(
            root='./data', train=False, download=True, transform=transform
        )
    else:
        raise ValueError(f"Dataset {config['dataset']} not supported yet")

    train_loader = DataLoader(train_dataset, batch_size=config['batch_size_train'], shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=config['batch_size_test'], shuffle=False)

    results = {
        'config': config,
        'keep_ratios': config['keep_ratios'],
        'models': {
            'baseline': {},
            'rgm': {},
            'rgm-no-reweighting': {},
            'sbp': {}
        }
    }

    optimizer_classes = {
        'adam': optim.Adam,
        'sgd': optim.SGD,
        'rmsprop': optim.RMSprop,
        'adamw': optim.AdamW,
    }

    def create_optimizer(model_parameters):
        """Instantiate the configured optimizer with ``learning_rate`` plus ``optimizer_params``."""
        optimizer_name = config['optimizer'].lower()
        if optimizer_name not in optimizer_classes:
            raise ValueError(f"Optimizer '{optimizer_name}' not supported. Use 'adam', 'sgd', 'rmsprop', or 'adamw'")
        params = {'lr': config['learning_rate'], **config['optimizer_params']}
        return optimizer_classes[optimizer_name](model_parameters, **params)

    if verbose:
        print("Training Configuration:")
        print(f"  • Dataset: {config['dataset']} ({config['image_size'][0]}×{config['image_size'][1]})")
        print("  • Architecture: 4-layer MLP (196→128→64→32→10)")
        print(f"  • Epochs: {config['num_epochs']}")
        print(f"  • Learning Rate: {config['learning_rate']}")
        print(f"  • Optimizer: {config['optimizer'].upper()}")
        print(f"  • Batch Size: {config['batch_size_train']} (train), {config['batch_size_test']} (test)")
        print(f"  • Keep Ratios: {config['keep_ratios']}")
        print("=" * 60)

    def clear_memory_if_needed():
        """Empty the CUDA cache and reset peak stats when configured and running on CUDA."""
        if config['clear_memory_between_models'] and device.type == 'cuda':
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()
            torch.cuda.synchronize()
            if verbose:
                print("🧹 GPU memory cleared")

    def train_variant(model_name, **net_kwargs):
        """Build ``SimpleNet(**net_kwargs)``, train it with this run's settings, return its results."""
        clear_memory_if_needed()
        model = SimpleNet(**net_kwargs).to(device)
        return train_model(
            model=model,
            optimizer=create_optimizer(model.parameters()),
            train_loader=train_loader,
            test_loader=test_loader,
            num_epochs=config['num_epochs'],
            device=device,
            model_name=model_name,
            clear_memory_first=False,  # clear_memory_if_needed() already ran
            track_time=False,
            track_memory=False,
        )

    if verbose:
        print("\n🔧 Training Baseline Model (Standard Backpropagation)...")
    results['models']['baseline'] = train_variant("Baseline", layer_type='standard')

    for keep_ratio in config['keep_ratios']:
        if verbose:
            print(f"\n🔧 Training Models with Keep Ratio: {keep_ratio}")

        # RGM with reweighting
        if verbose:
            print(f"  → Reweighted Gradient Message (Keep Ratio={keep_ratio:.1f})...")
        results['models']['rgm'][keep_ratio] = train_variant(
            f"Reweighted Gradient Message-{keep_ratio}",
            layer_type='rgm', keep_ratio=keep_ratio)

        # RGM without reweighting (ablation)
        if verbose:
            print(f"  → Reweighted Gradient Message w/o Reweighting (Keep Ratio={keep_ratio:.1f})...")
        results['models']['rgm-no-reweighting'][keep_ratio] = train_variant(
            f"Reweighted Gradient Message w/o Reweighting-{keep_ratio}",
            layer_type='rgm-no-reweighting', keep_ratio=keep_ratio)

        # SBP
        if verbose:
            print(f"  → SBP Method (keep_ratio={keep_ratio:.1f})...")
        results['models']['sbp'][keep_ratio] = train_variant(
            f"SBP-{keep_ratio}", layer_type='sbp', keep_ratio=keep_ratio)

    _plot_training_curves(results)

    return results


def _plot_training_curves(results):
    """Create loss and test accuracy plots for different methods."""
    import matplotlib.pyplot as plt

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

    # Plot 1: Training Loss
    epochs = range(1, len(results['models']['baseline']['epoch_losses']) + 1)

    # Baseline
    ax1.plot(epochs, results['models']['baseline']['epoch_losses'],
             label='Baseline', linewidth=2, marker='o', markersize=4)

    # RGM methods
    for keep_ratio in results['keep_ratios']:
        if keep_ratio in results['models']['rgm']:
            ax1.plot(epochs, results['models']['rgm'][keep_ratio]['epoch_losses'],
                     label=f'RGM (keep={keep_ratio})', linewidth=2, marker='s', markersize=4)

    # RGM without reweighting methods
    for keep_ratio in results['keep_ratios']:
        if keep_ratio in results['models']['rgm-no-reweighting']:
            ax1.plot(epochs, results['models']['rgm-no-reweighting'][keep_ratio]['epoch_losses'],
                     label=f'RGM w/o Reweighting (keep={keep_ratio})', linewidth=2, marker='x', markersize=4)

    # SBP methods
    for keep_ratio in results['keep_ratios']:
        if keep_ratio in results['models']['sbp']:
            ax1.plot(epochs, results['models']['sbp'][keep_ratio]['epoch_losses'],
                     label=f'SBP (keep={keep_ratio})', linewidth=2, marker='^', markersize=4)

    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Training Loss')
    ax1.set_title('Training Loss Over Epochs')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Plot 2: Test Accuracy
    ax2.plot(epochs, results['models']['baseline']['epoch_test_accuracies'],
             label='Baseline', linewidth=2, marker='o', markersize=4)

    # RGM methods
    for keep_ratio in results['keep_ratios']:
        if keep_ratio in results['models']['rgm']:
            ax2.plot(epochs, results['models']['rgm'][keep_ratio]['epoch_test_accuracies'],
                     label=f'RGM (keep={keep_ratio})', linewidth=2, marker='s', markersize=4)

    # RGM without reweighting methods
    for keep_ratio in results['keep_ratios']:
        if keep_ratio in results['models']['rgm-no-reweighting']:
            ax2.plot(epochs, results['models']['rgm-no-reweighting'][keep_ratio]['epoch_test_accuracies'],
                     label=f'RGM w/o Reweighting (keep={keep_ratio})', linewidth=2, marker='x', markersize=4)

    # SBP methods
    for keep_ratio in results['keep_ratios']:
        if keep_ratio in results['models']['sbp']:
            ax2.plot(epochs, results['models']['sbp'][keep_ratio]['epoch_test_accuracies'],
                     label=f'SBP (keep={keep_ratio})', linewidth=2, marker='^', markersize=4)

    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Test Accuracy (%)')
    ax2.set_title('Test Accuracy Over Epochs')
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    # Print final summary
    print("\n" + "="*60)
    print("📊 FINAL RESULTS SUMMARY")
    print("="*60)
    print(f"{'Method':<25} {'Final Test Acc':<15} {'Final Loss':<12}")
    print("-" * 60)

    baseline = results['models']['baseline']
    print(f"{'Baseline':<25} {baseline['test_accuracy']:<15.2f} {baseline['epoch_losses'][-1]:<12.4f}")

    for keep_ratio in results['keep_ratios']:
        if keep_ratio in results['models']['rgm']:
            custom = results['models']['rgm'][keep_ratio]
            print(f"{f'RGM (keep={keep_ratio})':<25} {custom['test_accuracy']:<15.2f} {custom['epoch_losses'][-1]:<12.4f}")

        if keep_ratio in results['models']['rgm-no-reweighting']:
            custom_wo = results['models']['rgm-no-reweighting'][keep_ratio]
            print(f"{f'RGM w/o Reweighting (keep={keep_ratio})':<25} {custom_wo['test_accuracy']:<15.2f} {custom_wo['epoch_losses'][-1]:<12.4f}")

        if keep_ratio in results['models']['sbp']:
            sbp = results['models']['sbp'][keep_ratio]
            print(f"{f'SBP (keep={keep_ratio})':<25} {sbp['test_accuracy']:<15.2f} {sbp['epoch_losses'][-1]:<12.4f}")


In [None]:
# Experiment: SGD, learning rate 0.1, keep ratio 0.5, short 5-epoch run.
config = dict(
    num_epochs=5,
    learning_rate=0.1,
    keep_ratios=[0.5],
    batch_size_train=1024,
    batch_size_test=1024,
    clear_memory_between_models=True,
    verbose=True,
    optimizer='sgd',      # optimizer type: 'adam', 'sgd', 'rmsprop' or 'adamw'
    optimizer_params={},  # additional optimizer parameters
)

# Note: every experiment cell rebinds training_results_epochs.
training_results_epochs = train_and_evaluate_models(config)

In [None]:
# Experiment: SGD, learning rate 0.1, keep ratio 0.5, 25 epochs.
config = dict(
    num_epochs=25,
    learning_rate=0.1,
    keep_ratios=[0.5],
    batch_size_train=1024,
    batch_size_test=1024,
    clear_memory_between_models=True,
    verbose=True,
    optimizer='sgd',      # optimizer type: 'adam', 'sgd', 'rmsprop' or 'adamw'
    optimizer_params={},  # additional optimizer parameters
)

# Note: every experiment cell rebinds training_results_epochs.
training_results_epochs = train_and_evaluate_models(config)

In [None]:
# Experiment: SGD, learning rate 0.1, keep ratio 0.1, 25 epochs.
config = dict(
    num_epochs=25,
    learning_rate=0.1,
    keep_ratios=[0.1],
    batch_size_train=1024,
    batch_size_test=1024,
    clear_memory_between_models=True,
    verbose=True,
    optimizer='sgd',      # optimizer type: 'adam', 'sgd', 'rmsprop' or 'adamw'
    optimizer_params={},  # additional optimizer parameters
)

# Note: every experiment cell rebinds training_results_epochs.
training_results_epochs = train_and_evaluate_models(config)

In [None]:
# Experiment: SGD, learning rate 0.01, keep ratio 0.1, 25 epochs.
config = dict(
    num_epochs=25,
    learning_rate=0.01,
    keep_ratios=[0.1],
    batch_size_train=1024,
    batch_size_test=1024,
    clear_memory_between_models=True,
    verbose=True,
    optimizer='sgd',      # optimizer type: 'adam', 'sgd', 'rmsprop' or 'adamw'
    optimizer_params={},  # additional optimizer parameters
)

# Note: every experiment cell rebinds training_results_epochs.
training_results_epochs = train_and_evaluate_models(config)

In [None]:
# Experiment: SGD, learning rate 0.1, keep ratio 0.3, 25 epochs.
config = dict(
    num_epochs=25,
    learning_rate=0.1,
    keep_ratios=[0.3],
    batch_size_train=1024,
    batch_size_test=1024,
    clear_memory_between_models=True,
    verbose=True,
    optimizer='sgd',      # optimizer type: 'adam', 'sgd', 'rmsprop' or 'adamw'
    optimizer_params={},  # additional optimizer parameters
)

# Note: every experiment cell rebinds training_results_epochs.
training_results_epochs = train_and_evaluate_models(config)

In [None]:
# Experiment: SGD, learning rate 0.1, keep ratio 0.7, 25 epochs.
config = dict(
    num_epochs=25,
    learning_rate=0.1,
    keep_ratios=[0.7],
    batch_size_train=1024,
    batch_size_test=1024,
    clear_memory_between_models=True,
    verbose=True,
    optimizer='sgd',      # optimizer type: 'adam', 'sgd', 'rmsprop' or 'adamw'
    optimizer_params={},  # additional optimizer parameters
)

# Note: every experiment cell rebinds training_results_epochs.
training_results_epochs = train_and_evaluate_models(config)

In [None]:
# Experiment: SGD, learning rate 0.1, keep ratio 0.9, 25 epochs.
config = dict(
    num_epochs=25,
    learning_rate=0.1,
    keep_ratios=[0.9],
    batch_size_train=1024,
    batch_size_test=1024,
    clear_memory_between_models=True,
    verbose=True,
    optimizer='sgd',      # optimizer type: 'adam', 'sgd', 'rmsprop' or 'adamw'
    optimizer_params={},  # additional optimizer parameters
)

# Note: every experiment cell rebinds training_results_epochs.
training_results_epochs = train_and_evaluate_models(config)

In [None]:
# Experiment: Adam, learning rate 0.01, keep ratio 0.1, 25 epochs.
config = dict(
    num_epochs=25,
    learning_rate=0.01,
    keep_ratios=[0.1],
    batch_size_train=1024,
    batch_size_test=1024,
    clear_memory_between_models=True,
    verbose=True,
    optimizer='adam',     # optimizer type: 'adam', 'sgd', 'rmsprop' or 'adamw'
    optimizer_params={},  # additional optimizer parameters
)

# Note: every experiment cell rebinds training_results_epochs.
training_results_epochs = train_and_evaluate_models(config)

In [None]:
# Experiment: Adam, learning rate 0.01, keep ratio 0.3, 25 epochs.
config = dict(
    num_epochs=25,
    learning_rate=0.01,
    keep_ratios=[0.3],
    batch_size_train=1024,
    batch_size_test=1024,
    clear_memory_between_models=True,
    verbose=True,
    optimizer='adam',     # optimizer type: 'adam', 'sgd', 'rmsprop' or 'adamw'
    optimizer_params={},  # additional optimizer parameters
)

# Note: every experiment cell rebinds training_results_epochs.
training_results_epochs = train_and_evaluate_models(config)

In [None]:
# Experiment: Adam, learning rate 0.01, keep ratio 0.5, 25 epochs.
config = dict(
    num_epochs=25,
    learning_rate=0.01,
    keep_ratios=[0.5],
    batch_size_train=1024,
    batch_size_test=1024,
    clear_memory_between_models=True,
    verbose=True,
    optimizer='adam',     # optimizer type: 'adam', 'sgd', 'rmsprop' or 'adamw'
    optimizer_params={},  # additional optimizer parameters
)

# Note: every experiment cell rebinds training_results_epochs.
training_results_epochs = train_and_evaluate_models(config)

In [None]:
# Experiment: Adam, learning rate 0.01, keep ratio 0.7, 25 epochs.
config = dict(
    num_epochs=25,
    learning_rate=0.01,
    keep_ratios=[0.7],
    batch_size_train=1024,
    batch_size_test=1024,
    clear_memory_between_models=True,
    verbose=True,
    optimizer='adam',     # optimizer type: 'adam', 'sgd', 'rmsprop' or 'adamw'
    optimizer_params={},  # additional optimizer parameters
)

# Note: every experiment cell rebinds training_results_epochs.
training_results_epochs = train_and_evaluate_models(config)

In [None]:
# Experiment: Adam, learning rate 0.01, keep ratio 0.9, 25 epochs.
config = dict(
    num_epochs=25,
    learning_rate=0.01,
    keep_ratios=[0.9],
    batch_size_train=1024,
    batch_size_test=1024,
    clear_memory_between_models=True,
    verbose=True,
    optimizer='adam',     # optimizer type: 'adam', 'sgd', 'rmsprop' or 'adamw'
    optimizer_params={},  # additional optimizer parameters
)

# Note: every experiment cell rebinds training_results_epochs.
training_results_epochs = train_and_evaluate_models(config)