In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.autograd import Function
import time

class CommunicationBackDropFunction(Function):
    """Linear op whose backward pass masks the weight used for the input gradient.

    The forward pass is a plain ``F.linear``. In the backward pass the gradient
    with respect to the input is computed with ``weight * comm_mask``
    (element-wise), while the weight and bias gradients are computed from the
    unmasked tensors.
    """

    @staticmethod
    def forward(ctx, input, weight, bias, training, keep_ratio=0.5, comm_mask=None):
        """
        input: [batch, D_in]
        weight: [D_out, D_in]
        bias: [D_out] or None
        training: stored on ``ctx``; the backward pass does not consult it
        keep_ratio: probability that a weight entry is kept when the mask is sampled
        comm_mask: optional mask shaped like ``weight``; sampled here when None
        """
        if comm_mask is None:
            # Each weight entry is kept independently with probability keep_ratio.
            comm_mask = torch.rand(weight.size(), device=input.device) < keep_ratio

        ctx.save_for_backward(input, weight, comm_mask)
        ctx.training = training
        return F.linear(input, weight, bias)

    @staticmethod
    def backward(ctx, grad_output):
        """
        grad_output: [batch, D_out]

        Returns one gradient per forward argument. Gradients that autograd does
        not need (including the bias gradient when ``bias`` was None) are
        returned as None; returning a tensor for a non-tensor input makes
        autograd raise.
        """
        input, weight, comm_mask = ctx.saved_tensors
        needs_input, needs_weight, needs_bias = ctx.needs_input_grad[:3]

        grad_input = grad_weight = grad_bias = None

        if needs_input:
            # Cast the mask to the weight dtype so half/double weights also work.
            masked_weight = weight * comm_mask.to(weight.dtype)
            grad_input = grad_output.mm(masked_weight)

        if needs_weight:
            grad_weight = grad_output.t().mm(input)

        if needs_bias:
            grad_bias = grad_output.sum(0)

        return grad_input, grad_weight, grad_bias, None, None, None


class ColumnwiseStructuredBackDropFunction(Function):
    """Linear op that back-propagates only through a subset of input columns.

    The forward pass is a plain ``F.linear``. Only the kept columns of the
    input and weight are saved, and in the backward pass the input and weight
    gradients are zero everywhere except in those columns.
    """

    @staticmethod
    def forward(ctx, input, weight, bias, training, keep_ratio=0.5, kept_col_mask=None):
        """
        input: [batch, D_in]
        weight: [D_out, D_in]
        bias: [D_out] or None
        training: accepted for interface parity; not used here
        keep_ratio: probability that a column is kept when the mask is sampled
        kept_col_mask: optional 1D column selector of length D_in, either
            torch.bool or a 0/1 floating-point vector; sampled here when None
        """
        if kept_col_mask is None:
            # Each input column is kept independently with probability keep_ratio.
            kept_col_mask = torch.rand(weight.size(1), device=input.device) < keep_ratio
        elif torch.is_tensor(kept_col_mask) and kept_col_mask.is_floating_point():
            # Floating-point tensors cannot be used as indices; turn the 0/1
            # vector into the equivalent boolean mask.
            kept_col_mask = kept_col_mask != 0

        input_reduced = input[:, kept_col_mask]    # [batch, D_in_reduced]
        weight_reduced = weight[:, kept_col_mask]  # [D_out, D_in_reduced]

        # Only the reduced tensors are kept for the backward pass.
        ctx.save_for_backward(input_reduced, weight_reduced)
        ctx.kept_col_mask = kept_col_mask
        ctx.input_shape = input.shape
        ctx.weight_shape = weight.shape

        return F.linear(input, weight, bias)

    @staticmethod
    def backward(ctx, grad_output):
        """
        grad_output: [batch, D_out]
        weight_reduced: [D_out, D_in_reduced]
        input_reduced: [batch, D_in_reduced]

        grad_input: [batch, D_in], non-zero only in kept columns
        grad_weight: [D_out, D_in], non-zero only in kept columns
        grad_bias: [D_out], or None when no bias gradient is required
        """
        input_reduced, weight_reduced = ctx.saved_tensors
        kept_col_mask = ctx.kept_col_mask
        needs_input, needs_weight, needs_bias = ctx.needs_input_grad[:3]

        grad_input = grad_weight = grad_bias = None

        if needs_input:
            grad_input_reduced = grad_output.mm(weight_reduced)
            grad_input = torch.zeros(ctx.input_shape, device=grad_input_reduced.device,
                                     dtype=grad_input_reduced.dtype)
            grad_input[:, kept_col_mask] = grad_input_reduced

        if needs_weight:
            grad_weight_reduced = grad_output.t().mm(input_reduced)
            grad_weight = torch.zeros(ctx.weight_shape, device=grad_weight_reduced.device,
                                      dtype=grad_weight_reduced.dtype)
            grad_weight[:, kept_col_mask] = grad_weight_reduced

        if needs_bias:
            grad_bias = grad_output.sum(0)

        return grad_input, grad_weight, grad_bias, None, None, None


