In [None]:
# Setup environment and install dependencies.
# `%pip` (not `!pip`) installs into the environment of the running kernel.
# NOTE(review): package versions are unpinned, so a fresh run may resolve to
# different torch/torchvision releases than the ones results were produced
# with — consider pinning (e.g. `torch==X.Y.Z`) once working versions are known.
%pip install torch torchvision matplotlib pandas tqdm


In [None]:
# Clone the project repository and make it the working directory.
# NOTE(review): `yourusername` is a placeholder — replace it with the real
# GitHub account before running.
# NOTE(review): after `%cd` the working directory is `Comp430_Project`, so
# re-running this cell would try to clone a second copy *inside* it; presumably
# `git clone` also refuses an already-existing destination — guard before re-runs.
!git clone https://github.com/yourusername/Comp430_Project.git
%cd Comp430_Project


In [None]:
# Import necessary libraries
import os
import sys
import math
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
import random
import matplotlib.pyplot as plt
import time
import copy
from collections import OrderedDict
from tqdm.notebook import tqdm
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split, Subset

# Check GPU availability
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Set random seed for reproducibility
def set_seed(seed=42):
    """Seed every RNG used in this notebook and make cuDNN deterministic.

    Args:
        seed: Integer seed applied to Python's `random`, NumPy's legacy global
            RNG, PyTorch's CPU RNG and (when available) every CUDA device.
    """
    for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
        seed_rng(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False

set_seed()


In [None]:
# Define a function to load different datasets
def load_dataset(dataset_name, data_dir='./data'):
    """
    Download (if needed) and return the train/test splits of a torchvision dataset.

    Each split is converted to a tensor and normalized with fixed per-channel
    mean/std constants specific to the dataset.

    Args:
        dataset_name: Case-insensitive name: 'mnist', 'cifar10' or 'fashion_mnist'.
        data_dir: Directory where the datasets are stored; created if missing.

    Returns:
        (train_dataset, test_dataset)

    Raises:
        ValueError: if `dataset_name` is not one of the supported names.
    """
    os.makedirs(data_dir, exist_ok=True)

    # name -> (torchvision.datasets class name, per-channel mean, per-channel std)
    dataset_specs = {
        'mnist': ('MNIST', (0.1307,), (0.3081,)),
        'cifar10': ('CIFAR10', (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        'fashion_mnist': ('FashionMNIST', (0.2860,), (0.3530,)),
    }
    spec = dataset_specs.get(dataset_name.lower())
    if spec is None:
        raise ValueError(f"Unknown dataset: {dataset_name}")

    class_name, channel_mean, channel_std = spec
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std)
    ])
    dataset_cls = getattr(datasets, class_name)

    train_split = dataset_cls(root=data_dir, train=True, download=True, transform=preprocess)
    test_split = dataset_cls(root=data_dir, train=False, download=True, transform=preprocess)
    return train_split, test_split


In [None]:
# Data partitioning for IID and non-IID settings

def partition_data_iid(dataset, num_clients):
    """
    Split `dataset` into `num_clients` random, disjoint, equally sized shards.

    Indices are shuffled with NumPy's global RNG. Every client receives
    `len(dataset) // num_clients` samples, except the last client, which also
    absorbs the remainder.

    Returns:
        Dict mapping client id (0..num_clients-1) to a `Subset` of `dataset`.
    """
    total = len(dataset)
    shard_size = total // num_clients
    shuffled = list(range(total))
    np.random.shuffle(shuffled)

    # Shard c spans [cuts[c], cuts[c + 1]); the final cut is pinned to `total`.
    cuts = [client * shard_size for client in range(num_clients)] + [total]
    return {
        client: Subset(dataset, shuffled[cuts[client]:cuts[client + 1]])
        for client in range(num_clients)
    }

def _dataset_targets(dataset):
    """Return every sample's class label in `dataset` as a 1-D NumPy array.

    Checks a `targets` attribute first (torchvision convention), then `labels`.
    A `torch.utils.data.Subset` has neither, so its labels are looked up in the
    wrapped dataset and re-indexed by `Subset.indices` (recursively, so nested
    subsets such as those produced by `random_split` also work).

    Raises:
        AttributeError: if no labels can be found.
    """
    if hasattr(dataset, 'targets'):
        return np.asarray(dataset.targets)
    if hasattr(dataset, 'labels'):
        return np.asarray(dataset.labels)
    if isinstance(dataset, Subset):
        parent_targets = _dataset_targets(dataset.dataset)
        return parent_targets[np.asarray(dataset.indices, dtype=np.int64)]
    raise AttributeError(
        f"Cannot partition {type(dataset).__name__}: it has no 'targets' or "
        "'labels' attribute and is not a Subset of a labelled dataset."
    )

def partition_data_dirichlet(dataset, num_clients, alpha):
    """
    Partitions the data using a Dirichlet distribution to create non-IID data splits.

    For each class, a proportion vector p ~ Dir(alpha * 1_K) over the K clients
    is drawn. The class's shuffled indices are then cut into chunks of size
    floor(p_k * n_c); the last client absorbs the rounding remainder.

    Args:
        dataset: The dataset to partition. Labels are read from `targets`,
                 `labels`, or (for a `Subset`) from the wrapped dataset.
        num_clients: Number of clients to create partitions for (>= 1).
        alpha: Dirichlet concentration parameter - controls skew
               alpha→0: extreme skew, each client gets mostly one class
               alpha→∞: balanced distribution (IID)

    Returns:
        A dictionary mapping client IDs to dataset subsets. Indices are
        positions within `dataset` itself. A client may receive no samples
        when alpha is small.

    Raises:
        ValueError: if num_clients < 1.
        AttributeError: if labels cannot be extracted from `dataset`.
    """
    if num_clients < 1:
        raise ValueError("num_clients must be at least 1.")

    targets = _dataset_targets(dataset)
    classes = np.unique(targets)
    idx_by_class = {c: np.where(targets == c)[0] for c in classes}

    client_indices = [[] for _ in range(num_clients)]

    for c in classes:
        # Draw this class's proportions, then convert them to integer counts.
        props = np.random.dirichlet(alpha * np.ones(num_clients))
        counts = (props * len(idx_by_class[c])).astype(int)
        # Flooring can only lose samples, so the remainder is non-negative.
        counts[-1] = len(idx_by_class[c]) - counts[:-1].sum()

        # Shuffle after the Dirichlet draw (same RNG order as before) and split.
        np.random.shuffle(idx_by_class[c])
        start = 0
        for cid, cnt in enumerate(counts):
            client_indices[cid].extend(idx_by_class[c][start:start + cnt].tolist())
            start += cnt

    return {cid: Subset(dataset, idx) for cid, idx in enumerate(client_indices)}

def get_client_dataloaders(dataset, num_clients, batch_size, distribution='iid', alpha=1.0):
    """
    Partition `dataset` across clients and wrap each shard in a shuffling DataLoader.

    Args:
        dataset: Dataset to partition.
        num_clients: Number of clients.
        batch_size: Batch size for every client's DataLoader.
        distribution: 'iid' (equal random shards) or 'dirichlet' (label-skewed shards).
        alpha: Dirichlet concentration; only used when distribution='dirichlet'.

    Returns:
        Dict mapping client id to that client's DataLoader. Memory is pinned
        when CUDA is available.

    Raises:
        ValueError: for an unknown `distribution`.
    """
    if distribution == 'iid':
        shards = partition_data_iid(dataset, num_clients)
    elif distribution == 'dirichlet':
        shards = partition_data_dirichlet(dataset, num_clients, alpha)
    else:
        raise ValueError(f"Unknown distribution method: {distribution}")

    pin = torch.cuda.is_available()
    return {
        client_id: DataLoader(shard, batch_size=batch_size, shuffle=True, pin_memory=pin)
        for client_id, shard in shards.items()
    }


In [None]:
# Define various models for our experiments

class SimpleCNN(nn.Module):
    """Simple CNN suitable for MNIST and Fashion-MNIST (1x28x28 inputs).

    The network is split at `split_layer`. The first `split_layer` blocks of
    `client_layers` form the client part (WC). Any remaining client blocks,
    followed by `server_layers`, form the server part (WS). Earlier versions
    silently *skipped* the unused client blocks. With the default
    `split_layer=1`, that fed a 16x14x14 map into a layer sized for 32x7x7 and
    crashed.
    """
    def __init__(self, num_classes=10, split_layer=1):
        super(SimpleCNN, self).__init__()
        self.split_layer = split_layer

        # Client-side layers (WC)
        self.client_layers = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(1, 16, kernel_size=5, padding=2),
                nn.ReLU(),
                nn.MaxPool2d(2)
            ),
            nn.Sequential(
                nn.Conv2d(16, 32, kernel_size=5, padding=2),
                nn.ReLU(),
                nn.MaxPool2d(2)
            )
        ])

        # Server-side layers (WS)
        self.server_layers = nn.ModuleList([
            nn.Sequential(
                nn.Flatten(),
                nn.Linear(32 * 7 * 7, 128),
                nn.ReLU()
            ),
            nn.Linear(128, num_classes)
        ])

    def _split_index(self):
        """Number of client blocks run on the client, clamped to [0, len(client_layers)]."""
        return max(0, min(self.split_layer, len(self.client_layers)))

    def forward(self, x):
        """Full forward pass: client part followed by server part."""
        return self.server_forward(self.client_forward(x))

    def client_forward(self, x):
        """Forward pass through client part of the model."""
        for layer in self.client_layers[:self._split_index()]:
            x = layer(x)
        return x

    def server_forward(self, x):
        """Forward pass through server part of the model.

        Runs the client blocks not executed by `client_forward`, then the
        server layers. `x` is expected to be `client_forward`'s output.
        """
        for layer in self.client_layers[self._split_index():]:
            x = layer(x)
        for layer in self.server_layers:
            x = layer(x)
        return x

