# Dataset & model setup

In [None]:
# Import modules
import os
import sys
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn as nn
import torch.optim as optim
import random
import torch.nn.functional as F
import copy
import matplotlib.pyplot as plt
import numpy as np
import time
import io
import json

In [None]:
# Pick the compute device once; every model and batch below is moved to it.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [None]:
# Load MNIST Dataset
d = './data'
# exist_ok=True replaces the check-then-create pattern: no race between the
# existence test and the creation, and it is a no-op when the directory exists.
os.makedirs(d, exist_ok=True)

# Images are converted to tensors, then Normalize is applied with
# mean 0.5 and std 1.0 on the single channel.
trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
train_set = datasets.MNIST(root=d, train=True, transform=trans, download=True)
test_set = datasets.MNIST(root=d, train=False, transform=trans, download=True)

batch_size = 32
# Training batches are reshuffled on every pass; the test order is kept fixed.
global_train_loader = torch.utils.data.DataLoader(
    dataset=train_set,
    batch_size=batch_size,
    shuffle=True,
)
global_test_loader = torch.utils.data.DataLoader(
    dataset=test_set,
    batch_size=batch_size,
    shuffle=False,
)

In [None]:
# Define MNIST model
class MLP(nn.Module):
    """Two-layer perceptron: 784 inputs -> 128 hidden units (ReLU) -> 10 logits."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Flatten `x` to rows of 784 values and return the raw class logits."""
        flat = x.view(-1, 28 * 28)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)

# Client Creation

In [None]:
# Number of simulated clients; the training set is split into this many
# shards and one client is created per shard further down in this cell.
n_clients = 10

def create_client(client_id, local_dataset, batch_size=32):
    """Build the state dictionary describing one federated client.

    Args:
        client_id: Identifier stored under "client_id".
        local_dataset: Dataset (here a list of samples) owned by this client.
        batch_size: Batch size of the client's shuffled DataLoader.

    Returns:
        A dict with keys "client_id", "local_dataset" (a DataLoader over
        `local_dataset`, despite the key name), "local_model" (a fresh MLP
        on `device`) and "optimizer" (SGD with lr=0.01, momentum=0.9 bound
        to that model's parameters).
    """
    local_model = MLP().to(device)
    data_loader = torch.utils.data.DataLoader(
        dataset=local_dataset,
        batch_size=batch_size,
        shuffle=True,
    )
    sgd = optim.SGD(local_model.parameters(), lr=0.01, momentum=0.9)

    client = {}
    client["client_id"] = client_id
    client["local_dataset"] = data_loader
    client["local_model"] = local_model
    client["optimizer"] = sgd
    return client

# Partition datapoints round-robin: sample k goes to client k mod n_clients.
# Iterating train_set materializes each (image, label) pair into plain lists.
local_datasets = [[] for _ in range(n_clients)]
for sample_idx, sample in enumerate(train_set):
    shard = local_datasets[sample_idx % n_clients]
    shard.append(sample)

# Create clients, one per shard, with ids 0..n_clients-1
clients = [create_client(cid, local_datasets[cid]) for cid in range(n_clients)]

# Federated Learning Training

In [None]:
def client_load_model(client, model):
    """Copy the weights of `model` into the client's local model in place."""
    local_model = client["local_model"]
    local_model.load_state_dict(model.state_dict())

def client_local_training(client):
    """Train the client's local model for one pass over its local data.

    Uses cross-entropy loss and the optimizer stored in the client dict;
    the model is updated in place and nothing is returned.
    """
    loss_fn = nn.CrossEntropyLoss()
    sgd = client["optimizer"]
    loader = client["local_dataset"]
    net = client["local_model"].to(device)
    net.train()

    for inputs, labels in loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        logits = net(inputs)
        batch_loss = loss_fn(logits, labels)
        # Clear gradients left over from the previous step before backprop.
        sgd.zero_grad()
        batch_loss.backward()
        sgd.step()

def server_aggregate_models(clients):
    """Return the element-wise, unweighted mean of the clients' model weights.

    Args:
        clients: Non-empty list of client dicts, each with a "local_model".

    Returns:
        A state dict with the same keys as the first client's model, each
        entry being the average of that entry over all given clients.
    """
    state_dicts = [client["local_model"].state_dict() for client in clients]
    n_models = len(state_dicts)
    # Start from a copy of the first state dict so its key order is kept.
    result = copy.deepcopy(state_dicts[0])
    with torch.no_grad():
        for name in result:
            running_total = 0
            for state in state_dicts:
                running_total = running_total + state[name]
            result[name] = running_total / n_models
    return result

