In [1]:
# NOTE(review): `!cd` is executed in a temporary subshell, so it does not change the
# kernel's working directory for later cells; `%cd <path>` is the IPython magic that
# does. The absolute path below is also specific to one machine.
!cd /home/ir739wb/ilyarekun/bc_project/federated-learning/



In [2]:
import os
import flwr as fl                    # Flower framework for federated learning
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import ImageFolder
import numpy as np
import kagglehub                     # Utility to download datasets from Kaggle
import shutil                        # For file operations like moving or copying files
from sklearn.metrics import precision_recall_fscore_support  # For computing classification metrics
import matplotlib.pyplot as plt
from flwr.common import parameters_to_ndarrays  # Convert model parameters to NumPy arrays for Flower

# --------------------------------------------------------------------------------
# Reproducibility: seed NumPy and PyTorch (CPU + CUDA) with one shared value
# --------------------------------------------------------------------------------
seed = 42
np.random.seed(seed)                   # NumPy global RNG (drives the np.random.shuffle calls below)
torch.manual_seed(seed)                # PyTorch CPU generator
torch.cuda.manual_seed(seed)           # Generator of the current CUDA device
torch.cuda.manual_seed_all(seed)       # Generators of every visible CUDA device
# Prefer repeatable cuDNN kernels over auto-tuned ones (can be slower, but runs are reproducible)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# --------------------------------------------------------------------------------
# Computation device: first CUDA GPU when one is available, CPU otherwise
# --------------------------------------------------------------------------------
device = "cpu"
if torch.cuda.is_available():
    device = "cuda:0"
print(f"Using device: {device}")


In [3]:
class BrainCNN(nn.Module):
    """
    Six-stage convolutional classifier producing 4 output logits.

    Every stage applies: 7x7 convolution (stride 1, padding 3) -> ReLU ->
    BatchNorm -> 2x2 max-pool -> channel-wise dropout. The six poolings halve
    the spatial size six times, and the classifier head expects a flattened
    512 x 3 x 3 feature map (which is what a 200x200 input yields:
    200 -> 100 -> 50 -> 25 -> 12 -> 6 -> 3).
    """

    def __init__(self):
        super().__init__()

        # (in_channels, out_channels, dropout probability) for each conv stage, in order
        stage_specs = (
            (3, 64, 0.45),
            (64, 128, 0.45),
            (128, 128, 0.45),
            (128, 256, 0.45),
            (256, 256, 0.4),
            (256, 512, 0.4),
        )

        feature_modules = []
        for in_channels, out_channels, drop_prob in stage_specs:
            feature_modules.extend([
                nn.Conv2d(in_channels, out_channels, kernel_size=7, stride=1, padding=3),
                nn.ReLU(),
                nn.BatchNorm2d(out_channels),
                nn.MaxPool2d(2),
                nn.Dropout2d(p=drop_prob),
            ])
        # Module order (and therefore the state_dict key numbering) is stage by stage
        self.conv_layers = nn.Sequential(*feature_modules)

        # Classifier head: 4608 -> 1024 -> 512 -> 4, with dropout after each hidden layer
        self.fc_layers = nn.Sequential(
            nn.Linear(512 * 3 * 3, 1024),
            nn.ReLU(),
            nn.Dropout(p=0.4),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(p=0.4),
            nn.Linear(512, 4),
        )

    def forward(self, x):
        """Run the conv stack, flatten each sample's features, and return the logits."""
        features = self.conv_layers(x)
        flat = features.view(features.size(0), -1)  # one row of features per sample
        return self.fc_layers(flat)