class CIFARCNN(nn.Module):
    """Deeper CNN suitable for CIFAR-10 (3x32x32 inputs).

    The network is split at `split_layer`. The first `split_layer` blocks of
    `client_layers` form the client part (WC). Any remaining client blocks,
    followed by `server_layers`, form the server part (WS). Earlier versions
    skipped the unused client blocks. With the default `split_layer=2`, that
    fed a 64x16x16 map into a layer sized for 128x8x8 and crashed.
    """
    def __init__(self, num_classes=10, split_layer=2):
        super(CIFARCNN, self).__init__()
        self.split_layer = split_layer

        # Client-side layers (WC)
        self.client_layers = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(3, 32, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.BatchNorm2d(32)
            ),
            nn.Sequential(
                nn.Conv2d(32, 64, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2),
                nn.BatchNorm2d(64)
            ),
            nn.Sequential(
                nn.Conv2d(64, 128, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2),
                nn.BatchNorm2d(128)
            )
        ])

        # Server-side layers (WS)
        self.server_layers = nn.ModuleList([
            nn.Sequential(
                nn.Flatten(),
                nn.Linear(128 * 8 * 8, 256),
                nn.ReLU(),
                nn.Dropout(0.5)
            ),
            nn.Linear(256, num_classes)
        ])

    def _split_index(self):
        """Number of client blocks run on the client, clamped to [0, len(client_layers)]."""
        return max(0, min(self.split_layer, len(self.client_layers)))

    def forward(self, x):
        """Full forward pass: client part followed by server part."""
        return self.server_forward(self.client_forward(x))

    def client_forward(self, x):
        """Forward pass through client part of the model."""
        for layer in self.client_layers[:self._split_index()]:
            x = layer(x)
        return x

    def server_forward(self, x):
        """Forward pass through server part of the model.

        Runs the client blocks not executed by `client_forward`, then the
        server layers. `x` is expected to be `client_forward`'s output.
        """
        for layer in self.client_layers[self._split_index():]:
            x = layer(x)
        for layer in self.server_layers:
            x = layer(x)
        return x

class MLP(nn.Module):
    """Baseline MLP model.

    The network is split at `split_layer`. The first `split_layer` blocks of
    `client_layers` form the client part (WC). Any remaining client blocks,
    followed by `server_layers`, form the server part (WS). Earlier versions
    silently skipped the unused client blocks: the default `split_layer=1`
    dropped the second hidden layer, and `split_layer=0` crashed on un-flattened
    input.
    """
    def __init__(self, input_dim=784, hidden_dim=256, num_classes=10, split_layer=1):
        super(MLP, self).__init__()
        self.split_layer = split_layer

        # Client-side layers (WC)
        self.client_layers = nn.ModuleList([
            nn.Sequential(
                nn.Flatten(),
                nn.Linear(input_dim, hidden_dim),
                nn.ReLU()
            ),
            nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU()
            )
        ])

        # Server-side layers (WS)
        self.server_layers = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim // 2),
                nn.ReLU()
            ),
            nn.Linear(hidden_dim // 2, num_classes)
        ])

    def _split_index(self):
        """Number of client blocks run on the client, clamped to [0, len(client_layers)]."""
        return max(0, min(self.split_layer, len(self.client_layers)))

    def forward(self, x):
        """Full forward pass: client part followed by server part."""
        return self.server_forward(self.client_forward(x))

    def client_forward(self, x):
        """Forward pass through client part of the model."""
        for layer in self.client_layers[:self._split_index()]:
            x = layer(x)
        return x

    def server_forward(self, x):
        """Forward pass through server part of the model.

        Runs the client blocks not executed by `client_forward`, then the
        server layers. `x` is expected to be `client_forward`'s output.
        """
        for layer in self.client_layers[self._split_index():]:
            x = layer(x)
        for layer in self.server_layers:
            x = layer(x)
        return x

def get_model(model_name, dataset_name, split_layer=1):
    """
    Build the model named `model_name`, sized for `dataset_name`.

    Args:
        model_name: 'simple_cnn', 'cifar_cnn' or 'mlp'.
        dataset_name: 'mnist', 'cifar10' or 'fashion_mnist'.
        split_layer: Layer index at which the model is split between client and server.

    Returns:
        A freshly initialised model with 10 output classes.

    Raises:
        ValueError: for an unknown model, or a model/dataset pairing that is not supported.
    """
    if model_name == 'simple_cnn':
        if dataset_name not in ['mnist', 'fashion_mnist']:
            raise ValueError(f"SimpleCNN not suitable for {dataset_name}")
        return SimpleCNN(num_classes=10, split_layer=split_layer)

    if model_name == 'cifar_cnn':
        if not dataset_name == 'cifar10':
            raise ValueError(f"CIFARCNN not suitable for {dataset_name}")
        return CIFARCNN(num_classes=10, split_layer=split_layer)

    if model_name == 'mlp':
        if dataset_name in ['mnist', 'fashion_mnist']:
            in_features, width = 784, 256
        elif dataset_name == 'cifar10':
            in_features, width = 3*32*32, 512
        else:
            raise ValueError(f"Unknown dataset for MLP: {dataset_name}")
        return MLP(input_dim=in_features, hidden_dim=width, num_classes=10, split_layer=split_layer)

    raise ValueError(f"Unknown model: {model_name}")


In [None]:
# Noise utilities

def laplace_eps(scale: float, sensitivity: float) -> float:
    """Return the pure-DP epsilon of a Laplace mechanism: sensitivity / scale."""
    epsilon = sensitivity / scale
    return epsilon

def add_laplacian_noise(tensor, sensitivity, epsilon_prime, device=None):
    """
    Adds Laplacian noise to a tensor based on sensitivity and epsilon_prime.

    Each element receives i.i.d. Laplace(0, b) noise with
    b = sensitivity / epsilon_prime. The noise is generated directly on
    ``tensor.device`` with ``tensor``'s floating dtype (or the default dtype for
    non-float tensors), so the addition never mixes devices or silently casts.

    Args:
        tensor: The input tensor (e.g., activations).
        sensitivity: The L1 sensitivity of the function outputting the tensor (>= 0).
        epsilon_prime: The privacy budget epsilon' for this noise addition step (> 0).
        device: Kept for backward compatibility only. Noise must live on the
            same device as ``tensor``, so it is always created there. The old
            default of ``'cuda'`` crashed on CPU-only machines.

    Returns:
        Tensor with added Laplacian noise (``tensor`` itself when sensitivity is 0).

    Raises:
        ValueError: if sensitivity is negative or epsilon_prime is not positive.
    """
    # If sensitivity is zero, no noise is added.
    if sensitivity == 0.0:
        return tensor
    if sensitivity < 0:
        raise ValueError("Sensitivity must be non-negative for Laplacian noise.")
    if epsilon_prime <= 0:
        raise ValueError("Epsilon prime must be positive for Laplacian noise.")

    # The former `assert laplace_eps(scale, ...) == epsilon_prime` was a
    # tautology with an absolute 1e-5 tolerance, so it could spuriously fail
    # for very large epsilon_prime; it has been removed.
    scale = sensitivity / epsilon_prime

    noise_dtype = tensor.dtype if tensor.is_floating_point() else torch.get_default_dtype()
    loc = torch.zeros((), dtype=noise_dtype, device=tensor.device)
    laplace_dist = torch.distributions.laplace.Laplace(loc=loc, scale=torch.full_like(loc, scale))
    noise = laplace_dist.sample(tensor.shape)

    return tensor + noise

def add_gaussian_noise(tensor, clip_norm, noise_multiplier, device=None):
    """
    Adds Gaussian noise scaled by clip_norm and noise_multiplier.
    Used after summing clipped per-sample gradients.

    Noise is N(0, sigma^2 I) with sigma = noise_multiplier * clip_norm. It is
    generated on ``tensor.device`` with ``tensor``'s floating dtype (or the
    default dtype for non-float tensors).

    Args:
        tensor: The input tensor (e.g., summed clipped gradients).
        clip_norm: The L2 norm bound C used for clipping.
        noise_multiplier: The noise multiplier z (sigma = z * C).
        device: Kept for backward compatibility only. Noise must live on the
            same device as ``tensor``, so it is always created there. The old
            default of ``'cuda'`` crashed on CPU-only machines.

    Returns:
        Tensor with added Gaussian noise, or ``tensor`` itself when no noise applies.

    Raises:
        ValueError: if noise_multiplier < 0, or clip_norm <= 0 while noise_multiplier > 0.
    """
    if noise_multiplier < 0:
        raise ValueError("Noise multiplier cannot be negative.")
    if clip_norm <= 0:
        # Allow clip_norm=0 for cases where no privacy is applied (noise_multiplier=0)
        if noise_multiplier > 0:
            raise ValueError("Clip norm must be positive if noise multiplier is positive.")
        # No clipping, no noise - return original tensor
        return tensor

    sigma = noise_multiplier * clip_norm

    # If sigma is zero, no noise is added.
    if sigma == 0.0:
        return tensor

    noise_dtype = tensor.dtype if tensor.is_floating_point() else torch.get_default_dtype()
    noise = torch.randn(tensor.shape, dtype=noise_dtype, device=tensor.device) * sigma

    return tensor + noise


In [None]:
# Privacy Accountant

def _log_comb(n: int, k: int) -> float:
    """Computes log(C(n, k)) using log gamma functions."""
    if k < 0 or k > n:
        return -float('inf') # Log of zero
    # Using math.lgamma which computes log(Gamma(x)) = log((x-1)!)
    # log(n! / (k! * (n-k)!)) = lgamma(n+1) - lgamma(k+1) - lgamma(n-k+1)
    return math.lgamma(n + 1) - math.lgamma(k + 1) - math.lgamma(n - k + 1)

def _log_add_exp(log_a: float, log_b: float) -> float:
    """Computes log(exp(log_a) + exp(log_b)) robustly."""
    if log_a == -float('inf'):
        return log_b
    if log_b == -float('inf'):
        return log_a
    if log_a > log_b:
        return log_a + math.log1p(math.exp(log_b - log_a))
    else:
        return log_b + math.log1p(math.exp(log_a - log_b))

