In [1]:
!cd /home/ir739wb/ilyarekun/bc_project/federated-learning/



In [2]:
import os
import flwr as fl
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import ImageFolder
import numpy as np
import kagglehub
import shutil
from sklearn.metrics import precision_recall_fscore_support
import matplotlib.pyplot as plt
from flwr.common import parameters_to_ndarrays

# Seed every random number generator used in this notebook with the same value
seed = 42
for _seed_fn in (
    torch.manual_seed,           # PyTorch default generator
    torch.cuda.manual_seed,      # current CUDA device
    torch.cuda.manual_seed_all,  # all CUDA devices
    np.random.seed,              # NumPy global generator
):
    _seed_fn(seed)

# Ask cuDNN for deterministic kernels and turn off its autotuner
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Pick the first CUDA device when one is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
print(f"Using device: {device}")

Using device: cuda:0


In [3]:
class BrainCNN(nn.Module):
    """
    Convolutional classifier with six conv stages followed by a 3-layer MLP head.

    Each conv stage is Conv2d(7x7, stride 1, padding 3) -> ReLU -> BatchNorm2d ->
    MaxPool2d(2) -> Dropout2d. The head expects 512 * 3 * 3 flattened features
    and emits 4 logits per sample.
    """

    def __init__(self):
        super().__init__()

        # (in_channels, out_channels, dropout probability) for each conv stage, in order
        stage_specs = [
            (3, 64, 0.45),
            (64, 128, 0.45),
            (128, 128, 0.45),
            (128, 256, 0.45),
            (256, 256, 0.4),
            (256, 512, 0.4),
        ]

        # Build the stages in sequence so layer indices (and state_dict keys) stay
        # conv_layers.0 ... conv_layers.29
        feature_layers = []
        for in_ch, out_ch, drop_p in stage_specs:
            feature_layers.extend([
                nn.Conv2d(in_ch, out_ch, kernel_size=7, stride=1, padding=3),
                nn.ReLU(),
                nn.BatchNorm2d(out_ch),
                nn.MaxPool2d(2),
                nn.Dropout2d(p=drop_p),
            ])
        self.conv_layers = nn.Sequential(*feature_layers)

        # Classification head: 4608 -> 1024 -> 512 -> 4
        self.fc_layers = nn.Sequential(
            nn.Linear(512 * 3 * 3, 1024),
            nn.ReLU(),
            nn.Dropout(p=0.4),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(p=0.4),
            nn.Linear(512, 4),
        )

    def forward(self, x):
        """Run the conv stack, flatten per sample, and return the head's logits."""
        features = self.conv_layers(x)
        flat = features.view(features.size(0), -1)  # (batch, everything else)
        return self.fc_layers(flat)

In [4]:
class EarlyStopping:
    """
    EarlyStopping utility to stop training when validation loss stops improving
    or reaches a specified threshold.

    The best model state is kept as a deep copy of ``model.state_dict()``, so
    later in-place updates to the model cannot alter the saved snapshot.
    """
    def __init__(self, patience=5, delta=0, threshold=0.19):
        """
        Args:
            patience (int): Number of calls without improvement before the stop flag is set.
            delta (float): Margin added to the best score; a call counts as an improvement
                only if ``-val_loss >= best_score + delta``.
            threshold (float): Absolute validation loss threshold; stop immediately if reached or below.
        """
        self.patience = patience              # How many calls to wait for improvement
        self.delta = delta                    # Minimum improvement in validation loss
        self.threshold = threshold            # Immediate stop if val_loss <= threshold
        self.best_score = None                # Best score seen so far (negative val_loss)
        self.early_stop = False               # Flag to indicate stopping
        self.counter = 0                      # Consecutive calls without improvement
        self.best_model_state = None          # Copy of the state dict of the best model

    @staticmethod
    def _snapshot(model):
        """Return an independent copy of ``model.state_dict()``."""
        # The tensors returned by state_dict() alias the live parameters; without a
        # copy the stored "best" state would follow every later in-place update.
        import copy  # local import keeps the class independent of file-level imports
        return copy.deepcopy(model.state_dict())

    def __call__(self, val_loss, model):
        """
        Update the stopping state with a new validation loss.

        Args:
            val_loss (float): Current validation loss.
            model (nn.Module): Model being trained; its state is copied when it is the best so far.

        Returns:
            None. Inspect ``early_stop`` and ``best_model_state`` afterwards.
        """
        # If validation loss is below or equal to the threshold, stop immediately
        if val_loss <= self.threshold:
            print(f"Val loss {val_loss:.5f} is below threshold {self.threshold}. Stopping training.")
            self.early_stop = True
            # The current model is treated as the best one
            self.best_model_state = self._snapshot(model)
            return

        # Convert validation loss to a score (we want to maximize -val_loss)
        score = -val_loss

        if self.best_score is None:
            # First call: nothing to compare against yet
            self.best_score = score
            self.best_model_state = self._snapshot(model)
        elif score < self.best_score + self.delta:
            # No sufficient improvement
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            # Improvement found: remember it and reset the patience counter
            self.best_score = score
            self.best_model_state = self._snapshot(model)
            self.counter = 0

    def load_best_model(self, model):
        """
        Load the best saved model state into the provided model.

        Args:
            model (nn.Module): Model into which to load the best state.
        """
        model.load_state_dict(self.best_model_state)


