In [1]:
# NOTE(review): `!cd` is executed in a temporary subshell, so the notebook kernel's working
# directory is presumably unchanged after this line; IPython's `%cd` magic is the persistent
# form -- confirm which behavior is intended.
!cd /home/ir739wb/ilyarekun/bc_project/federated-learning/fed-avg-code



In [None]:
# Import necessary libraries for file handling, federated learning, deep learning, and data processing
import os  # For file and directory operations
import flwr as fl  # Flower framework for federated learning
import torch  # PyTorch for building and training neural networks
import torch.nn as nn  # Neural network modules from PyTorch
import torch.optim as optim  # Optimizers from PyTorch
import torchvision.transforms as transforms  # Image transformations
from torch.utils.data import DataLoader, Subset  # Data loading and subset utilities
from torchvision.datasets import ImageFolder  # Dataset class for image folders
import numpy as np  # Numerical computations
import kagglehub  # For downloading datasets from Kaggle
import shutil  # For file moving operations
from sklearn.metrics import precision_recall_fscore_support  # Metrics for evaluation
import matplotlib.pyplot as plt  # Plotting utilities
from flwr.common import parameters_to_ndarrays  # Convert Flower parameters to NumPy arrays

# Reproducibility: feed the same seed to every random number generator this notebook seeds
seed = 42
for _apply_seed in (
    torch.manual_seed,           # PyTorch CPU generator
    torch.cuda.manual_seed,      # current CUDA device
    torch.cuda.manual_seed_all,  # every CUDA device
    np.random.seed,              # NumPy legacy global generator
):
    _apply_seed(seed)

# cuDNN: request deterministic kernels and switch off the auto-tuner
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Compute device: first CUDA GPU when one is available, CPU otherwise
device = "cpu"
if torch.cuda.is_available():
    device = "cuda:0"
print(f"Using device: {device}")  # Report the choice to the user

# Define the CNN model for brain tumor classification
class BrainCNN(nn.Module):
    """CNN classifier producing 4 logits per RGB image.

    ``conv_layers`` is a stack of six identical stages
    (Conv2d 7x7 -> ReLU -> BatchNorm2d -> MaxPool2d(2) -> Dropout2d), so the
    spatial size is halved six times.  ``fc_layers`` expects the flattened
    feature map to hold ``512 * 3 * 3`` values, which is what a 200x200 input
    yields (200 -> 100 -> 50 -> 25 -> 12 -> 6 -> 3).

    The attribute names and the position of every layer inside the two
    ``nn.Sequential`` containers determine the ``state_dict`` keys, which other
    code in this file relies on positionally -- keep them stable.
    """

    def __init__(self):
        super().__init__()

        def conv_stage(in_channels, out_channels, drop_prob):
            # One feature-extraction stage; "same" padding keeps H x W until the pool.
            return [
                nn.Conv2d(in_channels, out_channels, kernel_size=7, stride=1, padding=3),
                nn.ReLU(),
                nn.BatchNorm2d(out_channels),
                nn.MaxPool2d(2),
                nn.Dropout2d(drop_prob),
            ]

        # (input channels, output channels, dropout probability) for each stage
        stage_specs = (
            (3, 64, 0.45),
            (64, 128, 0.45),
            (128, 128, 0.45),
            (128, 256, 0.45),
            (256, 256, 0.4),
            (256, 512, 0.4),
        )
        stacked = []
        for spec in stage_specs:
            stacked.extend(conv_stage(*spec))
        self.conv_layers = nn.Sequential(*stacked)

        # Classification head: 4608 -> 1024 -> 512 -> 4
        self.fc_layers = nn.Sequential(
            nn.Linear(512 * 3 * 3, 1024),
            nn.ReLU(),
            nn.Dropout(p=0.4),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(p=0.4),
            nn.Linear(512, 4),
        )

    def forward(self, x):
        """Return the class logits for a batch of images ``x``."""
        features = self.conv_layers(x)
        flattened = features.view(features.size(0), -1)  # one row per sample
        return self.fc_layers(flattened)