def _compute_rdp_epsilon_step(q: float, noise_multiplier: float, alpha: int) -> float:
    """
    Computes the Renyi Differential Privacy (RDP) epsilon for a single step
    of the sampled Gaussian mechanism.
    """
    if q == 0:
        return 0.0 # No privacy cost if not sampled
    if q == 1.0:
        # Standard (non-sampled) Gaussian mechanism RDP
        if noise_multiplier == 0:
             return float('inf') # Infinite privacy cost with zero noise
        # RDP is alpha / (2 * sigma^2) where sigma is noise_multiplier
        return alpha / (2.0 * noise_multiplier**2)
    if noise_multiplier == 0:
         return float('inf') # Infinite privacy cost if noise is zero and sampled

    sigma_squared = noise_multiplier**2

    # Compute the sum using log-sum-exp trick for numerical stability
    log_sum_exp = -float('inf')
    log_q = math.log(q)
    log_1_minus_q = math.log1p(-q) # More accurate for small q

    for k in range(alpha + 1):
        log_comb_term = _log_comb(alpha, k)
        if log_comb_term == -float('inf'):
            continue

        # Term involving probabilities: k * log(q) + (alpha - k) * log(1-q)
        log_prob_term = k * log_q + (alpha - k) * log_1_minus_q

        # Term involving the RDP of non-sampled mechanism at order k
        # We need exp((k-1) * rdp_epsilon(k)) = exp((k-1) * k / (2 * sigma^2))
        # Handle k=0 and k=1 where the exponent term is 0 -> exp(0) = 1 -> log(1) = 0
        log_exp_term = 0.0
        if k > 1:
            log_exp_term = (k - 1.0) * k / (2.0 * sigma_squared)

        # Combine terms in log space: log(C(a,k) * q^k * (1-q)^(a-k) * exp(...))
        current_term_log = log_comb_term + log_prob_term + log_exp_term

        # Add to the total sum using log-add-exp
        log_sum_exp = _log_add_exp(log_sum_exp, current_term_log)

    # Final RDP epsilon is log(sum) / (alpha - 1)
    rdp_epsilon = log_sum_exp / (alpha - 1.0)

    return rdp_epsilon

class LaplaceAccumulator:
    """Running total of pure-DP epsilons spent by Laplace noise injections.

    Pure-DP costs compose by simple addition; `_eps_history` records the
    running total after every step.
    """
    def __init__(self):
        self._eps_history = []  # cumulative epsilon after each step
        self.eps_sum = 0.0

    def step(self, eps: float):
        """Add one Laplace release of cost `eps` (must be non-negative)."""
        assert eps >= 0
        self.eps_sum = self.eps_sum + eps
        self._eps_history.append(self.eps_sum)

class ManualPrivacyAccountant:
    """
    Manually implemented Moments Accountant (based on Renyi DP) to track
    cumulative privacy cost (epsilon, delta) for the Gaussian Noise mechanism.

    Per-step RDP is evaluated only at *integer* orders, because
    `_compute_rdp_epsilon_step` uses a finite binomial expansion. RDP is added
    across steps and converted with
        epsilon = min_alpha [ RDP(alpha) - log(delta) / (alpha - 1) ].
    """
    _DEFAULT_ORDERS = tuple(range(2, 33)) + (40, 48, 56, 64)

    def __init__(self, moment_orders=None):
        """
        Initializes the accountant.

        Args:
            moment_orders: Iterable of RDP orders to track. Each must be an
                integer (or integer-valued float) > 1. Duplicates are removed
                and the orders are sorted. If None, uses 2..32 and 40, 48, 56, 64.

        Raises:
            ValueError: if no orders are given, or an order is not an integer > 1.
                Non-integer orders used to be silently truncated for the RDP
                computation while the untruncated value was used in the
                conversion, which under-reported epsilon.
        """
        if moment_orders is None:
            moment_orders = self._DEFAULT_ORDERS

        normalized = set()
        for order in moment_orders:
            try:
                order_value = float(order)
            except (TypeError, ValueError):
                raise ValueError(f"Moment order {order!r} is not a number.") from None
            if order_value <= 1 or not order_value.is_integer():
                raise ValueError("Moment orders (alphas) must be integers > 1.")
            normalized.add(int(order_value))

        if not normalized:
            raise ValueError("Moment orders (alphas) must be > 1.")

        self.moment_orders = sorted(normalized)
        # Store total accumulated RDP epsilon for each order alpha
        self._total_rdp_epsilons = {alpha: 0.0 for alpha in self.moment_orders}
        self._steps = 0  # Track total number of steps taken

    def step(self, noise_multiplier: float, sampling_rate: float, num_steps: int = 1):
        """
        Records the privacy cost of applying the sampled Gaussian mechanism
        for a number of steps.

        Raises:
            ValueError: for a negative noise multiplier or a sampling rate outside [0, 1].
        """
        if noise_multiplier < 0:
            raise ValueError("Noise multiplier cannot be negative.")
        if not (0 <= sampling_rate <= 1):
            raise ValueError("Sampling rate must be between 0 and 1.")
        if num_steps <= 0:
            return  # No steps taken

        for alpha in self.moment_orders:
            # RDP composes additively: num_steps identical steps cost num_steps * per-step RDP.
            rdp_epsilon_step = _compute_rdp_epsilon_step(sampling_rate, noise_multiplier, alpha)
            self._total_rdp_epsilons[alpha] += num_steps * rdp_epsilon_step

        self._steps += num_steps

    def get_privacy_spent(self, delta: float) -> tuple:
        """
        Computes the (epsilon, delta)-DP guarantee for the accumulated
        privacy cost.

        Returns:
            (epsilon, delta). epsilon is 0.0 when no privacy has been spent
            (all accumulated RDP is zero), and inf when delta <= 0 or every
            order is infinite.
        """
        if delta <= 0:
            print("Warning: Target delta must be positive.")
            return float('inf'), delta

        # Zero RDP at every order means identical output distributions: 0-DP.
        # Without this, the conversion term alone would report epsilon > 0.
        if all(rdp == 0.0 for rdp in self._total_rdp_epsilons.values()):
            return 0.0, delta

        log_delta = math.log(delta)
        min_epsilon = float('inf')
        for alpha in self.moment_orders:
            total_rdp_epsilon = self._total_rdp_epsilons[alpha]
            if total_rdp_epsilon == float('inf'):
                continue  # This alpha gives infinite epsilon

            # epsilon = RDP_epsilon(alpha) - log(delta) / (alpha - 1), floored at 0
            epsilon = max(0.0, total_rdp_epsilon - (log_delta / (alpha - 1.0)))
            min_epsilon = min(min_epsilon, epsilon)

        return min_epsilon, delta

    @property
    def total_steps(self):
        """Total number of Gaussian steps recorded so far."""
        return self._steps

class HybridAccountant:
    """
    Tracks:
      1) pure-DP ε from Laplace steps (summed by a LaplaceAccumulator),
      2) RDP from Gaussian steps via ManualPrivacyAccountant.

    The reported guarantee is the sequential composition
    (ε_laplace + ε_gauss(δ), δ).
    """
    def __init__(self, noise_multiplier, sampling_rate, moment_orders=None, target_delta: float = 1e-5):
        """
        Args:
            noise_multiplier: Default σ used by `gaussian_step` when none is passed.
            sampling_rate: Sampling rate q applied to every Gaussian step.
            moment_orders: Forwarded to ManualPrivacyAccountant (None -> its defaults).
            target_delta: δ used for the running ε history and for
                `epsilon_gaussian`. Defaults to 1e-5, the value that was
                previously hard-coded.

        Raises:
            ValueError: if target_delta is not in (0, 1).
        """
        if not 0.0 < target_delta < 1.0:
            raise ValueError("target_delta must be in (0, 1).")

        # for Gaussian RDP
        self.gauss_acc = ManualPrivacyAccountant(moment_orders)
        self.fixed_sigma = noise_multiplier
        self.q = sampling_rate
        self.target_delta = target_delta
        self._eps_history = []  # total ε at target_delta after every recorded step

        # for Laplace pure-DP
        self.laplace_acc = LaplaceAccumulator()

    def laplace_step(self, epsilon_prime: float):
        """Call every time you inject Laplace(…,ε′)."""
        self.laplace_acc.step(epsilon_prime)
        self._update_history()

    def gaussian_step(self, noise_multiplier=None, num_steps: int = 1):
        """Call every time you inject Gaussian noise on gradients.

        Uses `noise_multiplier` when given (e.g. an adaptive σ), otherwise the
        fixed σ from construction.
        """
        sigma = noise_multiplier if noise_multiplier is not None else self.fixed_sigma
        self.gauss_acc.step(sigma, self.q, num_steps)
        self._update_history()

    def _update_history(self):
        """Append the current total ε (at target_delta) to the history."""
        eps_total, _ = self.get_privacy_spent(self.target_delta)
        self._eps_history.append(eps_total)

    def get_privacy_spent(self, delta: float):
        """
        Returns the composed (ε,δ):
          ε = ε_laplace + ε_gauss
          δ = δ  (pure-DP from Laplace contributes no δ)
        """
        eps_gauss, _ = self.gauss_acc.get_privacy_spent(delta)
        eps_lap = self.laplace_acc.eps_sum
        return eps_lap + eps_gauss, delta

    @property
    def epsilon_laplace(self):
        """Get the current Laplace privacy cost."""
        return self.laplace_acc.eps_sum

    @property
    def epsilon_gaussian(self):
        """Get the current Gaussian privacy cost at target_delta."""
        eps_gauss, _ = self.gauss_acc.get_privacy_spent(delta=self.target_delta)
        return eps_gauss

    @property
    def total_steps(self):
        """Get the total number of Gaussian steps."""
        return self.gauss_acc.total_steps


In [None]:
# Split Federated Learning implementation