In [5]:
import os
import kagglehub
import shutil
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import ImageFolder
import numpy as np

def data_preprocessing_tumor_IID(num_clients=4):
    """
    Download the brain tumor MRI dataset from Kaggle, merge the Training and Testing
    folders into a single General_Dataset directory, split every class 70/20/10 into
    train/val/test, and partition the training data evenly (IID) among the clients.

    Side effects:
        Image files are MOVED out of ``Training``/``Testing`` into ``General_Dataset``
        inside the downloaded dataset directory.

    Args:
        num_clients (int): Number of clients to split the training set into. Default is 4.

    Returns:
        client_train_loaders (list of DataLoader): One shuffled training loader per client.
        val_loader (DataLoader): DataLoader for the validation set.
        test_loader (DataLoader): DataLoader for the test set.
    """
    # Download the dataset from Kaggle and get the local path
    dataset_path = kagglehub.dataset_download("masoudnickparvar/brain-tumor-mri-dataset")

    # Directory that will hold all images together, one subfolder per class
    general_dataset_path = os.path.join(dataset_path, "General_Dataset")
    os.makedirs(general_dataset_path, exist_ok=True)

    # Move images from Training and Testing into General_Dataset, preserving class subfolders
    for split_name in ("Training", "Testing"):
        source_path = os.path.join(dataset_path, split_name)
        if not os.path.isdir(source_path):
            continue
        for class_name in os.listdir(source_path):
            class_path = os.path.join(source_path, class_name)
            # Skip stray files: treating them as classes would create an empty class
            # folder and then fail when listing a non-directory.
            if not os.path.isdir(class_path):
                continue
            general_class_path = os.path.join(general_dataset_path, class_name)
            os.makedirs(general_class_path, exist_ok=True)
            for img_name in os.listdir(class_path):
                shutil.move(
                    os.path.join(class_path, img_name),
                    os.path.join(general_class_path, img_name),
                )

    # Center crop to 400x400, resize to 200x200, then convert to a tensor
    transform = transforms.Compose([
        transforms.CenterCrop((400, 400)),
        transforms.Resize((200, 200)),
        transforms.ToTensor(),
    ])

    # Load all images via ImageFolder (subfolders = class names)
    general_dataset = ImageFolder(root=general_dataset_path, transform=transform)

    # Group sample indices by class label in one pass over the targets
    indices_by_class = {}
    for sample_idx, label in enumerate(general_dataset.targets):
        indices_by_class.setdefault(label, []).append(sample_idx)
    classes = sorted(indices_by_class)

    # Per-class split ratios: 70% train, 20% validation, remainder (about 10%) test
    train_ratio = 0.7
    val_ratio = 0.2

    # NOTE(review): indices are taken in dataset order without shuffling, so the split
    # depends on how ImageFolder orders the files — confirm this is intended.
    train_by_class = {}
    val_indices = []
    test_indices = []
    for class_label in classes:
        class_indices = indices_by_class[class_label]
        class_size = len(class_indices)
        train_size = int(train_ratio * class_size)
        val_size = int(val_ratio * class_size)

        train_by_class[class_label] = class_indices[:train_size]
        val_indices.extend(class_indices[train_size:train_size + val_size])
        test_indices.extend(class_indices[train_size + val_size:])

    # Validation and test loaders (no shuffling needed)
    val_loader = DataLoader(Subset(general_dataset, val_indices), batch_size=64, shuffle=False)
    test_loader = DataLoader(Subset(general_dataset, test_indices), batch_size=64, shuffle=False)

    # Split every class's training indices into num_clients roughly equal, disjoint parts
    client_indices = {client: [] for client in range(num_clients)}
    for class_label in classes:
        class_train_indices = list(train_by_class[class_label])
        np.random.shuffle(class_train_indices)  # Randomize which client gets which sample
        splits = np.array_split(class_train_indices, num_clients)
        for client in range(num_clients):
            client_indices[client].extend(splits[client].tolist())

    # One shuffled training loader per client
    client_train_loaders = [
        DataLoader(Subset(general_dataset, client_indices[client]), batch_size=64, shuffle=True)
        for client in range(num_clients)
    ]

    return client_train_loaders, val_loader, test_loader

