In [None]:
import numpy as np
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset

In [19]:
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.dropout1(F.max_pool2d(x, 2))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(self.dropout2(x))
        x = F.softmax(x, dim=1)
        return x

In [None]:
def get_mnist_dataset():
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    train_dataset = datasets.MNIST('../data', train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST('../data', train=False, transform=transform)

    return train_dataset, test_dataset

In [None]:
def distribute_client_data(dataset, num_clients, iid=True):
    if iid:
        num_items_per_client = len(dataset) // num_clients
        client_datasets = []

        indices = torch.randperm(len(dataset))
        for i in range(num_clients):
            start_idx = i * num_items_per_client
            end_idx = (i + 1) * num_items_per_client if i < num_clients - 1 else len(dataset)
            client_indices = indices[start_idx:end_idx]
            client_datasets.append(Subset(dataset, client_indices))
    else:
        labels = dataset.targets.numpy()
        sorted_indices = np.argsort(labels)
        client_datasets = []
        shards_per_client = 2

        num_shards = num_clients * shards_per_client
        items_per_shard = len(dataset) // num_shards
        shard_indices = []

        for i in range(num_shards):
            start_idx = i * items_per_shard
            end_idx = (i + 1) * items_per_shard if i < num_shards - 1 else len(sorted_indices)
            shard_indices.append(sorted_indices[start_idx:end_idx])

        np.random.shuffle(shard_indices)

        for i in range(num_clients):
            client_idx = []
            for j in range(shards_per_client):
                client_idx.extend(shard_indices[i * shards_per_client + j])
            client_datasets.append(Subset(dataset, client_idx))

    return client_datasets

In [22]:
class Client:
    def __init__(self, dataset, client_id, device):
        self.dataset = dataset
        self.client_id = client_id
        self.device = device
        self.model = CNN().to(device)
        self.dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

    def train(self, epochs=1):
        self.model.train()
        optimizer = optim.Adam(self.model.parameters(), lr=0.001)

        for epoch in range(epochs):
            total_loss = 0
            for batch_idx, (data, target) in enumerate(self.dataloader):
                data, target = data.to(self.device), target.to(self.device)

                optimizer.zero_grad()
                output = self.model(data)
                loss = F.cross_entropy(output, target)
                loss.backward()
                optimizer.step()

                total_loss += loss.item()

        return total_loss / len(self.dataloader)

    def evaluate(self, test_loader):
        self.model.eval()
        test_loss = 0
        correct = 0

        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)
                test_loss += F.cross_entropy(output, target, reduction='sum').item()
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()

        test_loss /= len(test_loader.dataset)
        accuracy = 100. * correct / len(test_loader.dataset)

        return test_loss, accuracy

    def get_parameters(self):
        return {k: v.cpu() for k, v in self.model.state_dict().items()}

    def set_parameters(self, parameters):
        params_on_device = {k: v.to(self.device) for k, v in parameters.items()}
        self.model.load_state_dict(params_on_device)

In [23]:
class Server:
    def __init__(self, test_dataset, device):
        self.clients = []
        self.device = device
        self.global_model = CNN().to(device)
        self.test_loader = DataLoader(test_dataset, batch_size=128)

    def add_client(self, client):
        self.clients.append(client)

    def aggregate_parameters(self, client_parameters):
        global_dict = OrderedDict()

        for k in client_parameters[0].keys():
            global_dict[k] = torch.stack([client_parameters[i][k] for i in range(len(client_parameters))], 0).mean(0)

        return global_dict

    def update_global_model(self):
        client_parameters = [client.get_parameters() for client in self.clients]
        global_parameters = self.aggregate_parameters(client_parameters)

        self.global_model.load_state_dict({k: v.to(self.device) for k, v in global_parameters.items()})

        for client in self.clients:
            client.set_parameters(global_parameters)

        return client_parameters[0]  # Return client 1's parameters for potential attack

    def evaluate_global_model(self):
        self.global_model.eval()
        test_loss = 0
        correct = 0

        with torch.no_grad():
            for data, target in self.test_loader:
                data, target = data.to(self.device), target.to(self.device)
                output = self.global_model(data)
                test_loss += F.cross_entropy(output, target, reduction='sum').item()
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()

        test_loss /= len(self.test_loader.dataset)
        accuracy = 100. * correct / len(self.test_loader.dataset)

        return test_loss, accuracy