class RowwiseStructuredBackDropFunction(Function):
    """Linear op that back-propagates only through a subset of output rows.

    The forward pass is a plain ``F.linear``. The backward pass uses only the
    kept rows of the weight (and the matching columns of ``grad_output``);
    weight and bias gradients are zero for the rows that were not kept.
    """

    @staticmethod
    def forward(ctx, input, weight, bias, training, keep_ratio=0.5, kept_row_mask=None):
        """
        input: [batch, D_in]
        weight: [D_out, D_in]
        bias: [D_out] or None
        kept_row_mask: optional selector over the D_out rows; sampled here when None
        """
        row_selector = kept_row_mask
        if row_selector is None:
            # Each output row is kept independently with probability keep_ratio.
            row_selector = torch.rand(weight.size(0), device=input.device) < keep_ratio

        # Keep the full input but only the selected weight rows: [D_out_reduced, D_in].
        ctx.save_for_backward(input, weight[row_selector, :])
        ctx.kept_row_mask = row_selector
        ctx.input_shape = input.shape
        ctx.weight_shape = weight.shape
        ctx.bias_length = None if bias is None else bias.shape[0]

        return F.linear(input, weight, bias)

    @staticmethod
    def backward(ctx, grad_output):
        """
        grad_output: [batch, D_out]

        grad_input: [batch, D_in], built from kept rows only
        grad_weight: [D_out, D_in], non-zero only in kept rows
        grad_bias: [D_out] with non-kept entries zero, or None without a bias
        """
        full_input, kept_weight = ctx.saved_tensors
        row_selector = ctx.kept_row_mask

        # Drop the gradient columns that belong to rows that were not kept.
        kept_grad = grad_output[:, row_selector]  # [batch, D_out_reduced]

        if row_selector.dim() == 0:
            # NOTE(review): this looks intended for a 0-dim integer index, where
            # both selections lose a dimension and must be made 2D again for mm.
            kept_grad = kept_grad.unsqueeze(1)
            kept_weight = kept_weight.unsqueeze(0)

        grad_input = kept_grad.mm(kept_weight)  # [batch, D_in]

        kept_weight_grad = kept_grad.t().mm(full_input)  # [D_out_reduced, D_in]
        grad_weight = torch.zeros(ctx.weight_shape, device=kept_weight_grad.device,
                                  dtype=kept_weight_grad.dtype)
        grad_weight[row_selector, :] = kept_weight_grad

        grad_bias = None
        if ctx.bias_length:
            grad_bias = torch.zeros(ctx.bias_length, device=grad_output.device,
                                    dtype=grad_output.dtype)
            grad_bias[row_selector] = kept_grad.sum(0)

        return grad_input, grad_weight, grad_bias, None, None, None


In [None]:
# Custom layer wrappers
class CommunicationBackDropLinear(nn.Module):
    """Linear layer whose computation is routed through ``CommunicationBackDropFunction``.

    Holds a ``[out_features, in_features]`` weight, an optional
    ``[out_features]`` bias, and the ``keep_ratio`` handed to the function on
    every forward call.
    """

    def __init__(self, in_features, out_features, keep_ratio, bias=True):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        if not bias:
            # Registering None keeps ``self.bias`` defined for the forward call.
            self.register_parameter('bias', None)
        else:
            self.bias = nn.Parameter(torch.empty(out_features))

        self.keep_ratio = keep_ratio
        self.reset_parameters()

    def reset_parameters(self):
        """Kaiming-uniform weight; bias uniform in ±1/sqrt(fan_in)."""
        nn.init.kaiming_uniform_(self.weight)
        if self.bias is None:
            return
        fan_in = nn.init._calculate_fan_in_and_fan_out(self.weight)[0]
        limit = 1 / (fan_in ** 0.5)
        nn.init.uniform_(self.bias, -limit, limit)

    def forward(self, input, training=True):
        """Apply the layer; ``training`` is forwarded to the autograd function."""
        return CommunicationBackDropFunction.apply(
            input, self.weight, self.bias, training, self.keep_ratio
        )


class ColumnwiseStructuredBackDropLinear(nn.Module):
    """Linear layer whose computation is routed through ``ColumnwiseStructuredBackDropFunction``.

    Holds a ``[out_features, in_features]`` weight, an optional
    ``[out_features]`` bias, and the ``keep_ratio`` handed to the function on
    every forward call.
    """

    def __init__(self, in_features, out_features, keep_ratio, bias=True):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        if not bias:
            # Registering None keeps ``self.bias`` defined for the forward call.
            self.register_parameter('bias', None)
        else:
            self.bias = nn.Parameter(torch.empty(out_features))
        self.keep_ratio = keep_ratio
        self.reset_parameters()

    def reset_parameters(self):
        """Kaiming-uniform weight; bias uniform in ±1/sqrt(fan_in)."""
        nn.init.kaiming_uniform_(self.weight)
        if self.bias is None:
            return
        fan_in = nn.init._calculate_fan_in_and_fan_out(self.weight)[0]
        limit = 1 / (fan_in ** 0.5)
        nn.init.uniform_(self.bias, -limit, limit)

    def forward(self, input, training=True):
        """Apply the layer; ``training`` is forwarded to the autograd function."""
        return ColumnwiseStructuredBackDropFunction.apply(
            input, self.weight, self.bias, training, self.keep_ratio
        )