# Class to implement early stopping based on validation loss
class EarlyStopping:
    """Track validation loss and signal when training should stop.

    Stopping is requested (``early_stop`` becomes True) when either

    * the loss is at or below ``threshold``, or
    * ``patience`` consecutive calls failed to improve the best loss by at
      least ``delta``.

    The weights belonging to the best loss seen so far are kept in
    ``best_model_state`` as an independent copy, so later in-place updates of
    the model cannot alter the saved snapshot.

    Args:
        patience: Number of non-improving calls tolerated before stopping.
        delta: Minimum decrease of the loss that counts as an improvement.
        threshold: Loss value at or below which training stops immediately.
    """

    def __init__(self, patience=5, delta=0, threshold=0.19):
        self.patience = patience
        self.delta = delta
        self.threshold = threshold
        self.best_score = None        # best (negated) validation loss so far
        self.early_stop = False       # set to True once stopping is requested
        self.counter = 0              # consecutive calls without improvement
        self.best_model_state = None  # copy of the best model's state dict

    @staticmethod
    def _snapshot(model):
        """Return a deep copy of ``model.state_dict()``.

        ``state_dict()`` hands out references to the live tensors; without the
        copy the stored "best" state would silently follow further training.
        """
        import copy  # function-scope import: `copy` is not imported at file level

        return copy.deepcopy(model.state_dict())

    def __call__(self, val_loss, model):
        """Record one validation result.

        Args:
            val_loss: Validation loss of the current model.
            model: Object exposing ``state_dict()``; snapshotted on improvement.
        """
        # Absolute target reached: stop right away and keep this model.
        if val_loss <= self.threshold:
            print(f"Val loss {val_loss:.5f} ниже порогового значения {self.threshold}.")
            self.early_stop = True
            self.best_model_state = self._snapshot(model)
            return

        score = -val_loss  # negate so that "higher is better"
        if self.best_score is None:
            self.best_score = score
            self.best_model_state = self._snapshot(model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.best_model_state = self._snapshot(model)
            self.counter = 0

    def load_best_model(self, model):
        """Load the best recorded weights into ``model``.

        Raises:
            RuntimeError: If no model state has been recorded yet.
        """
        if self.best_model_state is None:
            raise RuntimeError("No best model state has been recorded yet.")
        model.load_state_dict(self.best_model_state)

# Function to preprocess the brain tumor MRI dataset in a non-IID manner
def data_preprocessing_tumor_NON_IID(num_clients=4):
    """Download the brain-tumor MRI dataset and build non-IID client loaders.

    Steps:
      1. Download the Kaggle dataset and move every image from ``Training`` and
         ``Testing`` into one ``General_Dataset/<class>/`` tree (files are
         moved, not copied, so the downloaded folder is modified on disk).
      2. Split each class, in dataset order, into 70% train / 20% validation /
         remaining ~10% test.
      3. Shuffle each class's training indices with NumPy's global RNG and
         hand them out to the clients according to ``distribution``.

    Args:
        num_clients: Number of client loaders to build.  Must be between 1 and
            the number of columns of the distribution table (4).

    Returns:
        ``(client_train_loaders, val_loader, test_loader)`` where
        ``client_train_loaders`` is a list with one shuffling ``DataLoader``
        per client; all loaders use a batch size of 64.

    Raises:
        ValueError: If ``num_clients`` is outside the supported range or the
            dataset contains a class label missing from the distribution table.
    """
    # Share of each class's training samples given to each client.
    # Rows are indexed by CLASS label, columns by CLIENT id (the lookup below is
    # distribution[class_label][client]); every row sums to 1 so each class is
    # fully handed out, and class k sends its 70% share to client k.
    distribution = {
        0: [0.70, 0.15, 0.10, 0.05],
        1: [0.15, 0.70, 0.10, 0.05],
        2: [0.10, 0.15, 0.70, 0.05],
        3: [0.05, 0.10, 0.15, 0.70],
    }
    max_clients = min(len(shares) for shares in distribution.values())
    if not 1 <= num_clients <= max_clients:
        # Fail before the download instead of with an IndexError deep in the loop.
        raise ValueError(
            f"num_clients must be between 1 and {max_clients}, got {num_clients}"
        )

    # Download dataset from Kaggle
    dataset_path = kagglehub.dataset_download("masoudnickparvar/brain-tumor-mri-dataset")
    train_path = os.path.join(dataset_path, "Training")
    test_path = os.path.join(dataset_path, "Testing")
    general_dataset_path = os.path.join(dataset_path, "General_Dataset")
    os.makedirs(general_dataset_path, exist_ok=True)

    # Merge the published train/test folders into one class-structured tree.
    for source_path in [train_path, test_path]:
        if not os.path.isdir(source_path):
            continue  # nothing left to merge from this folder
        for class_name in os.listdir(source_path):
            class_path = os.path.join(source_path, class_name)
            if not os.path.isdir(class_path):
                continue  # skip stray files that are not class folders
            general_class_path = os.path.join(general_dataset_path, class_name)
            os.makedirs(general_class_path, exist_ok=True)
            for img_name in os.listdir(class_path):
                img_path = os.path.join(class_path, img_name)
                shutil.move(img_path, os.path.join(general_class_path, img_name))

    # Center-crop, downscale to the 200x200 input size, convert to tensor.
    transform = transforms.Compose([
        transforms.CenterCrop((400, 400)),
        transforms.Resize((200, 200)),
        transforms.ToTensor(),
    ])

    general_dataset = ImageFolder(root=general_dataset_path, transform=transform)
    targets = general_dataset.targets
    classes = sorted(set(targets))  # explicit order instead of relying on set iteration

    unknown = [label for label in classes if label not in distribution]
    if unknown:
        raise ValueError(f"No client distribution defined for class labels {unknown}")

    # Group sample indices by class once, preserving dataset order.
    indices_by_class = {label: [] for label in classes}
    for idx, target in enumerate(targets):
        indices_by_class[target].append(idx)

    train_ratio = 0.7  # 70% for training
    val_ratio = 0.2    # 20% for validation (the remainder is the test split)

    train_indices = []
    val_indices = []
    test_indices = []
    train_by_class = {}
    for class_label in classes:
        class_indices = indices_by_class[class_label]
        class_size = len(class_indices)
        train_size = int(train_ratio * class_size)
        val_size = int(val_ratio * class_size)
        train_by_class[class_label] = class_indices[:train_size]
        train_indices.extend(train_by_class[class_label])
        val_indices.extend(class_indices[train_size:train_size + val_size])
        test_indices.extend(class_indices[train_size + val_size:])

    val_set = Subset(general_dataset, val_indices)
    test_set = Subset(general_dataset, test_indices)
    val_loader = DataLoader(val_set, batch_size=64, shuffle=False)
    test_loader = DataLoader(test_set, batch_size=64, shuffle=False)

    client_indices = {client: [] for client in range(num_clients)}

    # Hand out each class's training samples to the clients.
    for class_label in classes:
        class_train_indices = list(train_by_class[class_label])
        np.random.shuffle(class_train_indices)  # uses NumPy's global RNG (seeded at file level)

        n = len(class_train_indices)
        allocation = [int(distribution[class_label][client] * n) for client in range(num_clients)]
        # The last client absorbs the rounding remainder so no sample is dropped.
        allocation[-1] = n - sum(allocation[:-1])

        start = 0
        for client, cnt in enumerate(allocation):
            client_indices[client].extend(class_train_indices[start:start + cnt])
            start += cnt

    client_train_loaders = [
        DataLoader(Subset(general_dataset, client_indices[client]), batch_size=64, shuffle=True)
        for client in range(num_clients)
    ]

    return client_train_loaders, val_loader, test_loader

# Build the data loaders once at module level: one training loader per client (4 clients)
# plus the shared validation and test loaders. The functions defined below read these
# names as globals.
client_train_loaders, val_loader, test_loader = data_preprocessing_tumor_NON_IID(num_clients=4)

# Helper function to instantiate the model
def get_model():
    """Return a freshly initialised ``BrainCNN`` placed on the module-level ``device``."""
    network = BrainCNN()
    return network.to(device)

# Helper function to get the optimizer
def get_optimizer(model):
    """Return an SGD optimizer over ``model.parameters()``.

    Uses a learning rate of 0.001, momentum 0.7 and weight decay 0.09.
    """
    sgd_settings = {"lr": 0.001, "momentum": 0.7, "weight_decay": 0.09}
    return optim.SGD(model.parameters(), **sgd_settings)

# Helper function to get the loss function
def get_loss_function():
    """Return the classification criterion: a default ``nn.CrossEntropyLoss``."""
    criterion = nn.CrossEntropyLoss()
    return criterion

# Configuration function for client training
def fit_config(server_round: int):
    """Return the per-round training configuration sent to clients.

    The round number is accepted but not used: every round asks for 5 local epochs.
    """
    round_config = dict(local_epochs=5)
    return round_config

# Evaluation function for the server to assess the global model
def evaluate_fn(server_round, parameters, config):
    """Evaluate a set of global weights on the validation loader.

    Args:
        server_round: Current round number (not used by the computation).
        parameters: Sequence of arrays, paired positionally with the keys of
            the model's ``state_dict()``.
        config: Evaluation configuration (not used).

    Returns:
        ``(val_loss, metrics)`` where ``val_loss`` is the sample-weighted mean
        cross-entropy over ``val_loader`` and ``metrics`` holds ``accuracy``
        and macro-averaged ``precision``, ``recall`` and ``f1``.
    """
    net = get_model()
    weights = {
        name: torch.tensor(array)
        for name, array in zip(net.state_dict().keys(), parameters)
    }
    net.load_state_dict(weights)
    net.eval()  # inference mode: dropout off, batch-norm uses running statistics

    loss_fn = get_loss_function()
    summed_loss = 0.0
    predictions, labels = [], []

    with torch.no_grad():
        for images, batch_labels in val_loader:
            images = images.to(device)
            batch_labels = batch_labels.to(device)
            logits = net(images)
            # The criterion returns a batch mean; scale by batch size for a per-sample sum.
            summed_loss += loss_fn(logits, batch_labels).item() * images.size(0)
            batch_predictions = torch.max(logits, 1)[1]
            predictions.extend(batch_predictions.cpu().numpy())
            labels.extend(batch_labels.cpu().numpy())

    mean_loss = summed_loss / len(val_loader.dataset)
    accuracy = (np.array(predictions) == np.array(labels)).mean()
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, predictions, average='macro', zero_division=0
    )

    metrics = {
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1": f1
    }
    return mean_loss, metrics