# Example usage (commented out on purpose):
# client_train_loaders, val_loader, test_loader = data_preprocessing_tumor_IID(num_clients=4)


In [6]:
def data_preprocessing_tumor_NON_IID(num_clients=4, distribution=None):
    """
    Download the brain tumor MRI dataset from Kaggle, merge the Training and Testing
    folders into General_Dataset, split every class 70/20/10 into train/val/test, and
    partition the training data among clients with a skewed (non-IID) class mix.

    Side effects:
        Image files are MOVED out of ``Training``/``Testing`` into ``General_Dataset``
        inside the downloaded dataset directory.

    Args:
        num_clients (int): Number of clients to split the training set into. Default is 4.
        distribution (dict | None): Maps each class label to a list of per-client
            fractions. Only the first ``num_clients`` entries of a row are used, and the
            last of those clients receives whatever remains after rounding. Defaults to
            the 4-class / 4-client table below.

    Returns:
        client_train_loaders (list of DataLoader): One shuffled training loader per client.
        val_loader (DataLoader): DataLoader for the validation set.
        test_loader (DataLoader): DataLoader for the test set.

    Raises:
        ValueError: If ``num_clients`` is below 1, or ``distribution`` lacks a class
            found in the dataset or has a row shorter than ``num_clients``.
    """
    if num_clients < 1:
        raise ValueError(f"num_clients must be at least 1, got {num_clients}")

    if distribution is None:
        # Default table (4 classes x 4 clients): each class is concentrated on one client
        distribution = {
            0: [0.70, 0.15, 0.10, 0.05],
            1: [0.15, 0.70, 0.10, 0.05],
            2: [0.10, 0.15, 0.70, 0.05],
            3: [0.05, 0.10, 0.15, 0.70]
        }

    # Download and prepare the dataset
    dataset_path = kagglehub.dataset_download("masoudnickparvar/brain-tumor-mri-dataset")
    general_dataset_path = os.path.join(dataset_path, "General_Dataset")
    os.makedirs(general_dataset_path, exist_ok=True)

    # Merge Training and Testing folders into a single General_Dataset folder by class
    for split_name in ("Training", "Testing"):
        source_path = os.path.join(dataset_path, split_name)
        if not os.path.isdir(source_path):
            continue
        for class_name in os.listdir(source_path):
            class_path = os.path.join(source_path, class_name)
            # Skip stray files: treating them as classes would create an empty class
            # folder and then fail when listing a non-directory.
            if not os.path.isdir(class_path):
                continue
            general_class_path = os.path.join(general_dataset_path, class_name)
            os.makedirs(general_class_path, exist_ok=True)
            for img_name in os.listdir(class_path):
                shutil.move(
                    os.path.join(class_path, img_name),
                    os.path.join(general_class_path, img_name),
                )

    # Define image transformations
    transform = transforms.Compose([
        transforms.CenterCrop((400, 400)),
        transforms.Resize((200, 200)),
        transforms.ToTensor(),
    ])

    # Create the PyTorch dataset from the general dataset folder
    general_dataset = ImageFolder(root=general_dataset_path, transform=transform)

    # Group sample indices by class label in one pass over the targets
    indices_by_class = {}
    for sample_idx, label in enumerate(general_dataset.targets):
        indices_by_class.setdefault(label, []).append(sample_idx)
    classes = sorted(indices_by_class)

    # Fail early with a clear message instead of an IndexError/KeyError mid-partition
    for class_label in classes:
        if class_label not in distribution:
            raise ValueError(f"distribution has no entry for class label {class_label}")
        if len(distribution[class_label]) < num_clients:
            raise ValueError(
                f"distribution for class {class_label} has "
                f"{len(distribution[class_label])} entries but num_clients={num_clients}"
            )

    # Per-class split: 70% train, 20% validation, remainder (about 10%) test
    train_ratio = 0.7
    val_ratio = 0.2

    # NOTE(review): indices are taken in dataset order without shuffling, so the split
    # depends on how ImageFolder orders the files — confirm this is intended.
    train_by_class = {}
    val_indices = []
    test_indices = []
    for class_label in classes:
        class_indices = indices_by_class[class_label]
        class_size = len(class_indices)
        train_size = int(train_ratio * class_size)
        val_size = int(val_ratio * class_size)
        train_by_class[class_label] = class_indices[:train_size]
        val_indices.extend(class_indices[train_size:train_size + val_size])
        test_indices.extend(class_indices[train_size + val_size:])

    val_loader = DataLoader(Subset(general_dataset, val_indices), batch_size=64, shuffle=False)
    test_loader = DataLoader(Subset(general_dataset, test_indices), batch_size=64, shuffle=False)

    # ===== Non-IID Train Data Partitioning =====
    client_indices = {client: [] for client in range(num_clients)}

    for class_label in classes:
        class_train_indices = list(train_by_class[class_label])
        np.random.shuffle(class_train_indices)

        n = len(class_train_indices)
        # Per-client sample counts for this class
        allocation = [
            int(distribution[class_label][client] * n) for client in range(num_clients)
        ]
        # The last client absorbs whatever is left after truncation
        allocation[-1] = n - sum(allocation[:-1])

        # Hand out consecutive slices of the shuffled indices
        start = 0
        for client, cnt in enumerate(allocation):
            client_indices[client].extend(class_train_indices[start:start + cnt])
            start += cnt

    # Create DataLoaders for each client's training subset
    client_train_loaders = [
        DataLoader(Subset(general_dataset, client_indices[client]), batch_size=64, shuffle=True)
        for client in range(num_clients)
    ]

    return client_train_loaders, val_loader, test_loader

