In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import copy
import random
from torch.utils.data import DataLoader, random_split

In [3]:
# Fix one seed for every random number generator used in this notebook.
seed = 42

# CPU-side generators: PyTorch, NumPy and Python's own `random` module.
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

# CUDA generators are only seeded when a GPU is actually visible.
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)
    torch.cuda.manual_seed(seed)

# Ask cuDNN for deterministic kernels and switch off its benchmark mode.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [4]:
# Define the model architecture
class CNN(nn.Module):
    """Small convolutional classifier: two conv/pool stages, then two linear layers.

    The flattening step expects 64 x 5 x 5 features per sample, which is what
    a 1 x 28 x 28 input yields after the two unpadded 3x3 convolutions and
    2x2 poolings applied in ``forward``.
    """

    def __init__(self):
        super().__init__()
        # Layers are declared in the order in which forward() applies them.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3)
        self.fc1 = nn.Linear(in_features=64 * 5 * 5, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Return log-probabilities (log-softmax over dim 1) for the batch ``x``."""
        # Each convolutional stage is conv -> ReLU -> 2x2 max-pool.
        for conv in (self.conv1, self.conv2):
            x = torch.max_pool2d(torch.relu(conv(x)), kernel_size=2, stride=2)
        # Collapse everything but the feature dimension expected by fc1.
        flat = x.view(-1, self.fc1.in_features)
        hidden = torch.relu(self.fc1(flat))
        return torch.log_softmax(self.fc2(hidden), dim=1)

# Function to simulate client training with MNIST data
def client_training(train_loader, model, criterion, optimizer, num_epochs=1):
    """Train ``model`` in place on one client's batches and print per-epoch metrics.

    Args:
        train_loader: Iterable yielding ``(images, labels)`` batches.
        model: Module to optimise; it is put into training mode.
        criterion: Loss callable, invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer that steps the parameters of ``model``.
        num_epochs: Number of passes over ``train_loader``.

    Returns:
        None. Progress is reported through ``print``.

    Notes:
        Batches are moved to the module-level ``device``.
        The printed loss is the *sum* of the per-batch loss values of the
        epoch, so it grows with the number of batches.
        An epoch that sees no samples reports an accuracy of 0.0 instead of
        raising ``ZeroDivisionError``.
    """
    model.train()
    for epoch in range(num_epochs):
        total_loss, correct, total = 0.0, 0, 0
        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            # detach() replaces the legacy `.data` access for metric bookkeeping.
            predicted = outputs.detach().argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        # Guard against an empty loader, which would otherwise divide by zero.
        accuracy = correct / total if total else 0.0
        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {total_loss}, Accuracy: {accuracy}")

# Function to test the trained model on client data
def client_testing(test_loader, model):
    """Return the classification accuracy of ``model`` over ``test_loader``.

    Args:
        test_loader: Iterable yielding ``(images, labels)`` batches.
        model: Module to evaluate; it is put into evaluation mode and left
            in that mode on return.

    Returns:
        Fraction of samples whose highest-scoring output index equals the
        label, as a float. Returns 0.0 when the loader yields no samples
        instead of raising ``ZeroDivisionError``.

    Notes:
        Batches are moved to the module-level ``device``.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            # Under no_grad the outputs carry no graph, so no `.data` is needed.
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        return 0.0
    return correct / total

In [5]:
def _partition_lengths(total, parts):
    """Split ``total`` items into ``parts`` sizes that differ by at most one and sum to ``total``."""
    base, extra = divmod(total, parts)
    return [base + 1 if i < extra else base for i in range(parts)]


def _swap_label_pairs(targets, indices, pairs):
    """Swap every ``(a, b)`` label pair in place in ``targets`` at the given ``indices``."""
    idx = torch.as_tensor(indices, dtype=torch.long)
    original = targets[idx]
    swapped = original.clone()
    # Masks are computed on the untouched labels so each sample is swapped once.
    for a, b in pairs:
        swapped[original == a] = b
        swapped[original == b] = a
    targets[idx] = swapped


def federated_learning(num_clients=5, num_local_epochs=5, num_global_epochs=5,
                       drift_clients=(2, 6, 9), drift_epoch=10):
    """Simulate FedAvg training on MNIST with label drift injected on some clients.

    The MNIST training set is split at random into ``num_clients`` partitions.
    In every global round each client trains a copy of the global model on its
    partition and the copies are averaged by ``aggregate_models``. At the start
    of global round ``drift_epoch`` (0-based) the labels 3<->8 and 5<->6 are
    swapped in the partitions of ``drift_clients``; the swap persists for all
    later rounds.

    Args:
        num_clients: Number of simulated clients (must be >= 1).
        num_local_epochs: Local epochs each client trains per global round.
        num_global_epochs: Number of global aggregation rounds.
        drift_clients: 0-based ids of the clients whose labels drift. Ids
            outside ``range(num_clients)`` are ignored.
        drift_epoch: 0-based global round at which the drift is injected; no
            drift happens if this round is never reached.

    Returns:
        List with one accuracy per client. Every client is evaluated with the
        final global model on the full MNIST test set.

    Raises:
        ValueError: If ``num_clients`` is smaller than 1.
    """
    if num_clients < 1:
        raise ValueError("num_clients must be a positive integer")

    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
    trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

    # random_split requires the lengths to sum to len(trainset); spreading the
    # remainder keeps this valid when the size is not divisible by num_clients.
    lengths = _partition_lengths(len(trainset), num_clients)
    client_indices = [subset.indices for subset in random_split(trainset, lengths)]

    # Drop drift ids that do not exist for this client count (and duplicates,
    # which would swap the labels twice and silently undo the drift).
    active_drift_clients = sorted({c for c in drift_clients if 0 <= c < num_clients})

    # Initialize global model
    global_model = CNN().to(device)

    # The loaders only hold references to the datasets, so they can be built
    # once: later in-place edits of trainset.targets are seen through them.
    client_train_loaders = [
        DataLoader(torch.utils.data.Subset(trainset, indices), batch_size=32, shuffle=True)
        for indices in client_indices
    ]
    client_test_loaders = [
        DataLoader(testset, batch_size=64, shuffle=False) for _ in range(num_clients)
    ]

    # CNN already returns log-probabilities; the log-softmax applied inside
    # CrossEntropyLoss leaves log-probabilities unchanged, so the loss is valid.
    criterion = nn.CrossEntropyLoss()

    # Federated learning process
    for global_epoch in range(num_global_epochs):
        print(f"Global Epoch {global_epoch + 1}/{num_global_epochs}")

        # Inject label drift exactly once, at the configured round.
        if global_epoch == drift_epoch:
            print("Swapping classes 3<->8 and 5<->6 in the train dataset")
            for client_id in active_drift_clients:
                _swap_label_pairs(trainset.targets, client_indices[client_id], ((3, 8), (5, 6)))

        # Every client starts the round from the current global weights.
        client_models = [copy.deepcopy(global_model) for _ in range(num_clients)]

        # Perform local training
        for i, train_loader in enumerate(client_train_loaders):
            print(f"\tTraining Client {i + 1}/{num_clients}")
            optimizer = optim.Adam(client_models[i].parameters(), lr=0.001)
            client_training(train_loader, client_models[i], criterion, optimizer, num_local_epochs)

        # Aggregate client models into global model
        print("\tAggregating client models")
        global_model = aggregate_models(global_model, client_models)

    # Test models on each client
    accuracies = []
    for i, test_loader in enumerate(client_test_loaders):
        print(f"Testing Client {i + 1}/{num_clients}")
        accuracy = client_testing(test_loader, global_model)
        accuracies.append(accuracy)
        print(f"\tAccuracy for Client {i + 1}/{num_clients}: {accuracy}")

    return accuracies

# Aggregation of models' weights using FedAvg
def aggregate_models(global_model, client_models):
    """Load the equally weighted mean of the client models into ``global_model`` (FedAvg).

    Args:
        global_model: Module whose state dict defines the keys to average and
            which receives the averaged values.
        client_models: Iterable of modules exposing the same state-dict keys.

    Returns:
        ``global_model`` itself, updated in place.

    Raises:
        ValueError: If ``client_models`` is empty.
    """
    # Build each client's state dict once instead of once per key.
    client_states = [client_model.state_dict() for client_model in client_models]
    if not client_states:
        raise ValueError("client_models must contain at least one model")

    global_dict = global_model.state_dict()
    for k in global_dict.keys():
        # Entries are cast to float so that the mean is defined for every entry.
        global_dict[k] = torch.stack([state[k].float() for state in client_states], 0).mean(0)
    global_model.load_state_dict(global_dict)
    return global_model

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Run the simulation: 60 clients, 1 local epoch per round, 40 global rounds.
if __name__ == "__main__":
    client_accuracies = federated_learning(
        num_clients=60,
        num_local_epochs=1,
        num_global_epochs=40,
    )
    print("Client Accuracies:", client_accuracies)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 36257086.68it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 1059582.36it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 7970362.37it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 2486365.02it/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw

Global Epoch 1/40
	Training Client 1/60
Epoch 1/1, Loss: 49.79967486858368, Accuracy: 0.507
	Training Client 2/60
Epoch 1/1, Loss: 48.39054638147354, Accuracy: 0.55
	Training Client 3/60
Epoch 1/1, Loss: 47.43098384141922, Accuracy: 0.508
	Training Client 4/60
Epoch 1/1, Loss: 47.85269504785538, Accuracy: 0.533
	Training Client 5/60
Epoch 1/1, Loss: 50.38480889797211, Accuracy: 0.516
	Training Client 6/60
Epoch 1/1, Loss: 48.61078721284866, Accuracy: 0.528
	Training Client 7/60
Epoch 1/1, Loss: 51.79433339834213, Accuracy: 0.464
	Training Client 8/60
Epoch 1/1, Loss: 52.09694707393646, Accuracy: 0.494
	Training Client 9/60
Epoch 1/1, Loss: 51.132887452840805, Accuracy: 0.525
	Training Client 10/60
Epoch 1/1, Loss: 47.5314040184021, Accuracy: 0.56
	Training Client 11/60
Epoch 1/1, Loss: 47.49448761343956, Accuracy: 0.564
	Training Client 12/60
Epoch 1/1, Loss: 50.81087863445282, Accuracy: 0.506
	Training Client 1