# Import Context for Flower client definition
from flwr.common import Context

# Define the client function for Flower simulation
def client_fn(context: Context):
    """Build the Flower client that trains on one client's share of the data.

    The returned client performs SCAFFOLD-style local training: every
    mini-batch gradient ``g`` is replaced by ``g - c_k + c`` (``c_k`` local,
    ``c`` global control variate) before the optimizer step, and after
    training the local control variate is refreshed with

        c_k+ = c_k - c + (x - y) / (eta * T)

    where ``x`` / ``y`` are the trainable parameters before / after local
    training, ``eta`` the learning rate and ``T`` the number of local steps.

    Model weights are exchanged as the values of ``state_dict()`` (trainable
    parameters AND buffers such as batch-norm statistics), in key order, which
    is the layout the server-side code pairs with ``state_dict().keys()``.
    Control variates cover ``model.parameters()`` only.

    Args:
        context: Flower context of the node the client is created for.

    Returns:
        A Flower ``Client`` wrapping the NumPy-based client defined below.
    """
    # Function-scope import: OrderedDict is not imported at file level.
    from collections import OrderedDict

    # Prefer the simulation's partition id when the context provides one and fall
    # back to the node id otherwise.
    # NOTE(review): presumably "partition-id" is what the installed Flower version
    # puts into node_config during simulation -- confirm.
    node_config = getattr(context, "node_config", None) or {}
    partition = node_config.get("partition-id", context.node_id)
    cid = int(partition) % len(client_train_loaders)
    train_loader = client_train_loaders[cid]

    class FlowerClient(fl.client.NumPyClient):
        """NumPy-based Flower client implementing the local SCAFFOLD step."""

        def __init__(self):
            self.model = get_model()
            self.train_loader = train_loader
            self.optimizer = get_optimizer(self.model)
            self.criterion = get_loss_function()
            self.device = device  # module-level device string
            # Local control variate c_k, one zero tensor per trainable parameter.
            # NOTE(review): a new FlowerClient is constructed on every client_fn
            # call, so c_k starts from zeros each time this function runs; if the
            # framework calls client_fn every round, c_k is not carried over.
            self.c_k = [torch.zeros_like(param) for param in self.model.parameters()]

        def _weights_as_ndarrays(self):
            # Full state dict (parameters + buffers) in key order.
            return [value.detach().cpu().numpy() for value in self.model.state_dict().values()]

        def fit(self, parameters, config):
            """Train locally and return ``(weights, num_examples, metrics)``.

            ``config`` must contain ``"global_control_variate"`` (one nested
            list / array per trainable parameter) and may contain
            ``"local_epochs"`` (default 1).  ``metrics["control_variate"]``
            carries the updated local control variate as nested lists.
            """
            try:
                # Load the server weights into the local model.
                keys = list(self.model.state_dict().keys())
                if len(parameters) != len(keys):
                    raise ValueError(
                        f"Expected {len(keys)} weight arrays (state_dict entries), got {len(parameters)}"
                    )
                state_dict = OrderedDict(
                    (name, torch.tensor(array)) for name, array in zip(keys, parameters)
                )
                self.model.load_state_dict(state_dict)
                self.model.train()

                trainable = list(self.model.parameters())

                # Global control variate sent by the server.
                global_control_variate = config.get("global_control_variate")
                if global_control_variate is None:
                    raise ValueError("Global control variate not provided in config")
                if len(global_control_variate) != len(trainable):
                    raise ValueError(
                        f"Expected {len(trainable)} control variate tensors, got {len(global_control_variate)}"
                    )
                # Match each parameter's dtype/device so the gradient arithmetic is valid.
                c = [
                    torch.tensor(layer, dtype=param.dtype, device=param.device)
                    for layer, param in zip(global_control_variate, trainable)
                ]
                # Validate shapes once up front instead of on every batch.
                for param, ck, gc in zip(trainable, self.c_k, c):
                    if param.shape != ck.shape or param.shape != gc.shape:
                        raise ValueError(f"Shape mismatch: param {param.shape}, ck {ck.shape}, gc {gc.shape}")

                # Parameters before local training (x in the update rule).
                initial_parameters = [param.detach().clone() for param in trainable]

                local_epochs = config.get("local_epochs", 1)
                eta = self.optimizer.param_groups[0]['lr']
                T = local_epochs * len(self.train_loader)  # total number of local steps

                if T == 0:
                    raise ValueError("Training loader is empty or not properly initialized")

                for _ in range(local_epochs):
                    for data, target in self.train_loader:
                        data, target = data.to(self.device), target.to(self.device)
                        self.optimizer.zero_grad()
                        output = self.model(data)
                        loss = self.criterion(output, target)
                        loss.backward()

                        # Drift correction: g <- g - c_k + c
                        with torch.no_grad():
                            for param, ck, gc in zip(trainable, self.c_k, c):
                                param.grad.add_(gc - ck)

                        self.optimizer.step()

                # c_k+ = c_k - c + (x - y) / (eta * T); computed without autograd
                # tracking so the result can be converted to NumPy.
                with torch.no_grad():
                    c_k_plus = [
                        ck - gc + (initial_param - param) / (eta * T)
                        for param, initial_param, ck, gc in zip(trainable, initial_parameters, self.c_k, c)
                    ]
                self.c_k = c_k_plus

                updated_params = self._weights_as_ndarrays()
                # NOTE(review): the control variate travels as nested Python lists
                # inside the metrics dict; Flower documents metric values as
                # scalars, so confirm this survives serialization in the installed version.
                control_variate = [variate.cpu().numpy().tolist() for variate in c_k_plus]
                return updated_params, len(self.train_loader.dataset), {"control_variate": control_variate}

            except Exception as e:
                print(f"Error in client fit: {str(e)}")  # log locally, then let the server see it
                raise

        def get_parameters(self, config):
            """Return the current model weights (full state dict, key order)."""
            return self._weights_as_ndarrays()

    return FlowerClient().to_client()

