In [1]:
!cd /home/ir739wb/ilyarekun/bc_project/federated-learning



In [None]:
# Import necessary libraries for file handling, federated learning, deep learning, and data processing
import os
import flwr as fl  # Flower framework for federated learning
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import ImageFolder
import numpy as np
import kagglehub  # Utility for downloading datasets from Kaggle
import shutil
from sklearn.metrics import precision_recall_fscore_support
import matplotlib.pyplot as plt
from flwr.common import parameters_to_ndarrays
import scipy.optimize as opt  # For Hungarian algorithm in FedMA

# Seed NumPy and every PyTorch generator with one value so runs are repeatable
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Trade cuDNN auto-tuning for deterministic kernels (may slow down training)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Pick the compute device: first CUDA GPU when present, CPU otherwise
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
print(f"Using device: {device}")



# Define the CNN model for brain tumor classification
class BrainCNN(nn.Module):
    """Six-stage convolutional network followed by a three-layer classifier.

    Every convolutional stage is Conv2d(7x7, padding 3) -> ReLU -> BatchNorm2d
    -> MaxPool2d(2) -> Dropout2d.  The classifier expects the convolutional
    stack to produce 512 * 3 * 3 features per sample and emits 4 logits.
    """

    def __init__(self):
        super().__init__()

        # (input channels, output channels, dropout probability) per stage
        stage_specs = (
            (3, 64, 0.45),
            (64, 128, 0.45),
            (128, 128, 0.45),
            (128, 256, 0.45),
            (256, 256, 0.4),
            (256, 512, 0.4),
        )
        stages = []
        for channels_in, channels_out, drop_p in stage_specs:
            stages.extend([
                nn.Conv2d(channels_in, channels_out, kernel_size=7, stride=1, padding=3),
                nn.ReLU(),
                nn.BatchNorm2d(channels_out),
                nn.MaxPool2d(2),  # halves each spatial dimension
                nn.Dropout2d(drop_p),
            ])
        # Unpacking into Sequential keeps the numeric sub-module names 0..29
        self.conv_layers = nn.Sequential(*stages)

        # Classifier head: flattened features -> 1024 -> 512 -> 4 class logits
        self.fc_layers = nn.Sequential(
            nn.Linear(512 * 3 * 3, 1024),
            nn.ReLU(),
            nn.Dropout(p=0.4),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(p=0.4),
            nn.Linear(512, 4),
        )

    def forward(self, x):
        """Return class logits for a batch of images."""
        features = self.conv_layers(x)
        # Collapse everything but the batch dimension before the linear layers
        flat = features.view(features.size(0), -1)
        return self.fc_layers(flat)



# Early stopping class to terminate training when validation loss plateaus or goes below a threshold
class EarlyStopping:
    """Track validation loss across calls and signal when training should stop.

    Stopping is requested either as soon as the loss is at or below
    ``threshold``, or after ``patience`` consecutive calls in which the loss
    did not improve on the best value by more than ``delta``.  A copy of the
    model weights from the best call is kept in ``best_model_state``.
    """

    def __init__(self, patience=5, delta=0, threshold=0.19):
        """
        Args:
            patience: Number of consecutive rounds without improvement before stopping.
            delta: Minimum change in validation loss to be considered an improvement.
            threshold: If validation loss falls to or below this value, stop immediately.
        """
        self.patience = patience
        self.delta = delta
        self.threshold = threshold
        self.best_score = None
        self.early_stop = False
        self.counter = 0
        self.best_model_state = None

    @staticmethod
    def _snapshot(model):
        """Return a detached copy of the model's state dict.

        ``state_dict()`` returns tensors that share storage with the live
        model, so storing it directly would let any later in-place update
        (more training, another ``load_state_dict``) overwrite the saved
        "best" weights.
        """
        return {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}

    def __call__(self, val_loss, model):
        """Record one validation result and update ``early_stop``.

        Args:
            val_loss: Validation loss of ``model`` for the current round.
            model: Model whose weights are snapshotted when it is the best so far.
        """
        # Loss already good enough: stop immediately and keep this model
        if val_loss <= self.threshold:
            print(f"Val loss {val_loss:.5f} is below threshold {self.threshold}.")
            self.early_stop = True
            self.best_model_state = self._snapshot(model)
            return

        score = -val_loss  # higher is better
        if self.best_score is None:
            # First observation becomes the baseline
            self.best_score = score
            self.best_model_state = self._snapshot(model)
        elif score < self.best_score + self.delta:
            # Not better than the best by more than delta
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            # Genuine improvement: remember it and restart the patience count
            self.best_score = score
            self.best_model_state = self._snapshot(model)
            self.counter = 0

    def load_best_model(self, model):
        """Load the saved best weights into ``model``.

        Raises:
            RuntimeError: If no model state has been recorded yet.
        """
        if self.best_model_state is None:
            raise RuntimeError("No model state has been recorded yet; call the EarlyStopping instance first.")
        model.load_state_dict(self.best_model_state)