class RowwiseStructuredBackDropLinear(nn.Module):
    """Linear layer whose computation is routed through ``RowwiseStructuredBackDropFunction``.

    Holds a ``[out_features, in_features]`` weight, an optional
    ``[out_features]`` bias, and the ``keep_ratio`` handed to the function on
    every forward call.
    """

    def __init__(self, in_features, out_features, keep_ratio, bias=True):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        if not bias:
            # Registering None keeps ``self.bias`` defined for the forward call.
            self.register_parameter('bias', None)
        else:
            self.bias = nn.Parameter(torch.empty(out_features))
        self.keep_ratio = keep_ratio
        self.reset_parameters()

    def reset_parameters(self):
        """Kaiming-uniform weight; bias uniform in ±1/sqrt(fan_in)."""
        nn.init.kaiming_uniform_(self.weight)
        if self.bias is None:
            return
        fan_in = nn.init._calculate_fan_in_and_fan_out(self.weight)[0]
        limit = 1 / (fan_in ** 0.5)
        nn.init.uniform_(self.bias, -limit, limit)

    def forward(self, input, training=True):
        """Apply the layer; ``training`` is forwarded to the autograd function."""
        return RowwiseStructuredBackDropFunction.apply(
            input, self.weight, self.bias, training, self.keep_ratio
        )

In [None]:
class SimpleNet(nn.Module):
    """Fully connected network with a selectable linear-layer implementation.

    Args:
        input_dim: number of features after flattening the input.
        output_dim: number of outputs of the final layer.
        hidden_dims: sizes of the hidden layers, in order.
        layer_type: 'original' (alias 'standard') for ``nn.Linear``, or
            'communication', 'columnwise', 'rowwise' for the BackDrop layers.
        keep_ratio: forwarded to the BackDrop layers; unused by ``nn.Linear``.
        activation: 'relu', 'tanh', 'sigmoid', 'leaky_relu' or 'gelu'.

    Raises:
        ValueError: for an unknown ``layer_type`` or ``activation``.
    """

    def __init__(self,
                 input_dim=196,
                 output_dim=10,
                 hidden_dims=(128, 64, 32),
                 layer_type='standard',
                 keep_ratio=0.5,
                 activation='relu'):
        super().__init__()
        self.layer_type = layer_type
        self.input_dim = input_dim

        # Full list of layer widths, input and output included.
        self.layer_dims = [input_dim] + list(hidden_dims) + [output_dim]

        self.layers = nn.ModuleList(
            self._make_layer(layer_type, in_dim, out_dim, keep_ratio)
            for in_dim, out_dim in zip(self.layer_dims[:-1], self.layer_dims[1:])
        )

        activations = {
            'relu': nn.ReLU,
            'tanh': nn.Tanh,
            'sigmoid': nn.Sigmoid,
            'leaky_relu': nn.LeakyReLU,
            'gelu': nn.GELU,
        }
        if activation not in activations:
            raise ValueError(f"Unknown activation: {activation}")
        self.activation = activations[activation]()

    @staticmethod
    def _make_layer(layer_type, in_dim, out_dim, keep_ratio):
        """Build one linear layer of the requested type."""
        # 'standard' is the constructor default, so it has to be accepted; it
        # selects the same plain nn.Linear as 'original'.
        if layer_type in ('original', 'standard'):
            return nn.Linear(in_dim, out_dim)
        if layer_type == 'communication':
            return CommunicationBackDropLinear(in_dim, out_dim, keep_ratio=keep_ratio)
        if layer_type == 'columnwise':
            return ColumnwiseStructuredBackDropLinear(in_dim, out_dim, keep_ratio=keep_ratio)
        if layer_type == 'rowwise':
            return RowwiseStructuredBackDropLinear(in_dim, out_dim, keep_ratio=keep_ratio)
        raise ValueError(f"Unknown layer_type: {layer_type}")

    def forward(self, x):
        # Flatten to [N, input_dim]; reshape also handles non-contiguous input.
        x = x.reshape(-1, self.input_dim)

        # Activation after every layer except the last one.
        for layer in self.layers[:-1]:
            x = self.activation(layer(x))

        return self.layers[-1](x)

    def get_architecture_info(self):
        """Return information about the network architecture"""
        return {
            'layer_dims': self.layer_dims,
            'layer_type': self.layer_type,
            'num_layers': len(self.layers),
            'total_params': sum(p.numel() for p in self.parameters())
        }

def get_gradients(model):
    """Return cloned gradients of every parameter that currently has one.

    Parameters whose ``.grad`` is None are skipped, so the list can be shorter
    than the parameter list.
    """
    return [p.grad.clone() for p in model.parameters() if p.grad is not None]


