In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import shap
from scipy.stats import mannwhitneyu, ttest_ind, ks_2samp, ttest_ind_from_stats
import copy
import random
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, random_split

In [2]:
# Define the model architecture
class CNN(nn.Module):
    """Small convolutional classifier with two conv layers and two linear layers.

    The network takes single-channel images and produces log-probabilities
    over 10 classes. The flatten size of 64 * 5 * 5 matches what two
    (3x3 conv -> 2x2 max-pool) stages produce from e.g. 28x28 inputs.

    NOTE(review): the activations and pooling are functional calls rather than
    nn.Module layers; shap.DeepExplainer appears to attach its handlers to
    modules only -- confirm the attributions are computed as intended.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in this order so seeded initialisation is stable.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3)
        self.fc1 = nn.Linear(in_features=64 * 5 * 5, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Return per-class log-probabilities for a batch of images."""
        # Stage 1: conv -> ReLU -> 2x2 max-pool
        stage1 = torch.max_pool2d(torch.relu(self.conv1(x)), kernel_size=2, stride=2)
        # Stage 2: conv -> ReLU -> 2x2 max-pool
        stage2 = torch.max_pool2d(torch.relu(self.conv2(stage1)), kernel_size=2, stride=2)
        # Flatten to (rows, 1600) for the fully connected head
        flattened = stage2.view(-1, 64 * 5 * 5)
        hidden = torch.relu(self.fc1(flattened))
        logits = self.fc2(hidden)
        return torch.log_softmax(logits, dim=1)

