In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import math
from torch.utils.data import DataLoader

In [None]:
# Your Custom Implementation (provided)
class CustomLinearLayer(Function):
    """Affine layer whose backward pass randomly drops weight connections.

    The forward pass is the ordinary ``input @ weight.T + bias``.  In the
    backward pass only the gradient sent back to ``input`` is perturbed: it
    is computed through a randomly masked (and rescaled) copy of ``weight``.
    The gradients returned for ``weight`` and ``bias`` are always the exact
    ones.

    NOTE(review): with ``apply_reweighting=False`` no mask is applied at all,
    i.e. the backward pass is plain backpropagation and
    ``backward_dropout_rate`` is ignored.  Confirm that this (rather than
    "mask without rescaling") is the intended meaning of the flag.
    """

    @staticmethod
    def forward(ctx, input, weight, bias, backward_dropout_rate, apply_reweighting=True):
        """Compute ``input @ weight.T + bias`` and stash what backward needs.

        Args:
            ctx: autograd context.
            input: tensor whose last dimension has size ``in_features``.
            weight: tensor of shape ``(out_features, in_features)``.
            bias: tensor of shape ``(out_features,)``.
            backward_dropout_rate: probability threshold used to drop
                off-diagonal weight entries in the backward pass.
            apply_reweighting: when False the backward pass is unmasked.
        """
        output = torch.matmul(input, weight.T) + bias
        # The bias value itself is never needed to compute any gradient.
        ctx.save_for_backward(input, weight)
        ctx.backward_dropout_rate = backward_dropout_rate
        ctx.apply_reweighting = apply_reweighting
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``(input, weight, bias, rate, flag)``.

        The last two entries are ``None`` because the dropout rate and the
        reweighting flag are not differentiable inputs.
        """
        input, weight = ctx.saved_tensors
        dropout_rate = ctx.backward_dropout_rate

        if ctx.apply_reweighting:
            # Keep an entry when its uniform draw exceeds the dropout rate;
            # the main diagonal is always kept.
            keep = torch.rand_like(weight) > dropout_rate
            keep.fill_diagonal_(True)

            if dropout_rate < 1.0:
                # Inverted-dropout rescale for the surviving entries.
                scale = torch.ones_like(weight) / (1 - dropout_rate)
            else:
                # Every off-diagonal entry is dropped, so the rescale factor
                # is irrelevant; avoid 1/0 = inf which would otherwise turn
                # into NaN when multiplied by the zero mask.
                scale = torch.ones_like(weight)
            scale.fill_diagonal_(1.0)  # diagonal entries are never rescaled

            weighted_mask = torch.where(keep, scale, torch.zeros_like(scale))
        else:
            weighted_mask = torch.ones_like(weight)

        grad_input = grad_output @ (weight * weighted_mask)

        # Collapse any leading batch dimensions so inputs of rank != 2
        # (which the forward matmul accepts) also get correct gradients.
        grad_rows = grad_output.reshape(-1, grad_output.shape[-1])
        input_rows = input.reshape(-1, input.shape[-1])
        grad_weight = grad_rows.T @ input_rows
        grad_bias = grad_rows.sum(dim=0)

        return grad_input, grad_weight, grad_bias, None, None

class CustomLinear(nn.Module):
    """``nn.Module`` front-end for ``CustomLinearLayer``.

    Owns a ``(out_features, in_features)`` weight and an ``(out_features,)``
    bias, and forwards them together with the backward-dropout settings to
    the custom autograd function.
    """

    def __init__(self, in_features, out_features, backward_dropout_rate=0.0, apply_reweighting=True):
        super().__init__()
        # Layer geometry.
        self.in_features = in_features
        self.out_features = out_features
        # Learnable parameters; storage is uninitialised until
        # reset_parameters() fills it below.
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        self.bias = nn.Parameter(torch.empty(out_features))
        # Settings handed to the autograd function on every forward call.
        self.backward_dropout_rate = backward_dropout_rate
        self.apply_reweighting = apply_reweighting
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise weight (Kaiming uniform) and bias (uniform in ±1/sqrt(fan_in))."""
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is None:
            return
        fan_in, _fan_out = nn.init._calculate_fan_in_and_fan_out(self.weight)
        if fan_in > 0:
            limit = 1 / math.sqrt(fan_in)
        else:
            limit = 0
        nn.init.uniform_(self.bias, -limit, limit)

    def forward(self, x):
        """Apply the custom linear function to ``x``."""
        return CustomLinearLayer.apply(
            x,
            self.weight,
            self.bias,
            self.backward_dropout_rate,
            self.apply_reweighting,
        )