def evaluate(model, dataset):
    """Compute the mean cross-entropy loss and accuracy of `model`.

    The model is moved to the module-level `device` and put in eval mode.

    Args:
        model: Module mapping a batch of inputs to class logits.
        dataset: Iterable of (inputs, targets) batches, e.g. a DataLoader.

    Returns:
        (avg_loss, acc): loss averaged over samples, and the fraction of
        samples whose highest-scoring class equals the target.

    Raises:
        ValueError: If `dataset` yields no samples.
    """
    model.to(device)
    model.eval()
    criterion = nn.CrossEntropyLoss()
    total_loss, total, correct = 0.0, 0, 0

    with torch.no_grad():
        for x, target in dataset:
            x, target = x.to(device), target.to(device)
            out = model(x)
            loss = criterion(out, target)
            _, pred_label = torch.max(out, 1)
            n_samples = x.size(0)
            # criterion returns the batch mean; scale it back to a sum so a
            # smaller final batch is not over-weighted in the average.
            total_loss += loss.item() * n_samples
            total += n_samples
            correct += (pred_label == target).sum().item()

    if total == 0:
        raise ValueError("evaluate() received a dataset with no samples")

    avg_loss = total_loss / total
    acc = correct / total
    print(f"loss: {avg_loss:.4f}, acc: {acc:.4f}")
    return avg_loss, acc

def get_model_size(model):
    """Return the size, in MB (1e6 bytes), of the model's serialized state dict."""
    # Serialize into memory rather than to disk, then count the bytes written.
    serialized = io.BytesIO()
    torch.save(model.state_dict(), serialized)
    n_bytes = len(serialized.getvalue())
    return n_bytes / 1e6

def federated_learning_loop(save_dir, rounds=1000, client_participation_fraction=0.2,
                            eval_interval=10, save_interval=50, seed=42):
    """Run federated averaging over the module-level `clients` and log metrics.

    Every round, each client is selected independently with probability
    `client_participation_fraction`. Selected clients receive the current
    global model, train it for one pass over their local data, and the
    global model is replaced by the average of the selected clients' models.
    The global model is evaluated on `global_test_loader` after every round.

    Args:
        save_dir: Directory (created if missing) receiving the periodic,
            best-accuracy, best-loss and final checkpoints.
        rounds: Number of communication rounds.
        client_participation_fraction: Per-client selection probability.
        eval_interval: The serialized model size is re-measured every this
            many rounds (plus the first and last round); other rounds reuse
            the most recent measurement.
        save_interval: A checkpoint is written every this many rounds.
        seed: Seed for torch and for the client-sampling generator.

    Returns:
        Dict of equally long lists keyed "round", "loss", "accuracy",
        "round_time" and "model_size", with one entry per round.
    """
    # Set seed
    torch.manual_seed(seed)
    # Client sampling draws from Python's `random`, which torch.manual_seed
    # does not cover; a dedicated generator makes the selection reproducible
    # without altering the global random state.
    rng = random.Random(seed)
    global_model = MLP().to(device)

    os.makedirs(save_dir, exist_ok=True)

    metrics = {
        "round": [],
        "loss": [],
        "accuracy": [],
        "round_time": [],
        "model_size": []
    }
    best_acc = 0.0
    best_loss = np.inf

    for r in range(rounds):
        print(f"\nRound {r+1}")

        # Start the timer
        start_time = time.time()

        # Sample this round's participants
        participants = [c for c in clients
                        if rng.random() <= client_participation_fraction]

        # Broadcast the global model to the participants and train locally
        for c in participants:
            client_load_model(c, global_model)
            client_local_training(c)

        # Aggregate only the participants' models. Averaging in idle clients
        # (which would still hold the unchanged global model) shrinks every
        # update towards the previous global weights.
        if participants:
            aggregated_model = server_aggregate_models(participants)
            global_model.load_state_dict(aggregated_model)
        else:
            print("No clients selected this round; global model unchanged")

        # Stop the timer before computing the metrics
        round_time = time.time() - start_time

        # Evaluation
        loss, acc = evaluate(global_model, global_test_loader)

        if (r + 1) % eval_interval == 0 or r == 0 or r == rounds - 1:
            model_size = get_model_size(global_model)
        else:
            # Always defined: the size is measured on the first round.
            model_size = metrics["model_size"][-1]

        # Save metrics for plotting
        metrics["round"].append(r+1)
        metrics["round_time"].append(round_time)
        metrics["loss"].append(loss)
        metrics["accuracy"].append(acc)
        metrics["model_size"].append(model_size)

        # Save model periodically
        if (r + 1) % save_interval == 0:
            torch.save(global_model.state_dict(), f"{save_dir}/model_round_{r+1}.pth")
            print(f"Saved checkpoint: model_round_{r+1}.pth")

        # Save best model so far
        if acc > best_acc:
            best_acc = acc
            torch.save(global_model.state_dict(), f"{save_dir}/best_acc_model.pth")
            print(f"New best acc model saved (acc={acc:.4f})")
        if loss < best_loss:
            best_loss = loss
            torch.save(global_model.state_dict(), f"{save_dir}/best_loss_model.pth")
            print(f"New best loss model saved (loss={loss:.4f})")

    # Save final model
    torch.save(global_model.state_dict(), f"{save_dir}/final_model.pth")
    print("Final model saved")

    return metrics