In [4]:
class EarlyStopping:
    """
    Track validation loss across calls and signal when training should stop.

    ``early_stop`` is set to True when either
    - the validation loss reaches ``threshold`` or lower, or
    - the loss has failed to improve for ``patience`` consecutive calls.

    The weights of the best model seen so far are kept as an independent
    snapshot, so further training of the same model object does not alter them.
    """
    def __init__(self, patience=5, delta=0, threshold=0.19):
        """
        Args:
            patience (int): How many calls with no improvement to wait before stopping.
            delta (float): Minimum decrease of the loss (relative to the best one) that
                counts as an improvement.
            threshold (float): Absolute validation loss threshold; stop immediately if
                the loss is at or below it.
        """
        self.patience = patience              # Number of calls to wait for an improvement
        self.delta = delta                    # Minimum decrease in loss to count as improvement
        self.threshold = threshold            # Immediate stop if validation loss <= threshold
        self.best_score = None                # Best score seen so far (negative val_loss)
        self.early_stop = False               # Flag indicating whether to stop training
        self.counter = 0                      # Calls since the last improvement
        self.best_model_state = None          # Snapshot of the best model's state dict

    def __call__(self, val_loss, model):
        """
        Update the stopping state with the latest validation loss.

        Args:
            val_loss (float): Current validation loss.
            model (nn.Module): Model whose weights are snapshotted when it is the best so far.
        """
        # Loss at or below the absolute threshold: stop immediately and keep this model
        if val_loss <= self.threshold:
            print(f"Val loss {val_loss:.5f} is below threshold {self.threshold}. Stopping training.")
            self.early_stop = True
            self.best_model_state = self._snapshot(model)
            return

        # Work with a score to maximize (negative loss)
        score = -val_loss

        if self.best_score is None:
            # First call: whatever we see is the best so far
            self.best_score = score
            self.best_model_state = self._snapshot(model)
        elif score < self.best_score + self.delta:
            # Loss did not drop by at least `delta` below the best one
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            # Improvement found: record it and restart the patience counter
            self.best_score = score
            self.best_model_state = self._snapshot(model)
            self.counter = 0

    def load_best_model(self, model):
        """
        Load the saved best model state into the provided model.

        Args:
            model (nn.Module): Model instance to load the best state into.
        """
        model.load_state_dict(self.best_model_state)

    @staticmethod
    def _snapshot(model):
        """Return a deep copy of ``model.state_dict()``.

        ``state_dict()`` hands back references to the model's live tensors; storing
        it directly would let later optimizer steps overwrite the saved "best" weights.
        """
        import copy  # local import: keeps this class self-contained

        return copy.deepcopy(model.state_dict())


In [5]:
def data_preprocessing_tumor_IID(num_clients=4):
    """
    Download the brain tumor MRI dataset from Kaggle, merge its "Training" and
    "Testing" folders into a single "General_Dataset" directory, split every class
    70/20/10 into train/val/test, and partition the training data evenly (IID)
    among the clients.

    Side effects:
        Image files are *moved* inside the downloaded dataset directory. Calling the
        function again is safe: already-moved files are simply no longer found in
        the source folders.

    Args:
        num_clients (int): Number of clients to split the training set into. Default is 4.

    Returns:
        client_train_loaders (list of DataLoader): List of DataLoader objects, one per client.
        val_loader (DataLoader): DataLoader for the validation set.
        test_loader (DataLoader): DataLoader for the test set.

    Raises:
        ValueError: If ``num_clients`` is smaller than 1.
    """
    # Fail fast, before anything is downloaded or moved
    if num_clients < 1:
        raise ValueError(f"num_clients must be at least 1, got {num_clients}")

    # Download the dataset from Kaggle and get the local directory path
    dataset_path = kagglehub.dataset_download("masoudnickparvar/brain-tumor-mri-dataset")

    # Directory that holds all images together, one subfolder per class
    general_dataset_path = os.path.join(dataset_path, "General_Dataset")
    os.makedirs(general_dataset_path, exist_ok=True)

    # Move images from both original splits into General_Dataset, preserving class subfolders
    for split_name in ("Training", "Testing"):
        source_path = os.path.join(dataset_path, split_name)
        if not os.path.isdir(source_path):
            continue  # nothing (left) to merge from this split
        for class_name in os.listdir(source_path):
            class_path = os.path.join(source_path, class_name)
            if not os.path.isdir(class_path):
                continue  # stray files next to the class folders are not classes
            general_class_path = os.path.join(general_dataset_path, class_name)
            os.makedirs(general_class_path, exist_ok=True)
            for img_name in os.listdir(class_path):
                shutil.move(
                    os.path.join(class_path, img_name),
                    os.path.join(general_class_path, img_name),
                )

    # Image preprocessing: center-crop, resize, convert to tensor
    transform = transforms.Compose([
        transforms.CenterCrop((400, 400)),  # Crop center 400x400 region
        transforms.Resize((200, 200)),      # Resize to 200x200 pixels
        transforms.ToTensor(),              # Convert PIL image to PyTorch tensor
    ])

    # Load all images; ImageFolder derives the numeric label from the class subfolder
    general_dataset = ImageFolder(root=general_dataset_path, transform=transform)
    targets = general_dataset.targets
    classes = sorted(set(targets))          # Sorted so the class order is deterministic

    # Split ratios per class; whatever is left after train and val goes to test
    train_ratio = 0.7
    val_ratio = 0.2

    # Group sample indices by class in a single pass over the labels
    indices_by_class = {class_label: [] for class_label in classes}
    for idx, target in enumerate(targets):
        indices_by_class[target].append(idx)

    # Stratified split: slice each class's indices (in dataset order) into train/val/test
    train_indices_by_class = {}
    val_indices = []
    test_indices = []
    for class_label in classes:
        class_indices = indices_by_class[class_label]
        class_size = len(class_indices)
        train_size = int(train_ratio * class_size)
        val_size = int(val_ratio * class_size)

        train_indices_by_class[class_label] = class_indices[:train_size]
        val_indices.extend(class_indices[train_size:train_size + val_size])
        test_indices.extend(class_indices[train_size + val_size:])

    # Shared validation and test loaders (no shuffling)
    val_loader = DataLoader(Subset(general_dataset, val_indices), batch_size=64, shuffle=False)
    test_loader = DataLoader(Subset(general_dataset, test_indices), batch_size=64, shuffle=False)

    # IID partition: shuffle each class's training indices and deal them out in
    # (almost) equal parts, so every client sees the same class mix
    client_indices = {client: [] for client in range(num_clients)}
    for class_label in classes:
        class_train_indices = list(train_indices_by_class[class_label])
        np.random.shuffle(class_train_indices)
        for client, part in enumerate(np.array_split(class_train_indices, num_clients)):
            client_indices[client].extend(part.tolist())

    # One shuffling DataLoader per client
    client_train_loaders = [
        DataLoader(Subset(general_dataset, client_indices[client]), batch_size=64, shuffle=True)
        for client in range(num_clients)
    ]

    return client_train_loaders, val_loader, test_loader