# Data preprocessing function: downloads, organizes, and splits data among clients in a non-IID fashion
def data_preprocessing_tumor_NON_IID(num_clients=4):
    """Download the brain-tumor MRI dataset and build non-IID client loaders.

    The Kaggle dataset's ``Training`` and ``Testing`` folders are merged into
    one ``General_Dataset`` folder (files are moved, not copied), each class is
    split 70/20/10 into train/validation/test in file order, and the training
    part of every class is shuffled and divided among clients according to a
    fixed share table.

    Args:
        num_clients: Number of client loaders to build.  The share table has
            four columns, so at most four clients are supported; with fewer,
            the last client also receives the shares of the omitted columns.

    Returns:
        Tuple ``(client_train_loaders, val_loader, test_loader)`` where the
        first element is a list with one shuffled DataLoader per client.

    Raises:
        ValueError: If ``num_clients`` is outside the supported range or the
            dataset contains a class label missing from the share table.
    """
    # distribution[class_label][client] = fraction of that class's training
    # samples given to that client
    distribution = {
        0: [0.70, 0.15, 0.10, 0.05],
        1: [0.15, 0.70, 0.10, 0.05],
        2: [0.10, 0.15, 0.70, 0.05],
        3: [0.05, 0.10, 0.15, 0.70]
    }
    max_clients = min(len(shares) for shares in distribution.values())
    # Fail before the download rather than with an IndexError afterwards
    if not 1 <= num_clients <= max_clients:
        raise ValueError(
            f"num_clients must be between 1 and {max_clients}, got {num_clients}"
        )

    # Download the dataset from Kaggle and locate its train/test folders
    dataset_path = kagglehub.dataset_download("masoudnickparvar/brain-tumor-mri-dataset")
    train_path = os.path.join(dataset_path, "Training")
    test_path = os.path.join(dataset_path, "Testing")
    general_dataset_path = os.path.join(dataset_path, "General_Dataset")
    os.makedirs(general_dataset_path, exist_ok=True)

    # Move all images into the merged folder, preserving class subfolders
    for source_path in [train_path, test_path]:
        for class_name in os.listdir(source_path):
            class_path = os.path.join(source_path, class_name)
            if not os.path.isdir(class_path):
                continue  # ignore stray files next to the class folders
            general_class_path = os.path.join(general_dataset_path, class_name)
            os.makedirs(general_class_path, exist_ok=True)
            for img_name in os.listdir(class_path):
                img_path = os.path.join(class_path, img_name)
                shutil.move(img_path, os.path.join(general_class_path, img_name))

    # Center crop to 400x400, resize to 200x200, convert to tensor
    transform = transforms.Compose([
        transforms.CenterCrop((400, 400)),
        transforms.Resize((200, 200)),
        transforms.ToTensor(),
    ])

    general_dataset = ImageFolder(root=general_dataset_path, transform=transform)
    targets = general_dataset.targets  # class label of every image

    # Group sample indices by class in a single pass
    indices_by_class = {}
    for idx, target in enumerate(targets):
        indices_by_class.setdefault(target, []).append(idx)
    classes = sorted(indices_by_class)

    unknown = [label for label in classes if label not in distribution]
    if unknown:
        raise ValueError(f"No client distribution defined for class labels {unknown}")

    train_ratio = 0.7
    val_ratio = 0.2
    train_by_class = {}
    val_indices = []
    test_indices = []
    for class_label in classes:
        class_indices = indices_by_class[class_label]
        train_size = int(train_ratio * len(class_indices))
        val_size = int(val_ratio * len(class_indices))
        # First part to train, next to validation, remainder to test
        train_by_class[class_label] = class_indices[:train_size]
        val_indices.extend(class_indices[train_size:train_size + val_size])
        test_indices.extend(class_indices[train_size + val_size:])

    val_loader = DataLoader(Subset(general_dataset, val_indices), batch_size=64, shuffle=False)
    test_loader = DataLoader(Subset(general_dataset, test_indices), batch_size=64, shuffle=False)

    # Hand each class's training samples to the clients according to the table
    client_indices = {client: [] for client in range(num_clients)}
    for class_label in classes:
        class_train_indices = train_by_class[class_label]
        np.random.shuffle(class_train_indices)

        n = len(class_train_indices)
        allocation = [int(distribution[class_label][client] * n) for client in range(num_clients)]
        # The last client absorbs rounding leftovers so every sample is used
        allocation[-1] = n - sum(allocation[:-1])

        start = 0
        for client, cnt in enumerate(allocation):
            client_indices[client].extend(class_train_indices[start:start + cnt])
            start += cnt

    client_train_loaders = [
        DataLoader(Subset(general_dataset, client_indices[client]), batch_size=64, shuffle=True)
        for client in range(num_clients)
    ]

    return client_train_loaders, val_loader, test_loader


