<a href="https://colab.research.google.com/github/hanjidani/FL/blob/main/FederatedLearning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Different federated learning methods for obtaining a global model from base (client) nodes
## Dr. Mojahedian
### Hossein Anjidani & S Yahya Tehrani

# Model Definition

In [None]:
import copy

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split, Subset
from torchvision import datasets, transforms

# Define a simple CNN model for MNIST
class SimpleCNN(nn.Module):
    """Small two-convolution network for single-channel 28x28 (MNIST) images.

    Pipeline: conv(1->32) -> ReLU -> 2x2 max-pool -> conv(32->64) -> ReLU ->
    2x2 max-pool -> flatten to 64*7*7 -> linear(->128) -> ReLU -> linear(->10).
    ``forward`` returns raw logits (no softmax is applied).
    """

    def __init__(self):
        super().__init__()
        # 3x3 kernels with padding 1 keep the spatial size; only the two
        # 2x2 pools shrink it (28 -> 14 -> 7), hence 64 * 7 * 7 below.
        # Layers are created in this order so parameter init and
        # state_dict keys stay stable.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        features = torch.max_pool2d(torch.relu(self.conv1(x)), 2)
        features = torch.max_pool2d(torch.relu(self.conv2(features)), 2)
        flat = features.view(-1, 64 * 7 * 7)
        hidden = torch.relu(self.fc1(flat))
        return self.fc2(hidden)
# Define a function to train a model on local data
def train_local_model(model, data_loader, epochs, lr):
    """Train ``model`` in place with plain SGD on one client's data.

    Args:
        model: network to train; its parameters are modified in place.
        data_loader: iterable of ``(inputs, targets)`` batches, iterated once
            per epoch.
        epochs: number of passes over ``data_loader``.
        lr: SGD learning rate.

    Returns:
        ``model.state_dict()`` after training.
    """
    loss_fn = nn.CrossEntropyLoss()
    sgd = optim.SGD(model.parameters(), lr=lr)
    model.train()
    for _epoch in range(epochs):
        for batch_inputs, batch_targets in data_loader:
            sgd.zero_grad()
            batch_loss = loss_fn(model(batch_inputs), batch_targets)
            batch_loss.backward()
            sgd.step()
    return model.state_dict()