# IID alternative (disabled): uncomment to build the IID loaders for 4 clients instead
#client_train_loaders, val_loader, test_loader = data_preprocessing_tumor_IID(num_clients=4)


In [6]:
def data_preprocessing_tumor_NON_IID(num_clients=4):
    """
    Download the brain tumor MRI dataset from Kaggle, merge its "Training" and
    "Testing" folders into a single "General_Dataset" directory, split every class
    70/20/10 into train/val/test, and partition the training data in a non-IID
    manner across the clients.

    The non-IID skew comes from a fixed table with one row per class (labels 0-3)
    and one share per client (up to 4 clients). With fewer than 4 clients the
    first ``num_clients - 1`` shares are used as-is and the last client receives
    all remaining samples of the class.

    Side effects:
        Image files are *moved* inside the downloaded dataset directory. Calling the
        function again is safe: already-moved files are simply no longer found in
        the source folders.

    Args:
        num_clients (int): Number of clients to split the training set into
            (1 to 4). Default is 4.

    Returns:
        client_train_loaders (list of DataLoader): List of DataLoader objects, one per client.
        val_loader (DataLoader): DataLoader for the validation set.
        test_loader (DataLoader): DataLoader for the test set.

    Raises:
        ValueError: If ``num_clients`` is outside the range covered by the
            distribution table, or if the dataset contains a class label that
            the table has no row for.
    """
    # ----------------------------------------
    # Non-IID distribution table
    # ----------------------------------------
    # Each list is the fraction of that class's training samples assigned to each client
    distribution = {
        0: [0.70, 0.15, 0.10, 0.05],  # Class 0: client0=70%, client1=15%, client2=10%, client3=5%
        1: [0.15, 0.70, 0.10, 0.05],  # Class 1: client0=15%, client1=70%, client2=10%, client3=5%
        2: [0.10, 0.15, 0.70, 0.05],  # Class 2: client0=10%, client1=15%, client2=70%, client3=5%
        3: [0.05, 0.10, 0.15, 0.70]   # Class 3: client0=5%, client1=10%, client2=15%, client3=70%
    }

    # Fail fast (before downloading or moving anything) when the table cannot serve the request
    max_clients = min(len(shares) for shares in distribution.values())
    if not 1 <= num_clients <= max_clients:
        raise ValueError(
            f"num_clients must be between 1 and {max_clients} for the built-in "
            f"non-IID distribution, got {num_clients}"
        )

    # ----------------------------------------
    # Download and merge the dataset
    # ----------------------------------------
    dataset_path = kagglehub.dataset_download("masoudnickparvar/brain-tumor-mri-dataset")
    general_dataset_path = os.path.join(dataset_path, "General_Dataset")
    os.makedirs(general_dataset_path, exist_ok=True)

    # Move all images from "Training" and "Testing" into "General_Dataset"
    # so that ImageFolder can load them uniformly by class subfolders
    for split_name in ("Training", "Testing"):
        source_path = os.path.join(dataset_path, split_name)
        if not os.path.isdir(source_path):
            continue  # nothing (left) to merge from this split
        for class_name in os.listdir(source_path):
            class_path = os.path.join(source_path, class_name)
            if not os.path.isdir(class_path):
                continue  # stray files next to the class folders are not classes
            general_class_path = os.path.join(general_dataset_path, class_name)
            os.makedirs(general_class_path, exist_ok=True)
            for img_name in os.listdir(class_path):
                shutil.move(
                    os.path.join(class_path, img_name),
                    os.path.join(general_class_path, img_name),
                )

    # ----------------------------------------
    # Define image transformations
    # ----------------------------------------
    transform = transforms.Compose([
        transforms.CenterCrop((400, 400)),  # Crop center region of size 400x400
        transforms.Resize((200, 200)),      # Resize cropped image to 200x200
        transforms.ToTensor(),              # Convert PIL image to PyTorch tensor
    ])

    # ----------------------------------------
    # Load the combined dataset with ImageFolder
    # ----------------------------------------
    general_dataset = ImageFolder(root=general_dataset_path, transform=transform)
    targets = general_dataset.targets
    classes = sorted(set(targets))            # Sorted so the class order is deterministic

    unknown_classes = [class_label for class_label in classes if class_label not in distribution]
    if unknown_classes:
        raise ValueError(
            f"No non-IID distribution defined for class labels {unknown_classes}; "
            f"expected labels {sorted(distribution)}"
        )

    # ----------------------------------------
    # Stratified train/validation/test split
    # ----------------------------------------
    # 70% train, 20% val, remainder (about 10%) test, per class
    train_ratio = 0.7
    val_ratio = 0.2

    # Group sample indices by class in a single pass over the labels
    indices_by_class = {class_label: [] for class_label in classes}
    for idx, target in enumerate(targets):
        indices_by_class[target].append(idx)

    train_indices_by_class = {}
    val_indices = []
    test_indices = []
    for class_label in classes:
        class_indices = indices_by_class[class_label]
        class_size = len(class_indices)
        train_size = int(train_ratio * class_size)
        val_size = int(val_ratio * class_size)

        train_indices_by_class[class_label] = class_indices[:train_size]
        val_indices.extend(class_indices[train_size:train_size + val_size])
        test_indices.extend(class_indices[train_size + val_size:])

    # Shared validation and test loaders (no shuffling needed)
    val_loader = DataLoader(Subset(general_dataset, val_indices), batch_size=64, shuffle=False)
    test_loader = DataLoader(Subset(general_dataset, test_indices), batch_size=64, shuffle=False)

    # ----------------------------------------
    # Non-IID partitioning of training data
    # ----------------------------------------
    client_indices = {client: [] for client in range(num_clients)}
    for class_label in classes:
        class_train_indices = list(train_indices_by_class[class_label])
        np.random.shuffle(class_train_indices)  # Shuffle indices before slicing

        n = len(class_train_indices)
        shares = distribution[class_label]
        allocation = [int(shares[client] * n) for client in range(num_clients)]
        # The last client absorbs the rounding remainder so every sample is assigned
        allocation[-1] = n - sum(allocation[:-1])

        # Hand out consecutive slices of the shuffled indices
        start = 0
        for client, cnt in enumerate(allocation):
            client_indices[client].extend(class_train_indices[start:start + cnt])
            start += cnt

    # One shuffling DataLoader per client
    client_train_loaders = [
        DataLoader(Subset(general_dataset, client_indices[client]), batch_size=64, shuffle=True)
        for client in range(num_clients)
    ]

    return client_train_loaders, val_loader, test_loader