# Preprocess data and obtain DataLoaders for each client, validation, and test sets.
# These three names are module-level globals read by evaluate_fn, client_fn and the
# final test evaluation further down.
client_train_loaders, val_loader, test_loader = data_preprocessing_tumor_NON_IID(num_clients=4)



# Helper function: return a new instance of the CNN model on the appropriate device
def get_model():
    """Build a fresh BrainCNN and place it on the module-level ``device``."""
    net = BrainCNN()
    return net.to(device)

# Helper function: return an optimizer configured for a given model
def get_optimizer(model):
    """Create the SGD optimizer (with momentum and weight decay) for ``model``."""
    sgd_settings = {"lr": 0.001, "momentum": 0.7, "weight_decay": 0.09}
    return optim.SGD(model.parameters(), **sgd_settings)

# Helper function: return the loss function for classification
def get_loss_function():
    """Return a new cross-entropy criterion for the classification task."""
    criterion = nn.CrossEntropyLoss()
    return criterion

# Function to supply configuration to clients before fitting (e.g., number of local epochs)
def fit_config(server_round: int):
    """Return the per-round training configuration sent to clients.

    The same configuration is used for every round; ``server_round`` is
    accepted but not consulted.
    """
    return dict(local_epochs=5)

# Evaluation function to be called by the server: computes validation loss and metrics
def evaluate_fn(server_round, parameters, config):
    """Score a set of global weights on the validation loader.

    Args:
        server_round: Current round number (not used in the computation).
        parameters: Sequence of arrays, paired in order with the model's
            ``state_dict()`` keys.
        config: Unused.

    Returns:
        Tuple ``(val_loss, metrics)`` where ``val_loss`` is the per-sample
        mean loss and ``metrics`` holds accuracy and macro-averaged
        precision, recall and F1.
    """
    # Build a model and overwrite its weights with the supplied arrays
    net = get_model()
    names = net.state_dict().keys()
    net.load_state_dict({name: torch.tensor(values) for name, values in zip(names, parameters)})
    net.eval()

    loss_fn = get_loss_function()
    summed_loss = 0.0
    predictions = []
    labels = []

    with torch.no_grad():
        for batch, batch_labels in val_loader:
            batch = batch.to(device)
            batch_labels = batch_labels.to(device)
            logits = net(batch)
            # Weight the batch-mean loss by batch size so the final mean is per sample
            summed_loss += loss_fn(logits, batch_labels).item() * batch.size(0)
            predicted = torch.max(logits, 1)[1]
            predictions.extend(predicted.cpu().numpy())
            labels.extend(batch_labels.cpu().numpy())

    mean_loss = summed_loss / len(val_loader.dataset)
    accuracy = (np.array(predictions) == np.array(labels)).mean()
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, predictions, average='macro', zero_division=0
    )

    metrics = {
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1": f1
    }
    return mean_loss, metrics



# Client function: defines the behavior of each client in the federated simulation
from flwr.common import Context