# Compute data loaders for 4 clients using the non-IID partitioning.
# The three names bound here are module-level globals that later cells read directly:
# client_fn indexes client_train_loaders, evaluate_fn iterates val_loader, and the
# final evaluation cell iterates test_loader.
client_train_loaders, val_loader, test_loader = data_preprocessing_tumor_NON_IID(num_clients=4)





In [None]:
def get_model():
    """
    Build a new BrainCNN and return it after moving it to the global `device`.
    """
    network = BrainCNN()
    return network.to(device)


def get_optimizer(model):
    """
    Build the SGD optimizer used to train `model`.

    Args:
        model (nn.Module): Model whose parameters will be optimized.

    Returns:
        optim.SGD: Optimizer over ``model.parameters()``.

    Notes on experiment history (do not modify these commented lines):
        _: lr = 0.001, weight_decay=0.001
        #return optim.SGD(model.parameters(), lr=0.0008, momentum=0.9, weight_decay=0.09)
        #return optim.SGD(model.parameters(), lr=0.0007, momentum=0.9, weight_decay=0.09)
        #return optim.SGD(model.parameters(), lr=0.0009, momentum=0.9, weight_decay=0.05)
        #return optim.SGD(model.parameters(), lr=0.0009, momentum=0.8, weight_decay=0.07) -- iid final
        #return optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=0.09) - 1
    The active configuration below was chosen for non-IID experiments:
    """
    sgd_settings = {"lr": 0.0008, "momentum": 0.7, "weight_decay": 0.09}  # -- non iid final
    return optim.SGD(model.parameters(), **sgd_settings)


def get_loss_function():
    """
    Build the criterion shared by training and evaluation:
    CrossEntropyLoss, for multi-class classification.
    """
    criterion = nn.CrossEntropyLoss()
    return criterion