# SBP Implementation from Paper
class SBPLinear2D(nn.Module):
    """Linear layer that back-propagates through only a random subset of samples.

    During training, every sample in the batch goes through the same affine
    map, but only a ``keep_ratio`` fraction of the rows is computed with
    autograd tracking; the remaining rows are computed under ``no_grad`` and
    therefore contribute nothing to the gradients.  In eval mode, or when
    gradient tracking is globally disabled, the layer is a plain linear map.

    Args:
        in_features: size of each input row.
        out_features: size of each output row.
        bias: when False the layer has no bias term.
        keep_ratio: fraction of batch rows that keep their gradient.
    """

    def __init__(self, in_features, out_features, bias=True, keep_ratio=0.5):
        super(SBPLinear2D, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.keep_ratio = keep_ratio

        # Allocate with torch.empty (not torch.randn): drawing random numbers
        # here would advance the RNG before _reset_parameters(), so the layer
        # would not reproduce nn.Linear's weights under the same seed.
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
        else:
            self.register_parameter('bias', None)
        self._reset_parameters()

    def _reset_parameters(self):
        """Initialise weight (Kaiming uniform) and bias (uniform in ±1/sqrt(fan_in))."""
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            # Guard against in_features == 0, which would divide by zero.
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)

    def _generate_keep_mask(self, batch_size):
        """Return a CPU bool mask of length ``batch_size`` marking kept rows.

        ``int(batch_size * keep_ratio)`` rows are kept, clamped to
        ``[0, batch_size]``.  When ``keep_ratio`` is positive at least one row
        is kept so that the layer output always carries a gradient.
        """
        num_keep = max(0, min(batch_size, int(batch_size * self.keep_ratio)))
        if num_keep == 0 and batch_size > 0 and self.keep_ratio > 0:
            num_keep = 1
        keep_mask = torch.zeros(batch_size, dtype=torch.bool)
        indices = torch.randperm(batch_size)[:num_keep]
        keep_mask[indices] = True
        return keep_mask

    def forward(self, x):
        """Apply the layer to a 2-D batch ``x`` of shape ``(batch, in_features)``."""
        # Sample dropping only affects gradients; skip it (and the graph it
        # would build) when no gradients will be computed.
        if not self.training or not torch.is_grad_enabled():
            return F.linear(x, self.weight, self.bias)

        batch_size, _ = x.shape
        keep_mask = self._generate_keep_mask(batch_size).to(x.device)

        output = torch.zeros(batch_size, self.out_features,
                             device=x.device, dtype=x.dtype)

        keep_indices = keep_mask.nonzero(as_tuple=True)[0]
        drop_indices = (~keep_mask).nonzero(as_tuple=True)[0]

        # Kept rows: computed with autograd tracking.
        if len(keep_indices) > 0:
            output[keep_indices] = F.linear(x[keep_indices], self.weight, self.bias)

        # Dropped rows: same values, but invisible to autograd.
        if len(drop_indices) > 0:
            with torch.no_grad():
                output[drop_indices] = F.linear(x[drop_indices], self.weight, self.bias)

        return output