# Custom Scaffold strategy extending FedAvg
class CustomScaffold(fl.server.strategy.FedAvg):
    """FedAvg strategy extended with a SCAFFOLD-style global control variate.

    On top of FedAvg weight aggregation the strategy

    * keeps a global control variate ``c`` (one array per trainable model
      parameter) and sends it to the clients in the fit config under
      ``"global_control_variate"``;
    * replaces ``c`` after each round by the example-weighted average of the
      control variates the clients report in ``metrics["control_variate"]``;
    * records server-side evaluation results in ``metrics_history`` and applies
      early stopping on the validation loss.

    ``final_parameters`` always holds a Flower ``Parameters`` object (or None
    before the first evaluation).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        model = get_model()  # only used to learn the parameter shapes
        self.c = [np.zeros(param.shape) for param in model.parameters()]
        # Per-round server-side evaluation results
        self.metrics_history = {
            "val_loss": [],
            "accuracy": [],
            "precision": [],
            "recall": [],
            "f1": []
        }
        self.final_parameters = None
        self.early_stopping = EarlyStopping(patience=40, delta=0.00001, threshold=0.0001)

    def configure_fit(self, server_round, parameters, client_manager):
        """Return FedAvg's fit instructions with the global control variate added."""
        fit_instructions = super().configure_fit(server_round, parameters, client_manager)
        global_control_variate = [layer.tolist() for layer in self.c]
        # Each entry is a (client, FitIns) pair; the config dict lives on FitIns.
        for _, fit_ins in fit_instructions:
            fit_ins.config["global_control_variate"] = global_control_variate
        return fit_instructions

    def aggregate_fit(self, server_round, results, failures):
        """Aggregate client weights (FedAvg) and update the global control variate.

        Returns:
            ``(aggregated_parameters, metrics)`` as produced by FedAvg, or
            ``(None, {})`` when there is nothing to aggregate.
        """
        if not results:
            return None, {}

        # The parent returns a (parameters, metrics) pair, not bare parameters.
        aggregated_parameters, aggregated_metrics = super().aggregate_fit(server_round, results, failures)
        if aggregated_parameters is None:
            return None, {}

        # Each result is a (client, FitRes) pair.
        weights = [fit_res.num_examples for _, fit_res in results]
        client_variates = [
            [np.asarray(layer) for layer in fit_res.metrics["control_variate"]]
            for _, fit_res in results
        ]
        self.c = [
            np.average([variate[i] for variate in client_variates], weights=weights, axis=0)
            for i in range(len(self.c))
        ]

        return aggregated_parameters, aggregated_metrics

    def evaluate(self, server_round, parameters):
        """Run server-side evaluation, log metrics and check early stopping.

        Raises:
            StopIteration: When early stopping triggers; ``final_parameters``
                is set to the best recorded weights before raising.
        """
        result = super().evaluate(server_round, parameters)
        if result is None:
            return None

        loss, metrics = result
        self.metrics_history["val_loss"].append(loss)
        for key, value in metrics.items():
            self.metrics_history.setdefault(key, []).append(value)
        self.final_parameters = parameters

        # Materialise the weights in a model so EarlyStopping can snapshot them.
        global_model = get_model()
        keys = list(global_model.state_dict().keys())
        ndarrays = parameters_to_ndarrays(parameters)
        global_model.load_state_dict({k: torch.tensor(v) for k, v in zip(keys, ndarrays)})

        self.early_stopping(loss, global_model)
        if self.early_stopping.early_stop:
            print(f"Early stopping triggered at round {server_round}.")
            best_state_dict = self.early_stopping.best_model_state
            best_ndarrays = [best_state_dict[k].cpu().numpy() for k in keys]
            # Keep the same type as in the non-stopping path so consumers can
            # always call parameters_to_ndarrays() on final_parameters.
            self.final_parameters = fl.common.ndarrays_to_parameters(best_ndarrays)
            raise StopIteration("Early stopping triggered.")
        return result

