In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import shap
from scipy.stats import mannwhitneyu, ttest_ind, ks_2samp, ttest_ind_from_stats
import copy
import random
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, random_split

In [3]:
# Reproducibility: one seed shared by every random number generator in use.
seed = 42

# Force cuDNN onto deterministic kernels and disable its auto-tuner.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Python, NumPy and PyTorch each keep their own generator state.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# CUDA generators are seeded only when a GPU is actually usable.
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

In [4]:
# Model architecture shared by the global models and every client
class CNN(nn.Module):
    """Small two-convolution classifier for single-channel images.

    Layout: conv(1->32, 3x3) -> ReLU -> 2x2 max-pool -> conv(32->64, 3x3)
    -> ReLU -> 2x2 max-pool -> flatten -> fc(1600->128) -> ReLU -> fc(128->10)
    -> log-softmax.

    ``fc1`` expects ``64 * 5 * 5`` features per sample, which is what the
    convolution/pooling stack produces for a 28x28 input
    (28 -> 26 -> 13 -> 11 -> 5).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.fc1 = nn.Linear(64 * 5 * 5, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return per-class log-probabilities of shape ``(batch, 10)``.

        Args:
            x: Image batch of shape ``(batch, 1, H, W)``; ``H = W = 28`` is
                the size that matches ``fc1``.

        Raises:
            RuntimeError: If the spatial size does not yield 1600 features
                per sample (raised by ``fc1``).
        """
        x = torch.relu(self.conv1(x))
        x = torch.max_pool2d(x, kernel_size=2, stride=2)
        x = torch.relu(self.conv2(x))
        x = torch.max_pool2d(x, kernel_size=2, stride=2)
        # Flatten per sample and keep the batch dimension fixed. Reshaping
        # with view(-1, 1600) would silently change the batch size whenever
        # the total element count happened to be a multiple of 1600.
        x = torch.flatten(x, start_dim=1)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return torch.log_softmax(x, dim=1)