class SFLClient:
    """
    Client in Split Federated Learning.
    Manages local data, client-side model (WC), performs local computations,
    applies noise, and communicates intermediate results.
    """
    def __init__(self, client_id: int, client_model: nn.Module, dataloader: DataLoader, config: dict, device: torch.device):
        """
        Args:
            client_id: Unique identifier for the client.
            client_model: A *copy* of the initial client-side model (WC) architecture.
            dataloader: DataLoader for the client's local dataset partition.
            config: Configuration dictionary. Reads 'lr', 'optimizer' and the
                'dp_noise' sub-dict ('clip_norm', 'initial_sigma',
                'adaptive_clipping_factor', 'laplacian_sensitivity', 'epsilon_prime').
            device: The torch device ('cpu' or 'cuda').
        """
        self.client_id = client_id
        self.client_model = client_model.to(device)
        self.dataloader = dataloader
        self.config = config
        self.device = device
        self.optimizer = self._create_optimizer()  # Optimizer for WC

        # State cached by local_forward_pass for local_backward_pass.
        self._activations_for_backward = None
        self._data_batch = None  # data batch whose per-sample gradients are clipped
        self._labels_batch = None  # labels corresponding to the data batch

        # Persistent iterator so consecutive batches walk through the local data
        # epoch by epoch (a fresh iter() per call always restarts the epoch).
        self._data_iter = None
        self._data_iter_source = None

        # Adaptive DP: previous round's gradient norms and current noise scale
        self._prev_round_grad_norms = None
        self._current_clip_threshold = config['dp_noise']['clip_norm']  # Initial threshold
        self._current_sigma = config['dp_noise']['initial_sigma']  # Initial noise scale

    def _create_optimizer(self) -> optim.Optimizer:
        """Creates the optimizer for the client-side model (WC)."""
        lr = self.config.get('lr', 0.01)
        optimizer_name = self.config.get('optimizer', 'SGD').lower()
        if optimizer_name == 'sgd':
            return optim.SGD(self.client_model.parameters(), lr=lr)
        elif optimizer_name == 'adam':
            return optim.Adam(self.client_model.parameters(), lr=lr)
        else:
            raise ValueError(f"Unsupported optimizer: {optimizer_name}")

    def set_model_params(self, global_params: OrderedDict):
        """Updates the local client model (WC) with parameters from the FedServer."""
        self.client_model.load_state_dict(global_params)

    def update_noise_scale(self, new_sigma: float):
        """Updates the current noise scale sigma_t."""
        self._current_sigma = new_sigma

    def _calculate_adaptive_clip_threshold(self) -> float:
        """Calculates the adaptive clipping threshold Ck_t for the current round.

        Returns the configured clip norm when adaptive clipping is disabled
        (factor 0.0) or no previous-round norms exist. Otherwise returns
        factor * mean(previous round's per-sample gradient norms).
        NOTE(review): the threshold is derived from un-noised gradient norms,
        which is itself not covered by any privacy accounting visible here.
        """
        dp_cfg = self.config['dp_noise']
        adaptive_factor = dp_cfg['adaptive_clipping_factor']
        if adaptive_factor == 0.0 or not self._prev_round_grad_norms:
            return dp_cfg['clip_norm']
        mean_norm = sum(self._prev_round_grad_norms) / len(self._prev_round_grad_norms)
        return float(adaptive_factor * mean_norm)

    def _next_batch(self):
        """Return the next local batch, starting a new epoch when the loader is exhausted.

        Raises:
            RuntimeError: if the client's dataloader yields no batches at all.
        """
        if self._data_iter is None or self._data_iter_source is not self.dataloader:
            self._data_iter = iter(self.dataloader)
            self._data_iter_source = self.dataloader
        try:
            return next(self._data_iter)
        except StopIteration:
            # End of epoch: restart (the loader reshuffles if shuffle=True).
            self._data_iter = iter(self.dataloader)
            try:
                return next(self._data_iter)
            except StopIteration:
                raise RuntimeError(f"Client {self.client_id}: dataloader is empty.") from None

    def local_forward_pass(self) -> tuple:
        """
        Performs the forward pass on the client model (WC) using one batch of local data.
        Applies Laplacian noise to the activations before returning.

        Returns:
            (noisy_activations, labels), both detached copies safe to send to the server.
        """
        data, labels = self._next_batch()
        data, labels = data.to(self.device), labels.to(self.device)
        self._data_batch = data
        self._labels_batch = labels

        self.optimizer.zero_grad()
        activations = self.client_model(data)

        # Laplacian Noise (Mechanism 1)
        sensitivity = self.config['dp_noise']['laplacian_sensitivity']
        epsilon_prime = self.config['dp_noise']['epsilon_prime']
        if sensitivity > 0:
            noisy_activations = add_laplacian_noise(
                activations,
                sensitivity,
                epsilon_prime,
                device=self.device
            )
        else:
            noisy_activations = activations

        # Keep the *clean* activations (with their graph) for the backward pass.
        self._activations_for_backward = activations
        return noisy_activations.detach().clone(), labels.clone()

    def local_backward_pass(self, activation_grads: torch.Tensor) -> dict:
        """
        Performs the backward pass with adaptive clipping and noise.

        Each sample's parameter gradient is clipped to L2 norm <= the current
        threshold. The clipped gradients are summed (not averaged), and Gaussian
        noise with sigma = current_sigma * threshold is added when current_sigma > 0.
        NOTE(review): with BatchNorm in the client layers (training mode), the
        per-sample slices still share batch statistics, so these are not strictly
        per-sample gradients.

        Args:
            activation_grads: Gradient of the loss w.r.t. the activations returned
                by `local_forward_pass`, one row per sample.

        Returns:
            OrderedDict mapping parameter name to its noisy summed clipped gradient.

        Raises:
            RuntimeError: if called without a preceding `local_forward_pass`.
        """
        if self._activations_for_backward is None or self._data_batch is None:
            raise RuntimeError("Client must perform forward pass before backward pass.")

        self._current_clip_threshold = self._calculate_adaptive_clip_threshold()

        trainable = [(name, param) for name, param in self.client_model.named_parameters()
                     if param.requires_grad]
        summed_clipped_grads = OrderedDict((name, torch.zeros_like(param)) for name, param in trainable)

        batch_size = self._data_batch.size(0)
        activation_grads = activation_grads.to(self.device)
        current_round_grad_norms = []  # feeds next round's adaptive threshold

        for i in range(batch_size):
            self.optimizer.zero_grad()
            # Keep the graph for the remaining samples; release it after the last one.
            self._activations_for_backward[i:i + 1].backward(
                gradient=activation_grads[i:i + 1],
                retain_graph=i < batch_size - 1,
            )

            sample_grads = [(name, param.grad.detach()) for name, param in trainable
                            if param.grad is not None]
            total_norm = math.sqrt(sum(float(grad.pow(2).sum()) for _, grad in sample_grads))
            current_round_grad_norms.append(total_norm)

            # Plain float so scaling never changes gradient shapes.
            clip_coef = min(1.0, self._current_clip_threshold / (total_norm + 1e-6))
            for name, grad in sample_grads:
                summed_clipped_grads[name].add_(grad, alpha=clip_coef)

        self._prev_round_grad_norms = current_round_grad_norms

        # Gaussian noise (Mechanism 2) calibrated to the clipping threshold.
        if self._current_sigma > 0:
            noisy_gradients = OrderedDict(
                (name, add_gaussian_noise(
                    summed_grad,
                    self._current_clip_threshold,
                    self._current_sigma,
                    device=self.device
                ))
                for name, summed_grad in summed_clipped_grads.items()
            )
        else:
            noisy_gradients = summed_clipped_grads

        # Clear intermediate values
        self._activations_for_backward = None
        self._data_batch = None
        self._labels_batch = None
        self.optimizer.zero_grad()

        return noisy_gradients

class MainServer:
    """
    Main Server in Split Federated Learning.
    Manages the server-side model (WS), processes client activations,
    computes gradients for clients, and updates the server model.

    Expected per-round call sequence (as used by ``train_sfl``):
    ``clear_round_data`` -> ``receive_client_data`` for each client ->
    ``forward_backward_pass`` for each client -> ``aggregate_and_update``.
    Parameter gradients from all backward passes of the round are accumulated,
    averaged, and applied with a single optimizer step.
    """
    def __init__(self, server_model: nn.Module, config: dict, device: torch.device):
        """
        Args:
            server_model: The server-side model (WS) architecture.
            config: Configuration dictionary (reads 'lr' and 'optimizer').
            device: The torch device ('cpu' or 'cuda').
        """
        self.server_model = server_model.to(device)
        self.config = config
        self.device = device
        self.optimizer = self._create_optimizer()
        self.criterion = nn.CrossEntropyLoss()

        # Per-round state
        self.client_data = {}  # {client_id: (activations, labels)}
        self.activation_gradients = {}  # {client_id: gradients}
        self._round_losses = []  # loss value of every backward pass since the last step

    def _create_optimizer(self) -> optim.Optimizer:
        """Creates the optimizer for the server-side model (WS)."""
        lr = self.config.get('lr', 0.01)
        optimizer_name = self.config.get('optimizer', 'SGD').lower()
        if optimizer_name == 'sgd':
            return optim.SGD(self.server_model.parameters(), lr=lr)
        elif optimizer_name == 'adam':
            return optim.Adam(self.server_model.parameters(), lr=lr)
        else:
            raise ValueError(f"Unsupported optimizer: {optimizer_name}")

    def receive_client_data(self, client_id: int, activations: torch.Tensor, labels: torch.Tensor):
        """
        Stores client activations and labels for processing.

        The activations are detached from any client-side autograd graph and stored as a
        fresh leaf tensor that requires grad. This ensures that:
        (a) the server's backward pass always populates ``.grad`` on the stored tensor;
        (b) it never propagates into, and frees, the client's graph, which the client
            back-propagates through itself later.
        """
        server_activations = activations.detach().to(self.device).requires_grad_(True)
        self.client_data[client_id] = (server_activations, labels.to(self.device))

    def clear_round_data(self):
        """Clears all client data from the current round and resets accumulated gradients."""
        self.client_data = {}
        self.activation_gradients = {}
        self._round_losses = []
        # Gradients are accumulated across clients in forward_backward_pass, so the
        # round must start from a clean slate.
        self.optimizer.zero_grad()

    def forward_backward_pass(self, client_id: int) -> torch.Tensor:
        """
        Performs forward and backward pass on the server model for a client.

        Parameter gradients are accumulated (not reset) so that ``aggregate_and_update``
        can average over all clients processed in the round.

        Args:
            client_id: The ID of the client whose data to process.

        Returns:
            The gradient of the loss with respect to the client's activations.

        Raises:
            ValueError: if no data was received from ``client_id``.
        """
        if client_id not in self.client_data:
            raise ValueError(f"No data received from client {client_id}")

        activations, labels = self.client_data[client_id]

        # Evaluation code may have switched the (shared) model to eval mode.
        self.server_model.train()
        # Avoid accumulating into a stale gradient if this client is processed again.
        activations.grad = None

        # Forward pass
        outputs = self.server_model(activations)
        loss = self.criterion(outputs, labels)

        # Backward pass: accumulates into parameter .grad and populates activations.grad
        loss.backward()

        # Store and return activation gradients
        act_grad = activations.grad.clone()
        self.activation_gradients[client_id] = act_grad
        self._round_losses.append(loss.item())

        return act_grad

    def aggregate_and_update(self):
        """
        Averages the accumulated parameter gradients and updates the server model.

        The gradients summed over every backward pass since the last update are divided
        by the number of such passes (FedAvg-style averaging). One optimizer step is then
        taken and the gradients are cleared.

        Returns:
            The mean server-side loss over the backward passes of this round, or 0.0 if
            no client data was received or no backward pass has run since the last update.
        """
        if not self.client_data:
            return 0.0

        num_passes = len(self._round_losses)
        if num_passes == 0:
            return 0.0

        with torch.no_grad():
            for param in self.server_model.parameters():
                if param.grad is not None:
                    param.grad.div_(num_passes)

        self.optimizer.step()
        self.optimizer.zero_grad()

        mean_loss = sum(self._round_losses) / num_passes
        self._round_losses = []
        return mean_loss

    def get_server_model(self) -> nn.Module:
        """Returns the current server model."""
        return self.server_model

    def get_criterion(self):
        """Returns the loss criterion used by the server."""
        return self.criterion