# Build the non-IID client loaders plus the shared validation/test loaders for 4 clients.
# These module-level names are read as globals further down in the notebook:
# evaluate_fn iterates val_loader, and client_fn indexes client_train_loaders.
client_train_loaders, val_loader, test_loader = data_preprocessing_tumor_NON_IID(num_clients=4)


In [7]:
def get_model():
    """
    Build a fresh BrainCNN and place it on the globally configured ``device``.
    """
    network = BrainCNN()
    return network.to(device)


def get_optimizer(model):
    """
    Build the SGD optimizer used to train the given model.

    Notes on experiment history (do not modify these commented lines):
        _: lr = 0.001, weight_decay=0.001
        #return optim.SGD(model.parameters(), lr=0.0008, momentum=0.7, weight_decay=0.09)  # 1 iid final
    Active configuration for the current experiments:
        lr=0.001, momentum=0.7, weight_decay=0.09
    """
    sgd_settings = {"lr": 0.001, "momentum": 0.7, "weight_decay": 0.09}
    return optim.SGD(model.parameters(), **sgd_settings)


def get_loss_function():
    """
    Build the criterion shared by training and evaluation:
    cross-entropy loss for multi-class classification.
    """
    criterion = nn.CrossEntropyLoss()
    return criterion


def fit_config(server_round: int):
    """
    Produce the per-round training configuration sent to the clients.

    Args:
        server_round (int): Current federated learning round (ignored; every
            round gets the same configuration).

    Returns:
        dict: ``{"local_epochs": 5}`` - the number of local epochs each client runs.
    """
    local_epochs = 5  # same value for every round
    return dict(local_epochs=local_epochs)