def evaluate_model(model, test_loader, device):
    """Return the top-1 accuracy (in percent) of ``model`` on ``test_loader``.

    The model is switched to eval mode for the duration of the call and put
    back into whatever mode it was in before, so calling this between training
    epochs does not leave the model in eval mode.

    Args:
        model: module mapping a data batch to per-class scores of shape [N, C].
        test_loader: iterable of ``(data, target)`` batches.
        device: device the batches are moved to before the forward pass.

    Returns:
        Accuracy in the range 0-100, or 0.0 when the loader yields no samples.
    """
    was_training = model.training
    model.eval()
    correct = 0
    total = 0

    try:
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()
                total += target.size(0)
    finally:
        # Restore the caller's train/eval mode even if a batch raised.
        model.train(was_training)

    if total == 0:
        return 0.0
    return 100. * correct / total


def train_model(model, optimizer, train_loader, test_loader, num_epochs, device, model_name,
                clear_memory_first=True, track_time=False, track_memory=False, debug=False):
    """Train ``model`` with cross-entropy loss and evaluate it after every epoch.

    Args:
        model: module to train (already on ``device``).
        optimizer: optimizer over the model's parameters.
        train_loader: iterable of ``(data, target)`` training batches.
        test_loader: iterable of ``(data, target)`` batches for ``evaluate_model``.
        num_epochs: number of passes over ``train_loader``.
        device: ``torch.device`` batches are moved to.
        model_name: label used in printed messages and in the result dict.
        clear_memory_first: on CUDA, empty the cache and reset peak stats first.
        track_time: add wall-clock training time to the results.
        track_memory: on CUDA, add allocated/peak memory figures to the results.
        debug: stop each epoch after the first 10 batches.

    Returns:
        dict with 'model_name', 'test_accuracy', 'epoch_losses' and
        'epoch_test_accuracies'; plus 'training_time_seconds' /
        'training_time_minutes' when ``track_time`` is set, and
        'gpu_memory_usage' when ``track_memory`` is set on a CUDA device.
    """
    on_cuda = device.type == 'cuda'

    # Clear GPU memory before starting if requested
    if clear_memory_first and on_cuda:
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats(device)
        torch.cuda.synchronize()

    criterion = nn.CrossEntropyLoss()

    if track_time:
        start_time = time.perf_counter()

    # GPU memory tracking if available and requested
    if track_memory and on_cuda:
        torch.cuda.reset_peak_memory_stats(device)
        initial_gpu_memory = torch.cuda.memory_allocated(device) / 1024 / 1024  # MB
        max_gpu_memory = initial_gpu_memory
    else:
        initial_gpu_memory = 0
        max_gpu_memory = 0

    epoch_losses = []
    epoch_test_accuracies = []

    for epoch in range(num_epochs):
        # The per-epoch evaluation may leave the model in eval mode, so
        # training mode is re-asserted at the start of every epoch.
        model.train()

        total_loss = 0.0
        correct = 0
        total = 0
        batches_seen = 0

        for batch_idx, (data, target) in enumerate(train_loader):
            if debug and batch_idx >= 10:  # Limit to first 10 batches for debugging
                break
            data, target = data.to(device), target.to(device)

            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            total += target.size(0)
            batches_seen += 1

            if track_memory and on_cuda:
                current_gpu_memory = torch.cuda.memory_allocated(device) / 1024 / 1024  # MB
                max_gpu_memory = max(max_gpu_memory, current_gpu_memory)

        # Average over the batches actually processed; in debug mode that is
        # fewer than len(train_loader).
        avg_loss = total_loss / batches_seen if batches_seen else float('nan')
        epoch_losses.append(avg_loss)

        test_acc = evaluate_model(model, test_loader, device)
        epoch_test_accuracies.append(test_acc)

        # Print progress every 5 epochs
        if (epoch + 1) % 5 == 0:
            train_acc = 100. * correct / total if total else 0.0
            print(f" Epoch {epoch+1:2d}/{num_epochs}: Loss={avg_loss:.4f}, "
                  f"Train Acc={train_acc:.2f}%, Test Acc={test_acc:.2f}%")

    if track_time:
        training_time = time.perf_counter() - start_time

    # Final accuracy is the last epoch's; with zero epochs, evaluate once so
    # the result is still defined.
    if epoch_test_accuracies:
        final_test_acc = epoch_test_accuracies[-1]
    else:
        final_test_acc = evaluate_model(model, test_loader, device)
    print(f" ✅ {model_name} Final Test Accuracy: {final_test_acc:.2f}%")

    results = {
        'model_name': model_name,
        'test_accuracy': final_test_acc,
        'epoch_losses': epoch_losses,
        'epoch_test_accuracies': epoch_test_accuracies,
    }

    if track_time:
        results.update({
            'training_time_seconds': training_time,
            'training_time_minutes': training_time / 60,
        })

    if track_memory and on_cuda:
        peak_gpu_memory = torch.cuda.max_memory_allocated(device) / 1024 / 1024  # MB
        results['gpu_memory_usage'] = {
            'initial_gpu_memory_mb': initial_gpu_memory,
            'max_gpu_memory_mb': max_gpu_memory,
            'peak_gpu_memory_mb': peak_gpu_memory,
            'gpu_memory_increase_mb': peak_gpu_memory - initial_gpu_memory
        }
        print(f" 🎮 GPU memory usage: {peak_gpu_memory:.2f}MB (increase: {peak_gpu_memory-initial_gpu_memory:.2f}MB)")
    elif track_memory and not on_cuda:
        print(f" ⚠️ GPU not available - no VRAM tracking")

    if track_time:
        print(f" ⏱️ Training time: {training_time:.2f}s ({training_time/60:.2f}m)")

    # Release cached blocks for the next model. The caller still holds the
    # model and optimizer, so deleting the local names here would free nothing.
    if on_cuda:
        torch.cuda.empty_cache()

    return results