In [None]:
def plot_metrics(metrics, filename):
    """Plot loss, accuracy, round time and model size against the round number.

    The four curves are drawn on a 2x2 grid, the figure is written to
    `filename` at 300 dpi and then displayed.

    Args:
        metrics: Dict with list entries "round", "loss", "accuracy",
            "round_time" and "model_size".
        filename: Path of the image file to write.
    """
    # (metrics key, panel title, y-axis label, line colour), in panel order.
    panel_specs = (
        ("loss", "Global Model Loss", "Loss", "tab:blue"),
        ("accuracy", "Global Model Accuracy", "Accuracy", "tab:orange"),
        ("round_time", "Round Duration", "Time (s)", "tab:green"),
        ("model_size", "Model Size", "Size (MB)", "tab:red"),
    )

    x_rounds = np.array(metrics["round"])
    # Convert every series up front so a missing key fails before any
    # figure is created.
    series = [np.array(metrics[key]) for key, _, _, _ in panel_specs]

    plt.style.use('seaborn-v0_8-paper')

    _, grid = plt.subplots(2, 2, figsize=(12, 8))

    for panel, values, (_, title, y_label, colour) in zip(grid.ravel(), series, panel_specs):
        panel.plot(x_rounds, values, color=colour, linewidth=2)
        panel.set_title(title, fontsize=15)
        panel.set_xlabel("Round", fontsize=13)
        panel.set_ylabel(y_label, fontsize=13)
        panel.grid(True)

    plt.tight_layout()
    plt.savefig(filename, dpi=300, bbox_inches="tight")
    plt.show()

    print(f"Plot saved to {filename}")