def evaluate_fn(server_round, parameters, config):
    """
    Server-side evaluation: score the given global weights on the validation set.

    Args:
        server_round (int): Current federated learning round (unused here).
        parameters (list of np.ndarray): Model weights, in the order of the
            model's ``state_dict()`` keys.
        config (dict): Additional configuration (unused here).

    Returns:
        Tuple containing:
            - val_loss (float): Sample-weighted average loss on the validation dataset.
            - metrics (dict): Dictionary with keys "accuracy", "precision", "recall", "f1"
              (precision/recall/F1 are macro-averaged).

    Relies on the module-level ``val_loader`` and ``device``.
    """
    # Rebuild the network and load the received weights, matching them to keys by position
    net = get_model()
    param_names = net.state_dict().keys()
    net.load_state_dict(
        {name: torch.tensor(weights) for name, weights in zip(param_names, parameters)}
    )
    net.eval()  # inference mode: dropout off, BatchNorm uses running statistics

    loss_fn = get_loss_function()
    summed_loss = 0.0
    predictions = []
    labels = []

    with torch.no_grad():
        for inputs, batch_labels in val_loader:
            inputs = inputs.to(device)
            batch_labels = batch_labels.to(device)
            logits = net(inputs)

            # The criterion returns a batch mean; scale back to a per-sample sum
            summed_loss += loss_fn(logits, batch_labels).item() * inputs.size(0)

            predictions.extend(logits.argmax(dim=1).cpu().numpy())
            labels.extend(batch_labels.cpu().numpy())

    # Average over every validation sample
    mean_loss = summed_loss / len(val_loader.dataset)

    accuracy = (np.array(predictions) == np.array(labels)).mean()
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, predictions, average='macro', zero_division=0
    )

    metrics = {
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1": f1
    }
    return mean_loss, metrics


In [8]:
from flwr.common import Context

def client_fn(context: Context):
    """
    Factory function to create a Flower client instance based on the provided context.

    Args:
        context (Context): Flower context containing node_id and configuration.
    Returns:
        flower.client.NumPyClient: A configured Flower client for federated training.
    """
    # Pick this client's data partition. Prefer the "partition-id" entry of the node
    # config when it is present: node ids are arbitrary integers, so `node_id % n`
    # can map several clients onto the same partition and leave others unused.
    # Fall back to the node id when no partition id is available.
    node_config = getattr(context, "node_config", None) or {}
    partition_id = node_config.get("partition-id")
    if partition_id is None:
        partition_id = context.node_id
    cid = int(partition_id) % len(client_train_loaders)

    # Fresh model, optimizer and loss for this client, built by the shared helpers
    model = get_model()
    optimizer = get_optimizer(model)
    criterion = get_loss_function()
    # Local training data of this client
    train_loader = client_train_loaders[cid]

    class FlowerClient(fl.client.NumPyClient):
        """
        Implementation of the Flower NumPyClient interface for federated learning.
        Each method defines how the client exchanges parameters, trains locally (fit),
        and performs local evaluation.
        """

        def get_parameters(self, config):
            """
            Return the current local model parameters as a list of NumPy arrays.

            Args:
                config (dict): Configuration dictionary provided by the server (unused here).
            Returns:
                List[np.ndarray]: Every state_dict entry converted to a NumPy array.
            """
            return [val.cpu().numpy() for _, val in model.state_dict().items()]

        def _load_parameters(self, parameters):
            """Load weights (ordered like the model's state_dict keys) into the local model."""
            state_dict = {
                k: torch.tensor(v)
                for k, v in zip(model.state_dict().keys(), parameters)
            }
            model.load_state_dict(state_dict)

        def fit(self, parameters, config):
            """
            Perform local training on this client's data, optionally including a proximal term
            for FedProx-style regularization.

            Args:
                parameters (List[np.ndarray]): Global model weights received from the server.
                config (dict): Configuration dictionary containing "local_epochs" and possibly "proximal_mu".
            Returns:
                Tuple[List[np.ndarray], int, dict]:
                    - Updated model parameters after local training
                    - Number of training examples used
                    - Metrics dictionary (empty here)
            """
            # 1. Start from the global weights
            self._load_parameters(parameters)
            model.train()

            # 2. FedProx strength; 0 disables the proximal term
            proximal_mu = config.get("proximal_mu", 0.0)

            # 3. Keep a frozen copy of the global weights as the proximal reference
            if proximal_mu > 0:
                global_params = [p.clone().detach() for p in model.parameters()]
            else:
                global_params = None

            # 4. Local training loop
            for _ in range(config["local_epochs"]):
                for data, target in train_loader:
                    data, target = data.to(device), target.to(device)
                    optimizer.zero_grad()
                    output = model(data)
                    loss = criterion(output, target)

                    # 5. FedProx: penalize the squared distance to the global weights
                    if proximal_mu > 0:
                        proximal_term = 0.0
                        for local_param, global_param in zip(model.parameters(), global_params):
                            proximal_term += (local_param - global_param).pow(2).sum()
                        loss += (proximal_mu / 2) * proximal_term

                    loss.backward()
                    optimizer.step()

            # Updated weights plus the number of local training examples
            return self.get_parameters(config), len(train_loader.dataset), {}

        def evaluate(self, parameters, config):
            """
            Evaluate the received global weights on this client's training set
            (used as a proxy, since the client has no separate local hold-out data).

            Args:
                parameters (List[np.ndarray]): Global model weights from the server.
                config (dict): Configuration dictionary (unused for evaluation).
            Returns:
                Tuple[float, int, dict]:
                    - Sample-weighted average loss over the local dataset
                    - Number of examples used for evaluation
                    - Metrics dictionary with the local "accuracy"
            """
            # 1. Load the received global parameters into the model
            self._load_parameters(parameters)
            model.eval()

            total_loss = 0.0
            num_correct = 0
            num_examples = 0

            # 2. Accumulate loss and hits over all local data without computing gradients
            with torch.no_grad():
                for data, target in train_loader:
                    data, target = data.to(device), target.to(device)
                    output = model(data)
                    batch_size = data.size(0)
                    # The criterion returns a batch mean; weight it by the batch size so a
                    # smaller final batch does not skew the average
                    total_loss += criterion(output, target).item() * batch_size
                    num_correct += (output.argmax(dim=1) == target).sum().item()
                    num_examples += batch_size

            # 3. Per-example averages; an empty loader yields zeros instead of dividing by zero
            if num_examples > 0:
                avg_loss = total_loss / num_examples
                accuracy = num_correct / num_examples
            else:
                avg_loss = 0.0
                accuracy = 0.0
            metrics = {"accuracy": accuracy}

            return avg_loss, num_examples, metrics

    # Instantiate and return the Flower client object for this federated node
    return FlowerClient().to_client()