def _collect_metric_series(results, extract):
    """Return ``{model_name: [value per keep ratio]}``, using 0 where a run or metric is missing."""
    series = {}
    for model_name in results['model_selection']:
        runs = results['models'][model_name]
        values = []
        for kr in results['keep_ratios']:
            value = extract(runs[kr]) if kr in runs else None
            values.append(0 if value is None else value)
        series[model_name] = values
    return series


def plot_training_results_time_and_space_analysis(results):
    """Plot peak VRAM (bars) and training time (lines) per model and keep ratio.

    Args:
        results: dict with 'keep_ratios', 'model_selection' and
            'models' (``models[model_name][keep_ratio]`` is one run's result
            dict). Runs without 'gpu_memory_usage' or 'training_time_seconds'
            are plotted as 0.

    Returns:
        The ``results`` argument, unchanged.
    """
    # Imported here so the function does not depend on another cell having
    # imported numpy/matplotlib first.
    import matplotlib.pyplot as plt
    import numpy as np

    plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'default')
    fig, axes = plt.subplots(1, 2, figsize=(15, 5))

    keep_ratios = results['keep_ratios']
    x = np.arange(len(keep_ratios))
    ratio_labels = [f'{r:.1f}' for r in keep_ratios]

    # Peak VRAM is recorded in MB; convert to GB for the plot.
    extracted_vram = _collect_metric_series(
        results,
        lambda run: (run['gpu_memory_usage']['peak_gpu_memory_mb'] / 1024
                     if 'gpu_memory_usage' in run else None),
    )
    extracted_time = _collect_metric_series(
        results,
        lambda run: run.get('training_time_seconds'),
    )

    # Plot 1: VRAM usage as grouped bars, each group centred on its tick.
    ax1 = axes[0]
    n_models = max(len(extracted_vram), 1)
    width = min(0.20, 0.8 / n_models)
    for index, (model_name, vram_data) in enumerate(extracted_vram.items()):
        offset = (index - (n_models - 1) / 2) * width
        ax1.bar(x + offset, vram_data, width, label=model_name, alpha=0.8)

    ax1.set_xlabel('Keep Ratio')
    ax1.set_ylabel('Max VRAM Usage (GB)')
    ax1.set_title('Maximum VRAM Usage Comparison')
    ax1.set_xticks(x)
    ax1.set_xticklabels(ratio_labels)
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Plot 2: training time, plotted against the keep-ratio values themselves.
    ax2 = axes[1]
    for model_name, time_data in extracted_time.items():
        ax2.plot(keep_ratios, time_data, label=model_name, linewidth=2, marker='o', markersize=6)
    ax2.set_xlabel('Keep Ratio')
    ax2.set_ylabel('Training Time (seconds)')
    ax2.set_title('Training Time Comparison')
    # Ticks must sit at the plotted x positions (the keep ratios), not at 0..n-1.
    ax2.set_xticks(keep_ratios)
    ax2.set_xticklabels(ratio_labels)
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    return results