class FedServer:
    """
    Federation Server in Split Federated Learning.
    Manages the global client-side model (WC) and aggregates client updates.

    It also adapts the Gaussian noise multiplier sigma. After each aggregation that is
    evaluated on a validation loader, sigma is multiplied by
    ``adaptive_noise_decay_factor`` if the validation loss strictly decreased in each of
    the last ``noise_decay_patience`` rounds. In 'vanilla' DP mode sigma stays fixed.
    """
    def __init__(self, client_model: nn.Module, config: dict, device: torch.device):
        """
        Args:
            client_model: An instance of the client-side model (WC) architecture.
            config: Configuration dictionary (reads 'lr', 'optimizer' and the
                'dp_noise' sub-dictionary).
            device: The torch device ('cpu' or 'cuda').
        """
        self.client_model = client_model.to(device)
        self.config = config
        self.device = device
        self.optimizer = self._create_optimizer()
        self._client_updates = []  # Stores received client updates in a round

        # Adaptive DP: Track validation loss and noise scale
        dp_cfg = config['dp_noise']
        self._current_sigma = dp_cfg['initial_sigma']
        self._validation_losses = []  # Validation loss of every evaluated round
        self._noise_decay_patience = dp_cfg['noise_decay_patience']
        self._adaptive_noise_decay_factor = dp_cfg['adaptive_noise_decay_factor']
        # 'vanilla' keeps sigma fixed; other modes ('adaptive', 'hybrid') may decay it.
        self._dp_mode = dp_cfg.get('mode', 'hybrid')
        self._criterion = nn.CrossEntropyLoss()
        self._sigma_history = [self._current_sigma]  # Track sigma history

    def _create_optimizer(self) -> optim.Optimizer:
        """Creates the optimizer for the global client-side model (WC)."""
        lr = self.config.get('lr', 0.01)
        optimizer_name = self.config.get('optimizer', 'SGD').lower()

        if optimizer_name == 'sgd':
            return optim.SGD(self.client_model.parameters(), lr=lr)
        elif optimizer_name == 'adam':
            return optim.Adam(self.client_model.parameters(), lr=lr)
        else:
            raise ValueError(f"Unsupported optimizer: {optimizer_name}")

    def get_client_model_params(self) -> OrderedDict:
        """Returns the state dictionary of the current global client model (WC)."""
        return self.client_model.state_dict()

    def receive_client_update(self, client_update: dict):
        """
        Receives and stores an update (typically gradients) from a client.

        Each tensor is detached and copied onto the server device so later changes on
        the client side cannot alias the stored update.
        """
        processed_update = OrderedDict()
        for name, param in client_update.items():
            processed_update[name] = param.detach().clone().to(self.device)
        self._client_updates.append(processed_update)

    def _update_noise_scale(self, validation_loss: float):
        """
        Records ``validation_loss`` and decays sigma after a sustained improvement.

        Sigma is multiplied by the decay factor when the loss strictly decreased in each
        of the last ``noise_decay_patience`` rounds. That condition needs the last
        ``patience + 1`` recorded losses. In 'vanilla' mode the loss is recorded but
        sigma never changes.
        """
        self._validation_losses.append(validation_loss)

        if self._dp_mode == 'vanilla':
            return

        # `patience` consecutive decreases require `patience + 1` losses.
        window = self._noise_decay_patience + 1
        if len(self._validation_losses) < window:
            return

        recent_losses = self._validation_losses[-window:]
        is_decreasing = all(prev > curr for prev, curr in zip(recent_losses, recent_losses[1:]))

        if is_decreasing:
            # Decrease noise scale
            self._current_sigma *= self._adaptive_noise_decay_factor
            print(f"FedServer: Loss decreasing for {self._noise_decay_patience} rounds. "
                  f"Updated noise scale to {self._current_sigma:.4f}")
            # Track sigma change
            self._sigma_history.append(self._current_sigma)

    def get_current_sigma(self) -> float:
        """Returns the current noise scale sigma_t."""
        return self._current_sigma

    def aggregate_updates(self, validation_loader=None, main_server=None):
        """
        Aggregates the received client updates using FedAvg and updates the global client model (WC).

        The averaged updates are installed as ``.grad`` on the matching parameters and
        applied with one optimizer step. Parameters without an update get a zero (or
        absent) gradient. If ``validation_loader`` is given, the model is then evaluated
        and the noise scale is adapted from the validation loss.
        """
        if not self._client_updates:
            print("FedServer: No client updates received for aggregation.")
            return

        # Federated averaging for gradients
        averaged_gradients = self.federated_averaging_gradients(self._client_updates)

        # Update the global client model parameters using the averaged gradients via the optimizer
        self.optimizer.zero_grad()
        with torch.no_grad():  # Manually assign gradients
            for name, param in self.client_model.named_parameters():
                if name in averaged_gradients:
                    if param.grad is None:
                        param.grad = torch.zeros_like(param)
                    param.grad.copy_(averaged_gradients[name])
                else:
                    if param.grad is not None:
                        param.grad.zero_()

        self.optimizer.step()  # Update parameters using assigned gradients

        # Update noise scale and evaluate metrics if validation loader is provided
        if validation_loader is not None:
            validation_loss, _ = self.evaluate_metrics(validation_loader, main_server)
            self._update_noise_scale(validation_loss)

        # Clear updates for the next round
        self._client_updates = []

    def evaluate_metrics(self, validation_loader, main_server=None) -> tuple:
        """
        Evaluates and prints validation loss and accuracy.

        With ``main_server``, the full split model (client part followed by the server
        part) is evaluated using the server's criterion. Without it, the client model's
        outputs are scored directly with cross-entropy. That is only meaningful if the
        client outputs are class logits. NOTE(review): confirm this path is used at all.

        Both models are put in eval mode for the pass and restored to their previous
        training mode afterwards. The server model object is the one MainServer trains
        with, so leaving it in eval mode would affect training.

        Returns:
            (avg_loss, accuracy): sample-weighted mean loss and accuracy in percent.
            Returns ``(nan, nan)`` if the loader yields no samples.
        """
        # Check if we can evaluate complete model
        evaluate_complete = main_server is not None
        server_model = main_server.get_server_model() if evaluate_complete else None

        client_was_training = self.client_model.training
        server_was_training = server_model.training if evaluate_complete else None
        self.client_model.eval()
        if evaluate_complete:
            server_model.eval()

        total_loss = 0.0
        correct = 0
        total = 0
        try:
            with torch.no_grad():
                for data, labels in validation_loader:
                    data, labels = data.to(self.device), labels.to(self.device)

                    # Forward pass through client model
                    client_outputs = self.client_model(data)

                    if evaluate_complete:
                        # Complete forward pass through server model
                        outputs = server_model(client_outputs)
                        loss = main_server.get_criterion()(outputs, labels)
                    else:
                        # Just use client outputs (which aren't actual predictions)
                        outputs = client_outputs
                        loss = self._criterion(outputs, labels)

                    _, predicted = torch.max(outputs, 1)
                    batch_size = labels.size(0)
                    # The CrossEntropyLoss criteria here default to mean reduction; weight
                    # by batch size so a short final batch is not over-counted.
                    total_loss += loss.item() * batch_size
                    total += batch_size
                    correct += (predicted == labels).sum().item()
        finally:
            self.client_model.train(client_was_training)
            if evaluate_complete:
                server_model.train(server_was_training)

        if total == 0:
            print("\nValidation Metrics: validation loader yielded no samples.")
            return float('nan'), float('nan')

        avg_loss = total_loss / total
        accuracy = 100 * correct / total

        print(f"\nValidation Metrics:")
        if evaluate_complete:
            print(f"Complete Model - Loss: {avg_loss:.4f} | Accuracy: {accuracy:.2f}%")
        else:
            print(f"Client Model Only (incomplete) - Loss: {avg_loss:.4f}")

        return avg_loss, accuracy

    def federated_averaging_gradients(self, gradients_list):
        """
        Implements FedAvg (unweighted mean) for a list of gradient dictionaries.

        The parameter names are taken from the first entry. Returns an empty OrderedDict
        for an empty list.
        """
        if not gradients_list:
            return OrderedDict()

        num_clients = len(gradients_list)

        # Initialize with zeros using the first client's structure
        avg_grad = OrderedDict()
        for name, grad in gradients_list[0].items():
            avg_grad[name] = torch.zeros_like(grad)

        # Sum all gradients
        for gradients in gradients_list:
            for name, grad in gradients.items():
                avg_grad[name] += grad

        # Divide by the number of clients to get the average
        for name in avg_grad.keys():
            avg_grad[name] = avg_grad[name] / num_clients

        return avg_grad