# Define the Scaffold strategy
strategy = CustomScaffold(
    fraction_fit=1.0,  # Use all clients for fitting
    fraction_evaluate=0.0,  # Clients implement no evaluate(); evaluation is server-side via evaluate_fn
    min_fit_clients=4,  # Minimum clients required for fitting
    min_available_clients=4,  # Minimum clients available
    evaluate_fn=evaluate_fn,  # Evaluation function
    on_fit_config_fn=fit_config  # Fit configuration function
)

# Start the federated learning simulation
try:
    history = fl.simulation.start_simulation(
        client_fn=client_fn,  # Client creation function
        num_clients=4,  # Number of clients
        config=fl.server.ServerConfig(num_rounds=72),  # Run for 72 rounds
        strategy=strategy,  # Use custom Scaffold strategy
        client_resources={
            "num_cpus": 2,
            # Only reserve GPU capacity when a GPU exists, matching the CPU fallback of `device`
            "num_gpus": 0.5 if torch.cuda.is_available() else 0.0,
        },
        ray_init_args={
            "num_cpus": 16,  # Total CPUs for Ray
            "object_store_memory": 40 * 1024**3  # Memory for Ray object store (40GB)
        }
    )
except StopIteration as e:
    print(e)  # Early stopping raised by the strategy