In [None]:
def train_and_evaluate_models(config=None):
    """
    Train models with different gradient methods and compare final accuracies with plots.

    For every keep ratio and every entry of ``model_selection`` a fresh
    ``SimpleNet`` is trained with ``train_model``; the collected results are
    written to a JSON file in the working directory and plotted.

    Args:
        config (dict): Configuration dictionary with training parameters.
            Missing keys take the defaults listed below. The dict passed in is
            not modified.

    Returns:
        dict with 'config', 'keep_ratios', 'model_selection' and 'models',
        where ``models[model_name][keep_ratio]`` is the dict returned by
        ``train_model``.

    Raises:
        ValueError: for an unsupported dataset or optimizer name.
    """
    import json

    # Default configuration
    default_config = {
        'num_epochs': 20,
        'model_selection': ['original', 'communication', 'columnwise', 'rowwise'],
        'learning_rate': 0.001,
        'keep_ratios': [0.3, 0.5, 0.7],
        'batch_size_train': 1024,
        'batch_size_test': 1024,
        'image_size': (14, 14),
        'dataset': 'MNIST',
        'normalize_mean': (0.1307,),
        'normalize_std': (0.3081,),
        'clear_memory_between_models': True,
        'verbose': True,
        'optimizer': 'adam',
        'optimizer_params': {},
        'track_time': False,
        'track_memory': False,
        'debug': False  # limit every epoch to the first 10 training batches
    }

    # Merge into a new dict so the caller's config is left untouched.
    config = {**default_config, **(config or {})}

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if config['verbose']:
        print(f"🚀 Starting Training Comparison on {device}")
        print("=" * 60)

    # Load data based on configuration
    transform = transforms.Compose([
        transforms.Resize(config['image_size']),
        transforms.ToTensor(),
        transforms.Normalize(config['normalize_mean'], config['normalize_std'])
    ])

    if config['dataset'] == 'MNIST':
        train_dataset = datasets.MNIST(
            root='./data', train=True, download=True, transform=transform
        )
        test_dataset = datasets.MNIST(
            root='./data', train=False, download=True, transform=transform
        )
    else:
        raise ValueError(f"Dataset {config['dataset']} not supported yet")

    train_loader = DataLoader(train_dataset, batch_size=config['batch_size_train'], shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=config['batch_size_test'], shuffle=False)

    # Results storage
    results = {
        'config': config,
        'keep_ratios': config['keep_ratios'],
        'model_selection': config['model_selection'],
        'models': {model_name: {} for model_name in config['model_selection']}
    }

    def create_optimizer(model_parameters):
        """Build the configured optimizer; 'optimizer_params' may override the learning rate."""
        optimizer_name = config['optimizer'].lower()
        base_params = {'lr': config['learning_rate']}
        base_params.update(config['optimizer_params'])

        if optimizer_name == 'adam':
            return optim.Adam(model_parameters, **base_params)
        elif optimizer_name == 'sgd':
            return optim.SGD(model_parameters, **base_params)
        elif optimizer_name == 'rmsprop':
            return optim.RMSprop(model_parameters, **base_params)
        elif optimizer_name == 'adamw':
            return optim.AdamW(model_parameters, **base_params)
        else:
            raise ValueError(f"Optimizer '{optimizer_name}' not supported. Use 'adam', 'sgd', 'rmsprop', or 'adamw'")

    if config['verbose']:
        print(f"Training Configuration:")
        print(f"  • Dataset: {config['dataset']} ({config['image_size'][0]}×{config['image_size'][1]})")
        print(f"  • Architecture: 4-layer MLP (196→128→64→32→10)")
        print(f"  • Epochs: {config['num_epochs']}")
        print(f"  • Learning Rate: {config['learning_rate']}")
        print(f"  • Optimizer: {config['optimizer'].upper()}")
        print(f"  • Batch Size: {config['batch_size_train']} (train), {config['batch_size_test']} (test)")
        print(f"  • Keep Ratios: {config['keep_ratios']}")
        print("=" * 60)

    def clear_memory_if_needed():
        """Empty the CUDA cache and reset peak stats when configured to do so."""
        if config['clear_memory_between_models'] and device.type == 'cuda':
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()
            torch.cuda.synchronize()
            if config['verbose']:
                print("🧹 GPU memory cleared")

    # Train models with different keep ratios
    for keep_ratio in config['keep_ratios']:
        if config['verbose']:
            print(f"\n🔧 Training Models with Keep Ratio: {keep_ratio}")

        for model_name in config['model_selection']:
            if config['verbose']:
                print(f"  → Training {model_name} Model (Keep Ratio={keep_ratio:.1f})...")

            clear_memory_if_needed()
            model = SimpleNet(layer_type=model_name, keep_ratio=keep_ratio).to(device)
            optimizer = create_optimizer(model.parameters())

            model_results = train_model(
                model=model,
                optimizer=optimizer,
                train_loader=train_loader,
                test_loader=test_loader,
                num_epochs=config['num_epochs'],
                device=device,
                model_name=f"{model_name}-{keep_ratio}",
                clear_memory_first=False,  # Already cleared above
                track_time=config['track_time'],
                track_memory=config['track_memory'],
                debug=config['debug']  # without this the 'debug' key has no effect
            )

            # Store results with keep_ratio as key
            results['models'].setdefault(model_name, {})[keep_ratio] = model_results

    # Save results to disk; the file name encodes dataset, models, keep ratios,
    # optimizer and learning rate.
    results_filename = f"training_results_{config['dataset']}_{'_'.join(config['model_selection'])}_{'_'.join(map(str, config['keep_ratios']))}_{config['optimizer']}_lr_{config['learning_rate']}.json"
    with open(results_filename, 'w') as f:
        json.dump(results, f, indent=4)
    if config['verbose']:
        print(f"📂 Results saved to {results_filename}")

    # Plot training curves
    _plot_training_curves(results)

    return results


