# Dataset & model setup

In [None]:
# Import modules
import os
import sys
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
import random
import torch.nn.functional as F
import copy
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [None]:
# Load MNIST Dataset
d = './data'
# makedirs(exist_ok=True) replaces the check-then-mkdir pair: there is no gap
# between the existence test and the creation, and nested paths also work.
os.makedirs(d, exist_ok=True)

# Convert images to tensors, then shift them by 0.5 (a std of 1.0 leaves the
# scale untouched).
trans = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (1.0,)),
])
train_set = datasets.MNIST(root=d, train=True, transform=trans, download=True)
test_set = datasets.MNIST(root=d, train=False, transform=trans, download=True)

batch_size = 32
# Loaders over the full (un-partitioned) splits; only the training loader
# shuffles.
global_train_loader = torch.utils.data.DataLoader(
    dataset=train_set,
    batch_size=batch_size,
    shuffle=True,
)
global_test_loader = torch.utils.data.DataLoader(
    dataset=test_set,
    batch_size=batch_size,
    shuffle=False,
)


In [None]:
# Define MNIST model
class MLP(nn.Module):
    """Two-layer perceptron: 784 inputs -> 128 hidden units (ReLU) -> 10 logits."""

    def __init__(self):
        super().__init__()
        # Layer creation order is kept (fc1, then fc2) so seeded initialisation
        # draws random numbers in the same sequence.
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Flatten the input to rows of 784 values and return the class logits."""
        flat = x.view(-1, 784)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)

# Client Creation

In [None]:
# Number of simulated clients; the training set is dealt out round-robin
# across this many local datasets below.
n_clients = 10

def create_client(client_id, local_dataset, batch_size=32):
    """Build the record describing one federated client.

    Args:
        client_id: Identifier stored under the ``"client_id"`` key.
        local_dataset: Dataset wrapped in a shuffling DataLoader.
        batch_size: Batch size of that DataLoader.

    Returns:
        A dict with keys ``"client_id"``, ``"local_dataset"`` (the DataLoader,
        not the raw dataset), ``"local_model"`` (a fresh MLP on ``device``) and
        ``"optimizer"`` (SGD over that model, lr 0.01, momentum 0.9).
    """
    local_model = MLP().to(device)
    local_loader = torch.utils.data.DataLoader(
        local_dataset, batch_size=batch_size, shuffle=True
    )
    sgd = optim.SGD(local_model.parameters(), lr=0.01, momentum=0.9)
    return dict(
        client_id=client_id,
        local_dataset=local_loader,
        local_model=local_model,
        optimizer=sgd,
    )

# Deal the training samples out round-robin: sample i goes to client i mod n_clients.
# Iterating train_set applies the transform once, so each shard holds ready tensors.
local_datasets = [[] for _ in range(n_clients)]
for i, datapoint in enumerate(train_set):
    owner = i % n_clients
    local_datasets[owner].append(datapoint)

# One client per shard; the client id is the shard's index.
clients = [create_client(client_id, local_datasets[client_id])
           for client_id in range(n_clients)]

# Federated Learning Training

In [None]:
def client_load_model(client, model):
    """Overwrite the client's local model weights with a copy of ``model``'s.

    Args:
        client: Client dict; its ``"local_model"`` entry is updated in place.
        model: Module whose ``state_dict()`` is loaded into the local model.
    """
    local_model = client["local_model"]
    local_model.load_state_dict(model.state_dict())

def client_local_training(client):
    """Train the client's local model for one pass over its local data.

    Uses the client's own optimizer and cross-entropy loss. The model is moved
    to ``device``, put in training mode and updated in place; nothing is
    returned.

    Args:
        client: Client dict providing ``"optimizer"``, ``"local_dataset"``
            (an iterable of ``(inputs, targets)`` batches) and ``"local_model"``.
    """
    loss_fn = nn.CrossEntropyLoss()
    opt = client["optimizer"]
    loader = client["local_dataset"]
    net = client["local_model"].to(device)
    net.train()

    for inputs, labels in loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        # Clear gradients left over from the previous step before backprop.
        opt.zero_grad()
        batch_loss = loss_fn(net(inputs), labels)
        batch_loss.backward()
        opt.step()

def server_aggregate_models(clients):
    """Return the element-wise mean of the given clients' model weights.

    Args:
        clients: Non-empty sequence of client dicts, each with a
            ``"local_model"`` exposing ``state_dict()``. All models are assumed
            to share the same keys and tensor shapes.

    Returns:
        A state dict (same keys as the first client's) whose entries are the
        unweighted average over all clients. The clients' models are not
        modified.
    """
    state_dicts = [client["local_model"].state_dict() for client in clients]
    n_models = len(state_dicts)
    # Copy the first state dict to obtain a container with the right keys;
    # every value is replaced by the average below.
    result = copy.deepcopy(state_dicts[0])
    with torch.no_grad():
        for name in result:
            running = 0
            for state in state_dicts:
                running = running + state[name]
            result[name] = running / n_models
    return result

def evaluate(model, dataset):
    """Measure cross-entropy loss and accuracy of ``model`` over ``dataset``.

    The model is moved to ``device`` and left in eval mode. A one-line summary
    is printed.

    Args:
        model: Module mapping a batch of inputs to per-class logits.
        dataset: Iterable of ``(inputs, targets)`` batches, e.g. a DataLoader.

    Returns:
        ``(avg_loss, acc)``: the mean loss per sample and the fraction of
        samples classified correctly.

    Raises:
        ZeroDivisionError: If ``dataset`` yields no samples.
    """
    model.to(device)
    model.eval()
    # Sum per-sample losses and divide by the sample count at the end.
    # Averaging per-batch means over the number of batches would over-weight
    # a smaller final batch.
    criterion = nn.CrossEntropyLoss(reduction="sum")
    total_loss, total, correct = 0.0, 0, 0

    with torch.no_grad():
        for x, target in dataset:
            x, target = x.to(device), target.to(device)
            out = model(x)
            _, pred_label = torch.max(out, 1)
            total_loss += criterion(out, target).item()
            total += x.size(0)
            correct += (pred_label == target).sum().item()

    avg_loss = total_loss / total
    acc = correct / total
    print(f"loss: {avg_loss:.4f}, acc: {acc:.4f}")
    return avg_loss, acc

def federated_learning_loop(save_dir, rounds=1000, client_participation_fraction=0.2,
                            save_every=50, seed=42):
    """Run federated averaging over the module-level ``clients`` list.

    Each round the current global weights are broadcast to every client, each
    client is independently selected with probability
    ``client_participation_fraction``, the selected clients train locally, and
    the global model becomes the average of the selected clients' weights. The
    global model is evaluated on ``global_test_loader`` after every round.

    Args:
        save_dir: Directory for checkpoints (created if missing). Receives
            ``model_round_<n>.pth`` every ``save_every`` rounds,
            ``best_model.pth`` whenever test accuracy improves, and
            ``final_model.pth`` at the end.
        rounds: Number of communication rounds.
        client_participation_fraction: Per-client selection probability.
        save_every: Periodic checkpoint interval in rounds; a falsy value
            disables periodic checkpoints.
        seed: Seed applied to both ``torch`` and ``random``.

    Returns:
        Dict with parallel lists under ``"round"``, ``"loss"`` and
        ``"accuracy"``, one entry per round.
    """
    # torch's RNG initialises the model weights; `random` decides which
    # clients take part. Seed both so a run can be repeated.
    torch.manual_seed(seed)
    random.seed(seed)
    global_model = MLP().to(device)

    os.makedirs(save_dir, exist_ok=True)

    metrics = {
        "round": [],
        "loss": [],
        "accuracy": []
    }
    best_acc = 0.0

    for r in range(rounds):
        print(f"\nRound {r+1}")

        # Broadcast global model to clients
        for c in clients:
            client_load_model(c, global_model)

        # Selected clients perform local training
        participants = [
            c for c in clients
            if random.random() <= client_participation_fraction
        ]
        for c in participants:
            client_local_training(c)

        # Aggregate only the clients that actually trained. Idle clients still
        # hold an untouched copy of the global model, and averaging them in
        # would shrink every update by the participation rate. With no
        # participants the global model is simply carried over.
        if participants:
            aggregated_model = server_aggregate_models(participants)
            global_model.load_state_dict(aggregated_model)

        # Evaluation
        loss, acc = evaluate(global_model, global_test_loader)

        # Save metrics for plotting
        metrics["round"].append(r+1)
        metrics["loss"].append(loss)
        metrics["accuracy"].append(acc)

        # Save model periodically
        if save_every and (r + 1) % save_every == 0:
            torch.save(global_model.state_dict(), f"{save_dir}/model_round_{r+1}.pth")
            print(f"Saved checkpoint: model_round_{r+1}.pth")

        # Save best model so far
        if acc > best_acc:
            best_acc = acc
            torch.save(global_model.state_dict(), f"{save_dir}/best_model.pth")
            print(f"New best model saved (acc={acc:.4f})")

    # Save final model
    torch.save(global_model.state_dict(), f"{save_dir}/final_model.pth")
    print("Final model saved")

    return metrics

In [None]:
# Run loop
# Checkpoints are written under save_dir; the returned per-round metrics are
# kept in `metrics` for the plotting cell.
save_dir = "checkpoints"
metrics = federated_learning_loop(save_dir)

In [None]:
def plot_metrics(metrics, filename):
    """Plot loss and accuracy per round side by side, save the figure and show it.

    Args:
        metrics: Dict with ``"round"``, ``"loss"`` and ``"accuracy"`` sequences
            of equal length.
        filename: Path the figure is written to at 300 dpi.
    """
    round_numbers = metrics["round"]
    loss_curve = np.array(metrics["loss"])
    accuracy_curve = np.array(metrics["accuracy"])

    plt.style.use('seaborn-v0_8-paper')

    plt.figure(figsize=(12,5))

    # One panel per metric: (values, axis label, line colour, panel title).
    panels = (
        (loss_curve, "Loss", 'tab:blue', "Global Model Loss"),
        (accuracy_curve, "Accuracy", 'tab:orange', "Global Model Accuracy"),
    )
    for position, (values, name, colour, title) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(round_numbers, values, label=name, color=colour, linewidth=2)
        plt.xlabel("Round", fontsize=14)
        plt.ylabel(name, fontsize=14)
        plt.title(title, fontsize=16)
        plt.grid(True)

    plt.tight_layout()
    plt.savefig(filename, dpi=300)
    plt.show()
    print(f"Plot saved to {filename}")

In [None]:
# Plotting metrics
# Draws the per-round loss/accuracy curves and writes them to `filename`,
# which the download cell below reads.
filename = "training_metrics.png"
plot_metrics(metrics, filename)

In [None]:
# Colab-only cell: `!zip` is an IPython shell escape and `google.colab` is
# imported for its `files.download` helper.
# Zips the checkpoint directory, then downloads the archive and the metrics plot.
# NOTE(review): the paths assume save_dir and filename were written relative to
# /content (the notebook's working directory) - confirm.
!zip -r /content/{save_dir}.zip /content/{save_dir}/
from google.colab import files
files.download(f"/content/{save_dir}.zip")
files.download(f"/content/{filename}")