# Define a function to evaluate the global model
def evaluate_model(model, test_loader):
    """Evaluate ``model`` on every batch of ``test_loader``.

    Args:
        model: classifier returning per-class logits.
        test_loader: iterable of ``(inputs, targets)`` batches.

    Returns:
        ``(accuracy, avg_loss)``: accuracy is in percent. ``avg_loss`` is the
        mean cross-entropy per *sample*, so a smaller final batch is not
        over-weighted the way a mean of per-batch means would be.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    # Sum (not mean) per batch so the final division is by the sample count.
    criterion = nn.CrossEntropyLoss(reduction="sum")
    model.eval()
    correct = 0
    total = 0
    loss_sum = 0.0
    with torch.no_grad():
        for inputs, targets in test_loader:
            outputs = model(inputs)
            loss_sum += criterion(outputs, targets).item()
            predicted = outputs.argmax(dim=1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()

    if total == 0:
        raise ValueError("test_loader yielded no samples; cannot compute accuracy/loss")

    accuracy = 100 * correct / total
    avg_loss = loss_sum / total
    return accuracy, avg_loss
# Load and preprocess the MNIST dataset
def load_mnist_data(num_agents, batch_size=32, test_batch_size=1000, root="."):
    """Download MNIST and split the training set into equal client shards.

    Client ``i`` gets the contiguous index range
    ``[i * shard, (i + 1) * shard)`` with ``shard = len(train) // num_agents``.
    Any remainder samples (``len(train) % num_agents``) are not assigned to
    any client.

    Args:
        num_agents: number of clients; must be >= 1.
        batch_size: batch size of each client's (shuffled) training loader.
        test_batch_size: batch size of the (unshuffled) test loader.
        root: directory where torchvision stores/downloads MNIST.

    Returns:
        ``(data_loaders, test_loader)``: one training ``DataLoader`` per
        client, plus a test ``DataLoader``.

    Raises:
        ValueError: if ``num_agents`` is < 1 or larger than the training set.
    """
    # Validate before touching the network/disk so bad input fails fast.
    if num_agents < 1:
        raise ValueError(f"num_agents must be >= 1, got {num_agents!r}")

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    mnist_train = datasets.MNIST(root, train=True, download=True, transform=transform)
    mnist_test = datasets.MNIST(root, train=False, download=True, transform=transform)

    shard_size = len(mnist_train) // num_agents
    if shard_size == 0:
        raise ValueError(
            f"num_agents={num_agents} exceeds the {len(mnist_train)} training samples"
        )

    data_loaders = [
        DataLoader(
            Subset(mnist_train, range(i * shard_size, (i + 1) * shard_size)),
            batch_size=batch_size,
            shuffle=True,
        )
        for i in range(num_agents)
    ]

    test_loader = DataLoader(mnist_test, batch_size=test_batch_size, shuffle=False)

    return data_loaders, test_loader

# FedAVG

In [None]:
# Simulate the federated learning process
def federated_averaging(global_model, data_loaders, test_loader, epochs, lr, rounds):
    """Run FedAvg: local SGD on every client, then weighted model averaging.

    In each of ``rounds`` rounds:
    - every client starts from a copy of the current global model;
    - it trains that copy with ``train_local_model(model, loader, epochs, lr)``;
    - the server replaces the global weights with the average of the client
      weights, weighted by each client's sample count
      (``len(loader.dataset)``).

    If any loader lacks a sized ``dataset``, all clients are weighted equally.
    After each round the global model is evaluated on ``test_loader`` and the
    result is printed.

    Args:
        global_model: model to train; it is updated in place and returned.
        data_loaders: one training loader per client.
        test_loader: loader passed to ``evaluate_model`` after every round.
        epochs: local epochs per client per round.
        lr: client SGD learning rate.
        rounds: number of communication rounds.

    Returns:
        ``global_model`` after the final round.

    Raises:
        ValueError: if ``data_loaders`` is empty.
    """
    num_agents = len(data_loaders)
    if num_agents == 0:
        raise ValueError("federated_averaging needs at least one client data loader")
    weights = _fedavg_client_weights(data_loaders)

    for rnd in range(rounds):
        local_states = []
        for loader in data_loaders:
            # deepcopy works for any architecture and leaves the global
            # model untouched while the client trains.
            local_model = copy.deepcopy(global_model)
            local_states.append(train_local_model(local_model, loader, epochs, lr))

        global_model.load_state_dict(_fedavg_weighted_average(local_states, weights))

        # Evaluate the global model on the test set
        accuracy, avg_loss = evaluate_model(global_model, test_loader)
        print(f"Round {rnd + 1}/{rounds} - Accuracy: {accuracy:.2f}%, Loss: {avg_loss:.4f}")

    return global_model


def _fedavg_client_weights(data_loaders):
    """Return per-client averaging weights proportional to local sample counts."""
    uniform = [1.0 / len(data_loaders)] * len(data_loaders)
    sizes = []
    for loader in data_loaders:
        try:
            sizes.append(len(loader.dataset))
        except (AttributeError, TypeError):
            return uniform
    total = sum(sizes)
    if total == 0:
        return uniform
    return [size / total for size in sizes]


def _fedavg_weighted_average(states, weights):
    """Weighted average of state dicts; non-float entries are copied from the first."""
    averaged = {}
    for key, first in states[0].items():
        if first.is_floating_point():
            averaged[key] = sum(w * state[key] for w, state in zip(weights, states))
        else:
            # Integer buffers (e.g. step counters) cannot be meaningfully averaged.
            averaged[key] = first.clone()
    return averaged

In [None]:
# FedAvg hyper-parameters
epochs = 1       # local epochs per client per round
lr = 0.01        # client SGD learning rate
rounds = 10      # communication rounds
num_agents = 5   # number of simulated clients

# Fresh global model shared by all clients
global_model = SimpleCNN()

# One training loader per client plus the shared MNIST test loader
data_loaders, test_loader = load_mnist_data(num_agents)

# Train with FedAvg, then persist the final global weights
global_model = federated_averaging(global_model, data_loaders, test_loader, epochs, lr, rounds)
torch.save(global_model.state_dict(), "fedavg_mnist_model.pth")

Round 1/10 - Accuracy: 91.35%, Loss: 0.3107
Round 2/10 - Accuracy: 94.35%, Loss: 0.1946
Round 3/10 - Accuracy: 95.72%, Loss: 0.1426
Round 4/10 - Accuracy: 96.57%, Loss: 0.1153
Round 5/10 - Accuracy: 97.05%, Loss: 0.0970
Round 6/10 - Accuracy: 97.30%, Loss: 0.0848
Round 7/10 - Accuracy: 97.48%, Loss: 0.0775
Round 8/10 - Accuracy: 97.70%, Loss: 0.0711
Round 9/10 - Accuracy: 97.95%, Loss: 0.0655
Round 10/10 - Accuracy: 98.07%, Loss: 0.0605


# SCAFFOLD

In [None]:

import numpy as np
# Define a function to train a model on local data
def train_local_model(model, data_loader, epochs, lr, c_local=None, c_global=None):
    """Train ``model`` in place with SGD, optionally applying the SCAFFOLD correction.

    When both control variates are given, each mini-batch gradient ``g`` of
    every parameter is replaced by ``g - c_local + c_global`` before the SGD
    step. When both are omitted this is plain SGD.

    This cell rebinds the module-level name ``train_local_model``. Making the
    control variates optional keeps the 4-argument call
    ``train_local_model(model, loader, epochs, lr)`` working.

    Args:
        model: network to train; its parameters are modified in place.
        data_loader: iterable of ``(inputs, targets)`` batches, iterated once
            per epoch.
        epochs: number of passes over ``data_loader``.
        lr: SGD learning rate.
        c_local: optional sequence of tensors, one per ``model.parameters()``
            entry and in the same order/shape (client control variate).
            Read only.
        c_global: optional sequence like ``c_local`` (server control
            variate). Read only.

    Returns:
        ``model.state_dict()`` after training.

    Raises:
        ValueError: if only one control variate is given, or their lengths
            do not match the number of model parameters.
    """
    if (c_local is None) != (c_global is None):
        raise ValueError("c_local and c_global must be given together (or both omitted)")
    params = list(model.parameters())
    use_correction = c_local is not None
    if use_correction and not (len(c_local) == len(params) == len(c_global)):
        raise ValueError(
            "control variates must have one tensor per model parameter "
            f"(got {len(c_local)} local, {len(c_global)} global, {len(params)} parameters)"
        )

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(params, lr=lr)
    model.train()
    for _ in range(epochs):
        for inputs, targets in data_loader:
            optimizer.zero_grad()
            loss = criterion(model(inputs), targets)
            loss.backward()

            if use_correction:
                # SCAFFOLD drift correction: g <- g - c_i + c
                with torch.no_grad():
                    for param, c_l, c_g in zip(params, c_local, c_global):
                        if param.grad is not None:  # unused params get no grad
                            param.grad.sub_(c_l).add_(c_g)

            optimizer.step()
    return model.state_dict()

def scaffold(global_model, data_loaders, test_loader, epochs, lr, rounds):
    """Run SCAFFOLD with full client participation and global step size 1.

    State: a server control variate ``c`` and one control variate ``c_i`` per
    client. Each is a list of tensors aligned with
    ``global_model.parameters()`` and starts at zero, so ``c == mean(c_i)``
    holds throughout.

    In each round, for every client i:
    1. Copy the global model ``x`` and train it with
       ``train_local_model(..., c_i, c)``, i.e. SGD on corrected gradients
       ``g - c_i + c``, giving ``y_i``.
    2. Update the control variate (SCAFFOLD "option II"):
       ``c_i+ = c_i - c + (x - y_i) / (K * lr)``, where ``K`` is the number
       of local SGD steps. This equals the average raw gradient seen during
       the local steps.

    Then the server sets ``x = mean_i(y_i)`` and ``c = c + mean_i(c_i+ - c_i)``.
    The global model is evaluated and the accuracy/loss are printed.

    Note: ``K`` is computed as ``epochs * len(loader)``. This assumes
    ``len(loader)`` is the number of batches per epoch, which holds for
    ``DataLoader``.

    Args:
        global_model: model to train; it is updated in place and returned.
        data_loaders: one training loader per client.
        test_loader: loader passed to ``evaluate_model`` after every round.
        epochs: local epochs per client per round.
        lr: client SGD learning rate (plain SGD, no momentum, so a step moves
            the weights by exactly ``lr * grad``).
        rounds: number of communication rounds.

    Returns:
        ``global_model`` after the final round.

    Raises:
        ValueError: if ``data_loaders`` is empty.
    """
    num_agents = len(data_loaders)
    if num_agents == 0:
        raise ValueError("scaffold needs at least one client data loader")

    global_control_variate = [torch.zeros_like(p) for p in global_model.parameters()]
    local_control_variates = [
        [torch.zeros_like(p) for p in global_model.parameters()]
        for _ in range(num_agents)
    ]

    for rnd in range(rounds):
        local_states = []
        cv_delta_sum = [torch.zeros_like(c) for c in global_control_variate]

        for i, loader in enumerate(data_loaders):
            local_model = copy.deepcopy(global_model)
            local_states.append(
                train_local_model(local_model, loader, epochs, lr,
                                  local_control_variates[i], global_control_variate)
            )

            # (x - y_i) / (K * lr); if no step was taken, c_i is left unchanged.
            step_scale = epochs * len(loader) * lr
            if step_scale <= 0:
                continue
            with torch.no_grad():
                new_local_cv = [
                    c_i - c + (x - y) / step_scale
                    for c_i, c, x, y in zip(local_control_variates[i],
                                            global_control_variate,
                                            global_model.parameters(),
                                            local_model.parameters())
                ]
                for j, (new_c, old_c) in enumerate(zip(new_local_cv, local_control_variates[i])):
                    cv_delta_sum[j] += new_c - old_c
            local_control_variates[i] = new_local_cv

        # Server model update: plain average of client models (global step size 1)
        averaged = {
            key: sum(state[key] for state in local_states) / num_agents
            for key in local_states[0]
        }
        global_model.load_state_dict(averaged)

        # Server control variate update: c <- c + mean_i(delta c_i)
        global_control_variate = [
            c + delta / num_agents
            for c, delta in zip(global_control_variate, cv_delta_sum)
        ]

        # Evaluate the global model on the test set
        accuracy, avg_loss = evaluate_model(global_model, test_loader)
        print(f"Round {rnd + 1}/{rounds} - Accuracy: {accuracy:.2f}%, Loss: {avg_loss:.4f}")

    return global_model

In [None]:
# SCAFFOLD hyper-parameters
epochs = 1       # local epochs per client per round
lr = 0.01        # client SGD learning rate
rounds = 10      # communication rounds
num_agents = 5   # number of simulated clients

# Fresh global model shared by all clients
global_model = SimpleCNN()

# One training loader per client plus the shared MNIST test loader
data_loaders, test_loader = load_mnist_data(num_agents)

# Train with SCAFFOLD, then persist the final global weights
global_model = scaffold(global_model, data_loaders, test_loader, epochs, lr, rounds)
torch.save(global_model.state_dict(), "scaffold_mnist_model.pth")

Round 1/10 - Accuracy: 91.22%, Loss: 0.3286
Round 2/10 - Accuracy: 10.15%, Loss: 5.2164
Round 3/10 - Accuracy: 10.09%, Loss: 229.3089
Round 4/10 - Accuracy: 9.74%, Loss: 72.1614
Round 5/10 - Accuracy: 12.98%, Loss: 5.8577
Round 6/10 - Accuracy: 31.63%, Loss: 4.1285
Round 7/10 - Accuracy: 10.10%, Loss: 3.1709
Round 8/10 - Accuracy: 11.32%, Loss: 4.6196
Round 9/10 - Accuracy: 10.54%, Loss: 15.8356
Round 10/10 - Accuracy: 9.25%, Loss: 8.4898