def _plot_training_curves(results):
    """Plot per-epoch training loss and test accuracy, then print a summary table.

    One line is drawn per (model, keep ratio) pair that has an entry in
    ``results['models']``; the table lists the final test accuracy and the
    last epoch's loss for the same pairs.
    """
    import matplotlib.pyplot as plt

    model_names = results['model_selection']
    ratios = results['keep_ratios']
    runs = results['models']

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

    # (axis, metric key in the run dict, y-axis label, title)
    panels = (
        (ax1, 'epoch_losses', 'Training Loss', 'Training Loss Over Epochs'),
        (ax2, 'epoch_test_accuracies', 'Test Accuracy (%)', 'Test Accuracy Over Epochs'),
    )
    for axis, metric_key, y_label, title in panels:
        for name in model_names:
            for ratio in ratios:
                if ratio not in runs[name]:
                    continue
                series = runs[name][ratio][metric_key]
                epoch_numbers = list(range(1, len(series) + 1))
                axis.plot(epoch_numbers, series, label=f'{name} (keep={ratio})',
                          linewidth=2, marker='s', markersize=4)

        axis.set_xlabel('Epoch')
        axis.set_ylabel(y_label)
        axis.set_title(title)
        axis.legend()
        axis.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    # Final summary table, grouped by keep ratio.
    rule = "=" * 60
    print("\n" + rule)
    print("📊 FINAL RESULTS SUMMARY")
    print(rule)
    print(f"{'Method':<25} {'Final Test Acc':<15} {'Final Loss':<12}")
    print("-" * 60)

    for ratio in ratios:
        for name in model_names:
            if ratio not in runs[name]:
                continue
            run = runs[name][ratio]
            label = f'{name} (keep={ratio})'
            print(f"{label:<25} {run['test_accuracy']:<15.2f} {run['epoch_losses'][-1]:<12.4f}")


In [None]:
# Experiment: 'communication' layers, SGD with lr=0.1, 100 epochs, swept over
# five keep ratios. Keys not listed here take the defaults defined inside
# train_and_evaluate_models.
config = dict(
    model_selection=['communication'],
    num_epochs=100,
    learning_rate=0.1,
    keep_ratios=[0.1, 0.3, 0.5, 0.7, 0.9],
    batch_size_train=1024,
    batch_size_test=1024,
    clear_memory_between_models=True,
    verbose=True,
    optimizer='sgd',  # one of 'adam', 'sgd', 'rmsprop', 'adamw'
)

training_results_epochs = train_and_evaluate_models(config)

In [None]:
# Experiment: 'communication' layers, SGD with lr=1, 100 epochs, swept over
# five keep ratios.
config = dict(
    model_selection=['communication'],
    num_epochs=100,
    learning_rate=1,
    keep_ratios=[0.1, 0.3, 0.5, 0.7, 0.9],
    batch_size_train=1024,
    batch_size_test=1024,
    clear_memory_between_models=True,
    verbose=True,
    optimizer='sgd',  # one of 'adam', 'sgd', 'rmsprop', 'adamw'
    optimizer_params={},  # extra keyword arguments for the optimizer
)

training_results_epochs = train_and_evaluate_models(config)

In [None]:
# Experiment: 'columnwise' layers, SGD with lr=0.1, 100 epochs, swept over
# five keep ratios.
config = dict(
    model_selection=['columnwise'],
    num_epochs=100,
    learning_rate=0.1,
    keep_ratios=[0.1, 0.3, 0.5, 0.7, 0.9],
    batch_size_train=1024,
    batch_size_test=1024,
    clear_memory_between_models=True,
    verbose=True,
    optimizer='sgd',  # one of 'adam', 'sgd', 'rmsprop', 'adamw'
    optimizer_params={},  # extra keyword arguments for the optimizer
)

training_results_epochs = train_and_evaluate_models(config)

In [None]:
# Experiment: 'rowwise' layers, SGD with lr=0.1, 100 epochs, swept over
# five keep ratios.
config = dict(
    model_selection=['rowwise'],
    num_epochs=100,
    learning_rate=0.1,
    keep_ratios=[0.1, 0.3, 0.5, 0.7, 0.9],
    batch_size_train=1024,
    batch_size_test=1024,
    clear_memory_between_models=True,
    verbose=True,
    optimizer='sgd',  # one of 'adam', 'sgd', 'rmsprop', 'adamw'
    optimizer_params={},  # extra keyword arguments for the optimizer
)

training_results_epochs = train_and_evaluate_models(config)

In [None]:
 # load three results from disk
import json
import numpy as np
def load_results_from_file(filename):
    """Read ``filename`` and return its contents decoded from JSON."""
    with open(filename, 'r') as handle:
        payload = json.loads(handle.read())
    return payload
# train_and_evaluate_models names its output
#   training_results_{dataset}_{models}_{keep ratios}_{optimizer}_lr_{lr}.json
# so the optimizer / learning-rate suffix is part of every file name. These are
# the SGD, lr=0.1 runs.
results_communication = load_results_from_file('training_results_MNIST_communication_0.1_0.3_0.5_0.7_0.9_sgd_lr_0.1.json')
results_columnwise = load_results_from_file('training_results_MNIST_columnwise_0.1_0.3_0.5_0.7_0.9_sgd_lr_0.1.json')
results_rowwise = load_results_from_file('training_results_MNIST_rowwise_0.1_0.3_0.5_0.7_0.9_sgd_lr_0.1.json')
results_original = load_results_from_file('training_results_MNIST_original_0.9_sgd_lr_0.1.json')
# The misspelled name is kept because plot_keep_ratio_comparison reads it as a global.
restuls_original = results_original