In [None]:
def summarize_metrics(metrics, filename):
    """Condense per-round metrics into summary statistics, print and save them.

    Args:
        metrics: Dict with list entries "round", "loss", "accuracy",
            "round_time" and "model_size" (one value per round).
        filename: Path of the JSON file the summary is written to.

    Returns:
        The summary dict (plain floats and ints) that was written to disk.
    """
    round_ids = np.array(metrics["round"])
    loss_hist = np.array(metrics["loss"])
    acc_hist = np.array(metrics["accuracy"])
    time_hist = np.array(metrics["round_time"])
    size_hist = np.array(metrics["model_size"])

    # Timing statistics (NaN entries are ignored).
    mean_time = np.nanmean(time_hist)
    std_time = np.nanstd(time_hist)
    total_time = np.nansum(time_hist)

    # Final and best values, with the positions at which the best occurred.
    last_loss = loss_hist[-1]
    last_acc = acc_hist[-1]
    lowest_loss = np.nanmin(loss_hist)
    lowest_loss_idx = np.nanargmin(loss_hist)
    lowest_loss_round = round_ids[lowest_loss_idx]
    highest_acc = np.nanmax(acc_hist)
    highest_acc_idx = np.nanargmax(acc_hist)
    highest_acc_round = round_ids[highest_acc_idx]

    first_size = size_hist[0]
    last_size = size_hist[-1]
    size_at_best_acc = size_hist[highest_acc_idx]
    size_at_best_loss = size_hist[lowest_loss_idx]

    def _size_drop(size_mb):
        """Absolute and percentage size reduction relative to the first round."""
        saved = first_size - size_mb
        if first_size > 0:
            return saved, (saved / first_size) * 100
        return saved, 0

    drop_final, drop_final_pct = _size_drop(last_size)
    drop_best_acc, drop_best_acc_pct = _size_drop(size_at_best_acc)
    drop_best_loss, drop_best_loss_pct = _size_drop(size_at_best_loss)

    if size_at_best_acc > 0:
        acc_per_mb = highest_acc / size_at_best_acc
    else:
        acc_per_mb = np.nan

    # Cast to built-in types so the summary is JSON-serializable.
    summary = {
        "final_loss": float(last_loss),
        "best_loss": float(lowest_loss),
        "final_accuracy": float(last_acc),
        "best_accuracy": float(highest_acc),
        "best_loss_round": int(lowest_loss_round),
        "best_acc_round": int(highest_acc_round),
        "mean_round_time": float(mean_time),
        "std_round_time": float(std_time),
        "total_training_time": float(total_time),
        "initial_model_size_MB": float(first_size),
        "final_model_size_MB": float(last_size),
        "best_acc_model_size_MB": float(size_at_best_acc),
        "best_loss_model_size_MB": float(size_at_best_loss),
        "size_reduction_MB": float(drop_final),
        "size_reduction_pct": float(drop_final_pct),
        "size_reduction_best_acc_MB": float(drop_best_acc),
        "size_reduction_pct_best_acc": float(drop_best_acc_pct),
        "size_reduction_best_loss_MB": float(drop_best_loss),
        "size_reduction_pct_best_loss": float(drop_best_loss_pct),
        "accuracy_per_MB": float(acc_per_mb)
    }

    report_lines = [
        "=== Training Summary ===",
        f"Final loss: {summary['final_loss']:.4f}",
        f"Best loss: {summary['best_loss']:.4f}",
        f"Final accuracy: {summary['final_accuracy']:.4f}",
        f"Best accuracy: {summary['best_accuracy']:.4f}",
        f"Best loss round: {summary['best_loss_round']}",
        f"Best accuracy round: {summary['best_acc_round']}",
        "",
        f"Mean round time: {summary['mean_round_time']:.3f} s",
        f"Std round time: {summary['std_round_time']:.3f} s",
        f"Total training time: {summary['total_training_time']:.3f} s",
        "",
        f"Initial model size: {summary['initial_model_size_MB']:.3f} MB",
        f"Final model size: {summary['final_model_size_MB']:.3f} MB",
        f"Model size at best accuracy: {summary['best_acc_model_size_MB']:.3f} MB",
        f"Model size at best loss: {summary['best_loss_model_size_MB']:.3f} MB",
        "",
        f"Size reduction (final): {summary['size_reduction_MB']:.3f} MB ({summary['size_reduction_pct']:.2f}%)",
        f"Size reduction (best accuracy): {summary['size_reduction_best_acc_MB']:.3f} MB ({summary['size_reduction_pct_best_acc']:.2f}%)",
        f"Size reduction (best loss): {summary['size_reduction_best_loss_MB']:.3f} MB ({summary['size_reduction_pct_best_loss']:.2f}%)",
        "",
        f"Accuracy per MB: {summary['accuracy_per_MB']:.4f}",
    ]
    for line in report_lines:
        print(line)

    with open(filename, "w") as f:
        json.dump(summary, f, indent=4)

    print(f"\nSummary saved to {filename}")
    return summary

In [None]:
# Run the federated training loop. Checkpoints are written under
# `run_save_dir`; the returned per-round metrics feed the cells below.
run_save_dir = "checkpoints"
metrics = federated_learning_loop(run_save_dir)

In [None]:
# Plot the recorded metrics and save the figure to `metrics_filename`
# (a relative path, so it lands in the current working directory).
metrics_filename = "training_metrics.png"
plot_metrics(metrics, metrics_filename)

In [None]:
# Compute summary statistics, print them, and write them as JSON to
# `summary_filename` (a relative path, resolved against the working directory).
summary_filename = "summary.json"
summary = summarize_metrics(metrics, summary_filename)

In [None]:
# Google Colab only: zip the checkpoint directory and download the archive,
# the metrics plot and the JSON summary through the browser.
# NOTE(review): the artifacts above are written to relative paths, so the
# hard-coded /content prefix presumably relies on the notebook's working
# directory being /content — confirm before running elsewhere.
!zip -r /content/{run_save_dir}.zip /content/{run_save_dir}/
from google.colab import files
files.download(f"/content/{run_save_dir}.zip")
files.download(f"/content/{metrics_filename}")
files.download(f"/content/{summary_filename}")