# Simple 4-layer Network
class SimpleNet(nn.Module):
    """Four-layer MLP (196 -> 128 -> 64 -> 32 -> 10) with a selectable layer type.

    Args:
        layer_type: one of ``'standard'`` (``nn.Linear``), ``'custom'``
            (``CustomLinear``), ``'custom-no-reweighting'`` (``CustomLinear``
            with ``apply_reweighting=False``) or ``'sbp'`` (``SBPLinear2D``).
        keep_ratio: forwarded to ``SBPLinear2D``; ignored by other types.
        backward_dropout_rate: forwarded to ``CustomLinear``; ignored by
            other types.

    Raises:
        ValueError: if ``layer_type`` is not one of the supported names.
    """

    # Flattened input size (14x14 images).
    _INPUT_DIM = 196

    def __init__(self, layer_type='standard', keep_ratio=0.5, backward_dropout_rate=0.5):
        super(SimpleNet, self).__init__()
        self.layer_type = layer_type

        # Pick a factory once; the four layers below then share it.
        if layer_type == 'standard':
            def make_layer(n_in, n_out):
                return nn.Linear(n_in, n_out)
        elif layer_type == 'custom':
            def make_layer(n_in, n_out):
                return CustomLinear(n_in, n_out,
                                    backward_dropout_rate=backward_dropout_rate)
        elif layer_type == 'custom-no-reweighting':
            def make_layer(n_in, n_out):
                return CustomLinear(n_in, n_out,
                                    backward_dropout_rate=backward_dropout_rate,
                                    apply_reweighting=False)
        elif layer_type == 'sbp':
            def make_layer(n_in, n_out):
                return SBPLinear2D(n_in, n_out, keep_ratio=keep_ratio)
        else:
            # Previously an unknown name silently produced a network without
            # any layers, which only failed later inside forward().
            raise ValueError(
                f"Unknown layer_type {layer_type!r}; expected 'standard', "
                "'custom', 'custom-no-reweighting' or 'sbp'"
            )

        # Architecture: 196 -> 128 -> 64 -> 32 -> 10
        self.fc1 = make_layer(self._INPUT_DIM, 128)
        self.fc2 = make_layer(128, 64)
        self.fc3 = make_layer(64, 32)
        self.fc4 = make_layer(32, 10)

        self.relu = nn.ReLU()

    def forward(self, x):
        """Flatten ``x`` to rows of 196 features and return the 10 logits per row."""
        # reshape (unlike view) also accepts non-contiguous inputs.
        x = x.reshape(-1, self._INPUT_DIM)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.relu(self.fc3(x))
        x = self.fc4(x)
        return x

# Gradient Analysis Functions
def compute_cosine_similarity(grad1, grad2):
    """Return the cosine similarity of two gradient collections as a float.

    Each collection is an iterable of tensors (``None`` entries are skipped).
    The tensors of a collection are flattened and concatenated into a single
    vector before the two vectors are compared.
    """
    def _to_row_vector(grads):
        # One (1, total_elements) row per collection.
        pieces = [g.reshape(-1) for g in grads if g is not None]
        return torch.cat(pieces).unsqueeze(0)

    similarity = F.cosine_similarity(_to_row_vector(grad1), _to_row_vector(grad2))
    return similarity.item()

def get_gradients(model):
    """Return cloned ``.grad`` tensors of ``model``'s parameters, in parameter order.

    Parameters whose ``.grad`` is ``None`` are skipped.  The returned tensors
    are copies, so later changes to the model's gradients do not affect them.
    """
    return [p.grad.clone() for p in model.parameters() if p.grad is not None]