# plot the keep=0.1 results for all three methods
import matplotlib.pyplot as plt
def plot_keep_ratio_comparison(results_communication, results_columnwise, results_rowwise, keep_ratio=0.1,
                               results_original=None, original_keep_ratio=0.9):
    """Plot training-loss curves of the three BackDrop methods at one keep ratio.

    The curve of the 'original' model is added as a baseline. All result dicts
    are expected in the form loaded from the JSON files, where keep ratios are
    string keys.

    Args:
        results_communication, results_columnwise, results_rowwise: result
            dicts holding ``['models'][method][str(keep_ratio)]['epoch_losses']``.
        keep_ratio: keep ratio whose runs are compared.
        results_original: baseline result dict; when None, the module-level
            ``restuls_original`` is used.
        original_keep_ratio: key under which the baseline run is stored.

    Returns:
        The three method result dicts, unchanged.
    """
    if results_original is None:
        # Fall back to the notebook-level global (spelled as it is defined there).
        results_original = restuls_original

    plt.style.use('default')
    fig, ax = plt.subplots(figsize=(8, 6))

    # Explicit mapping instead of eval() on a constructed variable name.
    results_by_method = {
        'communication': results_communication,
        'columnwise': results_columnwise,
        'rowwise': results_rowwise,
    }
    for method, method_results in results_by_method.items():
        losses = method_results['models'][method][f'{keep_ratio}']['epoch_losses']
        epochs = list(range(1, len(losses) + 1))
        ax.plot(epochs, losses, label=f'{method}', linewidth=2)

    baseline_losses = results_original['models']['original'][f'{original_keep_ratio}']['epoch_losses']
    baseline_epochs = list(range(1, len(baseline_losses) + 1))
    ax.plot(baseline_epochs, baseline_losses, label=f'original', linewidth=2)

    ax.loglog()
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Training loss')
    ax.set_title(f'Training loss Comparison at Keep Ratio {keep_ratio}')
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()
    return results_communication, results_columnwise, results_rowwise
# Plot the comparison once per keep ratio used in the experiments.
for keep_ratio in (0.1, 0.3, 0.5, 0.7, 0.9):
    results_communication, results_columnwise, results_rowwise = plot_keep_ratio_comparison(
        results_communication, results_columnwise, results_rowwise, keep_ratio=keep_ratio
    )

In [None]:
# Baseline: plain 'original' layers, SGD with lr=0.1, 100 epochs. A single
# keep ratio is listed, so the model is trained once.
config = dict(
    model_selection=['original'],
    num_epochs=100,
    learning_rate=0.1,
    keep_ratios=[0.9],
    batch_size_train=1024,
    batch_size_test=1024,
    clear_memory_between_models=True,
    verbose=True,
    optimizer='sgd',  # one of 'adam', 'sgd', 'rmsprop', 'adamw'
    optimizer_params={},  # extra keyword arguments for the optimizer
)

training_results_epochs = train_and_evaluate_models(config)


In [None]:
# Long run: 'rowwise' layers at keep ratio 0.1, SGD with lr=0.5, 500 epochs.
config = dict(
    model_selection=['rowwise'],
    num_epochs=500,
    learning_rate=0.5,
    keep_ratios=[0.1],
    batch_size_train=1024,
    batch_size_test=1024,
    clear_memory_between_models=True,
    verbose=True,
    optimizer='sgd',  # one of 'adam', 'sgd', 'rmsprop', 'adamw'
    optimizer_params={},  # extra keyword arguments for the optimizer
)

training_results_epochs = train_and_evaluate_models(config)

In [None]:
# Time / memory comparison: all four layer types, one epoch each, with time
# and memory tracking switched on, followed by the time-and-space plot.
config = dict(
    model_selection=['original', 'communication', 'columnwise', 'rowwise'],
    num_epochs=1,
    learning_rate=0.1,
    keep_ratios=[0.1, 0.3, 0.5, 0.7, 0.9],
    batch_size_train=1024,
    batch_size_test=1024,
    clear_memory_between_models=True,
    verbose=True,
    optimizer='sgd',  # one of 'adam', 'sgd', 'rmsprop', 'adamw'
    track_time=True,
    track_memory=True,
    debug=True,  # debug mode to limit training batches
)

training_results_epochs = train_and_evaluate_models(config)
plot_training_results_time_and_space_analysis(training_results_epochs)
print(training_results_epochs['models']['original'][0.1].keys())

In [None]:
# Experiment: 'communication' layers at keep ratio 0.1, SGD with lr=0.3,
# 100 epochs.
config = dict(
    model_selection=['communication'],
    num_epochs=100,
    learning_rate=0.3,
    keep_ratios=[0.1],
    batch_size_train=1024,
    batch_size_test=1024,
    clear_memory_between_models=True,
    verbose=True,
    optimizer='sgd',  # one of 'adam', 'sgd', 'rmsprop', 'adamw'
)

training_results_epochs = train_and_evaluate_models(config)