In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from scipy.ndimage import rotate
import shap
from scipy.stats import mannwhitneyu, ttest_ind, ks_2samp, ttest_ind_from_stats
import copy
import random
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, random_split

In [2]:
# Seed every random number generator the experiment touches so that reruns
# reproduce the same data splits, initial weights and shuffling order.
seed = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(seed)
del _seed_fn
if torch.cuda.is_available():
    # manual_seed seeds the current GPU; manual_seed_all seeds every visible GPU.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
# Ask cuDNN for deterministic kernels and turn off autotuning, which may pick
# different algorithms from run to run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [3]:
# Define the model architecture
class CNN(nn.Module):
    """Two-convolution classifier for single-channel 28x28 images (MNIST-sized).

    Spatial trace for a 28x28 input: conv3x3 -> 26, max-pool/2 -> 13,
    conv3x3 -> 11, max-pool/2 -> 5, so the flattened feature size is 64 * 5 * 5.

    ``forward`` returns log-probabilities (``log_softmax``) over 10 classes.
    Pairing it with ``nn.CrossEntropyLoss`` applies log_softmax a second time.
    That is numerically harmless because log_softmax is idempotent, but
    ``nn.NLLLoss`` is the conventional partner.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.fc1 = nn.Linear(64 * 5 * 5, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Map a (batch, 1, 28, 28) tensor to (batch, 10) log-probabilities."""
        x = torch.relu(self.conv1(x))
        x = torch.max_pool2d(x, kernel_size=2, stride=2)
        x = torch.relu(self.conv2(x))
        x = torch.max_pool2d(x, kernel_size=2, stride=2)
        # Flatten per sample. The previous view(-1, 64 * 5 * 5) could silently
        # regroup features across samples when the input was not 28x28
        # (e.g. 25 images of 32x32 became 36 "samples"). Now fc1 raises instead.
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return torch.log_softmax(x, dim=1)

# Function to simulate client training with MNIST data
def client_training(train_loader, model, criterion, optimizer, num_epochs=1, device=None):
    """Train ``model`` in place on one client's batches and report the loss.

    Args:
        train_loader: iterable of ``(images, labels)`` batches.
        model: module to train; it is switched to train mode.
        criterion: loss called as ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: number of passes over ``train_loader``.
        device: where each batch is moved. Defaults to the device of the
            model's first parameter (CPU for a parameter-less model), so the
            inputs always land next to the weights without relying on a global.

    Returns:
        The SUM of per-batch losses over the last epoch, not the mean, so it
        grows with the number of batches. Returns 0.0 when no epoch runs or
        the loader is empty.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    total_loss = 0.0  # defined up front so num_epochs == 0 still returns a value
    for epoch in range(num_epochs):
        total_loss, correct, total = 0.0, 0, 0
        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            predicted = outputs.detach().argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        # An empty loader would otherwise divide by zero.
        accuracy = correct / total if total else 0.0
        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {total_loss}, Accuracy: {accuracy}")

    return total_loss

# Function to test the trained model on client data
def client_testing(test_loader, model, device=None):
    """Return the top-1 accuracy of ``model`` over ``test_loader``.

    Args:
        test_loader: iterable of ``(images, labels)`` batches.
        model: module to evaluate; it is switched to eval mode and run under
            ``torch.no_grad()``.
        device: where each batch is moved. Defaults to the device of the
            model's first parameter (CPU for a parameter-less model).

    Returns:
        Fraction of correctly classified samples in [0, 1]; 0.0 for an empty loader.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    # Avoid ZeroDivisionError when the loader yields no samples.
    return correct / total if total else 0.0