# Function to simulate client training with MNIST data
def client_training(train_loader, model, criterion, optimizer, num_epochs=1):
    """Train `model` on one client's data and report per-epoch statistics.

    Args:
        train_loader: Iterable yielding ``(images, labels)`` tensor batches.
        model: The ``nn.Module`` to optimise; it is switched to train mode.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer already bound to ``model``'s parameters.
        num_epochs: Number of passes over ``train_loader``.

    Returns:
        The sum of the per-batch loss values from the last epoch, or ``0.0``
        when no epoch ran (``num_epochs <= 0``) or the loader was empty.
    """
    # Move batches to wherever the model lives instead of relying on a
    # module-level `device` that may not be defined yet.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    total_loss = 0.0  # defined up front so num_epochs == 0 returns cleanly
    for epoch in range(num_epochs):
        total_loss, correct, total = 0.0, 0, 0
        for images, labels in train_loader:
            images = images.to(target_device)
            labels = labels.to(target_device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            predicted = outputs.detach().argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        # Guard against an empty loader, which would otherwise divide by zero.
        accuracy = correct / total if total else 0.0
        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {total_loss}, Accuracy: {accuracy}")

    return total_loss

# Function to test the trained model on client data
def client_testing(test_loader, model):
    """Evaluate `model` and return its classification accuracy.

    Args:
        test_loader: Iterable yielding ``(images, labels)`` tensor batches.
        model: The ``nn.Module`` to evaluate; it is switched to eval mode.

    Returns:
        Fraction of correctly classified samples in ``[0, 1]``; ``0.0`` when
        the loader yields no samples.
    """
    # Evaluate on the model's own device rather than a module-level global.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(target_device)
            labels = labels.to(target_device)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    # An empty loader would otherwise raise ZeroDivisionError.
    return correct / total if total else 0.0

In [3]:
# Function to compute SHAP values
def compute_shap_values(model, shap_data, background, min_total_abs_shap=5):
    """Compute mean SHAP attributions per input feature for the top-ranked output.

    Args:
        model: Model passed to ``shap.DeepExplainer``.
        shap_data: Batch of samples to explain.
        background: Background samples passed to ``shap.DeepExplainer``.
        min_total_abs_shap: A feature is kept only when the sum of its absolute
            SHAP values over all explained samples is at least this value.

    Returns:
        A pair ``(values, features)``: ``values`` is a 1-D array with the mean
        SHAP value (over samples) of every kept feature, and ``features`` is the
        matching list of names ``'Feature_<flat index>'``.
    """
    explainer = shap.DeepExplainer(model, background)
    ranked_values, _ranked_outputs = explainer.shap_values(
        shap_data, ranked_outputs=1, check_additivity=False
    )

    if isinstance(ranked_values, (list, tuple)):
        # One array per ranked output; take the top-ranked one.
        top_output = np.asarray(ranked_values[0])
    else:
        # NOTE(review): newer SHAP releases appear to return a single ndarray
        # with the ranked outputs stacked on the last axis -- confirm against
        # the installed version. Indexing [0] there would pick the first
        # sample instead of the first ranked output.
        top_output = np.asarray(ranked_values)[..., 0]

    # Flatten every sample to one row: (num_samples, num_features).
    per_sample = top_output.reshape(top_output.shape[0], -1)
    mean_per_feature = per_sample.mean(axis=0)

    keep_mask = np.abs(per_sample).sum(axis=0) >= min_total_abs_shap
    kept_features = [f'Feature_{i}' for i in np.flatnonzero(keep_mask)]

    return mean_per_feature[keep_mask], kept_features

In [10]:
# Federated learning process with shap analysis
def _scramble_pixels_in_place(images, image_idx, num_swaps):
    """Apply `num_swaps` random pixel transpositions to `images[image_idx]` in place."""
    h, w = images[image_idx].shape
    num_pixels = h * w
    # Track the swaps on plain Python ints and apply them once at the end.
    # Swapping through tensor views (a, b = b, a) aliases the same storage and
    # ends up copying one pixel over the other instead of exchanging them.
    order = list(range(num_pixels))
    for src in random.sample(range(num_pixels), min(num_swaps, num_pixels)):
        dst = random.randint(0, h - 1) * w + random.randint(0, w - 1)
        order[src], order[dst] = order[dst], order[src]
    flat = images[image_idx].reshape(-1)
    images[image_idx] = flat[order].reshape(h, w)


def _collect_class_balanced_samples(loader, per_class=10, num_classes=10):
    """Return a stacked tensor with `per_class` images per label, or None if a class falls short."""
    picked = []
    counts = {label: 0 for label in range(num_classes)}
    for images, labels in loader:
        for img, lbl in zip(images, labels):
            label = lbl.item()
            if counts[label] < per_class:
                picked.append(img)
                counts[label] += 1
        if all(count >= per_class for count in counts.values()):
            return torch.stack(picked)
    return None


def _ks_p_value(curr_values, curr_features, prev_values, prev_features, restrict_to=None):
    """Return the KS-test p-value over features shared by both rounds (1.0 when not comparable)."""
    if curr_values is None or prev_values is None:
        return 1.0
    shared = set(curr_features) & set(prev_features)
    if restrict_to is not None:
        shared &= set(restrict_to)
    if not shared:
        return 1.0
    # Dict lookups instead of repeated list.index() calls.
    curr_pos = {name: pos for pos, name in enumerate(curr_features)}
    prev_pos = {name: pos for pos, name in enumerate(prev_features)}
    shared = sorted(shared)
    _, p_value = ks_2samp(
        curr_values[[curr_pos[name] for name in shared]],
        prev_values[[prev_pos[name] for name in shared]],
    )
    return p_value


def federated_learning(num_clients=5, num_local_epochs=5, num_global_epochs=5,
                       drift_clients=(2, 6, 9), drift_epoch=10, warmup_epochs=5,
                       p_value_threshold=0.005, seed=None):
    """Run FedAvg on MNIST with SHAP-based concept-drift detection.

    Each global epoch: every client trains a copy of the global model on its
    own partition, SHAP values are computed on a class-balanced sample of that
    partition, and a two-sample Kolmogorov-Smirnov test compares them with the
    client's SHAP values from the previous epoch. Client models are then
    averaged into the global model. At ``drift_epoch`` the training images of
    classes 3, 8, 5 and 6 on the drift clients are scrambled in place.

    Args:
        num_clients: Number of simulated clients (>= 1).
        num_local_epochs: Local training epochs per client per global epoch.
        num_global_epochs: Number of federated rounds.
        drift_clients: Zero-based ids of clients that receive drifted data;
            ids outside ``range(num_clients)`` are ignored.
        drift_epoch: Zero-based global epoch at which drift is injected. Drift
            is never injected when this is >= ``num_global_epochs``.
        warmup_epochs: Zero-based global epoch from which drift decisions are
            made and the frequent-feature filter is applied.
        p_value_threshold: A client is flagged as drifting when its KS p-value
            is at most this value.
        seed: When given, seeds ``random``, NumPy and PyTorch.

    Returns:
        List with the final global model's test accuracy, one entry per client
        (all clients are evaluated on the same MNIST test set).

    Raises:
        ValueError: If ``num_clients`` is smaller than 1.
    """
    if num_clients < 1:
        raise ValueError("num_clients must be at least 1")
    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
    trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

    # random_split requires the lengths to sum to len(trainset); spread the
    # remainder over the first clients so any num_clients works.
    base_size, remainder = divmod(len(trainset), num_clients)
    lengths = [base_size + (1 if c < remainder else 0) for c in range(num_clients)]
    client_indices = [subset.indices for subset in random_split(trainset, lengths)]

    # Clients that will experience concept drift (only ids that exist).
    active_drift_clients = [c for c in drift_clients if 0 <= c < num_clients]
    drift_classes = (3, 8, 5, 6)
    num_pixels_to_swap = 700

    # The loaders wrap `trainset`/`testset` by reference, so the in-place drift
    # edit below is visible through them and they can be built once.
    client_train_loaders = [
        torch.utils.data.DataLoader(torch.utils.data.Subset(trainset, indices), batch_size=32, shuffle=True)
        for indices in client_indices
    ]
    client_test_loaders = [
        torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)
        for _ in range(num_clients)
    ]

    # Initialize global model and per-client SHAP history
    global_model = CNN().to(device)
    previous_shap_values_per_client = [None] * num_clients
    previous_shap_features_per_client = [None] * num_clients

    # How often each feature survived the SHAP filter, across clients and epochs
    feature_occurrence_dict = {}
    frequent_features = []

    # Counters for drift detection performance
    true_positives = 0
    true_negatives = 0
    false_positives = 0
    false_negatives = 0

    for global_epoch in range(num_global_epochs):
        print(f"Global Epoch {global_epoch + 1}/{num_global_epochs}")

        if global_epoch == drift_epoch:
            print(f"Swapping {num_pixels_to_swap} pixels within each image of classes 3, 8, 5 and 6 "
                  "in the train dataset for each drift client")
            for client_id in active_drift_clients:
                for idx in client_indices[client_id]:
                    if int(trainset.targets[idx]) in drift_classes:
                        _scramble_pixels_in_place(trainset.data, idx, num_pixels_to_swap)

        client_models = [copy.deepcopy(global_model) for _ in range(num_clients)]

        # Indexed by client id; None marks a client with no SHAP values this
        # round, so later steps never attribute values to the wrong client.
        shap_values_per_client = [None] * num_clients
        shap_features_per_client = [None] * num_clients

        for client_id, train_loader in enumerate(client_train_loaders):
            print(f"\tTraining Client {client_id + 1}/{num_clients}")
            criterion = nn.CrossEntropyLoss()
            optimizer = optim.Adam(client_models[client_id].parameters(), lr=0.001)
            client_training(train_loader, client_models[client_id], criterion, optimizer, num_local_epochs)

            samples = _collect_class_balanced_samples(train_loader, per_class=10, num_classes=10)
            if samples is None:
                continue
            samples = samples.to(device)

            # First 90% of the samples are explained, last 10% are the background
            split_idx = int(samples.size(0) * 0.9)
            background = samples[split_idx:]
            shap_data = samples[:split_idx]

            print(f"\t\tComputed SHAP values for Client {client_id + 1}/{num_clients}")
            shap_values, shap_features = compute_shap_values(client_models[client_id], shap_data, background)
            shap_values_per_client[client_id] = shap_values
            shap_features_per_client[client_id] = shap_features

        # Perform Kolmogorov-Smirnov test on SHAP values
        print("\tPerforming Kolmogorov–Smirnov test for SHAP values")
        use_frequent = global_epoch >= warmup_epochs and bool(frequent_features)
        p_values = [
            _ks_p_value(
                shap_values_per_client[client_id],
                shap_features_per_client[client_id],
                previous_shap_values_per_client[client_id],
                previous_shap_features_per_client[client_id],
                frequent_features if use_frequent else None,
            )
            for client_id in range(num_clients)
        ]

        # Update SHAP history and feature occurrence counts
        for client_id in range(num_clients):
            if shap_values_per_client[client_id] is None:
                continue  # keep the last known reference for this client
            previous_shap_values_per_client[client_id] = shap_values_per_client[client_id]
            previous_shap_features_per_client[client_id] = shap_features_per_client[client_id]
            for feature in shap_features_per_client[client_id]:
                feature_occurrence_dict[feature] = feature_occurrence_dict.get(feature, 0) + 1

        frequent_features = [feature for feature, count in feature_occurrence_dict.items() if count >= 20]

        # Detect concept drift from the p-values and score against ground truth
        if global_epoch >= warmup_epochs:
            detected_drift_clients = {c for c, p_value in enumerate(p_values) if p_value <= p_value_threshold}

            for client_id in range(num_clients):
                ground_truth_drift = (client_id in active_drift_clients) and (global_epoch >= drift_epoch)
                detected_drift = client_id in detected_drift_clients

                if detected_drift and ground_truth_drift:
                    true_positives += 1
                elif not detected_drift and not ground_truth_drift:
                    true_negatives += 1
                elif detected_drift and not ground_truth_drift:
                    false_positives += 1
                else:
                    false_negatives += 1

                if detected_drift == ground_truth_drift:
                    print(f"Global Epoch {global_epoch + 1}, Client {client_id + 1}: Drift detection is accurate.")
                else:
                    print(f"Global Epoch {global_epoch + 1}, Client {client_id + 1}: Drift detection is inaccurate.")

        # Aggregate client models into global model
        print("\tAggregating client models")
        global_model = aggregate_models(global_model, client_models)

    # Test the final global model on each client's test loader
    accuracies = []
    for client_id, test_loader in enumerate(client_test_loaders):
        print(f"Testing Client {client_id + 1}/{num_clients}")
        test_accuracy = client_testing(test_loader, global_model)
        accuracies.append(test_accuracy)
        print(f"\tAccuracy for Client {client_id + 1}/{num_clients}: {test_accuracy}")

    # Performance metrics of drift detection. No decisions are made when the
    # run ends before `warmup_epochs`, so every ratio needs a zero guard.
    total_decisions = true_positives + true_negatives + false_positives + false_negatives
    accuracy = (true_positives + true_negatives) / total_decisions if total_decisions > 0 else 0
    precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
    recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0
    f1_score = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

    print(f"Drift Detection metircs::: true_positives:{true_positives},true_negatives:{true_negatives},false_positives:{false_positives},false_negatives:{false_negatives}")
    print(f"Drift Detection Performance:\nAccuracy: {accuracy}\nPrecision: {precision}\nRecall: {recall}\nF1 Score: {f1_score}")

    return accuracies