def fit_config(server_round: int):
    """
    Build the fit configuration sent to clients each round.

    The value is the same for every round: 5 local epochs per client.

    Args:
        server_round (int): Current round index on the server (unused here).

    Returns:
        dict: ``{"local_epochs": 5}``.
    """
    round_config = dict(local_epochs=5)  # 5 local epochs per client per round
    return round_config


def evaluate_fn(server_round, parameters, config):
    """
    Server-side evaluation of the global model on the validation set.

    Args:
        server_round (int): Current federated learning round (unused here).
        parameters (list of np.ndarray): Model weights, ordered like the model's state_dict keys.
        config (dict): Additional configuration (unused here).

    Returns:
        Tuple containing:
            - val_loss (float): Sample-weighted average loss over the validation dataset.
            - metrics (dict): Keys "accuracy", "precision", "recall", "f1"
              (precision/recall/F1 are macro-averaged).

    Relies on the module-level ``val_loader`` and ``device``.
    """
    # Fresh model, weights filled from the received arrays (paired with keys by position)
    model = get_model()
    weight_names = model.state_dict().keys()
    model.load_state_dict(
        {name: torch.tensor(array) for name, array in zip(weight_names, parameters)}
    )
    model.eval()  # Evaluation mode for dropout and batch norm

    criterion = get_loss_function()
    loss_sum = 0.0
    all_preds, all_targets = [], []

    with torch.no_grad():
        for data, target in val_loader:
            data = data.to(device)
            target = target.to(device)
            output = model(data)

            # criterion returns a batch mean; multiply by batch size to sum per sample
            batch_loss = criterion(output, target).item()
            loss_sum += batch_loss * data.size(0)

            # Index of the largest logit is the predicted class
            predicted = torch.max(output, 1).indices
            all_preds.extend(predicted.cpu().numpy())
            all_targets.extend(target.cpu().numpy())

    # Average over every validation sample
    val_loss = loss_sum / len(val_loader.dataset)

    accuracy = (np.array(all_preds) == np.array(all_targets)).mean()
    precision, recall, f1, _ = precision_recall_fscore_support(
        all_targets, all_preds, average='macro', zero_division=0
    )

    metrics = {
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1": f1
    }
    return val_loss, metrics


In [None]:
from flwr.common import Context