In [9]:
class CustomFedProx(fl.server.strategy.FedProx):
    """
    FedProx strategy extended with metric tracking and early stopping.

    On top of Flower's built-in FedProx this class:
    - records the validation loss and every metric returned by the parent
      ``evaluate`` call in ``metrics_history`` (one entry per evaluation);
    - feeds each validation loss to an ``EarlyStopping`` instance and aborts
      training by raising ``StopIteration`` once it reports ``early_stop``;
    - keeps the parameters to use after training in ``final_parameters``,
      always as a Flower ``Parameters`` object.
    """
    def __init__(self, proximal_mu=0.1, *args, **kwargs):
        """
        Args:
            proximal_mu (float): Coefficient for the FedProx proximal term.
            *args, **kwargs: Additional arguments passed to the parent FedProx constructor.
        """
        # Initialize the base FedProx strategy with the given proximal_mu and any other args
        super().__init__(*args, proximal_mu=proximal_mu, **kwargs)

        # Per-evaluation history: validation loss plus the metrics reported by the
        # parent evaluate(). Further metric names are added on first sight.
        self.metrics_history = {
            "val_loss": [],
            "accuracy": [],
            "precision": [],
            "recall": [],
            "f1": []
        }

        # Latest evaluated parameters, or the best ones once early stopping fires.
        # Stays None until the first successful evaluation.
        self.final_parameters = None

        # Create an EarlyStopping instance with high patience and small delta for fine-grained stopping
        self.early_stopping = EarlyStopping(
            patience=40,        # Number of rounds to wait without improvement
            delta=0.00001,      # Minimum change in validation loss to count as improvement
            threshold=0.0001    # Absolute threshold for validation loss to stop immediately
        )

    def evaluate(self, server_round, parameters):
        """
        Evaluate the global model, record the metrics and check early stopping.

        Args:
            server_round (int): Round number supplied by the Flower server.
            parameters: Global model weights as a Flower ``Parameters`` object.
        Returns:
            Whatever the parent ``evaluate`` returned: a ``(loss, metrics)`` pair,
            or ``None`` when the parent produced no result.
        Raises:
            StopIteration: When the early-stopping criterion is met. Before raising,
                ``final_parameters`` is pointed at the best recorded weights.
        """
        # Call the base class evaluate, which uses evaluate_fn to compute loss & metrics
        result = super().evaluate(server_round, parameters)
        if result is None:
            return result

        loss, metrics = result

        # Append the new validation loss and other metrics to history.
        # setdefault keeps an unexpected metric name from raising KeyError.
        self.metrics_history["val_loss"].append(loss)
        for key, value in metrics.items():
            self.metrics_history.setdefault(key, []).append(value)

        # Save the current parameters as the most recent parameters
        self.final_parameters = parameters

        # Rebuild a model carrying the current global weights, because
        # EarlyStopping is called with a model object rather than raw arrays.
        global_model = get_model()
        current_ndarrays = parameters_to_ndarrays(parameters)
        state_dict = {
            k: torch.tensor(v)
            for k, v in zip(global_model.state_dict().keys(), current_ndarrays)
        }
        global_model.load_state_dict(state_dict)

        # Check early stopping with the current validation loss and model
        self.early_stopping(loss, global_model)
        if self.early_stopping.early_stop:
            print(f"Early stopping triggered at round {server_round}.")

            # Retrieve the best model state recorded by EarlyStopping
            best_state_dict = self.early_stopping.best_model_state
            # NOTE(review): guard in case EarlyStopping stops before it has stored a
            # state; the current parameters are kept then — confirm against EarlyStopping.
            if best_state_dict is not None:
                best_ndarrays = [
                    best_state_dict[k].cpu().numpy()
                    for k in global_model.state_dict().keys()
                ]
                # Store as Flower Parameters so final_parameters has the same type
                # whether or not early stopping fired (consumers call
                # parameters_to_ndarrays on it).
                self.final_parameters = fl.common.ndarrays_to_parameters(best_ndarrays)

            # Raise StopIteration to halt federated training early
            raise StopIteration("Early stopping triggered.")

        # Return the result (val_loss and metrics) if early stopping did not occur
        return result