In [None]:
# Default configuration for experiments.
# Experiment grids deep-copy this template and override individual fields.
default_config = dict(
    batch_size=32,
    optimizer='adam',
    lr=0.001,
    num_rounds=50,
    num_clients=5,
    log_interval=5,
    split_layer=1,  # index of the layer at which the model is split into client/server parts
    seed=42,
    dp_noise=dict(
        mode='hybrid',  # one of 'vanilla', 'adaptive', 'hybrid'
        initial_sigma=1.0,  # starting Gaussian noise multiplier
        clip_norm=1.0,  # starting gradient clipping norm
        adaptive_clipping_factor=1.5,  # adaptive clipping: C_t = factor * mean_norm
        adaptive_noise_decay_factor=0.9,  # multiplier applied to sigma when validation loss keeps falling
        noise_decay_patience=3,  # rounds of decreasing loss required before sigma decays
        delta=1e-5,  # target delta for privacy accounting
        laplacian_sensitivity=0.1,  # sensitivity of the Laplace noise on activations
        epsilon_prime=0.1,  # privacy budget for the Laplace mechanism
        validation_set_ratio=0.1,  # fraction of training data held out for validation
    ),
)

def create_experiment_configs(base_config=None, datasets_to_run=None, dp_modes=None,
                              alphas=None, client_counts=None):
    """
    Create the grid of experiment configurations used for comparison.

    Every combination of dataset x DP mode x data distribution x client count is
    generated. Dirichlet partitions are additionally crossed with every concentration
    parameter in ``alphas``. Each config is a deep copy of ``base_config`` with the
    experiment-specific keys overwritten and a unique ``'id'`` string added. The
    template itself is never modified.

    Args:
        base_config: Template configuration to copy; must contain a ``'dp_noise'``
            dict. Defaults to the module-level ``default_config``.
        datasets_to_run: Dataset names to include. Defaults to
            ``['mnist', 'fashion_mnist', 'cifar10']``. Each must have an entry in the
            dataset -> model mapping below.
        dp_modes: DP noise modes to include. Defaults to ``['vanilla', 'adaptive']``.
        alphas: Dirichlet concentration parameters. Defaults to ``[0.1, 0.5, 1.0, 10.0]``.
        client_counts: Numbers of clients. Defaults to ``[5, 10, 20]``.

    Returns:
        List of configuration dicts, ordered dataset -> dp_mode -> distribution
        ('iid' before 'dirichlet') -> alpha -> num_clients.
    """
    if base_config is None:
        base_config = default_config
    if datasets_to_run is None:
        datasets_to_run = ['mnist', 'fashion_mnist', 'cifar10']
    if dp_modes is None:
        dp_modes = ['vanilla', 'adaptive']
    if alphas is None:
        alphas = [0.1, 0.5, 1.0, 10.0]  # Dirichlet concentration parameters
    if client_counts is None:
        client_counts = [5, 10, 20]

    # Model architecture used for each dataset
    models = {
        'mnist': 'simple_cnn',
        'fashion_mnist': 'simple_cnn',
        'cifar10': 'cifar_cnn',
    }
    distributions = ['iid', 'dirichlet']

    def _make_config(dataset, model, dp_mode, distribution, num_clients, alpha=None):
        """Deep-copies base_config and fills in the fields of one experiment."""
        config = copy.deepcopy(base_config)
        config['dataset'] = dataset
        config['model'] = model
        config['dp_noise']['mode'] = dp_mode
        config['partition_method'] = distribution
        alpha_tag = ''
        if alpha is not None:
            config['dirichlet_alpha'] = alpha
            alpha_tag = f"_alpha{alpha}"
        config['num_clients'] = num_clients
        # Unique identifier for this experiment
        config['id'] = f"{dataset}_{model}_{dp_mode}_{distribution}{alpha_tag}_clients{num_clients}"
        return config

    configs = []
    for dataset in datasets_to_run:
        model = models[dataset]
        for dp_mode in dp_modes:
            for distribution in distributions:
                # IID partitions have no alpha; Dirichlet partitions sweep every alpha.
                alpha_values = alphas if distribution == 'dirichlet' else [None]
                for alpha in alpha_values:
                    for num_clients in client_counts:
                        configs.append(
                            _make_config(dataset, model, dp_mode, distribution, num_clients, alpha)
                        )

    return configs

# Build the full experiment grid and report how many runs it contains
experiment_configs = create_experiment_configs()
print(f"Generated {len(experiment_configs)} experiment configurations.")

# Show the first configuration: scalar fields first, then the nested DP settings
sample_config = experiment_configs[0]
print("\nSample experiment configuration:")
for key, value in ((k, v) for k, v in sample_config.items() if k != 'dp_noise'):
    print(f"{key}: {value}")
print("\nDP Noise settings:")
for key, value in sample_config['dp_noise'].items():
    print(f"  {key}: {value}")


In [None]:
# Training function
def _sync_full_model(full_model, client_params, server_params, split_layer):
    """
    Copy trained client/server parameters into the matching layers of ``full_model``.

    ``client_params`` and ``server_params`` are ``nn.Sequential`` state dicts, so their
    keys look like ``"<layer index>.<param name>"``. Only the first ``split_layer``
    client layers are synced, mirroring how the client model is built in ``train_sfl``.

    In ``train_sfl`` the split models wrap the very same layer objects as
    ``full_model``, so this copy is effectively a no-op there. It keeps evaluation
    correct if the split models ever hold separate copies.
    """
    for i, layer in enumerate(full_model.client_layers):
        if i < split_layer:
            for name, param in layer.state_dict().items():
                full_param_name = f"{i}.{name}"
                if full_param_name in client_params:
                    param.copy_(client_params[full_param_name])

    for i, layer in enumerate(full_model.server_layers):
        for name, param in layer.state_dict().items():
            full_param_name = f"{i}.{name}"
            if full_param_name in server_params:
                param.copy_(server_params[full_param_name])