def client_fn(context: Context):
    """Build the Flower client for one simulated node.

    The client trains a local copy of the model on its own partition of the
    training data and exchanges the complete ``state_dict`` (trainable weights
    plus BatchNorm buffers) in ``state_dict`` order, which is the order the
    rest of this file uses when loading arrays into a model.

    Args:
        context: Flower context describing the node this client runs on.

    Returns:
        The client converted with ``to_client()``.
    """
    import torch
    from collections import OrderedDict

    # Prefer the simulation's partition id so every client gets its own
    # loader; node ids are arbitrary integers and may collide after the modulo.
    # NOTE(review): "partition-id" is the key used by recent Flower
    # simulations — confirm against the installed version.
    node_config = getattr(context, "node_config", None) or {}
    partition_id = node_config.get("partition-id")
    if partition_id is None:
        partition_id = context.node_id
    cid = int(partition_id) % len(client_train_loaders)
    train_loader = client_train_loaders[cid]

    # Subclassing NumPyClient is what provides to_client()
    class FlowerClient(fl.client.NumPyClient):
        def __init__(self):
            self.model = get_model()
            self.train_loader = train_loader
            self.optimizer = get_optimizer(self.model)
            self.criterion = get_loss_function()
            # Module-level device string, the same one get_model() uses
            self.device = device

        def _load_parameters(self, parameters):
            """Copy the received arrays into the local model."""
            keys = list(self.model.state_dict().keys())
            if len(parameters) != len(keys):
                raise ValueError(
                    f"Expected {len(keys)} arrays (one per state_dict entry), got {len(parameters)}"
                )
            state_dict = OrderedDict(
                (k, torch.tensor(v).to(self.device)) for k, v in zip(keys, parameters)
            )
            self.model.load_state_dict(state_dict)

        def _export_parameters(self):
            """Return every state_dict entry as a NumPy array, in state_dict order."""
            return [t.detach().cpu().numpy() for t in self.model.state_dict().values()]

        def fit(self, parameters, config):
            try:
                # Start from the incoming global model
                self._load_parameters(parameters)
                self.model.train()

                local_epochs = config.get("local_epochs", 1)
                for _ in range(local_epochs):
                    for data, target in self.train_loader:
                        data, target = data.to(self.device), target.to(self.device)
                        self.optimizer.zero_grad()
                        output = self.model(data)
                        loss = self.criterion(output, target)
                        loss.backward()
                        self.optimizer.step()

                # Updated weights, number of local training examples, no metrics
                return self._export_parameters(), len(self.train_loader.dataset), {}

            except Exception as e:
                # Print before re-raising so the failure is visible in the notebook output
                print(f"Error in client fit: {str(e)}")
                raise

        def get_parameters(self, config):
            # Return the current model state without training
            return self._export_parameters()

    return FlowerClient().to_client()



