# Federated Learning on CIFAR-10 Using ResNet-18

This notebook demonstrates a simple implementation of Federated Learning that uses a ResNet-18 model for image classification on the CIFAR-10 dataset. The process consists of the following steps:


1. **Dataset preparation:** The CIFAR-10 training set is divided among multiple clients, each of which receives a random subset of 500 samples.
2. **Client-side training:** Each client trains its own model locally on its subset. Only the final classification layer of ResNet-18 is trained; the pre-trained weights of all other layers are frozen.
3. **Model aggregation:** After every client has finished training, the client models are combined with Federated Averaging (FedAvg), which averages each parameter element-wise across the client models.
4. **Evaluation:** The aggregated model is evaluated on a small random subset of the CIFAR-10 test set, and each individual client model is evaluated on the same subset for comparison.

This approach illustrates the core idea of Federated Learning: models are trained locally on different clients and then combined into a single global model, so the raw training data never has to leave the clients.

In [None]:
# import all dependencies
import copy
import os
import random

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms

In [None]:
# Step 1: Define Transform and Dataset
def get_datasets(root='./data'):
    """Download (if needed) and return the CIFAR-10 train and test splits.

    Images are resized from 32x32 to 224x224. They are then normalized with the
    ImageNet channel statistics, which is the input format torchvision's
    ImageNet-pretrained ResNet-18 weights expect. The backbone built by
    ``get_resnet_model`` is frozen, so it cannot adapt to inputs normalized any
    other way.

    Args:
        root: Directory in which the CIFAR-10 archive is stored / downloaded.

    Returns:
        Tuple ``(cifar_train, cifar_test)`` of torchvision ``CIFAR10`` datasets
        that share the same transform.
    """
    imagenet_mean = (0.485, 0.456, 0.406)
    imagenet_std = (0.229, 0.224, 0.225)
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(imagenet_mean, imagenet_std),
    ])
    cifar_train = torchvision.datasets.CIFAR10(root=root, train=True, download=True, transform=transform)
    cifar_test = torchvision.datasets.CIFAR10(root=root, train=False, download=True, transform=transform)
    return cifar_train, cifar_test


In [None]:
# Step 2: Define the Model
def get_resnet_model(num_classes=10, pretrained=True, freeze_backbone=True):
    """Build a ResNet-18 whose final fully connected layer predicts ``num_classes`` classes.

    Args:
        num_classes: Output size of the new classification head (``model.fc``).
        pretrained: Load torchvision's ImageNet ``IMAGENET1K_V1`` weights, which is
            what the deprecated ``pretrained=True`` argument selected. Pass False
            when the weights will be overwritten anyway, e.g. right before
            ``load_state_dict``, to skip loading them.
        freeze_backbone: Set ``requires_grad=False`` on every pre-existing
            parameter so that only the new head is trainable.

    Returns:
        The ``torch.nn.Module``; its ``fc`` layer is freshly initialised and trainable.
    """
    weights = models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
    model = models.resnet18(weights=weights)
    if freeze_backbone:
        for param in model.parameters():
            param.requires_grad = False
    # The head is replaced after freezing, so its new parameters keep requires_grad=True.
    in_features = model.fc.in_features
    model.fc = nn.Linear(in_features, num_classes)
    return model