# Local training step executed by a single client
def client_training(train_loader, model, criterion, optimizer, num_epochs=1):
    """Train ``model`` in place on one client's data.

    Args:
        train_loader: Iterable yielding ``(images, labels)`` batches.
        model: Network to optimise; it is switched to training mode.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimiser already bound to ``model``'s parameters.
        num_epochs: Number of passes over ``train_loader``.

    Returns:
        float: Sum of the per-batch losses of the LAST epoch (a sum, not a
        mean, so it scales with the number of batches). ``0.0`` when
        ``num_epochs`` is 0 or the loader is empty.
    """
    # Batches follow the model's own device, so the function does not depend
    # on a module-level `device` having been defined before it is called.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    total_loss = 0.0  # keeps the return value defined when num_epochs == 0
    for epoch in range(num_epochs):
        total_loss, correct, total = 0.0, 0, 0
        for images, labels in train_loader:
            images = images.to(target_device)
            labels = labels.to(target_device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            # detach(): the prediction is only used for bookkeeping.
            predicted = outputs.detach().argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        # An empty loader yields no samples; report 0.0 instead of dividing by zero.
        accuracy = correct / total if total else 0.0
        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {total_loss}, Accuracy: {accuracy}")

    return total_loss
    

# Evaluation step: classification accuracy of a model on one loader
def client_testing(test_loader, model):
    """Return the fraction of correctly classified samples.

    Args:
        test_loader: Iterable yielding ``(images, labels)`` batches.
        model: Network to evaluate; it is switched to (and left in) eval mode.

    Returns:
        float: ``correct / total`` over all batches, or ``0.0`` when the
        loader yields no samples.
    """
    # Batches follow the model's own device, so the function does not depend
    # on a module-level `device` having been defined before it is called.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(target_device)
            labels = labels.to(target_device)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    # Guard the empty-loader case instead of raising ZeroDivisionError.
    return correct / total if total else 0.0

In [6]:
# Federated learning with two global models (regular and drift-specific)
def federated_learning(num_clients=5, num_local_epochs=5, num_global_epochs=5):
    """Simulate FedAvg on MNIST with loss-based concept-drift detection.

    The MNIST training set is split across ``num_clients`` clients. At a
    fixed round the labels of selected clients are swapped (3<->8 and 5<->6)
    to inject concept drift. After a warm-up, a client is flagged when its
    training loss first jumps by a factor of ``drastic_change_factor``
    relative to its previous round and then stays above a fixed threshold.
    The first non-empty set of flagged clients is frozen and from then on
    trained and aggregated through a separate global model.

    Args:
        num_clients: Number of simulated clients.
        num_local_epochs: Local epochs each client trains per round.
        num_global_epochs: Number of federated rounds.

    Returns:
        list[float]: One test accuracy per client. Every entry is computed
        with the regular ``global_model`` on the full MNIST test set.

    Side effects:
        Downloads MNIST into ``./data`` if missing and prints training,
        detection and evaluation progress.
    """
    # Experiment schedule and thresholds, named once so the label swap and
    # the ground-truth check below cannot drift apart.
    drift_start_epoch = 10          # round at which labels are swapped
    detection_warmup_epochs = 5     # rounds before detection starts
    sustained_loss_threshold = 4    # summed-loss level counted as "still high"
    drastic_change_factor = 3       # loss jump (x previous round) that arms a client
    label_swaps = {3: 8, 8: 3, 5: 6, 6: 5}

    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
    trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

    # random_split needs the lengths to add up to len(trainset); spread any
    # remainder over the first clients so non-divisible client counts work.
    # For divisible counts this yields exactly the original equal partition.
    partition_size, remainder = divmod(len(trainset), num_clients)
    lengths = [partition_size + (1 if i < remainder else 0) for i in range(num_clients)]
    datasets = random_split(trainset, lengths)

    client_indices = [dataset.indices for dataset in datasets]

    # Clients that will experience concept drift. Ids beyond num_clients are
    # dropped so small simulations do not index past client_indices.
    drift_clients = [c for c in [2, 6, 9, 28, 41, 56] if c < num_clients]

    # Initialize global model
    global_model = CNN().to(device)
    previous_client_losses = [[] for _ in range(num_clients)]
    drastic_change_detected = [False] * num_clients
    sustained_high_loss = [False] * num_clients

    # Initialize detected drift clients
    detected_drift_clients = []
    fixed_detected_drift_clients = []
    global_model_sdrift = CNN().to(device)

    # Initialize counters for drift detection performance
    true_positives = 0
    true_negatives = 0
    false_positives = 0
    false_negatives = 0

    # The loaders are loop-invariant: Subset reads labels from trainset on
    # access, so the later label swap is still visible through them. Building
    # them here also keeps client_test_loaders defined when
    # num_global_epochs == 0.
    client_train_loaders = [torch.utils.data.DataLoader(torch.utils.data.Subset(trainset, indices), batch_size=32, shuffle=True) for indices in client_indices]
    client_test_loaders = [torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False) for _ in range(num_clients)]

    # Federated learning process
    for global_epoch in range(num_global_epochs):
        print(f"Global Epoch {global_epoch + 1}/{num_global_epochs}")

        # Inject concept drift once, at the configured round. The message
        # only mentions 3/8, but 5/6 are swapped as well.
        if global_epoch == drift_start_epoch:
            print("Swapping classes 3 and 8 in the train dataset for selected clients")
            for client_id in drift_clients:
                for idx in client_indices[client_id]:
                    label = int(trainset.targets[idx])
                    if label in label_swaps:
                        trainset.targets[idx] = label_swaps[label]

        # Distribute the respective global models back to the clients
        client_models = [
            copy.deepcopy(global_model_sdrift if i in fixed_detected_drift_clients else global_model)
            for i in range(num_clients)
        ]

        # Perform local training
        client_losses = []

        for i, train_loader in enumerate(client_train_loaders):
            print(f"\tTraining Client {i + 1}/{num_clients}")
            criterion = nn.CrossEntropyLoss()
            optimizer = optim.Adam(client_models[i].parameters(), lr=0.001)
            loss = client_training(train_loader, client_models[i], criterion, optimizer, num_local_epochs)

            client_losses.append(loss)

            # Keep the full per-client loss history for the jump test below
            previous_client_losses[i].append(loss)

        # concept drift detection and performance metrics
        if global_epoch >= detection_warmup_epochs:
            for i in range(num_clients):
                # [-1] is this round's loss, so [-2] is the previous round's.
                prev_loss = previous_client_losses[i][-2]
                if not drastic_change_detected[i]:
                    # Stage 1: arm the client on a sudden loss jump.
                    if client_losses[i] > (prev_loss * drastic_change_factor):
                        drastic_change_detected[i] = True
                else:
                    # Stage 2: confirm only if the loss stays high; otherwise
                    # treat the jump as a transient and reset both flags.
                    if client_losses[i] >= sustained_loss_threshold:
                        sustained_high_loss[i] = True
                    else:
                        drastic_change_detected[i] = False
                        sustained_high_loss[i] = False

            detected_drift_clients = [
                i for i, (drastic, sustained) in enumerate(zip(drastic_change_detected, sustained_high_loss))
                if drastic and sustained
            ]

            # The first non-empty detection is frozen and never updated again.
            if detected_drift_clients and not fixed_detected_drift_clients:
                fixed_detected_drift_clients = detected_drift_clients

            for client_id in range(num_clients):
                ground_truth_drift = (client_id in drift_clients) and (global_epoch >= drift_start_epoch)
                detected_drift = client_id in detected_drift_clients

                if detected_drift and ground_truth_drift:
                    true_positives += 1
                elif not detected_drift and not ground_truth_drift:
                    true_negatives += 1
                elif detected_drift and not ground_truth_drift:
                    false_positives += 1
                elif not detected_drift and ground_truth_drift:
                    false_negatives += 1

                if detected_drift == ground_truth_drift:
                    print(f"Global Epoch {global_epoch + 1}, Client {client_id + 1}: Drift detection is accurate.")
                else:
                    print(f"Global Epoch {global_epoch + 1}, Client {client_id + 1}: Drift detection is inaccurate.")

        # Aggregate client models into their respective global model
        print("\tAggregating client models")
        drifted_client_models = [client_models[i] for i in fixed_detected_drift_clients]
        non_drifted_client_models = [client_models[i] for i in range(num_clients) if i not in fixed_detected_drift_clients]

        if drifted_client_models:
            global_model_sdrift = aggregate_models(global_model_sdrift, drifted_client_models)

        if non_drifted_client_models:
            global_model = aggregate_models(global_model, non_drifted_client_models)

    # Test models on each client
    # NOTE(review): every client, including those in
    # fixed_detected_drift_clients, is evaluated with global_model on the same
    # test set; global_model_sdrift is never evaluated. Confirm this is intended.
    accuracies = []
    for i, test_loader in enumerate(client_test_loaders):
        print(f"Testing Client {i + 1}/{num_clients}")
        accuracy = client_testing(test_loader, global_model)
        accuracies.append(accuracy)
        print(f"\tAccuracy for Client {i + 1}/{num_clients}: {accuracy}")

    # Calculate performance metrics of drift detection. No decisions are
    # recorded when training stops before the warm-up ends (which includes
    # the default arguments), so the accuracy needs the same zero guard as
    # the other metrics.
    total_decisions = true_positives + true_negatives + false_positives + false_negatives
    accuracy = (true_positives + true_negatives) / total_decisions if total_decisions > 0 else 0
    precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
    recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0
    f1_score = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

    print(f"Drift Detection metircs::: true_positives:{true_positives},true_negatives:{true_negatives},false_positives:{false_positives},false_negatives:{false_negatives}")
    print(f"Drift Detection Performance:\nAccuracy: {accuracy}\nPrecision: {precision}\nRecall: {recall}\nF1 Score: {f1_score}")

    return accuracies