In [10]:
# Server-side strategy for the 4-client run: callbacks first, then client
# sampling requirements, then the FedProx-specific coefficient.
strategy = CustomFedProx(
    evaluate_fn=evaluate_fn,        # Evaluation callback used by the strategy's evaluate()
    on_fit_config_fn=fit_config,    # Supplies the per-round fit configuration
    fraction_fit=1.0,               # Fraction of clients used for training
    min_fit_clients=4,              # Minimum number of clients per training round
    min_available_clients=4,        # Minimum number of clients that must be available
    proximal_mu=0.9,                # Strength of the proximal term; adjust as needed
)

In [11]:
# Run the federated simulation. CustomFedProx.evaluate raises StopIteration to
# end training early, so that signal is expected and is not treated as a failure.
# `history` is pre-bound so the name exists even when the run is cut short.
history = None
try:
    history = fl.simulation.start_simulation(
        client_fn=client_fn,
        num_clients=4,
        config=fl.server.ServerConfig(num_rounds=72),
        strategy=strategy,
        client_resources={"num_cpus": 2, "num_gpus": 0.5},
        ray_init_args={
            "num_cpus": 16,
            "object_store_memory": 40 * 1024**3  # 40 GiB expressed in bytes
        }
    )
except StopIteration as stop_signal:
    print(stop_signal)
except RuntimeError as sim_error:
    # NOTE(review): depending on the installed Flower release (and on whether the
    # StopIteration crosses a generator frame), the early-stopping signal may
    # surface wrapped in a RuntimeError with the original exception as __cause__
    # — confirm against the installed version. Anything else is a genuine
    # failure and is re-raised unchanged.
    if not isinstance(sim_error.__cause__, StopIteration):
        raise
    print(sim_error.__cause__)
print("Federated learning simulation completed.")


In [12]:
# Plot every tracked metric against the evaluation index and save the figure.

# One x-axis position per recorded evaluation, numbered from 1.
# NOTE(review): if the server also evaluates the initial parameters, the first
# entry would correspond to round 0 rather than round 1 — confirm.
rounds = range(1, len(strategy.metrics_history['accuracy']) + 1)

# (history key, legend label) pairs, in plotting order
plotted_metrics = [
    ('val_loss', 'Validation Loss'),
    ('accuracy', 'Accuracy'),
    ('precision', 'Precision'),
    ('recall', 'Recall'),
    ('f1', 'F1 Score'),
]

# Explicit figure/axes instead of the implicit pyplot state (12 x 8 inches)
fig, ax = plt.subplots(figsize=(12, 8))
for history_key, legend_label in plotted_metrics:
    ax.plot(rounds, strategy.metrics_history[history_key], label=legend_label)

