In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
from collections import OrderedDict
import numpy as np

class MyCNNModel(nn.Module):
    def __init__(self):
        super(MyCNNModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

def train_model(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = nn.CrossEntropyLoss()(output, target)
        loss.backward()
        optimizer.step()

def test_model(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += nn.CrossEntropyLoss(reduction='sum')(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    return test_loss / len(test_loader.dataset), correct / len(test_loader.dataset)

def average_weights(weight_list):
    avg_weights = OrderedDict()
    for key in weight_list[0].keys():
        avg_weights[key] = sum([client_weights[key] for client_weights in weight_list]) / len(weight_list)
    return avg_weights
    
def create_client_dataset(dataset, num_clients):
    indices = list(range(len(dataset)))
    np.random.shuffle(indices)
    client_indices = np.array_split(indices, num_clients)
    client_datasets = [Subset(dataset, index_list) for index_list in client_indices]
    return client_datasets

# Load the CIFAR-10 dataset
transform = transforms.Compose([transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

num_clients = 5
client_datasets = create_client_dataset(train_dataset, num_clients)
train_loaders = [DataLoader(client_dataset, batch_size=100, shuffle=True) for client_dataset in client_datasets]
test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
epochs = 10


def federated_learning(train_loaders, test_loader, device, epochs, num_clients, communication_rounds, client_weighting, clients_per_round):
    global_model = MyCNNModel().to(device)
    client_models = [MyCNNModel().to(device) for _ in range(num_clients)]

    for client_model in client_models:
        client_model.load_state_dict(global_model.state_dict())

    for com_round in range(1, communication_rounds + 1):
        print(f'Communication round {com_round}/{communication_rounds}')

        selected_clients = np.random.choice(range(num_clients), size=clients_per_round, replace=False)
        client_weights = []

        for idx in selected_clients:
            client_model = client_models[idx]
            optimizer = optim.SGD(client_model.parameters(), lr=0.01, momentum=0.9)
            train_model(client_model, device, train_loaders[idx], optimizer, epochs)

            client_weight = client_weighting[idx] if client_weighting else 1
            weighted_client_state = OrderedDict()

            for key in client_model.state_dict().keys():
                weighted_client_state[key] = client_model.state_dict()[key] * client_weight

            client_weights.append(weighted_client_state)

        global_weights = average_weights(client_weights)
        global_model.load_state_dict(global_weights)

        test_loss, accuracy = test_model(global_model, device, test_loader)
        print(f'Test loss: {test_loss:.4f}, Accuracy: {accuracy * 100:.2f}%\n')

    return global_model

communication_rounds = 50
client_weighting = [1, 1, 1, 1, 1]  # Equal weighting for all clients
clients_per_round = 5

federated_model = federated_learning(train_loaders, test_loader, device, epochs, num_clients, communication_rounds, client_weighting, clients_per_round)


Files already downloaded and verified
Files already downloaded and verified
Communication round 1/50
Test loss: 2.2610, Accuracy: 14.68%

Communication round 2/50
Test loss: 2.0581, Accuracy: 24.30%

Communication round 3/50
Test loss: 1.9453, Accuracy: 27.94%

Communication round 4/50
Test loss: 1.8436, Accuracy: 31.63%

Communication round 5/50
Test loss: 1.7379, Accuracy: 35.24%

Communication round 6/50
Test loss: 1.7029, Accuracy: 37.77%

Communication round 7/50
Test loss: 1.6659, Accuracy: 38.32%

Communication round 8/50
Test loss: 1.6211, Accuracy: 40.57%

Communication round 9/50
Test loss: 1.5851, Accuracy: 42.77%

Communication round 10/50
Test loss: 1.5666, Accuracy: 42.81%

Communication round 11/50
Test loss: 1.5409, Accuracy: 44.17%

Communication round 12/50
Test loss: 1.5442, Accuracy: 43.65%

Communication round 13/50
Test loss: 1.4969, Accuracy: 45.37%

Communication round 14/50
Test loss: 1.4994, Accuracy: 45.81%

Communication round 15/50
Test loss: 1.4793, Accura