# Aggregation of models' weights using FedAvg (unweighted mean)
def aggregate_models(global_model, client_models):
    """Overwrite ``global_model`` with the element-wise mean of the clients.

    Args:
        global_model: Model whose state dict is replaced in place.
        client_models: Sequence of models with the same state-dict keys and
            tensor shapes as ``global_model``.

    Returns:
        The same ``global_model`` object, updated in place. It is returned
        unchanged when ``client_models`` is empty.
    """
    # Nothing to average: keep the current weights rather than letting
    # torch.stack fail on an empty list.
    if not client_models:
        return global_model

    # Build each client's state dict once instead of once per key.
    client_states = [client_model.state_dict() for client_model in client_models]

    global_dict = global_model.state_dict()
    for key in global_dict.keys():
        # Entries are cast to float before averaging.
        global_dict[key] = torch.stack([state[key].float() for state in client_states], 0).mean(0)
    global_model.load_state_dict(global_dict)
    return global_model


# Compute device shared by the functions above: GPU when CUDA is usable, CPU otherwise
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Entry point: 60 clients, 1 local epoch per round, 40 federated rounds
if __name__ == "__main__":
    client_accuracies = federated_learning(
        num_clients=60,
        num_local_epochs=1,
        num_global_epochs=40,
    )
    print("Client Accuracies:", client_accuracies)

Global Epoch 1/40
	Training Client 1/60
Epoch 1/1, Loss: 50.54583424329758, Accuracy: 0.523
	Training Client 2/60
Epoch 1/1, Loss: 55.41438293457031, Accuracy: 0.451
	Training Client 3/60
Epoch 1/1, Loss: 49.73745948076248, Accuracy: 0.518
	Training Client 4/60
Epoch 1/1, Loss: 56.56654405593872, Accuracy: 0.466
	Training Client 5/60
Epoch 1/1, Loss: 53.078854858875275, Accuracy: 0.466
	Training Client 6/60
Epoch 1/1, Loss: 53.16413390636444, Accuracy: 0.45
	Training Client 7/60
Epoch 1/1, Loss: 54.305288553237915, Accuracy: 0.451
	Training Client 8/60
Epoch 1/1, Loss: 56.43608808517456, Accuracy: 0.434
	Training Client 9/60
Epoch 1/1, Loss: 51.65361696481705, Accuracy: 0.506
	Training Client 10/60
Epoch 1/1, Loss: 52.701732099056244, Accuracy: 0.463
	Training Client 11/60
Epoch 1/1, Loss: 52.24538081884384, Accuracy: 0.52
	Training Client 12/60
Epoch 1/1, Loss: 58.476162135601044, Accuracy: 0.396
	Training Client 13/60
Epoch 1/1, Loss: 51.254629135131836, Accuracy: 0.478
	Training Cli

KeyboardInterrupt: 