ax.set_xlabel('Round')
ax.set_ylabel('Metric Value')
ax.set_title('Federated Learning Metrics Over Rounds')
ax.legend()
ax.grid(True)

# Save at 300 DPI before showing; create the output directory first so that
# savefig does not fail when it is missing.
graph_path = '/home/ir739wb/ilyarekun/bc_project/federated-learning/outputs/fed-prox-non-iid-graph1.png'
os.makedirs(os.path.dirname(graph_path), exist_ok=True)
fig.savefig(graph_path, dpi=300)

# Display the plot (useful in interactive environments like Jupyter notebooks)
plt.show()

# Close the figure to free memory and prevent overlap with future plots
plt.close(fig)

In [13]:
# Import necessary libraries for federated learning, PyTorch, NumPy, and metrics computation
from flwr.common import parameters_to_ndarrays  # Import function to convert Flower Parameters to NumPy arrays
import torch  # Import PyTorch for model operations
import numpy as np  # Import NumPy for numerical operations
from sklearn.metrics import precision_recall_fscore_support  # Import function to compute precision, recall, and F1 score

# Evaluate the parameters kept by the strategy on the test set and write all
# per-round and test metrics to a text file.

# The strategy stores the weights to use in `final_parameters` in both cases;
# only the message differs.
best_parameters = strategy.final_parameters
if strategy.early_stopping.early_stop:
    print("Using the best model parameters from early stopping.")
else:
    print("Early stopping was not triggered. Using the final round's parameters.")

if best_parameters is None:
    # Fail with a clear message instead of an opaque error further down.
    raise RuntimeError("The strategy recorded no model parameters; nothing to evaluate.")

# Create a new instance of the model (assumes get_model() returns BrainCNN().to(device))
model = get_model()

# `final_parameters` may be a Flower Parameters object or, after early stopping,
# already a plain list of NumPy arrays; convert only in the former case.
if isinstance(best_parameters, list):
    final_ndarrays = best_parameters
else:
    final_ndarrays = parameters_to_ndarrays(best_parameters)

# Pair each state_dict key with its array (positional order) and load the weights
state_dict = {k: torch.tensor(v) for k, v in zip(model.state_dict().keys(), final_ndarrays)}
model.load_state_dict(state_dict)

# Set the model to evaluation mode
model.eval()

# Collect predicted and true class labels over the whole test set
all_preds = []
all_targets = []

with torch.no_grad():  # No gradients needed for inference
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        output = model(data)
        _, predicted = torch.max(output, 1)  # Index of the highest score per example
        all_preds.extend(predicted.cpu().numpy())
        all_targets.extend(target.cpu().numpy())

# Accuracy = share of matching labels; precision/recall/F1 are macro-averaged
accuracy = (np.array(all_preds) == np.array(all_targets)).mean()
precision, recall, f1, _ = precision_recall_fscore_support(all_targets, all_preds, average='macro', zero_division=0)

# Print the test metrics with 4 decimal places for clarity
print(f"Test Accuracy: {accuracy:.4f}")
print(f"Test Precision: {precision:.4f}")
print(f"Test Recall: {recall:.4f}")
print(f"Test F1 Score: {f1:.4f}")

# Define the file path to save the metrics
metrics_file = '/home/ir739wb/ilyarekun/bc_project/federated-learning/outputs/fed-prox-non-iid-metrics1.txt'

# Create the output directory first so that open() does not fail when it is missing
os.makedirs(os.path.dirname(metrics_file), exist_ok=True)

# Save metrics to a text file, including per-round metrics and final test metrics
with open(metrics_file, 'w') as f:
    rounds = range(1, len(strategy.metrics_history['val_loss']) + 1)
    for round_num in rounds:
        entry = round_num - 1  # History lists are 0-based, round labels start at 1
        f.write(f"Round {round_num}:\n")
        f.write(f"  Validation Loss: {strategy.metrics_history['val_loss'][entry]:.4f}\n")
        f.write(f"  Accuracy: {strategy.metrics_history['accuracy'][entry]:.4f}\n")
        f.write(f"  Precision: {strategy.metrics_history['precision'][entry]:.4f}\n")
        f.write(f"  Recall: {strategy.metrics_history['recall'][entry]:.4f}\n")
        f.write(f"  F1 Score: {strategy.metrics_history['f1'][entry]:.4f}\n")
    f.write("\nTest Metrics:\n")
    f.write(f"  Accuracy: {accuracy:.4f}\n")
    f.write(f"  Precision: {precision:.4f}\n")
    f.write(f"  Recall: {recall:.4f}\n")
    f.write(f"  F1 Score: {f1:.4f}\n")

# Confirm that the metrics have been saved to the specified file
print(f"Metrics saved to '{metrics_file}'")