# Aggregation of models' weights using FedAvg
def aggregate_models(global_model, client_models):
    """Average the clients' weights into `global_model` (unweighted FedAvg).

    Args:
        global_model: Model whose state is overwritten with the average.
        client_models: Non-empty sequence of models whose state dicts contain
            every key of ``global_model``'s state dict.

    Returns:
        ``global_model`` itself, updated in place.

    Raises:
        ValueError: If ``client_models`` is empty.
    """
    # Materialise each client's state dict once instead of once per key.
    client_states = [client_model.state_dict() for client_model in client_models]
    if not client_states:
        raise ValueError("aggregate_models requires at least one client model")

    averaged_state = global_model.state_dict()
    for key in averaged_state.keys():
        stacked = torch.stack([state[key].float() for state in client_states], 0)
        averaged_state[key] = stacked.mean(0)
    global_model.load_state_dict(averaged_state)
    return global_model


# Pick the compute device: CUDA when a GPU is available, otherwise the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Example usage: 10 clients, 1 local epoch per round, 20 federated rounds
if __name__ == "__main__":
    client_accuracies = federated_learning(
        num_clients=10,
        num_local_epochs=1,
        num_global_epochs=20,
    )
    print("Client Accuracies:", client_accuracies)

Global Epoch 1/20
	Training Client 1/10
Epoch 1/1, Loss: 110.54682398214936, Accuracy: 0.8206666666666667
		Computed SHAP values for Client 1/10
	Training Client 2/10
Epoch 1/1, Loss: 106.04906145483255, Accuracy: 0.824
		Computed SHAP values for Client 2/10
	Training Client 3/10
Epoch 1/1, Loss: 108.55197816714644, Accuracy: 0.8243333333333334
		Computed SHAP values for Client 3/10
	Training Client 4/10
Epoch 1/1, Loss: 109.73743585497141, Accuracy: 0.8206666666666667
		Computed SHAP values for Client 4/10
	Training Client 5/10
Epoch 1/1, Loss: 100.90648332238197, Accuracy: 0.8366666666666667
		Computed SHAP values for Client 5/10
	Training Client 6/10
Epoch 1/1, Loss: 103.60011716187, Accuracy: 0.8321666666666667
		Computed SHAP values for Client 6/10
	Training Client 7/10
Epoch 1/1, Loss: 100.37193609774113, Accuracy: 0.8361666666666666
		Computed SHAP values for Client 7/10
	Training Client 8/10
Epoch 1/1, Loss: 103.87043135985732, Accuracy: 0.8251666666666667
		Computed SHAP value