# Split Learning on CIFAR-10 Using ResNet-18
This notebook demonstrates split learning combined with federated averaging on the CIFAR-10 dataset. The experiment involves:

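1. Data Preparation: Each client receives its own random subset of the CIFAR-10 training set, simulating a federated learning setup where each participant holds only its local data. Because the subsets are sampled uniformly at random, the client data here is approximately IID. Real-world federated data is typically non-IID (not independent and identically distributed), which would require, for example, a label-skewed split to simulate.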

2. Client-Server Architecture: Each client holds the first layer(s) of a pretrained ResNet-18 (client-side model), which is kept frozen. The server holds the remaining layers plus a new linear classifier (server-side model), and only that classifier is trained. Only intermediate representations are sent from the client to the server, so raw images never leave the client.
3. Federated Averaging: Five clients are used, each training its own server-side model on its own random subset of 500 training samples. After training, the parameters of the five server-side models are averaged to create a global model.
4. Inference and Evaluation: Each client's models are evaluated on a common test subset of 50 samples. The averaged global model is evaluated on the same subset to compare its performance with the individual client models.


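This setup demonstrates the feasibility of combining federated learning and split learning for privacy-preserving distributed training, with model aggregation used to improve performance. Note that with only 50 test images, each image is worth 2 percentage points of accuracy, so differences between the models in the results below are noisy.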

In [None]:
# import all dependencies
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models

In [None]:
# Function to prepare datasets and loaders
def get_datasets(num_clients=5, samples_per_client=500, test_samples=50,
                 batch_size=32, root='./data', seed=None):
    """Download CIFAR-10 and build per-client training loaders plus a shared test loader.

    Every client receives a *disjoint* random slice of the training set: a single
    permutation of all training indices is drawn and cut into consecutive chunks
    of ``samples_per_client``, so no training image is shared between clients.

    Images are resized to 224x224, converted to tensors and normalized with the
    fixed per-channel mean/std below.

    Args:
        num_clients: Number of client training loaders to build.
        samples_per_client: Number of training images assigned to each client.
        test_samples: Size of the random test subset shared by all clients.
        batch_size: Batch size used by every loader.
        root: Directory CIFAR-10 is downloaded to / read from.
        seed: Optional seed for subset sampling and training-loader shuffling.
            ``None`` uses torch's global RNG (so ``torch.manual_seed`` applies).

    Returns:
        ``(train_loaders, test_loader)``: a list of ``num_clients`` shuffled
        training loaders and one unshuffled test loader.

    Raises:
        ValueError: if ``num_clients * samples_per_client`` exceeds the size
            of the training set.
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ])
    # A dedicated generator makes sampling reproducible when a seed is given;
    # None falls back to torch's global RNG.
    generator = torch.Generator().manual_seed(seed) if seed is not None else None

    full_dataset = datasets.CIFAR10(root=root, train=True, download=True, transform=transform)
    needed = num_clients * samples_per_client
    if needed > len(full_dataset):
        raise ValueError(
            f"{num_clients} clients x {samples_per_client} samples = {needed} "
            f"exceeds the {len(full_dataset)} available training images"
        )
    # One permutation cut into consecutive chunks => client subsets never overlap.
    train_perm = torch.randperm(len(full_dataset), generator=generator).tolist()
    client_subsets = [
        torch.utils.data.Subset(
            full_dataset,
            train_perm[c * samples_per_client:(c + 1) * samples_per_client],
        )
        for c in range(num_clients)
    ]
    train_loaders = [
        torch.utils.data.DataLoader(subset, batch_size=batch_size, shuffle=True, generator=generator)
        for subset in client_subsets
    ]

    test_dataset = datasets.CIFAR10(root=root, train=False, download=True, transform=transform)
    test_indices = torch.randperm(len(test_dataset), generator=generator)[:test_samples].tolist()
    test_subset = torch.utils.data.Subset(test_dataset, test_indices)
    test_loader = torch.utils.data.DataLoader(test_subset, batch_size=batch_size, shuffle=False)

    return train_loaders, test_loader

In [None]:
# Function to create ResNet-based client and server models
def get_resnet_model(client_layers, server_start, server_end, output_classes=10):
    """Return a ``(ClientModel, ServerModel)`` pair of ``nn.Module`` classes.

    Both classes slice the top-level children of a freshly loaded
    ``models.resnet18(pretrained=True)``; every instantiation loads its own
    backbone copy.

    Args:
        client_layers: Exclusive end index of the backbone children kept by the client.
        server_start: Start index of the backbone children kept by the server.
        server_end: Exclusive end index of the backbone children kept by the server.
        output_classes: Output size of the server's linear classifier.

    Returns:
        Tuple ``(ClientModel, ServerModel)`` of classes with no-argument constructors.
    """

    def backbone_children():
        # A new pretrained ResNet-18 per call; only its top-level children are used.
        return list(models.resnet18(pretrained=True).children())

    class ClientModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.client_part = nn.Sequential(*backbone_children()[:client_layers])

        def forward(self, x):
            # Intermediate activations handed to the server.
            return self.client_part(x)

    class ServerModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.feature_extractor = nn.Sequential(*backbone_children()[server_start:server_end])
            # NOTE(review): 512 assumes the sliced features flatten to 512 values
            # per sample; confirm this still holds if server_start/server_end change.
            self.classifier = nn.Linear(512, output_classes)

        def forward(self, x):
            features = torch.flatten(self.feature_extractor(x), 1)
            return self.classifier(features)

    return ClientModel, ServerModel

In [None]:
# Function to train a client
def train_client(client_model, server_model, train_loader, criterion, optimizer, device, epochs=5):
    """Train the server-side model on activations produced by a frozen client model.

    The client model is put in eval mode and run under ``torch.no_grad()``, so no
    gradient reaches it; only the parameters registered with ``optimizer`` are
    updated. The mean batch loss of each epoch is printed.

    Args:
        client_model: Client-side module; kept frozen.
        server_model: Server-side module mapping client activations to logits.
        train_loader: Iterable of ``(inputs, labels)`` batches.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        optimizer: Optimizer over the server parameters that should be trained.
        device: Device each batch is moved to.
        epochs: Number of passes over ``train_loader``.

    Returns:
        List of per-epoch mean batch losses (``nan`` for an epoch that yields
        no batches).
    """
    client_model.eval()  # Client model is frozen
    # Enter training mode explicitly: evaluate_model() switches the server
    # model to eval mode, which would otherwise carry over into a later run.
    server_model.train()
    epoch_losses = []
    for epoch in range(epochs):
        running_loss = 0.0
        num_batches = 0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            with torch.no_grad():
                intermediate_output = client_model(inputs)
            optimizer.zero_grad()
            outputs = server_model(intermediate_output)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1

        # Guard against an empty loader instead of raising ZeroDivisionError.
        epoch_loss = running_loss / num_batches if num_batches else float("nan")
        epoch_losses.append(epoch_loss)
        print(f"Epoch {epoch + 1}, Loss: {epoch_loss:.4f}")
    return epoch_losses



In [None]:
# Federated averaging function
def federated_averaging(server_models, num_clients=None):
    """Average the state of several identically-structured models (FedAvg).

    Floating-point entries of the state dicts (parameters and float buffers) are
    averaged element-wise and stored back in their original dtype. Non-floating
    entries, such as integer counters, are copied from the first model.

    Args:
        server_models: Sequence of models sharing the same architecture.
        num_clients: Number of models from the front of ``server_models`` to
            average; ``None`` (default) averages all of them.

    Returns:
        A new model (a deep copy of ``server_models[0]``, on the same device, in
        training mode, with no gradients) holding the averaged state.

    Raises:
        ValueError: if ``num_clients`` is not in ``1..len(server_models)``.
    """
    import copy

    if num_clients is None:
        num_clients = len(server_models)
    if not 1 <= num_clients <= len(server_models):
        raise ValueError(
            f"num_clients must be between 1 and {len(server_models)}, got {num_clients}"
        )

    # Snapshot each state dict once rather than once per key.
    state_dicts = [server_models[c].state_dict() for c in range(num_clients)]
    reference = state_dicts[0]

    averaged_state_dict = {}
    for key, ref_tensor in reference.items():
        if ref_tensor.is_floating_point():
            # Accumulate in at least float32 (float64 stays float64, float16 and
            # bfloat16 are lifted) and cast back to the entry's own dtype.
            acc_dtype = torch.promote_types(ref_tensor.dtype, torch.float32)
            stacked = torch.stack([sd[key].to(acc_dtype) for sd in state_dicts])
            averaged_state_dict[key] = stacked.mean(0).to(ref_tensor.dtype)
        else:
            # Integer/bool entries (e.g. batch counters) are not averaged.
            averaged_state_dict[key] = ref_tensor.clone()

    # Deep-copy the first model instead of calling its class with no arguments:
    # works for constructors that take arguments and avoids re-running
    # expensive initialisation (such as reloading pretrained weights).
    averaged_model = copy.deepcopy(server_models[0])
    averaged_model.load_state_dict(averaged_state_dict)
    averaged_model.zero_grad(set_to_none=True)
    averaged_model.train()
    return averaged_model

In [None]:
# Function to evaluate models
def evaluate_model(client_model, server_model, test_loader, device):
    """Return the top-1 accuracy (in percent) of the client -> server pipeline.

    Both models are switched to eval mode (and left there), and inference runs
    under ``torch.no_grad()``.

    Args:
        client_model: Client-side module producing intermediate activations.
        server_model: Server-side module mapping activations to class logits.
        test_loader: Iterable of ``(inputs, labels)`` batches.
        device: Device each batch is moved to.

    Returns:
        ``100 * correct / total`` as a float.

    Raises:
        ValueError: if ``test_loader`` yields no samples (accuracy undefined).
    """
    client_model.eval()
    server_model.eval()
    correct, total = 0, 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = server_model(client_model(inputs))
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        raise ValueError("test_loader yielded no samples; accuracy is undefined")
    return 100 * correct / total

In [None]:

# Main script
if __name__ == "__main__":
    # Seed the global RNG so client subsets, shuffling and classifier
    # initialisation are reproducible on Restart & Run All.
    SEED = 0
    torch.manual_seed(SEED)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_loaders, test_loader = get_datasets()
    # Client keeps backbone children [:1]; server keeps children [1:-1] and
    # adds its own classifier (the backbone's last child is not used).
    ClientModel, ServerModel = get_resnet_model(1, 1, -1)

    # Derive the client count from the data split so the two cannot drift apart.
    num_clients = len(train_loaders)
    epochs = 5

    # Each client model is a frozen pretrained slice; one instance per client
    # mirrors the split-learning setup.
    client_models = [ClientModel().to(device) for _ in range(num_clients)]
    server_models = [ServerModel().to(device) for _ in range(num_clients)]

    criterion = nn.CrossEntropyLoss()

    # Train each client's server model; only the classifier head is optimised.
    for client_id in range(num_clients):
        print(f"\nTraining Client {client_id + 1}")
        server_optimizer = optim.SGD(server_models[client_id].classifier.parameters(), lr=0.001, momentum=0.9)
        train_client(client_models[client_id], server_models[client_id], train_loaders[client_id], criterion, server_optimizer, device, epochs)

    # Federated Averaging
    averaged_server_model = federated_averaging(server_models, num_clients).to(device)

    # Evaluation
    print("\nEvaluating Individual Client Models")
    for client_id in range(num_clients):
        accuracy = evaluate_model(client_models[client_id], server_models[client_id], test_loader, device)
        print(f"Client {client_id + 1} Accuracy: {accuracy:.2f}%")

    print("\nEvaluating Averaged Model")
    # Client models are never trained, so any one of them can feed the
    # averaged server model; client 0's is used.
    average_accuracy = evaluate_model(client_models[0], averaged_server_model, test_loader, device)
    print(f"Averaged Model Accuracy: {average_accuracy:.2f}%")

Files already downloaded and verified
Files already downloaded and verified

Training Client 1
Epoch 1, Loss: 2.3497
Epoch 2, Loss: 2.0476
Epoch 3, Loss: 1.7757
Epoch 4, Loss: 1.5271
Epoch 5, Loss: 1.3528

Training Client 2
Epoch 1, Loss: 2.3116
Epoch 2, Loss: 2.0590
Epoch 3, Loss: 1.7802
Epoch 4, Loss: 1.5473
Epoch 5, Loss: 1.3919

Training Client 3
Epoch 1, Loss: 2.3843
Epoch 2, Loss: 2.0683
Epoch 3, Loss: 1.7760
Epoch 4, Loss: 1.5524
Epoch 5, Loss: 1.3858

Training Client 4
Epoch 1, Loss: 2.3718
Epoch 2, Loss: 2.0571
Epoch 3, Loss: 1.7492
Epoch 4, Loss: 1.5136
Epoch 5, Loss: 1.3368

Training Client 5
Epoch 1, Loss: 2.4294
Epoch 2, Loss: 2.1347
Epoch 3, Loss: 1.8136
Epoch 4, Loss: 1.5661
Epoch 5, Loss: 1.4142

Evaluating Individual Client Models
Client 1 Accuracy: 60.00%
Client 2 Accuracy: 54.00%
Client 3 Accuracy: 64.00%
Client 4 Accuracy: 54.00%
Client 5 Accuracy: 44.00%

Evaluating Averaged Model
Averaged Model Accuracy: 64.00%