In [4]:
# Function to apply gradual rotation drift
def apply_rotation_drift(images, current_epoch, max_rotation, start_epoch, end_epoch, total_epochs):
    """Return a copy of ``images`` with a growing random subset rotated by a growing angle.

    Inside the inclusive window ``[start_epoch, end_epoch]``:

    * the angle rises linearly from 0 degrees at ``start_epoch`` to
      ``max_rotation`` degrees at ``end_epoch``;
    * the fraction of images rotated rises from ``1 / window_length`` to 1.

    Outside the window an unmodified clone is returned. A single-epoch window
    (``start_epoch == end_epoch``) rotates every image by ``max_rotation``.

    Args:
        images: tensor indexed by image along dim 0. Each ``images[i]`` is fed
            to ``scipy.ndimage.rotate`` with ``reshape=False``, presumably a 2-D
            (H, W) image; TODO confirm before using multi-channel input.
        current_epoch: epoch for which to compute the drift.
        max_rotation: angle in degrees reached at ``end_epoch``.
        start_epoch: first epoch of the drift window (inclusive).
        end_epoch: last epoch of the drift window (inclusive); must be >= start_epoch.
        total_epochs: unused; kept for call compatibility.

    Returns:
        New tensor with the same shape and dtype as ``images``. The input is not modified.

    Raises:
        ValueError: if ``end_epoch < start_epoch``.

    The subset to rotate is drawn from torch's global RNG via ``torch.randperm``.
    """
    if end_epoch < start_epoch:
        raise ValueError(f"end_epoch ({end_epoch}) must be >= start_epoch ({start_epoch})")
    window = end_epoch - start_epoch

    if start_epoch <= current_epoch <= end_epoch:
        # A zero-length window previously divided by zero; treat it as fully progressed.
        transition_progress = (current_epoch - start_epoch) / window if window else 1.0
        rotation_angle = transition_progress * max_rotation
    else:
        rotation_angle = 0

    fraction_rotated = (current_epoch - start_epoch + 1) / (window + 1)
    num_images_to_rotate = int(fraction_rotated * len(images))

    drifted_images = images.clone()
    # fraction_rotated <= 0 (before the window) or > 1 (after it) leaves everything untouched.
    if num_images_to_rotate > 0 and fraction_rotated <= 1:
        indices_to_rotate = torch.randperm(len(images))[:num_images_to_rotate]
        for idx in indices_to_rotate:
            # scipy needs a host NumPy array; .cpu() lets GPU-resident input through.
            rotated = rotate(images[idx].cpu().numpy(), rotation_angle, reshape=False)
            drifted_images[idx] = torch.as_tensor(rotated, dtype=images.dtype)

    return drifted_images