In [None]:
def run_federated_learning(num_clients=3, num_rounds=3, local_epochs=1, iid=False, device='cpu'):
    print(f"Using device: {device}")

    train_dataset, test_dataset = get_mnist_dataset()
    client_datasets = distribute_client_data(train_dataset, num_clients, iid=iid)

    server = Server(test_dataset, device)
    clients = []
    client_params_list = []

    for i in range(num_clients):
        client = Client(client_datasets[i], i, device)
        clients.append(client)
        server.add_client(client)

    global_accuracies = []
    client_losses = [[] for _ in range(num_clients)]

    # Federated learning loop
    for round_num in range(num_rounds):
        client_params_list.append([])

        print(f"\nRound {round_num+1}/{num_rounds}")

        for i, client in enumerate(clients):
            loss = client.train(epochs=local_epochs)
            client_losses[i].append(loss)
            print(f"Client {i+1} loss: {loss:.4f}")

            client_params_list[round_num].append(client.get_parameters())

        server.update_global_model()
        test_loss, accuracy = server.evaluate_global_model()
        global_accuracies.append(accuracy)
        print(f"Global model - Test loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}%")

    return (global_accuracies, client_losses), (server, clients, client_params_list)

In [28]:
num_clients = 10
device = (
    'mps' if torch.mps.is_available()
    else 'cuda' if torch.cuda.is_available()
    else 'cpu'
)

(global_accuracies, client_losses), (server, clients, client_params_list) = run_federated_learning(
    num_clients=10,
    num_rounds=10,
    local_epochs=1,
    iid=True,
    device=device
)

model_params = [clients[i].get_parameters() for i in range(num_clients)]

Using device: mps

Round 1/10




Client 1 loss: 1.8023
Client 2 loss: 1.7548
Client 3 loss: 1.7869
Client 4 loss: 1.8559
Client 5 loss: 1.7742
Client 6 loss: 1.7964
Client 7 loss: 1.8075
Client 8 loss: 1.7904
Client 9 loss: 1.7917
Client 10 loss: 1.7837
Global model - Test loss: 2.3022, Accuracy: 16.31%

Round 2/10
Client 1 loss: 1.8384
Client 2 loss: 1.8285
Client 3 loss: 1.8238
Client 4 loss: 1.8366
Client 5 loss: 1.8291
Client 6 loss: 1.8364
Client 7 loss: 1.8331
Client 8 loss: 1.8176
Client 9 loss: 1.8366
Client 10 loss: 1.8377
Global model - Test loss: 1.6013, Accuracy: 86.92%

Round 3/10
Client 1 loss: 1.6747
Client 2 loss: 1.6663
Client 3 loss: 1.6617
Client 4 loss: 1.6657
Client 5 loss: 1.6690
Client 6 loss: 1.6701
Client 7 loss: 1.6639
Client 8 loss: 1.6649
Client 9 loss: 1.6713
Client 10 loss: 1.6879
Global model - Test loss: 1.5606, Accuracy: 90.20%

Round 4/10
Client 1 loss: 1.6394
Client 2 loss: 1.6333
Client 3 loss: 1.6396
Client 4 loss: 1.6375
Client 5 loss: 1.6348
Client 6 loss: 1.6276
Client 7 loss: 1

In [30]:
def average_models(param_list):
    avg_params = OrderedDict()
    
    for key in param_list[0].keys():
        avg_params[key] = torch.stack([params[key].float() for params in param_list], dim=0).mean(dim=0)

    return avg_params

global_params = average_models(model_params)

global_model = CNN()
global_model.load_state_dict(global_params)
torch.save(global_model.state_dict(), '../models/federated_model.pth')