def client_fn(context: Context):
    """
    Factory function to create a Flower client instance based on the node context.

    The client's data partition is taken from ``context.node_config["partition-id"]``
    when present. ``context.node_id`` is only a fallback: reducing an arbitrary node id
    modulo the number of loaders can map several clients to the same partition and
    leave others unused.

    Args:
        context (Context): Flower-provided context containing node ID and configuration.
    Returns:
        A Flower client (``NumPyClient`` converted with ``to_client()``) bound to one
        entry of the global ``client_train_loaders``.
    """
    num_partitions = len(client_train_loaders)

    # NOTE(review): "partition-id" is the key documented for Flower's Context-based
    # client_fn — confirm against the installed flwr version.
    node_config = getattr(context, "node_config", None) or {}
    partition_id = node_config.get("partition-id")
    if partition_id is None:
        partition_id = context.node_id  # fallback; may collide across clients
    cid = int(partition_id) % num_partitions

    # Fresh model, optimizer and loss for this client instance
    model = BrainCNN().to(device)
    optimizer = get_optimizer(model)
    criterion = get_loss_function()
    # Local training data for this client
    train_loader = client_train_loaders[cid]

    def load_parameters(parameters):
        """Copy a list of ndarrays into the local model, pairing with state_dict keys by position."""
        state_dict = {
            k: torch.tensor(v)
            for k, v in zip(model.state_dict().keys(), parameters)
        }
        model.load_state_dict(state_dict)

    class FlowerClient(fl.client.NumPyClient):
        """
        Flower NumPyClient that exchanges parameters, trains locally (fit) and
        evaluates locally (evaluate) on this client's training partition.
        """

        def get_parameters(self, config):
            """
            Return the current local model parameters as a list of NumPy arrays.
            Args:
                config (dict): Configuration dictionary provided by the server (unused here).
            Returns:
                List[np.ndarray]: All state_dict entries, in state_dict order.
            """
            return [val.cpu().numpy() for _, val in model.state_dict().items()]

        def fit(self, parameters, config):
            """
            Perform local training on this client's data.
            Args:
                parameters (List[np.ndarray]): Global model weights from the server.
                config (dict): Configuration dictionary (must include "local_epochs").
            Returns:
                Tuple[List[np.ndarray], int, dict]:
                    - Updated model parameters,
                    - Number of examples in the local training set,
                    - Metrics dictionary (empty here).
            """
            load_parameters(parameters)
            model.train()  # Switch to training mode

            for _ in range(config["local_epochs"]):
                for data, target in train_loader:
                    data, target = data.to(device), target.to(device)
                    optimizer.zero_grad()
                    output = model(data)
                    loss = criterion(output, target)
                    loss.backward()
                    optimizer.step()

            return self.get_parameters(config), len(train_loader.dataset), {}

        def evaluate(self, parameters, config):
            """
            Evaluate the received weights on this client's training set (used as a proxy here).
            Args:
                parameters (List[np.ndarray]): Global model weights from the server.
                config (dict): Configuration dictionary (unused for evaluation).
            Returns:
                Tuple[float, int, dict]:
                    - Sample-weighted average loss over the local dataset,
                    - Number of examples used for evaluation,
                    - Metrics dictionary with the local "accuracy".
            """
            load_parameters(parameters)
            model.eval()  # Switch to evaluation mode

            total_loss = 0.0
            correct = 0
            num_examples = 0

            with torch.no_grad():
                for data, target in train_loader:
                    data, target = data.to(device), target.to(device)
                    output = model(data)
                    batch_size = data.size(0)
                    # criterion returns the batch mean; weight it by the batch size so a
                    # smaller final batch does not skew the average
                    total_loss += criterion(output, target).item() * batch_size
                    correct += (output.argmax(dim=1) == target).sum().item()
                    num_examples += batch_size

            # Guard against an empty loader
            if num_examples > 0:
                avg_loss = total_loss / num_examples
                accuracy = correct / num_examples
            else:
                avg_loss = 0.0
                accuracy = 0.0

            return avg_loss, num_examples, {"accuracy": accuracy}

    # Instantiate and return the Flower client object for this node
    return FlowerClient().to_client()