In [5]:
def federated_learning(num_clients=60, num_local_epochs=5, num_global_epochs=40):
    """Simulate FedAvg on MNIST with injected concept drift and drift-aware aggregation.

    The MNIST training set is split into ``num_clients`` equal partitions. Two
    kinds of drift are injected into fixed client subsets:

    * gradual rotation drift over global rounds 10-30: a growing random fraction
      of each affected client's images is rotated by a growing angle (up to 30
      degrees), always starting from the client's original images;
    * sudden class-swap drift at round 10: labels 3 and 8 are exchanged.

    Each round, every client trains a copy of its group's model (clean, sudden
    drift, or gradual drift). The per-client training losses feed two heuristic
    detectors. A client detected in more than 3 rounds is confirmed as drifted
    and from then on trains with, and is aggregated into, its drift group's model.

    Args:
        num_clients: number of equal-size client partitions. Hard-coded drift
            client ids >= num_clients are ignored.
        num_local_epochs: local epochs per client per global round.
        num_global_epochs: number of federated rounds.

    Returns:
        List with one test accuracy per client. Every entry is the clean
        ``global_model`` evaluated on the shared MNIST test set, so all entries are equal.

    Side effects: downloads MNIST into ``./data`` and prints progress and
    drift-detection metrics.
    """
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
    trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

    # Equal-size partitions. random_split needs lengths that cover the whole dataset,
    # so any remainder goes into one extra partition that no client uses. Previously,
    # a client count that does not divide the dataset size raised ValueError.
    partition_size = len(trainset) // num_clients
    remainder = len(trainset) - partition_size * num_clients
    lengths = [partition_size] * num_clients + ([remainder] if remainder else [])
    datasets = random_split(trainset, lengths)[:num_clients]

    client_indices = [dataset.indices for dataset in datasets]

    # Clients that will experience rotation drift
    rotation_drift_clients = [c for c in (2, 6, 9, 13, 15, 27) if c < num_clients]

    # Clients that will experience class swap drift
    class_swap_drift_clients = [c for c in (3, 10, 18, 41, 55) if c < num_clients]

    # Original images of the rotation-drift clients (advanced indexing returns a copy).
    # Each round's drift is computed from these, so round t shows exactly the rotation
    # scheduled for round t. Re-rotating the already-rotated data in place compounded
    # the angle far beyond max_rotation and blurred images through repeated interpolation.
    clean_rotation_images = {cid: trainset.data[client_indices[cid]] for cid in rotation_drift_clients}

    # Initialize global model
    global_model = CNN().to(device)
    previous_client_losses = [[] for _ in range(num_clients)]
    # Gradual-drift detector flags
    drastic_change_detected = [False] * num_clients
    sustained_high_loss = [False] * num_clients
    # Sudden-drift detector flags
    drastic_change = [False] * num_clients
    sustained_high = [False] * num_clients
    fixed_long_averages = [None] * num_clients  # Long-window average frozen at the first sustained-high detection

    # Drift detected this round, and clients confirmed as drifted (never removed)
    detected_drift_clients = []
    detected_drift_clients_1 = []  # sudden-drift detections of the latest round
    fixed_detected_drift_clients = {'sudden': [], 'gradual': []}
    # NOTE(review): the drift-group models start from fresh random weights rather than
    # from global_model, so a newly confirmed client's first round trains from scratch.
    global_model_sdrift = CNN().to(device)
    global_model_gdrift = CNN().to(device)

    # Number of rounds in which each client has been detected as drifted
    drift_count = {i: 0 for i in range(num_clients)}

    # Window sizes (in rounds) for the gradual-drift moving averages
    short_window = 15
    long_window = 20

    # Factor for drastic change detection (sudden drift)
    drastic_change_factor = 3

    # Counters for drift detection performance
    true_positives = 0
    true_negatives = 0
    false_positives = 0
    false_negatives = 0

    # Loaders are loop-invariant. The Subsets reference trainset itself, so in-place
    # drift edits to trainset.data / trainset.targets are seen on the next iteration.
    client_train_loaders = [
        DataLoader(torch.utils.data.Subset(trainset, indices), batch_size=32, shuffle=True)
        for indices in client_indices
    ]
    # Every client is evaluated on the same MNIST test set.
    test_loader = DataLoader(testset, batch_size=64, shuffle=False)

    # Federated learning process
    for global_epoch in range(num_global_epochs):
        print(f"Global Epoch {global_epoch + 1}/{num_global_epochs}")

        # Apply rotation drift gradually over rounds 10..30 (inclusive)
        if 10 <= global_epoch <= 30:
            print("Applying rotation drift to specific clients")
            for client_id in rotation_drift_clients:
                drifted_images = apply_rotation_drift(
                    clean_rotation_images[client_id], global_epoch,
                    max_rotation=30, start_epoch=10, end_epoch=30, total_epochs=num_global_epochs,
                )
                trainset.data[client_indices[client_id]] = drifted_images

        # Apply class swap drift (labels 3 <-> 8) once, at round 10
        if global_epoch == 10:
            print("Swapping classes 3 and 8 in the train dataset for selected clients")
            for client_id in class_swap_drift_clients:
                idx = torch.as_tensor(client_indices[client_id])
                labels = trainset.targets[idx]
                swapped = labels.clone()
                swapped[labels == 8] = 3
                swapped[labels == 3] = 8
                trainset.targets[idx] = swapped

        # Confirmed drift groups as of the start of this round. They drive model and
        # optimizer choice. Detection only updates them after local training.
        sudden_fixed = set(fixed_detected_drift_clients['sudden'])
        gradual_fixed = set(fixed_detected_drift_clients['gradual'])

        # Each client starts from a copy of its group's model. Gradual drift takes
        # precedence when a client is confirmed in both groups.
        client_models = []
        for i in range(num_clients):
            if i in gradual_fixed:
                client_models.append(copy.deepcopy(global_model_gdrift))
            elif i in sudden_fixed:
                client_models.append(copy.deepcopy(global_model_sdrift))
            else:
                client_models.append(copy.deepcopy(global_model))

        # Perform local training and record each client's loss
        client_losses = []
        for i, train_loader in enumerate(client_train_loaders):
            print(f"\tTraining Client {i + 1}/{num_clients}")
            criterion = nn.CrossEntropyLoss()
            if i in gradual_fixed:
                # Gradual-drift clients (including those also in the sudden group) use
                # lower Adam betas, i.e. shorter moving averages of the gradient moments.
                optimizer = optim.Adam(client_models[i].parameters(), lr=0.001, betas=(0.6, 0.7))
            else:
                optimizer = optim.Adam(client_models[i].parameters(), lr=0.001)

            loss = client_training(train_loader, client_models[i], criterion, optimizer, num_local_epochs)
            client_losses.append(loss)

            # Keep only the most recent long_window losses per client
            previous_client_losses[i].append(loss)
            if len(previous_client_losses[i]) > long_window:
                previous_client_losses[i].pop(0)

        # Sudden drift: a loss jump of more than drastic_change_factor x the previous
        # round's loss, followed by losses that stay >= 4 (an absolute threshold on the
        # summed per-batch loss returned by client_training).
        if global_epoch >= 5:
            for i in range(num_clients):
                prev_loss = previous_client_losses[i][-2]
                if not drastic_change[i]:
                    if client_losses[i] > (prev_loss * drastic_change_factor):
                        drastic_change[i] = True
                else:
                    if client_losses[i] >= 4:
                        sustained_high[i] = True
                    else:
                        drastic_change[i] = False
                        sustained_high[i] = False

            detected_drift_clients_1 = [
                i for i, (drastic, sustained) in enumerate(zip(drastic_change, sustained_high))
                if drastic and sustained
            ]

            print(f'sdrft_client{detected_drift_clients_1}')

        # Gradual drift detection, drift confirmation and performance metrics (rounds 5..28)
        if 5 <= global_epoch < 29:
            for i in range(num_clients):
                if len(previous_client_losses[i]) >= long_window:
                    # NOTE(review): the history already holds this round's loss, so the
                    # current loss is compared with averages that include itself.
                    short_avg = np.mean(previous_client_losses[i][-short_window:])

                    if fixed_long_averages[i] is None:
                        long_avg = np.mean(previous_client_losses[i][-long_window:])
                    else:
                        long_avg = fixed_long_averages[i]

                    drastic_change_detected[i] = bool(client_losses[i] > short_avg)

                    if client_losses[i] > long_avg:
                        sustained_high_loss[i] = True
                        if fixed_long_averages[i] is None:
                            fixed_long_averages[i] = long_avg
                    else:
                        sustained_high_loss[i] = False

            # Gradual drift: loss above both moving averages while the sudden-drift
            # detector is quiet for that client.
            detected_drift_clients_2 = [
                i for i, (drastic, sustained, d_change, s_high) in enumerate(
                    zip(drastic_change_detected, sustained_high_loss, drastic_change, sustained_high)
                ) if drastic and sustained and not d_change and not s_high
            ]

            print(f'gdrft_client{detected_drift_clients_2}')

            detected_drift_clients = list(set(detected_drift_clients_1) | set(detected_drift_clients_2))

            for client in detected_drift_clients:
                drift_count[client] += 1

            print(drift_count)

            # Confirm a client's drift once it has been detected in more than 3 rounds
            for client in detected_drift_clients_1:
                if client not in fixed_detected_drift_clients['sudden'] and drift_count[client] > 3:
                    fixed_detected_drift_clients['sudden'].append(client)

            for client in detected_drift_clients_2:
                if client not in fixed_detected_drift_clients['gradual'] and drift_count[client] > 3:
                    fixed_detected_drift_clients['gradual'].append(client)

            # Score this round's detections against the injected drift schedule
            for client_id in range(num_clients):
                ground_truth_drift = (
                    (client_id in class_swap_drift_clients and global_epoch >= 10)
                    or (client_id in rotation_drift_clients and 19 <= global_epoch <= 29)
                )
                detected_drift = client_id in detected_drift_clients

                if detected_drift and ground_truth_drift:
                    true_positives += 1
                elif not detected_drift and not ground_truth_drift:
                    true_negatives += 1
                elif detected_drift and not ground_truth_drift:
                    false_positives += 1
                else:
                    false_negatives += 1

                if detected_drift == ground_truth_drift:
                    if client_id in detected_drift_clients_1 and client_id in class_swap_drift_clients:
                        print(f"Global Epoch {global_epoch + 1}, Client {client_id + 1}: True Drift Detected, Drift Type: Sudden .")
                    if client_id in detected_drift_clients_2 and client_id in rotation_drift_clients:
                        print(f"Global Epoch {global_epoch + 1}, Client {client_id + 1}: True Drift Detected, Drift Type: Gradual .")
                else:
                    # NOTE(review): also printed for missed drift (false negatives).
                    print(f"Global Epoch {global_epoch + 1}, Client {client_id + 1}: False Drift Detected.")

        # Aggregate client models into their group models
        print("\tAggregating client models")
        sudden_clients = set(fixed_detected_drift_clients['sudden'])
        gradual_clients = set(fixed_detected_drift_clients['gradual'])
        print(f'fixed_detected_drift_clients:{fixed_detected_drift_clients}')

        common_dclients = sudden_clients & gradual_clients
        print(f"common_dclients:{common_dclients}")
        all_dclients = sudden_clients | gradual_clients

        if common_dclients:
            # Any overlap merges every confirmed drifted client into the gradual-drift model.
            sdrifted_client_models = []
            gdrifted_client_models = [client_models[i] for i in all_dclients]
        else:
            sdrifted_client_models = [client_models[i] for i in sudden_clients]
            gdrifted_client_models = [client_models[i] for i in gradual_clients]
        non_drifted_client_models = [client_models[i] for i in range(num_clients) if i not in all_dclients]

        if non_drifted_client_models:
            global_model = aggregate_models(global_model, non_drifted_client_models)
        if gdrifted_client_models:
            global_model_gdrift = aggregate_models(global_model_gdrift, gdrifted_client_models)
        if sdrifted_client_models:
            global_model_sdrift = aggregate_models(global_model_sdrift, sdrifted_client_models)

    # All clients share one test set and are scored with the clean global model,
    # so a single evaluation gives every client's accuracy.
    shared_accuracy = client_testing(test_loader, global_model)
    accuracies = []
    for i in range(num_clients):
        print(f"Testing Client {i + 1}/{num_clients}")
        accuracies.append(shared_accuracy)
        print(f"\tAccuracy for Client {i + 1}/{num_clients}: {shared_accuracy}")

    # Drift detection metrics; guard against runs too short to reach a detection round
    total_decisions = true_positives + true_negatives + false_positives + false_negatives
    accuracy = (true_positives + true_negatives) / total_decisions if total_decisions > 0 else 0
    precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
    recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0
    f1_score = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

    print(f"Drift Detection metrics::: true_positives:{true_positives}, true_negatives:{true_negatives}, false_positives:{false_positives}, false_negatives:{false_negatives}")
    print(f"Drift Detection Performance:\nAccuracy: {accuracy}\nPrecision: {precision}\nRecall: {recall}\nF1 Score: {f1_score}")

    return accuracies