# Custom FedMA strategy: implements federated matching and averaging to align neurons/filters
class CustomFedMA(fl.server.strategy.Strategy):
    """Server strategy that aligns client neurons/filters before averaging.

    For every client other than the first, the output units of each weight
    tensor (rows of a linear layer, filters of a convolution) are matched to
    the first client's units with the Hungarian algorithm.  The resulting
    permutation is applied to that tensor, to the per-unit vectors that follow
    it, and to the input axis of the next weight tensor, after which all
    tensors are averaged weighted by each client's number of examples.  This
    is a simplified permutation-matching scheme, not the full layer-wise FedMA
    procedure.

    The propagation of permutations assumes a plain feed-forward stack in which
    the arrays arrive in layer order, as produced by ``BrainCNN``.
    """

    def __init__(self, fraction_fit=1.0, min_fit_clients=4, min_available_clients=4, evaluate_fn=None, on_fit_config_fn=None):
        """
        Args:
            fraction_fit: Fraction of available clients sampled for each fit round.
            min_fit_clients: Lower bound on the number of sampled clients.
            min_available_clients: Clients that must be available before sampling.
            evaluate_fn: Optional ``(server_round, ndarrays, config) -> (loss, metrics)``.
            on_fit_config_fn: Optional ``(server_round) -> dict`` sent to clients.
        """
        super().__init__()
        self.fraction_fit = fraction_fit
        self.min_fit_clients = min_fit_clients
        self.min_available_clients = min_available_clients
        self.evaluate_fn = evaluate_fn
        self.on_fit_config_fn = on_fit_config_fn
        self.metrics_history = {
            "val_loss": [],
            "accuracy": [],
            "precision": [],
            "recall": [],
            "f1": []
        }
        # Always a Flower Parameters object once an evaluation has run
        self.final_parameters = None
        # Use early stopping with high patience for federated rounds
        self.early_stopping = EarlyStopping(patience=40, delta=0.00001, threshold=0.0001)

    def initialize_parameters(self, client_manager):
        """Provide the initial global model from a freshly constructed network."""
        model = get_model()
        arrays = [tensor.detach().cpu().numpy() for tensor in model.state_dict().values()]
        return fl.common.ndarrays_to_parameters(arrays)

    def configure_fit(self, server_round, parameters, client_manager):
        """
        Select which clients to run the next fit round on and package the FitIns messages.
        """
        config = {}
        if self.on_fit_config_fn is not None:
            config = self.on_fit_config_fn(server_round)
        sample_size, min_num_clients = self.num_fit_clients(client_manager.num_available())
        clients = client_manager.sample(num_clients=sample_size, min_num_clients=min_num_clients)
        return [(client, fl.common.FitIns(parameters, config)) for client in clients]

    def aggregate_fit(self, server_round, results, failures):
        """
        Aggregate client updates after aligning their neurons/filters.

        Args:
            server_round: Current round number (unused).
            results: List of ``(client_proxy, fit_res)`` pairs.
            failures: Failed client calls (ignored).

        Returns:
            ``(parameters, metrics)``; ``(None, {})`` when there are no results.
        """
        if not results:
            return None, {}

        client_params = [parameters_to_ndarrays(fit_res.parameters) for _, fit_res in results]
        num_examples = [fit_res.num_examples for _, fit_res in results]

        weights = np.asarray(num_examples, dtype=np.float64)
        if weights.sum() <= 0:
            # No example counts reported: fall back to a plain mean
            weights = np.ones(len(client_params), dtype=np.float64)

        aligned = self._align_to_reference(client_params)

        aggregated_params = []
        for layer_idx, reference_layer in enumerate(client_params[0]):
            stacked = np.stack([client[layer_idx] for client in aligned])
            averaged = np.average(stacked, axis=0, weights=weights)
            if np.issubdtype(reference_layer.dtype, np.integer):
                # Integer buffers (e.g. batch counters) stay integral
                averaged = np.rint(averaged)
            aggregated_params.append(np.asarray(averaged, dtype=reference_layer.dtype))

        # The server expects a Parameters object, not a list of arrays
        return fl.common.ndarrays_to_parameters(aggregated_params), {}

    def configure_evaluate(self, server_round, parameters, client_manager):
        """Schedule no client-side evaluation; evaluation happens on the server."""
        return []

    def aggregate_evaluate(self, server_round, results, failures):
        """Nothing to aggregate because no client-side evaluation is scheduled."""
        return None, {}

    def evaluate(self, server_round, parameters):
        """
        Evaluate global model on validation data and optionally trigger early stopping.

        Raises:
            StopIteration: When the early-stopping criterion is met.  Before
                raising, ``final_parameters`` is set to the best recorded model.
        """
        if self.evaluate_fn is None:
            return None

        # evaluate_fn pairs arrays with state_dict keys, so convert once here
        if isinstance(parameters, (list, tuple)):
            ndarrays = list(parameters)
            parameters = fl.common.ndarrays_to_parameters(ndarrays)
        else:
            ndarrays = parameters_to_ndarrays(parameters)

        result = self.evaluate_fn(server_round, ndarrays, {})
        if not result:
            return result

        loss, metrics = result
        self.metrics_history["val_loss"].append(loss)
        for key, value in metrics.items():
            self.metrics_history.setdefault(key, []).append(value)
        self.final_parameters = parameters

        # Rebuild the global model so early stopping can snapshot its weights
        global_model = get_model()
        keys = list(global_model.state_dict().keys())
        global_model.load_state_dict({k: torch.tensor(v) for k, v in zip(keys, ndarrays)})

        self.early_stopping(loss, global_model)
        if self.early_stopping.early_stop:
            print(f"Early stopping triggered at round {server_round}.")
            best_state_dict = self.early_stopping.best_model_state
            best_arrays = [best_state_dict[k].cpu().numpy() for k in keys]
            # Keep the same type as in the non-stopping path
            self.final_parameters = fl.common.ndarrays_to_parameters(best_arrays)
            raise StopIteration("Early stopping triggered.")
        return result

    def num_fit_clients(self, num_available):
        """
        Determine how many clients to sample for fitting based on fraction_fit and minimums.
        """
        num_clients = int(num_available * self.fraction_fit)
        return max(num_clients, self.min_fit_clients), self.min_available_clients

    @staticmethod
    def _match_units(reference, candidate):
        """Return ``perm`` such that ``candidate[perm][i]`` is matched to ``reference[i]``.

        The cost of pairing two units is the Euclidean distance between their
        flattened weights; the assignment minimises the total cost.
        """
        ref_flat = reference.reshape(reference.shape[0], -1).astype(np.float64)
        cand_flat = candidate.reshape(candidate.shape[0], -1).astype(np.float64)
        # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, computed for all pairs at once
        squared = (
            np.sum(ref_flat ** 2, axis=1)[:, None]
            + np.sum(cand_flat ** 2, axis=1)[None, :]
            - 2.0 * ref_flat @ cand_flat.T
        )
        cost = np.sqrt(np.clip(squared, 0.0, None))
        if not np.all(np.isfinite(cost)):
            # Diverged weights cannot be matched; leave the order untouched
            return np.arange(reference.shape[0])
        _, assignment = opt.linear_sum_assignment(cost)
        return assignment

    @staticmethod
    def _permute_inputs(weight, perm):
        """Reorder the input axis of ``weight`` to follow an upstream permutation."""
        n_inputs = weight.shape[1]
        n_units = perm.shape[0]
        if n_inputs == n_units:
            return weight[:, perm]
        if n_inputs % n_units == 0:
            # Flattened feature maps: each upstream channel owns one
            # contiguous block of input columns
            block = n_inputs // n_units
            columns = (perm[:, None] * block + np.arange(block)[None, :]).reshape(-1)
            return weight[:, columns]
        raise ValueError(
            f"Cannot propagate a permutation of {n_units} units to {n_inputs} inputs"
        )

    @staticmethod
    def _align_to_reference(client_params):
        """Permute every client's units to line up with the first client's."""
        reference = client_params[0]
        weight_positions = [i for i, arr in enumerate(reference) if arr.ndim in (2, 4)]
        # The final weight tensor produces the class logits and keeps its order
        last_weight = weight_positions[-1] if weight_positions else None

        aligned = [list(reference)]
        for params in client_params[1:]:
            realigned = []
            perm = None  # permutation applied to the most recent layer's outputs
            for idx, (arr, ref_arr) in enumerate(zip(params, reference)):
                if arr.ndim in (2, 4):
                    if perm is not None:
                        arr = CustomFedMA._permute_inputs(arr, perm)
                    if idx != last_weight:
                        perm = CustomFedMA._match_units(ref_arr, arr)
                        arr = arr[perm]
                    else:
                        perm = None
                elif arr.ndim == 1 and perm is not None and arr.shape[0] == perm.shape[0]:
                    # Biases and normalisation vectors follow their layer's units
                    arr = arr[perm]
                realigned.append(arr)
            aligned.append(realigned)
        return aligned