def train_sfl(config, results_dict=None):
    """
    Train a model using Split Federated Learning.

    Each round proceeds as follows:
    1. The FedServer's current noise scale and global client-model weights are
       broadcast to every client.
    2. Each client runs its forward pass; activation noise is presumably added inside
       ``SFLClient``.
    3. The main server computes activation gradients and updates its own model.
    4. Each client back-propagates those gradients into a noisy update of its
       client-side model.
    5. The FedServer averages the updates, steps the global client model, and adapts
       sigma from the validation loss.

    Every ``log_interval`` rounds, and on the last round, the combined model is evaluated
    on the test set and the privacy spent so far is recorded.

    Args:
        config: Configuration dictionary (see ``default_config`` and
            ``create_experiment_configs`` for the keys read here).
        results_dict: Optional dictionary; if given, a summary of this run is stored
            under ``config['id']``.

    Returns:
        Dictionary of metrics. 'test_accuracy' starts with the accuracy measured
        before training; the other logged lists get one entry per logging round.
        NOTE: 'train_loss' and 'validation_loss' are currently never populated.
    """
    # Set random seed for reproducibility
    set_seed(config['seed'])

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load dataset
    train_dataset, test_dataset = load_dataset(config['dataset'])

    # Split training data into client data and validation set
    validation_ratio = config['dp_noise']['validation_set_ratio']
    validation_size = int(len(train_dataset) * validation_ratio)
    train_size = len(train_dataset) - validation_size

    train_subset, validation_subset = torch.utils.data.random_split(
        train_dataset, [train_size, validation_size]
    )

    validation_loader = torch.utils.data.DataLoader(
        validation_subset,
        batch_size=config['batch_size'],
        shuffle=False,
        pin_memory=torch.cuda.is_available()
    )

    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=config['batch_size'],
        shuffle=False,
        pin_memory=torch.cuda.is_available()
    )

    # Create client dataloaders
    distribution = config.get('partition_method', 'iid')
    alpha = config.get('dirichlet_alpha', 1.0)  # Only used for dirichlet distribution

    client_loaders = get_client_dataloaders(
        train_subset,
        config['num_clients'],
        config['batch_size'],
        distribution=distribution,
        alpha=alpha
    )

    # Initialize model
    full_model = get_model(config['model'], config['dataset'], config['split_layer'])

    # Split model. The layer objects are shared (not copied) with full_model, so
    # training the split parts also updates full_model.
    client_model = nn.Sequential(
        *[layer for i, layer in enumerate(full_model.client_layers) if i < config['split_layer']]
    )
    server_model = nn.Sequential(*[layer for layer in full_model.server_layers])

    # Create FedServer and MainServer
    fed_server = FedServer(client_model, config, device)
    main_server = MainServer(server_model, config, device)

    # Create clients; each gets its own copy of the client-side model
    clients = []
    for i in range(config['num_clients']):
        client_model_copy = copy.deepcopy(client_model)
        client = SFLClient(
            client_id=i,
            client_model=client_model_copy,
            dataloader=client_loaders[i],
            config=config,
            device=device
        )
        clients.append(client)

    # Privacy accountant
    sampling_rate = config['batch_size'] / train_size
    privacy_accountant = HybridAccountant(
        noise_multiplier=config['dp_noise']['initial_sigma'],
        sampling_rate=sampling_rate
    )

    # Training metrics
    metrics = {
        'train_loss': [],
        'validation_loss': [],
        'test_accuracy': [],
        'epsilon': [],
        'epsilon_laplace': [],
        'epsilon_gaussian': [],
        'sigma': [],
        'round_times': []
    }

    # Training loop
    num_rounds = config['num_rounds']
    log_interval = config['log_interval']

    # Initial test accuracy
    test_accuracy = evaluate_model(full_model, test_loader, device)
    metrics['test_accuracy'].append(test_accuracy)
    print(f"Initial Test Accuracy: {test_accuracy:.2f}%")

    for round_num in tqdm(range(num_rounds), desc="Training Rounds"):
        round_start_time = time.time()

        # Broadcast current noise scale to clients
        current_sigma = fed_server.get_current_sigma()
        for client in clients:
            client.update_noise_scale(current_sigma)

        # Broadcast client model parameters
        global_client_params = fed_server.get_client_model_params()
        for client in clients:
            client.set_model_params(global_client_params)

        # Clear previous round data
        client_activation_grads = {}
        client_noisy_wc_grads = []
        main_server.clear_round_data()

        # Client forward passes
        client_data_for_main_server = {}
        for client in clients:
            noisy_activations, labels = client.local_forward_pass()
            client_data_for_main_server[client.client_id] = (noisy_activations, labels)

            # Track Laplace privacy cost
            if config['dp_noise']['laplacian_sensitivity'] > 0:
                privacy_accountant.laplace_step(config['dp_noise']['epsilon_prime'])

        # Server processes client data
        for client_id, (noisy_acts, lbls) in client_data_for_main_server.items():
            main_server.receive_client_data(client_id, noisy_acts, lbls)

        # Server computes activation gradients
        for client_id in client_data_for_main_server.keys():
            act_grad = main_server.forward_backward_pass(client_id)
            client_activation_grads[client_id] = act_grad

        # Server aggregates and updates its model
        main_server.aggregate_and_update()

        # Clients compute gradients and apply Gaussian noise
        for client in clients:
            if client.client_id in client_activation_grads:
                activation_grad = client_activation_grads[client.client_id]
                noisy_wc_grad = client.local_backward_pass(activation_grad)
                client_noisy_wc_grads.append(noisy_wc_grad)

                # Track Gaussian privacy cost
                if current_sigma > 0:
                    privacy_accountant.gaussian_step(noise_multiplier=current_sigma)
            else:
                print(f"Warning: No activation gradient received for Client {client.client_id}")

        # Fed server aggregates client updates
        for noisy_grad in client_noisy_wc_grads:
            fed_server.receive_client_update(noisy_grad)

        # Fed server updates global client model
        fed_server.aggregate_updates(validation_loader, main_server)

        # Record metrics
        if (round_num + 1) % log_interval == 0 or round_num == num_rounds - 1:
            # Make sure the full model carries the latest client and server parameters
            _sync_full_model(
                full_model,
                fed_server.get_client_model_params(),
                main_server.get_server_model().state_dict(),
                config['split_layer'],
            )

            # Evaluate on test set
            test_accuracy = evaluate_model(full_model, test_loader, device)
            metrics['test_accuracy'].append(test_accuracy)

            # Get privacy cost
            epsilon, delta = privacy_accountant.get_privacy_spent(delta=config['dp_noise']['delta'])
            metrics['epsilon'].append(epsilon)
            metrics['epsilon_laplace'].append(privacy_accountant.epsilon_laplace)
            metrics['epsilon_gaussian'].append(privacy_accountant.epsilon_gaussian)
            metrics['sigma'].append(current_sigma)

            print(f"\nRound {round_num + 1}:")
            print(f"  Test Accuracy: {test_accuracy:.2f}%")
            print(f"  Privacy Budget (ε, δ): ({epsilon:.4f}, {config['dp_noise']['delta']})")
            print(f"  Current Noise Scale (σ): {current_sigma:.4f}")

        # Record round time
        round_end_time = time.time()
        metrics['round_times'].append(round_end_time - round_start_time)

    # Final metrics for results dictionary
    if results_dict is not None:
        def _final(values):
            # Logged lists stay empty when no logging round ran (num_rounds == 0).
            return values[-1] if values else None

        results_dict[config['id']] = {
            'config': config,
            'final_test_acc': metrics['test_accuracy'][-1],
            'final_epsilon': _final(metrics['epsilon']),
            'final_epsilon_laplace': _final(metrics['epsilon_laplace']),
            'final_epsilon_gaussian': _final(metrics['epsilon_gaussian']),
            'final_sigma': _final(metrics['sigma']),
            'test_accuracy_history': metrics['test_accuracy'],
            'epsilon_history': metrics['epsilon'],
            'sigma_history': metrics['sigma'],
            'total_training_time': sum(metrics['round_times'])
        }

    return metrics

def evaluate_model(model, test_loader, device):
    """
    Evaluates the classification accuracy of ``model`` on ``test_loader``.

    The model is switched to eval mode for the pass and afterwards restored to the
    training mode it had on entry. In ``train_sfl`` the client and server models share
    layer objects with the full model. Leaving the full model in eval mode would
    therefore make later training rounds run in eval mode.

    Args:
        model: Classifier whose output has one score per class along dim 1.
        test_loader: Iterable of ``(data, labels)`` batches.
        device: Device the batches are moved to.

    Returns:
        Accuracy in percent (0-100), or ``float('nan')`` if the loader yields no samples.
    """
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for data, labels in test_loader:
                data, labels = data.to(device), labels.to(device)
                outputs = model(data)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        model.train(was_training)

    if total == 0:
        return float('nan')
    return 100 * correct / total


In [None]:
# Run experiments and save results
def run_experiments(configs, results_csv_path='results.csv', subset=None):
    """
    Run experiments based on configurations and save a summary CSV.

    Experiments that raise are reported with a traceback and skipped, so one failing
    configuration does not stop the batch.

    Args:
        configs: List of configuration dictionaries (each needs an 'id' key)
        results_csv_path: Path to save CSV results
        subset: Optional positive integer to run only the first ``subset`` configurations

    Returns:
        Tuple ``(results, results_df)``:
        - ``results`` maps experiment id to a dict holding the summary fields written by
          ``train_sfl`` (e.g. 'final_test_acc') plus 'config' and 'metrics'.
        - ``results_df`` has one row per successfully completed experiment. It always
          carries the summary columns, even when no experiment succeeded.
    """
    # Subset configurations if requested
    if subset is not None and subset > 0:
        configs = configs[:subset]

    # Dictionary to store results
    results = {}

    # Run each experiment
    for i, config in enumerate(configs):
        print(f"\n\n{'='*80}")
        print(f"Experiment {i+1}/{len(configs)}: {config['id']}")
        print(f"{'='*80}")

        # Run the experiment
        try:
            metrics = train_sfl(config, results)

            # train_sfl already stored the summary fields under config['id'];
            # merge rather than overwrite so both the summary and full metrics survive.
            entry = results.setdefault(config['id'], {})
            entry['config'] = config
            entry['metrics'] = metrics

        except Exception as e:
            print(f"Error in experiment {config['id']}: {str(e)}")
            import traceback
            traceback.print_exc()

    summary_columns = [
        'experiment_id', 'dataset', 'model', 'dp_mode', 'distribution', 'alpha',
        'num_clients', 'final_test_acc', 'final_epsilon', 'final_epsilon_laplace',
        'final_epsilon_gaussian', 'final_sigma', 'total_training_time',
    ]
    rows = []
    for exp_id, entry in results.items():
        if 'final_test_acc' not in entry:
            continue
        cfg = entry['config']
        rows.append({
            'experiment_id': exp_id,
            'dataset': cfg['dataset'],
            'model': cfg['model'],
            'dp_mode': cfg['dp_noise']['mode'],
            'distribution': cfg.get('partition_method', 'iid'),
            'alpha': cfg.get('dirichlet_alpha', 'N/A'),
            'num_clients': cfg['num_clients'],
            'final_test_acc': entry['final_test_acc'],
            'final_epsilon': entry['final_epsilon'],
            'final_epsilon_laplace': entry['final_epsilon_laplace'],
            'final_epsilon_gaussian': entry['final_epsilon_gaussian'],
            'final_sigma': entry['final_sigma'],
            'total_training_time': entry['total_training_time'],
        })

    # Passing `columns` keeps the schema even when there are no rows.
    results_df = pd.DataFrame(rows, columns=summary_columns)

    # Save results to CSV
    results_df.to_csv(results_csv_path, index=False)
    print(f"Results saved to {results_csv_path}")

    return results, results_df


In [None]:
# Visualization functions
def plot_accuracy_vs_epsilon(results_df, dataset=None, distribution=None, num_clients=None):
    """
    Scatter final test accuracy against final privacy budget, one series per DP mode.

    Args:
        results_df: DataFrame of results (as returned by ``run_experiments``)
        dataset: Optional filter for dataset
        distribution: Optional filter for distribution ('iid' or 'dirichlet')
        num_clients: Optional filter for number of clients

    If no rows survive the filters, a message is printed and nothing is plotted.
    """
    # Filter results (boolean indexing returns new frames; results_df is not modified)
    filtered_df = results_df
    if dataset is not None:
        filtered_df = filtered_df[filtered_df['dataset'] == dataset]
    if distribution is not None:
        filtered_df = filtered_df[filtered_df['distribution'] == distribution]
    if num_clients is not None:
        filtered_df = filtered_df[filtered_df['num_clients'] == num_clients]

    if filtered_df.empty:
        print("plot_accuracy_vs_epsilon: no results match the given filters.")
        return

    fig, ax = plt.subplots(figsize=(10, 6))
    for name, group in filtered_df.groupby('dp_mode'):
        ax.scatter(
            group['final_epsilon'],
            group['final_test_acc'],
            label=name,
            alpha=0.7,
            s=50
        )

    ax.set_xlabel('Privacy Budget (ε)')
    ax.set_ylabel('Test Accuracy (%)')
    ax.set_title('Accuracy vs Privacy Budget by DP Mode')
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.show()