In [None]:
class CustomFedAvg(fl.server.strategy.FedAvg):
    """
    Custom federated averaging strategy that extends Flower's FedAvg and adds:
    - Tracking of validation metrics history
    - Early stopping based on the server-side validation loss

    Attributes (all set in __init__):
        metrics_history: dict of lists, one entry appended per evaluated round.
        final_parameters: Flower ``Parameters`` of the last evaluated round, or of the
            best round if early stopping fired. Always a ``Parameters`` object, so it can
            be passed to ``parameters_to_ndarrays`` in either case.
        early_stopping: EarlyStopping instance fed with the validation loss.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # History of validation loss and metrics over rounds
        self.metrics_history = {
            "val_loss": [],
            "accuracy": [],
            "precision": [],
            "recall": [],
            "f1": []
        }
        self.final_parameters = None
        self.early_stopping = EarlyStopping(
            patience=40,        # Rounds to wait without improvement before stopping
            delta=0.00001,      # Minimum change in validation loss to count as improvement
            threshold=0.0001    # Absolute validation loss at which to stop immediately
        )

    def evaluate(self, server_round, parameters):
        """
        Evaluate the global model, record the metrics and apply early stopping.

        Args:
            server_round (int): Current server round.
            parameters: Flower ``Parameters`` of the current global model.

        Returns:
            Whatever ``FedAvg.evaluate`` returned (``(loss, metrics)`` or a falsy value).

        Raises:
            StopIteration: When early stopping fires. ``final_parameters`` then holds
                the best weights recorded by the EarlyStopping instance.
        """
        # The base class runs the configured evaluate_fn
        result = super().evaluate(server_round, parameters)

        if result:
            loss, metrics = result

            self.metrics_history["val_loss"].append(loss)
            for key, value in metrics.items():
                # setdefault tolerates metric names that were not pre-registered
                self.metrics_history.setdefault(key, []).append(value)

            # Most recent evaluated parameters
            self.final_parameters = parameters

            # Load the parameters into a fresh model so EarlyStopping can snapshot its state
            global_model = get_model()
            final_ndarrays = parameters_to_ndarrays(parameters)
            state_dict = {
                k: torch.tensor(v)
                for k, v in zip(global_model.state_dict().keys(), final_ndarrays)
            }
            global_model.load_state_dict(state_dict)

            self.early_stopping(loss, global_model)
            if self.early_stopping.early_stop:
                print(f"Early stopping triggered at server round {server_round}.")
                best_state_dict = self.early_stopping.best_model_state
                # Same key order as the model, so positions match the usual parameter list
                best_ndarrays = [
                    best_state_dict[k].cpu().numpy()
                    for k in global_model.state_dict().keys()
                ]
                # Convert back to a Parameters object: storing the raw list would make
                # final_parameters a different type from the non-early-stop case.
                self.final_parameters = fl.common.ndarrays_to_parameters(best_ndarrays)
                raise StopIteration("Early stopping triggered.")

        return result


# Build the strategy object that the simulation cell passes to Flower.
strategy = CustomFedAvg(
    evaluate_fn=evaluate_fn,       # evaluates the global model on validation data
    on_fit_config_fn=fit_config,   # passes configuration (e.g., local epochs) to clients
    fraction_fit=1.0,              # train on 100% of available clients each round
    min_fit_clients=4,             # minimum number of clients sampled for training
    min_available_clients=4,       # minimum number of connected clients before training starts
)


In [None]:
# Bound up front so the name exists even when the run ends through an exception
history = None
try:
    # Start a federated learning simulation using Flower's simulation API
    history = fl.simulation.start_simulation(
        client_fn=client_fn,  # Function to create new client instances
        num_clients=4,  # Total number of federated clients to simulate
        config=fl.server.ServerConfig(num_rounds=72),  # Run for up to 72 rounds
        strategy=strategy,  # Custom FedAvg strategy (aggregation, evaluation, early stopping)
        client_resources={"num_cpus": 2, "num_gpus": 0.5},
            # Resources to allocate per client: 2 CPU cores and 0.5 GPU
        ray_init_args={
            "num_cpus": 16,
                # Total number of CPU cores available to Ray for parallel client simulation
            "object_store_memory": 40 * 1024**3
                # Amount of memory (in bytes) for Ray's object store (40 GiB)
        }
    )
except StopIteration as e:
    # Early stopping raised directly out of the custom strategy
    print(e)  # Print reason for early stopping
except RuntimeError as e:
    # NOTE(review): recent Flower releases appear to catch exceptions raised inside the
    # server loop and re-raise RuntimeError("Simulation crashed.") with the original as
    # __cause__ — confirm against the installed flwr version. Treat that as a normal
    # early stop; re-raise anything else.
    stopped_early = isinstance(e.__cause__, StopIteration) or strategy.early_stopping.early_stop
    if not stopped_early:
        raise
    print(e.__cause__ if e.__cause__ is not None else e)
# This line will execute after the simulation finishes (or is stopped early)
print("Federated learning simulation completed.")


In [None]:
# x-axis: one point per evaluated round, starting from 1
rounds = range(1, len(strategy.metrics_history['accuracy']) + 1)

# (history key, legend label) for every curve, in plotting order
metric_curves = [
    ('val_loss', 'Validation Loss'),
    ('accuracy', 'Accuracy'),
    ('precision', 'Precision'),
    ('recall', 'Recall'),
    ('f1', 'F1 Score'),
]

# New 12x8 inch figure with all curves on shared axes
plt.figure(figsize=(12, 8))
for history_key, curve_label in metric_curves:
    plt.plot(rounds, strategy.metrics_history[history_key], label=curve_label)

plt.xlabel('Round')
plt.ylabel('Metric Value')
plt.title('Federated Learning Metrics Over Rounds')
plt.legend()
plt.grid(True)

# savefig does not create missing directories, so make sure the target folder exists
plot_path = '/home/ir739wb/ilyarekun/bc_project/federated-learning/outputs/fed-avg-non-iid-graph2.png'
os.makedirs(os.path.dirname(plot_path), exist_ok=True)

# Save the figure at 300 DPI resolution
plt.savefig(
    plot_path,
    dpi=300
)

# Display the plot (if running in an interactive environment)
plt.show()

# Close the figure to release memory
plt.close()


In [None]:
from flwr.common import parameters_to_ndarrays
import torch
import numpy as np
from sklearn.metrics import precision_recall_fscore_support

# Select the aggregated parameters to evaluate on the test set.
#
# Both outcomes read the same attribute, `strategy.final_parameters`; the
# early-stopping flag only decides which message is printed.
# NOTE(review): confirm that the strategy stores the best (not merely the last)
# weights in `final_parameters` when early stopping fires. If it keeps them
# under a different attribute, the first message below is inaccurate and the
# assignment should read from that attribute instead.
if strategy.early_stopping.early_stop:
    print("Using the best model parameters from early stopping.")
else:
    print("Early stopping was not triggered. Using the final round's parameters.")
best_parameters = strategy.final_parameters

# Fail here with a clear message instead of deeper inside the conversion code
# when no parameters were recorded (e.g. the simulation cell was not run).
if best_parameters is None:
    raise RuntimeError(
        "strategy.final_parameters is None; run the simulation cell to completion "
        "before evaluating."
    )

# Rebuild a model from the selected federated parameters.
# get_model() is defined earlier in the notebook (presumably BrainCNN on `device`).
model = get_model()

# Unpack the Flower Parameters object into a list of NumPy arrays
final_ndarrays = parameters_to_ndarrays(best_parameters)

# Pair each state_dict key with the array at the same position, converting the
# arrays to tensors on the fly; pairing stops at the shorter of the two sequences.
state_dict = dict(zip(model.state_dict().keys(), map(torch.tensor, final_ndarrays)))

# Copy the weights into the model, then switch to evaluation mode
# (disables dropout, etc.)
model.load_state_dict(state_dict)
model.eval()

# Run the restored model over the held-out test set and collect, batch by
# batch, the predicted and the true class indices.
all_preds = []
all_targets = []
with torch.no_grad():  # inference only: no gradient tracking
    for inputs, labels in test_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        logits = model(inputs)
        # Index of the largest output along dim 1 is the predicted class
        batch_preds = logits.max(dim=1).indices
        all_preds.extend(batch_preds.cpu().numpy())
        all_targets.extend(labels.cpu().numpy())

# Overall accuracy: fraction of predictions equal to their targets
accuracy = np.mean(np.asarray(all_preds) == np.asarray(all_targets))

# Macro-averaged precision / recall / F1; the fourth returned item (support)
# is not needed
macro_scores = precision_recall_fscore_support(
    all_targets, all_preds, average='macro', zero_division=0
)
precision, recall, f1 = macro_scores[:3]

# Report the test-set metrics, one per line
print(
    f"Test Accuracy: {accuracy:.4f}",
    f"Test Precision: {precision:.4f}",
    f"Test Recall: {recall:.4f}",
    f"Test F1 Score: {f1:.4f}",
    sep="\n",
)

# Persist the per-round validation metrics and the final test metrics as text.
metrics_file = '/home/ir739wb/ilyarekun/bc_project/federated-learning/outputs/fed-avg-non-iid-metrics2.txt'

# Create the output directory first so open() cannot fail on a missing folder
# after the whole simulation and evaluation have already run.
os.makedirs(os.path.dirname(metrics_file), exist_ok=True)

history = strategy.metrics_history

# (label written to the file, metrics_history key) for each per-round line
round_fields = (
    ("Validation Loss", "val_loss"),
    ("Accuracy", "accuracy"),
    ("Precision", "precision"),
    ("Recall", "recall"),
    ("F1 Score", "f1"),
)

# (label written to the file, value) for the test-set summary
test_fields = (
    ("Accuracy", accuracy),
    ("Precision", precision),
    ("Recall", recall),
    ("F1 Score", f1),
)

with open(metrics_file, 'w') as f:
    # One section per round; the number of rounds is taken from the
    # validation-loss history, and round numbers are 1-based.
    rounds = range(1, len(history['val_loss']) + 1)
    for round_num in rounds:
        f.write(f"Round {round_num}:\n")
        for label, key in round_fields:
            f.write(f"  {label}: {history[key][round_num - 1]:.4f}\n")

    # Overall test-set metrics go at the end of the file
    f.write("\nTest Metrics:\n")
    for label, value in test_fields:
        f.write(f"  {label}: {value:.4f}\n")

print(f"Metrics saved to '{metrics_file}'")