# Define the FedMA strategy with specified parameters
strategy = CustomFedMA(
    fraction_fit=1.0,
    min_fit_clients=4,
    min_available_clients=4,
    evaluate_fn=evaluate_fn,
    on_fit_config_fn=fit_config
)

# Only ask for a GPU share when a GPU exists; otherwise clients cannot be scheduled
client_resources = {
    "num_cpus": 2,
    "num_gpus": 0.5 if torch.cuda.is_available() else 0.0,
}

# Start the federated learning simulation
history = None
try:
    history = fl.simulation.start_simulation(
        client_fn=client_fn,
        num_clients=4,
        config=fl.server.ServerConfig(num_rounds=72),  # Total federated rounds
        strategy=strategy,
        client_resources=client_resources,
        ray_init_args={
            "num_cpus": 16,
            "object_store_memory": 40 * 1024**3
        }
    )
except (StopIteration, RuntimeError) as e:
    # The strategy signals early stopping with StopIteration; Flower may
    # re-raise it wrapped in a RuntimeError. Anything else is a real failure.
    if not strategy.early_stopping.early_stop:
        raise
    print(e.__cause__ or e)

print("Federated learning simulation completed.")



# Plot metrics recorded over federated rounds
rounds = range(1, len(strategy.metrics_history['accuracy']) + 1)
plt.figure(figsize=(12, 8))
for history_key, curve_label in (
    ('val_loss', 'Validation Loss'),
    ('accuracy', 'Accuracy'),
    ('precision', 'Precision'),
    ('recall', 'Recall'),
    ('f1', 'F1 Score'),
):
    plt.plot(rounds, strategy.metrics_history[history_key], label=curve_label)