# Aggregation of models' weights using FedAvg
def aggregate_models(global_model, client_models):
    """Overwrite ``global_model`` with the element-wise mean of the client models' weights.

    This is unweighted FedAvg: every client counts equally, regardless of how
    much data it trained on. Each state-dict entry is averaged in float32, then
    cast back to the global model's dtype and device for that entry.

    Args:
        global_model: module whose state dict is replaced in place.
        client_models: sequence of modules with the same architecture
            (identical state-dict keys and shapes).

    Returns:
        ``global_model`` itself (the same object). If ``client_models`` is
        empty, it is returned unchanged instead of failing inside ``torch.stack``.
    """
    if not client_models:
        return global_model

    # Build each client's state dict once, not once per parameter key.
    client_states = [client_model.state_dict() for client_model in client_models]
    global_dict = global_model.state_dict()
    for key, reference in list(global_dict.items()):
        stacked = torch.stack([state[key].float() for state in client_states], 0)
        global_dict[key] = stacked.mean(0).to(dtype=reference.dtype, device=reference.device)
    global_model.load_state_dict(global_dict)
    return global_model

# Choose the compute device once. The training/testing helpers read this module global.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Entry point: run the full experiment (60 clients, 1 local epoch, 60 rounds)
# when executed directly. This is also the case inside a notebook kernel.
if __name__ == "__main__":
    client_accuracies = federated_learning(
        num_clients=60,
        num_local_epochs=1,
        num_global_epochs=60,
    )
    print("Client Accuracies:", client_accuracies)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 15946247.87it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 465653.99it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 3804497.36it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 2228132.02it/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw

Global Epoch 1/60
	Training Client 1/60
Epoch 1/1, Loss: 46.32598519325256, Accuracy: 0.537
	Training Client 2/60
Epoch 1/1, Loss: 49.65331691503525, Accuracy: 0.509
	Training Client 3/60
Epoch 1/1, Loss: 46.03896060585976, Accuracy: 0.51
	Training Client 4/60
Epoch 1/1, Loss: 46.34492492675781, Accuracy: 0.558
	Training Client 5/60
Epoch 1/1, Loss: 48.653301537036896, Accuracy: 0.499
	Training Client 6/60
Epoch 1/1, Loss: 48.20056077837944, Accuracy: 0.563
	Training Client 7/60
Epoch 1/1, Loss: 49.92811328172684, Accuracy: 0.502
	Training Client 8/60
Epoch 1/1, Loss: 47.378638207912445, Accuracy: 0.542
	Training Client 9/60
Epoch 1/1, Loss: 49.788997173309326, Accuracy: 0.521
	Training Client 10/60
Epoch 1/1, Loss: 46.613293409347534, Accuracy: 0.546
	Training Client 11/60
Epoch 1/1, Loss: 47.87422716617584, Accuracy: 0.545
	Training Client 12/60
Epoch 1/1, Loss: 48.75999730825424, Accuracy: 0.552
	Training Cli

KeyboardInterrupt: 