def load_mnist_data(subset_size=500, batch_size=32, root='./data'):
    """Return a shuffled ``DataLoader`` over a random subset of MNIST train.

    Images are resized to 14x14, converted to tensors and normalised with
    mean 0.1307 and std 0.3081.  The dataset is downloaded to ``root`` if it
    is not already there.

    Args:
        subset_size: number of randomly chosen training images to keep.
        batch_size: batch size of the returned loader.
        root: directory in which the dataset is stored / downloaded.

    Returns:
        A ``DataLoader`` that reshuffles the chosen subset on every pass.
    """
    transform = transforms.Compose([
        transforms.Resize((14, 14)),  # Resize to 14x14 as requested
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    dataset = torchvision.datasets.MNIST(
        root=root, train=True, download=True, transform=transform
    )

    # Use a small random subset for faster computation.  Plain Python ints
    # are used as indices so the dataset is never indexed with a tensor.
    subset_indices = torch.randperm(len(dataset))[:subset_size].tolist()
    subset = torch.utils.data.Subset(dataset, subset_indices)

    return DataLoader(subset, batch_size=batch_size, shuffle=True)



In [None]:
# Main Comparison Function
def compare_gradient_methods():
    """Measure how close each method's gradients are to standard backprop.

    For every keep ratio and every batch of the MNIST subset, four freshly
    initialised networks (baseline, custom, custom without reweighting, SBP)
    are built after re-seeding with ``42 + batch_index``, one backward pass is
    run on each, and the cosine similarity between the baseline gradients and
    each variant's gradients is recorded.

    Returns:
        dict with the tested ``'keep_ratios'`` plus, for each variant, the
        per-ratio mean (``'*_similarities'``) and standard deviation
        (``'*_std'``) of the cosine similarity across batches.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    dataloader = load_mnist_data()

    # Test different keep ratios from 0.1 to 0.9
    keep_ratios = [0.1, 0.3, 0.5, 0.7, 0.9]
    results = {
        'keep_ratios': keep_ratios,
        'custom_similarities': [],
        'custom_no_reweighting_similarities': [],
        'sbp_similarities': [],
        'custom_std': [],
        'custom_no_reweighting_std': [],
        'sbp_std': []
    }
    variant_names = ('custom', 'custom_no_reweighting', 'sbp')

    def _gradients_after_one_step(model, data, target, criterion):
        # One forward/backward pass; returns cloned parameter gradients.
        model.zero_grad()
        loss = criterion(model(data), target)
        loss.backward()
        return get_gradients(model)

    for keep_ratio in keep_ratios:
        print(f"\nTesting keep ratio: {keep_ratio}")

        # Constructor arguments per model, in the order the models are
        # built and back-propagated (the order matters for RNG consumption).
        model_kwargs = (
            ('baseline', {'layer_type': 'standard'}),
            ('custom', {'layer_type': 'custom',
                        'backward_dropout_rate': 1-keep_ratio}),
            ('custom_no_reweighting', {'layer_type': 'custom-no-reweighting',
                                       'backward_dropout_rate': 1-keep_ratio}),
            ('sbp', {'layer_type': 'sbp', 'keep_ratio': keep_ratio}),
        )
        per_batch = {name: [] for name in variant_names}

        # Test on multiple batches for statistical significance
        for i, (data, target) in enumerate(dataloader):
            data, target = data.to(device), target.to(device)

            # Re-seed before every construction so all models start from the
            # same RNG state; all four are built before any backward pass.
            models = {}
            for name, kwargs in model_kwargs:
                torch.manual_seed(42 + i)
                models[name] = SimpleNet(**kwargs).to(device)

            criterion = nn.CrossEntropyLoss()
            grads = {
                name: _gradients_after_one_step(model, data, target, criterion)
                for name, model in models.items()
            }

            for name in variant_names:
                per_batch[name].append(
                    compute_cosine_similarity(grads['baseline'], grads[name])
                )

        # Store results
        means = {name: np.mean(per_batch[name]) for name in variant_names}
        stds = {name: np.std(per_batch[name]) for name in variant_names}
        for name in variant_names:
            results[f'{name}_similarities'].append(means[name])
        for name in variant_names:
            results[f'{name}_std'].append(stds[name])

        print(f"  Custom method: {means['custom']:.4f} ± {stds['custom']:.4f}")
        print(f"  Custom method (no reweighting): {means['custom_no_reweighting']:.4f} ± {stds['custom_no_reweighting']:.4f}")
        print(f"  SBP method: {means['sbp']:.4f} ± {stds['sbp']:.4f}")

    return results

In [None]:
# Visualization Function
def plot_results(results):
    """Plot mean cosine similarity (with std error bars) against keep ratio.

    ``results`` is the dict produced by ``compare_gradient_methods``; only the
    custom and SBP series are drawn.
    """
    style_name = 'seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'default'
    plt.style.use(style_name)
    fig, ax = plt.subplots(1, 1, figsize=(9, 6))

    # (mean key, std key, marker, legend label, colour) for each drawn curve.
    curves = (
        ('custom_similarities', 'custom_std', 'o', 'Custom Method', 'blue'),
        ('sbp_similarities', 'sbp_std', 's', 'SBP Method', 'red'),
    )
    ratios = results['keep_ratios']
    for mean_key, std_key, marker, label, color in curves:
        ax.errorbar(ratios, results[mean_key],
                    yerr=results[std_key], marker=marker, linewidth=2,
                    capsize=5, label=label, color=color)

    ax.set_xlabel('Keep Ratio')
    ax.set_ylabel('Cosine Similarity with Baseline')
    ax.set_title('Gradient Cosine Similarity Comparison')
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.set_ylim([0, 1])

    plt.tight_layout()
    plt.show()

# Banner describing the experiment that is about to run.
print(
    "Starting Gradient Similarity Comparison",
    "=" * 50,
    "Comparing Custom Backward Dropout vs SBP Method",
    "Dataset: MNIST (resized to 14x14)",
    "Architecture: 4-layer MLP (196→128→64→32→10)",
    "=" * 50,
    sep="\n",
)


# Run the gradient-similarity experiment and plot its outcome.
results = compare_gradient_methods()
plot_results(results)

# Closing notes on how to read the plot.
print(
    "\n Analysis completed successfully!",
    "\n INTERPRETATION GUIDE:",
    "• Higher cosine similarity = gradients more similar to baseline",
    "• SBP drops entire batch samples, Custom drops weight connections",
    "• Error bars show variance across different batches",
    "• Keep ratio: proportion of gradients/samples retained",
    sep="\n",
)

In [None]:
# Additional Training Block - Run after the previous comparison code
import time
from torch.utils.data import DataLoader
import torch.optim as optim

def train_and_evaluate_models():
    """Train baseline, custom and SBP networks on MNIST and collect results.

    One baseline network is trained first; then, for each keep ratio, a
    custom-backward-dropout network and an SBP network are trained.  Each run
    is timed with wall-clock time around the ``train_model`` call.

    Returns:
        dict with ``'keep_ratios'``, the per-ratio test accuracies
        (``'baseline_acc'``, ``'custom_acc'``, ``'sbp_acc'``) and the
        corresponding ``'training_times'``.  The single baseline result is
        repeated once per keep ratio.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"🚀 Starting Training Comparison on {device}")
    print("=" * 60)

    # Same preprocessing for the training and the test split.
    transform = transforms.Compose([
        transforms.Resize((14, 14)),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    train_dataset = torchvision.datasets.MNIST(
        root='./data', train=True, download=True, transform=transform
    )
    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)

    test_dataset = torchvision.datasets.MNIST(
        root='./data', train=False, download=True, transform=transform
    )
    test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)

    # Training parameters
    num_epochs = 1
    learning_rate = 0.001
    keep_ratios = [0.1, 0.3, 0.5, 0.7, 0.9]  # Test different keep ratios

    results = {
        'keep_ratios': keep_ratios,
        'baseline_acc': [],
        'custom_acc': [],
        'sbp_acc': [],
        'training_times': {
            'baseline': [],
            'custom': [],
            'sbp': []
        }
    }

    print(f"Training Configuration:")
    print(f"  • Dataset: MNIST (14×14)")
    print(f"  • Architecture: 4-layer MLP (196→128→64→32→10)")
    print(f"  • Epochs: {num_epochs}")
    print(f"  • Learning Rate: {learning_rate}")
    print(f"  • Batch Size: 128 (train), 256 (test)")
    print("=" * 60)

    def _train_timed(run_name, **net_kwargs):
        # Build a network and its Adam optimiser, then time the training run.
        net = SimpleNet(**net_kwargs).to(device)
        adam = optim.Adam(net.parameters(), lr=learning_rate)
        started = time.time()
        accuracy = train_model(net, adam, train_loader,
                               test_loader, num_epochs, device, run_name)
        return accuracy, time.time() - started

    print("\n🔧 Training Baseline Model (Standard Backpropagation)...")
    baseline_acc, baseline_time = _train_timed("Baseline", layer_type='standard')

    for keep_ratio in keep_ratios:
        print(f"\n🔧 Training Models with Keep Ratio: {keep_ratio}")

        print(f"  → Custom Backward Dropout (dropout_rate={1-keep_ratio:.1f})...")
        custom_acc, custom_time = _train_timed(
            f"Custom-{keep_ratio}",
            layer_type='custom', backward_dropout_rate=1-keep_ratio,
        )

        print(f"  → SBP Method (keep_ratio={keep_ratio:.1f})...")
        sbp_acc, sbp_time = _train_timed(
            f"SBP-{keep_ratio}",
            layer_type='sbp', keep_ratio=keep_ratio,
        )

        results['custom_acc'].append(custom_acc)
        results['sbp_acc'].append(sbp_acc)
        results['training_times']['custom'].append(custom_time)
        results['training_times']['sbp'].append(sbp_time)

    # The baseline does not depend on the keep ratio; repeat its single
    # result so it lines up with the other series.
    n_ratios = len(keep_ratios)
    results['baseline_acc'] = [baseline_acc] * n_ratios
    results['training_times']['baseline'] = [baseline_time] * n_ratios

    return results

def train_model(model, optimizer, train_loader, test_loader, num_epochs, device, model_name):
    """Train ``model`` with cross-entropy loss and return its test accuracy.

    Args:
        model: network to train (switched to training mode here).
        optimizer: optimiser already bound to ``model``'s parameters.
        train_loader: iterable of ``(data, target)`` training batches.
        test_loader: iterable of ``(data, target)`` batches for the final
            evaluation.
        num_epochs: number of passes over ``train_loader``.
        device: device the batches are moved to.
        model_name: label used in the final accuracy message.

    Returns:
        Test accuracy in percent, as computed by ``evaluate_model``.
    """
    criterion = nn.CrossEntropyLoss()
    model.train()

    for epoch in range(num_epochs):
        running_loss = 0
        n_correct = 0
        n_seen = 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            logits = model(inputs)
            batch_loss = criterion(logits, labels)
            batch_loss.backward()
            optimizer.step()

            # Running statistics for the progress report.
            running_loss += batch_loss.item()
            predicted = logits.argmax(dim=1, keepdim=True)
            n_correct += predicted.eq(labels.view_as(predicted)).sum().item()
            n_seen += labels.size(0)

        # Progress is only reported on every fifth epoch.
        if (epoch + 1) % 5 != 0:
            continue
        train_acc = 100. * n_correct / n_seen
        print(f"    Epoch {epoch+1:2d}/{num_epochs}: Loss={running_loss/len(train_loader):.4f}, "
              f"Train Acc={train_acc:.2f}%")

    # Final evaluation
    test_acc = evaluate_model(model, test_loader, device)
    print(f"    ✅ {model_name} Final Test Accuracy: {test_acc:.2f}%")
    return test_acc

def evaluate_model(model, test_loader, device):
    """Return the classification accuracy (in percent) of ``model`` on ``test_loader``.

    The model is switched to eval mode and left in that mode.  Predictions
    are the arg-max over dimension 1 of the model output, and gradient
    tracking is disabled for the whole pass.
    """
    model.eval()
    hits, seen = 0, 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            predicted = model(inputs).argmax(dim=1, keepdim=True)
            hits += predicted.eq(labels.view_as(predicted)).sum().item()
            seen += labels.size(0)

    return 100. * hits / seen

def plot_training_results(results):
    """Draw a grouped bar chart of final test accuracies per keep ratio.

    ``results`` is the dict produced by ``train_and_evaluate_models``.  The
    same dict is returned unchanged after the figure has been shown.
    """
    style_name = 'seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'default'
    plt.style.use(style_name)
    fig, ax = plt.subplots(1, 1, figsize=(9, 5))

    keep_ratios = results['keep_ratios']
    x = np.arange(len(keep_ratios))
    width = 0.25

    # (bar positions, results key, legend label, colour) for each method.
    groups = (
        (x - width, 'baseline_acc', 'Baseline (Standard BP)', 'green'),
        (x, 'custom_acc', 'Custom Backward Dropout', 'blue'),
        (x + width, 'sbp_acc', 'SBP Method', 'red'),
    )
    bar_sets = [
        ax.bar(positions, results[key], width,
               label=label, alpha=0.8, color=color)
        for positions, key, label, color in groups
    ]

    # Write each bar's value just above it.
    for bar_set in bar_sets:
        for bar in bar_set:
            height = bar.get_height()
            ax.text(bar.get_x() + bar.get_width()/2., height + 0.1,
                    f'{height:.1f}%', ha='center', va='bottom', fontsize=10, fontweight='bold')

    ax.set_xlabel('Keep Ratio')
    ax.set_ylabel('Test Accuracy (%)')
    ax.set_title('Final Test Accuracy Comparison')
    ax.set_xticks(x)
    ax.set_xticklabels([f'{r:.1f}' for r in keep_ratios])
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.set_ylim([0, 100])

    plt.tight_layout()
    plt.show()

    return results

# Run the training comparison
# Announce the (long-running) training comparison.
print(
    "🎯 EXTENDED ANALYSIS: Training Models to Completion",
    "This will train all models for multiple epochs and compare final accuracies...",
    "Expected runtime: ~5-10 minutes depending on hardware",
    sep="\n",
)

# Any failure during training or plotting is reported instead of propagating.
try:
    training_results = train_and_evaluate_models()
    final_plot = plot_training_results(training_results)

    print(
        "\n" + "="*60,
        "✅ TRAINING ANALYSIS COMPLETED!",
        "="*60,
        "\n📋 INTERPRETATION GUIDE:",
        "• Final accuracies show real-world performance of each method",
        "• Training times indicate computational overhead",
        "• SBP maintains paper's claims about minimal accuracy loss",
        "• Custom method reveals effectiveness of connection-level dropout",
        "\n🔍 KEY FINDINGS:",
        "• Compare accuracy drops: which method preserves performance better?",
        "• Analyze training efficiency: time vs. accuracy trade-offs",
        "• Observe keep ratio effects: optimal balance point",
        sep="\n",
    )

except Exception as e:
    print(f"❌ Training failed: {e}")
    print("Check CUDA availability and memory if using GPU")