def plot_accuracy_vs_distribution(results_df, dataset=None, dp_mode=None):
    """
    Plot accuracy for different data distributions.
    
    Args:
        results_df: DataFrame of results
        dataset: Optional filter for dataset
        dp_mode: Optional filter for DP mode
    """
    # Filter results
    filtered_df = results_df.copy()
    if dataset is not None:
        filtered_df = filtered_df[filtered_df['dataset'] == dataset]
    if dp_mode is not None:
        filtered_df = filtered_df[filtered_df['dp_mode'] == dp_mode]
    
    # Create distribution category for grouping
    def get_distribution_category(row):
        if row['distribution'] == 'iid':
            return 'IID'
        else:  # dirichlet
            alpha = row['alpha']
            if alpha <= 0.1:
                return 'Highly Non-IID (α=0.1)'
            elif alpha <= 0.5:
                return 'Moderately Non-IID (α=0.5)'
            elif alpha <= 1.0:
                return 'Slightly Non-IID (α=1.0)'
            else:
                return 'Nearly IID (α=10.0)'
    
    filtered_df['distribution_category'] = filtered_df.apply(get_distribution_category, axis=1)
    
    # Group by distribution category
    grouped = filtered_df.groupby(['distribution_category', 'num_clients'])
    
    # Calculate mean accuracy for each group
    mean_acc = grouped['final_test_acc'].mean().reset_index()
    pivot_df = mean_acc.pivot(index='distribution_category', columns='num_clients', values='final_test_acc')
    
    # Sort categories
    category_order = [
        'IID', 
        'Nearly IID (α=10.0)', 
        'Slightly Non-IID (α=1.0)', 
        'Moderately Non-IID (α=0.5)', 
        'Highly Non-IID (α=0.1)'
    ]
    pivot_df = pivot_df.reindex(category_order)
    
    plt.figure(figsize=(12, 6))
    pivot_df.plot(kind='bar', ax=plt.gca())
    
    plt.xlabel('Data Distribution')
    plt.ylabel('Average Test Accuracy')
    plt.title('Impact of Data Distribution on Accuracy')
    plt.legend(title='Number of Clients')
    plt.grid(True, axis='y', alpha=0.3)
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.show()

def plot_privacy_decay(results, exp_id):
    """
    Plot how the privacy budget, noise scale and test accuracy evolve during training.

    The top panel shows the recorded epsilon series (total, Laplace and Gaussian).
    The bottom panel shows the noise scale sigma on the left y-axis and test
    accuracy on a separate right y-axis: the two quantities have unrelated units,
    so sharing one scale would flatten whichever has the smaller range.
    The x-axis is the index of each logged entry, not the training round.

    Args:
        results: Dictionary of results keyed by experiment ID; each entry is
            expected to hold a 'metrics' dict of per-logging-step sequences.
        exp_id: Experiment ID to plot

    Returns:
        None. Prints a message and returns early if `exp_id` or its 'metrics'
        entry is missing; otherwise displays the figure.
    """
    if exp_id not in results or 'metrics' not in results[exp_id]:
        print(f"No results found for experiment {exp_id}")
        return

    metrics = results[exp_id]['metrics']

    fig, (eps_ax, noise_ax) = plt.subplots(2, 1, figsize=(12, 8))

    # Privacy budget panel. Series absent from `metrics` are skipped rather than
    # raising KeyError, so a partially-logged experiment can still be inspected.
    for key, label, marker in (
        ('epsilon', 'Total ε', 'o'),
        ('epsilon_laplace', 'Laplace ε', 's'),
        ('epsilon_gaussian', 'Gaussian ε', '^'),
    ):
        if key in metrics:
            eps_ax.plot(metrics[key], label=label, marker=marker)
    eps_ax.set_xlabel('Logging Step')
    eps_ax.set_ylabel('Privacy Budget (ε)')
    eps_ax.set_title('Privacy Budget Evolution')
    if eps_ax.get_legend_handles_labels()[0]:
        eps_ax.legend()
    eps_ax.grid(True, alpha=0.3)

    # Noise / accuracy panel: accuracy gets its own y-axis (twinx) so both curves
    # stay readable regardless of their magnitudes.
    acc_ax = noise_ax.twinx()
    handles = []
    if 'sigma' in metrics:
        handles += noise_ax.plot(metrics['sigma'], label='Noise Scale (σ)',
                                 marker='o', color='tab:blue')
    if 'test_accuracy' in metrics:
        handles += acc_ax.plot(metrics['test_accuracy'], label='Test Accuracy',
                               marker='s', color='tab:orange')
    noise_ax.set_xlabel('Logging Step')
    noise_ax.set_ylabel('Noise Scale (σ)')
    acc_ax.set_ylabel('Test Accuracy')
    noise_ax.set_title('Noise Scale and Accuracy Evolution')
    if handles:
        # Lines live on two different axes, so collect their handles explicitly
        # to show them in a single legend.
        noise_ax.legend(handles=handles, loc='best')
    noise_ax.grid(True, alpha=0.3)

    fig.tight_layout()
    plt.show()


In [None]:
# Run a subset of experiments to demonstrate functionality.
# Note: running every experiment may take several hours or days, so this demo
# shortens training and asks `run_experiments` for only a subset of configs
# (see `run_experiments` for how the subset is chosen).
DEMO_NUM_ROUNDS = 10    # reduced from 50 for demonstration
DEMO_LOG_INTERVAL = 2   # log more frequently so short runs still yield curves
subset_size = 4         # adjust based on available computation time

# Work on shortened copies rather than editing `experiment_configs` in place:
# the original configurations stay full-length for a later complete run, and
# re-running this cell never depends on state left by a previous execution.
demo_configs = []
for config in experiment_configs:
    demo_config = copy.copy(config)
    demo_config['num_rounds'] = DEMO_NUM_ROUNDS
    demo_config['log_interval'] = DEMO_LOG_INTERVAL
    demo_configs.append(demo_config)

results, results_df = run_experiments(demo_configs, results_csv_path='results.csv', subset=subset_size)

summary_columns = ['experiment_id', 'dataset', 'dp_mode', 'distribution', 'alpha',
                   'num_clients', 'final_test_acc', 'final_epsilon']

# Display results summary and plots
print("\nResults Summary:")
if len(results_df) > 0:
    print(results_df[summary_columns])

    # Plot accuracy vs privacy budget
    plot_accuracy_vs_epsilon(results_df)

    # Plot the first experiment's privacy decay. Guard on `results` itself: the
    # check above only covers `results_df`, and indexing the keys of an empty
    # dict would raise IndexError.
    first_exp_id = next(iter(results), None)
    if first_exp_id is not None:
        plot_privacy_decay(results, first_exp_id)

    # The distribution comparison is only informative with more than one distribution.
    if len(results_df['distribution'].unique()) > 1:
        plot_accuracy_vs_distribution(results_df)
else:
    print("No experiment results were produced.")


In [None]:
# Additional analysis functions

def compare_dp_modes(results_df, metric='final_test_acc'):
    """
    Compare different DP modes on a single result column.

    Prints the mean/std/min/max of `metric` for each `dp_mode` and shows a box
    plot of its distribution per mode.

    Args:
        results_df: DataFrame of experiment results with a 'dp_mode' column.
        metric: Name of the numeric column to summarise (default 'final_test_acc').
    """
    grouped = results_df.groupby('dp_mode')[metric].agg(['mean', 'std', 'min', 'max'])

    print(f"Comparison of DP modes by {metric}:")
    print(grouped)

    # Create the axes ourselves and hand them to pandas. With `by=...` and no `ax`,
    # DataFrame.boxplot opens a brand-new figure, which ignores the requested size
    # and leaves an empty figure behind.
    fig, ax = plt.subplots(figsize=(10, 6))
    results_df.boxplot(column=metric, by='dp_mode', grid=False, ax=ax)
    ax.set_title(f'Distribution of {metric} by DP Mode')
    fig.suptitle('')  # Remove pandas' automatic "Boxplot grouped by ..." title
    ax.set_ylabel(metric)
    plt.show()

def calculate_privacy_utility_ratio(results_df):
    """
    Calculate and compare the privacy-utility ratio (accuracy / epsilon) per DP mode.

    A higher ratio means more accuracy per unit of privacy budget spent.

    Args:
        results_df: DataFrame with 'final_test_acc', 'final_epsilon' and 'dp_mode'
            columns. It is modified in place: a 'privacy_utility_ratio' column is
            added.

    Returns:
        The same DataFrame object, now carrying the 'privacy_utility_ratio' column.
    """
    # Higher ratio means better utility per privacy cost. A zero epsilon makes the
    # ratio undefined (±inf); store NaN instead so the summary statistics and the
    # box plot skip those rows rather than being dominated by infinities.
    ratio = results_df['final_test_acc'] / results_df['final_epsilon']
    results_df['privacy_utility_ratio'] = ratio.replace([np.inf, -np.inf], np.nan)

    grouped = results_df.groupby('dp_mode')['privacy_utility_ratio'].agg(['mean', 'std', 'min', 'max'])

    print("Privacy-Utility Ratio (Accuracy/Epsilon) - higher is better:")
    print(grouped)

    # Pass our own axes to pandas: with `by=...` and no `ax`, DataFrame.boxplot
    # creates a separate figure, leaving an empty one behind and ignoring figsize.
    fig, ax = plt.subplots(figsize=(10, 6))
    results_df.boxplot(column='privacy_utility_ratio', by='dp_mode', grid=False, ax=ax)
    ax.set_title('Privacy-Utility Ratio by DP Mode')
    fig.suptitle('')  # Remove pandas' automatic "Boxplot grouped by ..." title
    ax.set_ylabel('Accuracy/Epsilon Ratio')
    plt.show()

    return results_df

# Run the follow-up analyses only once the experiment cell has produced results;
# otherwise tell the reader what to run first.
if 'results_df' not in locals() or len(results_df) == 0:
    print("No results available yet. Run the experiments first.")
else:
    compare_dp_modes(results_df, metric='final_test_acc')
    compare_dp_modes(results_df, metric='final_epsilon')
    # Keep the returned frame, which carries the added 'privacy_utility_ratio' column.
    results_df = calculate_privacy_utility_ratio(results_df)