except RuntimeError as e:
    # NOTE(review): some Flower versions presumably re-raise exceptions from the server
    # loop as RuntimeError -- confirm for the installed version. Treat it as early
    # stopping only when the strategy actually requested it; otherwise propagate.
    if not strategy.early_stopping.early_stop:
        raise
    print(e)
print("Federated learning simulation completed.")

# Draw every recorded validation metric against the round index and save the figure
rounds = range(1, len(strategy.metrics_history['accuracy']) + 1)
fig, ax = plt.subplots(figsize=(12, 8))
for metric_key, curve_label in (
    ('val_loss', 'Validation Loss'),
    ('accuracy', 'Accuracy'),
    ('precision', 'Precision'),
    ('recall', 'Recall'),
    ('f1', 'F1 Score'),
):
    ax.plot(rounds, strategy.metrics_history[metric_key], label=curve_label)
ax.set_xlabel('Round')
ax.set_ylabel('Metric Value')
ax.set_title('Federated Learning Metrics Over Rounds')
ax.legend()
ax.grid(True)
fig.savefig('/home/ir739wb/ilyarekun/bc_project/federated-learning/fed-avg-code/scaffold-non-iid-graph.png', dpi=300)
plt.close(fig)  # release the figure once it is on disk

# Select the parameters to test: the strategy keeps them in `final_parameters` in both
# cases (best weights after early stopping, otherwise the last evaluated weights).
best_parameters = strategy.final_parameters
if best_parameters is None:
    raise RuntimeError("No evaluated global parameters are available; the simulation produced no evaluation result.")