In [None]:
# Step 3: Train a Single Client
def train_client(client_id, train_loader, num_epochs=10, lr=0.001, momentum=0.9):
    """Train a fresh ResNet-18 classification head on one client's local data.

    Only ``model.fc`` is optimized. The rest of the network is kept in eval mode.
    Setting ``requires_grad=False`` stops the backbone's weights from changing,
    but ``model.train()`` would still update the BatchNorm running statistics on
    every forward pass. That would silently alter the supposedly frozen backbone
    differently for each client.

    Args:
        client_id: Identifier used only in the progress messages.
        train_loader: Iterable of ``(inputs, labels)`` batches.
        num_epochs: Number of passes over ``train_loader``.
        lr: SGD learning rate for the head.
        momentum: SGD momentum for the head.

    Returns:
        The trained model.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    print("\n")
    print(f"Training Client {client_id}")
    model = get_resnet_model()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.fc.parameters(), lr=lr, momentum=momentum)

    # Backbone in eval mode keeps the pretrained BatchNorm running stats fixed;
    # the linear head has no mode-dependent layers, but mark it as training anyway.
    model.eval()
    model.fc.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        num_batches = 0
        for inputs, labels in train_loader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1
        if num_batches == 0:
            raise ValueError(f"Client {client_id} has no training batches")
        print(f"Client {client_id}, Epoch {epoch + 1}, Loss: {running_loss / num_batches:.4f}")
    return model

In [None]:
# Step 4: Save Model Weights
def save_model(model, filename, directory='/content/'):
    """Save ``model.state_dict()`` to ``directory/filename`` and return that path.

    Args:
        model: Any ``torch.nn.Module``.
        filename: File name to write inside ``directory``.
        directory: Output directory. The default ``/content/`` appears to target
            Google Colab; pass another directory when running elsewhere. It is
            created if it does not exist.

    Returns:
        The full path of the written file.
    """
    os.makedirs(directory, exist_ok=True)
    path = os.path.join(directory, filename)
    torch.save(model.state_dict(), path)
    print(f"Model saved as {path}")
    return path


In [None]:
# Step 5: Aggregate Weights
def federated_averaging(models, weights=None):
    """Combine client models into one model by averaging their state dicts (FedAvg).

    Every parameter and buffer is averaged element-wise across ``models``. The
    result is cast back to the entry's original dtype (e.g. integer BatchNorm
    counters), then loaded into a deep copy of ``models[0]``. The input models
    are not modified.

    Args:
        models: Non-empty sequence of ``torch.nn.Module`` with identical
            architectures / state-dict keys and shapes. (The name shadows the
            ``torchvision.models`` import inside this function only.)
        weights: Optional per-model non-negative weights, e.g. client sample
            counts. They are normalized to sum to 1. ``None`` (the default)
            gives every model equal weight.

    Returns:
        A new model holding the averaged state.

    Raises:
        ValueError: if ``models`` is empty, or if ``weights`` has the wrong
            length, contains negatives, or sums to zero.
    """
    if not models:
        raise ValueError("federated_averaging requires at least one model")
    print("\nPerforming Federated Averaging...")

    # Build each state dict once instead of once per key.
    state_dicts = [model.state_dict() for model in models]

    if weights is None:
        coeffs = torch.full((len(models),), 1.0 / len(models), dtype=torch.float64)
    else:
        if len(weights) != len(models):
            raise ValueError(f"Expected {len(models)} weights, got {len(weights)}")
        coeffs = torch.as_tensor(weights, dtype=torch.float64)
        total = coeffs.sum()
        if (coeffs < 0).any() or total <= 0:
            raise ValueError("weights must be non-negative and sum to a positive value")
        coeffs = coeffs / total

    avg_weights = {}
    for key, reference in state_dicts[0].items():
        stacked = torch.stack([sd[key].double() for sd in state_dicts])
        # Reshape coeffs to (num_models, 1, 1, ...) so it broadcasts over the tensor dims.
        scale = coeffs.to(stacked.device).view((-1,) + (1,) * reference.dim())
        avg_weights[key] = (stacked * scale).sum(dim=0).to(reference.dtype)

    # Reuse the clients' architecture instead of rebuilding (and re-loading pretrained weights).
    aggregated_model = copy.deepcopy(models[0])
    aggregated_model.load_state_dict(avg_weights)
    return aggregated_model

In [None]:

# Step 6: Evaluate a Model
def evaluate_model(model, test_loader):
    """Return the top-1 accuracy (in percent) of ``model`` over ``test_loader``.

    Puts ``model`` into eval mode (it is left there) and runs without gradient
    tracking. Batches are moved to the device of the model's first parameter,
    if the model has any.

    Args:
        model: Classifier whose output has class scores along dim 1.
        test_loader: Iterable of ``(inputs, labels)`` batches.

    Returns:
        ``100 * correct / total`` as a float.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    model.eval()
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            if device is not None:
                inputs, labels = inputs.to(device), labels.to(device)
            predicted = model(inputs).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        raise ValueError("evaluate_model received an empty test_loader")
    accuracy = 100 * correct / total
    return accuracy

In [None]:

# Main Function
if __name__ == "__main__":
    # Seed the RNGs this run uses: client partition, DataLoader shuffling,
    # classifier-head initialisation, and the evaluation subset. This makes
    # Restart & Run All reproduce the printed numbers.
    SEED = 42
    random.seed(SEED)
    torch.manual_seed(SEED)

    # Prepare Datasets
    cifar_train, cifar_test = get_datasets()

    # Federated Learning Setup
    num_clients = 5
    num_samples_per_client = 500
    num_epochs = 5
    num_test_samples = 50  # small subset keeps evaluation fast, at the cost of noisy accuracy
    eval_batch_size = 32
    train_loaders = []
    models_list = []

    # Give each client a disjoint slice of one shuffled index list, so that no
    # training image is shared between clients.
    samples_needed = num_clients * num_samples_per_client
    if samples_needed > len(cifar_train):
        raise ValueError(
            f"Need {samples_needed} training samples for {num_clients} clients, "
            f"but the training set has only {len(cifar_train)}."
        )
    shuffled_indices = torch.randperm(len(cifar_train)).tolist()

    for client_id in range(num_clients):
        start = client_id * num_samples_per_client
        indices = shuffled_indices[start:start + num_samples_per_client]
        subset = torch.utils.data.Subset(cifar_train, indices)
        train_loader = torch.utils.data.DataLoader(subset, batch_size=32, shuffle=True)
        train_loaders.append(train_loader)

        model = train_client(client_id + 1, train_loader, num_epochs=num_epochs)
        models_list.append(model)

        # Save Model
        save_model(model, f"resnet18_client_{client_id + 1}.pt")

    # Perform Federated Averaging
    aggregated_model = federated_averaging(models_list)
    save_model(aggregated_model, "avg_resnet18.pt")

    # Test Aggregated Model
    # evaluate_model switches to eval mode (BatchNorm uses running stats), so the
    # batch size should not change the accuracy; batching only speeds evaluation up.
    test_indices = random.sample(range(len(cifar_test)), num_test_samples)
    test_loader = torch.utils.data.DataLoader(
        torch.utils.data.Subset(cifar_test, test_indices), batch_size=eval_batch_size
    )
    print("\nEvaluating Averaged Model")
    aggregated_accuracy = evaluate_model(aggregated_model, test_loader)
    print(f"Averaged Model Accuracy: {aggregated_accuracy:.2f}%")

    # Test Individual Models (on the same test subset, for a direct comparison)
    print("\nEvaluating Individual Client Models")
    for client_id, model in enumerate(models_list):
        accuracy = evaluate_model(model, test_loader)
        print(f"Client {client_id + 1} Model Accuracy: {accuracy:.2f}%")

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:17<00:00, 9.60MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


Training Client 1


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 198MB/s]


Client 1, Epoch 1, Loss: 2.4114
Client 1, Epoch 2, Loss: 2.0779
Client 1, Epoch 3, Loss: 1.7546
Client 1, Epoch 4, Loss: 1.5381
Client 1, Epoch 5, Loss: 1.3699
Model saved as /content/resnet18_client_1.pt


Training Client 2
Client 2, Epoch 1, Loss: 2.4401
Client 2, Epoch 2, Loss: 2.1313
Client 2, Epoch 3, Loss: 1.8229
Client 2, Epoch 4, Loss: 1.5910
Client 2, Epoch 5, Loss: 1.3885
Model saved as /content/resnet18_client_2.pt


Training Client 3
Client 3, Epoch 1, Loss: 2.3791
Client 3, Epoch 2, Loss: 2.0709
Client 3, Epoch 3, Loss: 1.8133
Client 3, Epoch 4, Loss: 1.5934
Client 3, Epoch 5, Loss: 1.4287
Model saved as /content/resnet18_client_3.pt


Training Client 4
Client 4, Epoch 1, Loss: 2.3260
Client 4, Epoch 2, Loss: 2.0214
Client 4, Epoch 3, Loss: 1.7551
Client 4, Epoch 4, Loss: 1.5411
Client 4, Epoch 5, Loss: 1.3835
Model saved as /content/resnet18_client_4.pt


Training Client 5
Client 5, Epoch 1, Loss: 2.4222
Client 5, Epoch 2, Loss: 2.0690
Client 5, Epoch 3, Loss: 1.7768
Clie