plt.xlabel('Round')
plt.ylabel('Metric Value')
plt.title('Federated Learning Metrics Over Rounds (FedMA)')
plt.legend()
plt.grid(True)
# Save the plot to a file, creating the output folder first so savefig cannot
# fail on a missing directory
plot_file = '/home/ir739wb/ilyarekun/bc_project/federated-learning/outputs/fedma-non-iid-graph.png'
os.makedirs(os.path.dirname(plot_file), exist_ok=True)
plt.savefig(plot_file, dpi=300)
plt.close()



# After simulation, pick the parameters the strategy recorded: the best model
# when early stopping fired, otherwise the last evaluated round
if strategy.early_stopping.early_stop:
    print("Using the best model parameters from early stopping.")
else:
    print("Early stopping was not triggered. Using the final round's parameters.")
best_parameters = strategy.final_parameters
if best_parameters is None:
    raise RuntimeError("No global parameters were recorded; the simulation did not complete an evaluation.")

# Load the best global model and evaluate on the test set
model = get_model()
# The strategy may hold either a Flower Parameters object or a plain list of arrays
if isinstance(best_parameters, (list, tuple)):
    final_ndarrays = list(best_parameters)
else:
    final_ndarrays = parameters_to_ndarrays(best_parameters)
state_dict = {k: torch.tensor(v) for k, v in zip(model.state_dict().keys(), final_ndarrays)}
model.load_state_dict(state_dict)
model.eval()

all_preds = []
all_targets = []
with torch.no_grad():
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        output = model(data)
        _, predicted = torch.max(output, 1)
        all_preds.extend(predicted.cpu().numpy())
        all_targets.extend(target.cpu().numpy())

# Compute and print final test metrics (macro-averaged over classes)
accuracy = (np.array(all_preds) == np.array(all_targets)).mean()
precision, recall, f1, _ = precision_recall_fscore_support(all_targets, all_preds, average='macro', zero_division=0)

print(f"Test Accuracy: {accuracy:.4f}")
print(f"Test Precision: {precision:.4f}")
print(f"Test Recall: {recall:.4f}")
print(f"Test F1 Score: {f1:.4f}")

# Save all recorded metrics (per evaluation and final test) to a text file
metrics_file = '/home/ir739wb/ilyarekun/bc_project/federated-learning/outputs/fedma-non-iid-metrics.txt'
os.makedirs(os.path.dirname(metrics_file), exist_ok=True)
with open(metrics_file, 'w') as f:
    rounds = range(1, len(strategy.metrics_history['val_loss']) + 1)
    for round_num in rounds:
        f.write(f"Round {round_num}:\n")
        f.write(f"  Validation Loss: {strategy.metrics_history['val_loss'][round_num-1]:.4f}\n")
        f.write(f"  Accuracy: {strategy.metrics_history['accuracy'][round_num-1]:.4f}\n")
        f.write(f"  Precision: {strategy.metrics_history['precision'][round_num-1]:.4f}\n")
        f.write(f"  Recall: {strategy.metrics_history['recall'][round_num-1]:.4f}\n")
        f.write(f"  F1 Score: {strategy.metrics_history['f1'][round_num-1]:.4f}\n")
    f.write("\nTest Metrics:\n")
    f.write(f"  Accuracy: {accuracy:.4f}\n")
    f.write(f"  Precision: {precision:.4f}\n")
    f.write(f"  Recall: {recall:.4f}\n")
    f.write(f"  F1 Score: {f1:.4f}\n")

print(f"Metrics saved to '{metrics_file}'")