if strategy.early_stopping.early_stop:
    print("Using the best model parameters from early stopping.")
else:
    print("Early stopping was not triggered. Using the final round's parameters.")

# Evaluate the selected model on the test set
model = get_model()
# `final_parameters` may be a Flower Parameters object or already a list of arrays.
if isinstance(best_parameters, (list, tuple)):
    final_ndarrays = list(best_parameters)
else:
    final_ndarrays = parameters_to_ndarrays(best_parameters)
state_dict = {k: torch.tensor(v) for k, v in zip(model.state_dict().keys(), final_ndarrays)}
model.load_state_dict(state_dict)
model.eval()  # inference mode

all_preds = []  # predicted class per test sample
all_targets = []  # true class per test sample
with torch.no_grad():
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        output = model(data)
        _, predicted = torch.max(output, 1)
        all_preds.extend(predicted.cpu().numpy())
        all_targets.extend(target.cpu().numpy())

# Compute test metrics (macro-averaged over classes)
accuracy = (np.array(all_preds) == np.array(all_targets)).mean()
precision, recall, f1, _ = precision_recall_fscore_support(all_targets, all_preds, average='macro', zero_division=0)

# Print test metrics
print(f"Test Accuracy: {accuracy:.4f}")
print(f"Test Precision: {precision:.4f}")
print(f"Test Recall: {recall:.4f}")
print(f"Test F1 Score: {f1:.4f}")

# Persist the per-round validation metrics followed by the final test metrics
metrics_file = '/home/ir739wb/ilyarekun/bc_project/federated-learning/fed-avg-code/scaffold-non-iid-metrics.txt'
with open(metrics_file, 'w') as f:
    metrics_log = strategy.metrics_history
    rounds = range(1, len(metrics_log['val_loss']) + 1)
    for round_num in rounds:
        pos = round_num - 1  # history lists are 0-based, round labels 1-based
        round_lines = [
            f"Round {round_num}:\n",
            f"  Validation Loss: {metrics_log['val_loss'][pos]:.4f}\n",
            f"  Accuracy: {metrics_log['accuracy'][pos]:.4f}\n",
            f"  Precision: {metrics_log['precision'][pos]:.4f}\n",
            f"  Recall: {metrics_log['recall'][pos]:.4f}\n",
            f"  F1 Score: {metrics_log['f1'][pos]:.4f}\n",
        ]
        for line in round_lines:
            f.write(line)
    # Test-set results close the report
    test_lines = [
        "\nTest Metrics:\n",
        f"  Accuracy: {accuracy:.4f}\n",
        f"  Precision: {precision:.4f}\n",
        f"  Recall: {recall:.4f}\n",
        f"  F1 Score: {f1:.4f}\n",
    ]
    for line in test_lines:
        f.write(line)
print(f"Metrics saved to '{metrics_file}'")  # Confirm file save