In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import shap
from scipy.stats import mannwhitneyu, ttest_ind, ks_2samp, ttest_ind_from_stats
import copy
import random
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, random_split

In [2]:
# Seed the Python, NumPy and PyTorch RNGs up front so every rerun draws the same numbers.
seed = 42
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(seed)
del _seed_rng
if torch.cuda.is_available():
    # Seed the current GPU, then every visible GPU.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
# Request deterministic cuDNN kernels and turn off cuDNN benchmark mode
# (PyTorch's reproducibility notes recommend both).
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [3]:
# Define the model architecture
class CNN(nn.Module):
    """Small two-conv-layer classifier for 1-channel 28x28 images with 10 classes.

    forward() returns log-probabilities (``log_softmax`` over dim 1).
    """

    def __init__(self):
        super().__init__()
        # Layers are created in this exact order so seeded initialisation is unchanged.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        # A 28x28 input becomes 64 channels of 5x5 after two conv(3) + maxpool(2) stages.
        self.fc1 = nn.Linear(64 * 5 * 5, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        # conv -> ReLU -> 2x2 max-pool, applied twice.
        for conv_layer in (self.conv1, self.conv2):
            x = torch.max_pool2d(torch.relu(conv_layer(x)), kernel_size=2, stride=2)
        hidden = torch.relu(self.fc1(x.view(-1, 64 * 5 * 5)))
        return torch.log_softmax(self.fc2(hidden), dim=1)

# Function to simulate client training with MNIST data
def client_training(train_loader, model, criterion, optimizer, num_epochs=1, device=None):
    """Run local training for one federated client and report per-epoch progress.

    Args:
        train_loader: iterable of ``(images, labels)`` batches.
        model: the client's ``nn.Module``; its parameters are updated in place.
        criterion: loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: optimizer bound to ``model``'s parameters.
        num_epochs: number of passes over ``train_loader``.
        device: device each batch is moved to. Defaults to the device of the
            model's first parameter (CPU if it has none), so the function no
            longer depends on a notebook-global ``device``.

    Returns:
        float: sum of the per-batch losses of the *last* epoch. Returns 0.0 when
        ``num_epochs`` is 0 or the loader yields no batches.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    # Defined before the loop so num_epochs=0 returns 0.0 instead of raising UnboundLocalError.
    total_loss = 0.0
    for epoch in range(num_epochs):
        total_loss, correct, total = 0.0, 0, 0
        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            predicted = outputs.detach().argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        # Guard against an empty loader (previously ZeroDivisionError).
        accuracy = correct / total if total else 0.0
        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {total_loss}, Accuracy: {accuracy}")

    return total_loss
    

# Function to test the trained model on client data
def client_testing(test_loader, model, device=None):
    """Return the classification accuracy of ``model`` over ``test_loader``.

    Args:
        test_loader: iterable of ``(images, labels)`` batches.
        model: ``nn.Module`` whose argmax over dim 1 is taken as the prediction.
            It is put in eval mode and left there.
        device: device each batch is moved to. Defaults to the device of the
            model's first parameter (CPU if it has none).

    Returns:
        float: fraction of correctly classified samples. Returns 0.0 for an
        empty loader (previously ZeroDivisionError).
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct / total if total else 0.0


In [10]:
# Function to compute SHAP values
def compute_shap_values(model, shap_data, background, min_total_abs=0.5):
    """Explain ``model`` on ``shap_data`` with ``shap.DeepExplainer`` and keep salient features.

    Args:
        model: the torch model to explain.
        shap_data: batch of inputs to explain, passed as-is to ``shap_values``.
        background: reference batch given to ``shap.DeepExplainer``.
        min_total_abs: a feature is kept only if the sum, over all explained
            samples, of its absolute SHAP value is at least this value
            (previously hard-coded to 0.5).

    Returns:
        (filtered_shap_values, filtered_features): the mean SHAP value across
        samples for each kept feature (1-D array), and the matching names
        ``"Feature_<flat index>"``.
    """
    explainer = shap.DeepExplainer(model, background)
    # ranked_outputs=1 explains only the top-ranked output; the ranked output
    # indices (second return value) are not needed here.
    ranked_shap, _ranked_indices = explainer.shap_values(
        shap_data, ranked_outputs=1, check_additivity=False
    )

    # NOTE(review): assumes ranked_shap[0] has the explained samples on axis 0;
    # the return layout of shap_values depends on the installed shap version.
    top_shap = np.asarray(ranked_shap[0])
    per_sample = top_shap.reshape(top_shap.shape[0], -1)  # one row of flat features per sample

    shap_mean_values = per_sample.mean(axis=0)
    keep = np.abs(per_sample).sum(axis=0) >= min_total_abs

    filtered_shap_values = shap_mean_values[keep]
    filtered_features = [f"Feature_{idx}" for idx in np.flatnonzero(keep)]
    return filtered_shap_values, filtered_features

In [2]:
# Scratch check: how many client ids are in the drift list used for the
# 60-client run (the same ids appear as drift_clients in that experiment).
b = [2, 6, 9, 11, 15, 22, 31, 44, 51, 57]
a = sum(1 for _client_id in b)
a

10

In [11]:
def federated_learning(num_clients=5, num_local_epochs=5, num_global_epochs=5):
    """Simulate FedAvg on MNIST with injected label-swap drift and score SHAP-based drift detection.

    Each global round works as follows:
      - Every client fine-tunes a copy of the global model on its own MNIST shard.
      - SHAP values are computed on a small per-client sample.
      - A two-sample Kolmogorov–Smirnov test compares each client's SHAP values
        with that client's previous round, on the features both rounds kept.
      - The client models are averaged into the global model.

    A loss-based detector also runs. At round index ``drift_epoch`` (10), labels
    3<->8 and 5<->6 are swapped in the shards of ``drift_clients``. From round
    ``detection_start_epoch`` (5) onward, the SHAP-based detections are scored
    against that ground truth.

    Args:
        num_clients: number of simulated clients. The MNIST train set is split
            evenly, with any remainder going to the last client.
        num_local_epochs: local epochs per client per round.
        num_global_epochs: number of federated rounds.

    Returns:
        list[float]: test accuracy of the final global model, one entry per
        client. All clients are evaluated on the same MNIST test set.
    """
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
    trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

    # Experiment constants (previously inline magic numbers).
    drift_epoch = 10                 # 0-based round at which the label swap is injected
    detection_start_epoch = 5        # rounds before this only warm up the detectors
    samples_per_class = 10           # SHAP sample size per digit class
    p_value_threshold = 0.005        # KS p-value below which a client is flagged
    frequent_feature_min_count = 10  # rounds a feature must survive to count as "frequent"
    drastic_change_factor = 3        # loss must exceed previous loss * factor to count as a drastic change
    high_loss_threshold = 4          # summed epoch loss at/above this counts as "still high"

    # One shard per client; the last client absorbs the remainder so the lengths
    # always sum to len(trainset), as random_split requires.
    partition_size = len(trainset) // num_clients
    lengths = [partition_size] * num_clients
    lengths[-1] += len(trainset) - partition_size * num_clients
    datasets = random_split(trainset, lengths)
    client_indices = [dataset.indices for dataset in datasets]

    # Clients that will experience concept drift. Ids that do not exist for this
    # num_clients are dropped instead of raising IndexError at the swap.
    drift_clients = [c for c in (2, 6, 9, 11, 15, 22, 31, 44, 51, 57) if c < num_clients]

    global_model = CNN().to(device)
    previous_shap_values_per_client = [None] * num_clients
    previous_shap_features_per_client = [None] * num_clients
    previous_client_losses = [0.0] * num_clients
    drastic_change_detected = [False] * num_clients
    sustained_high_loss = [False] * num_clients

    # Per client: how many rounds each SHAP feature survived filtering.
    feature_occurrence_dict = {client_id: {} for client_id in range(num_clients)}
    frequent_features_per_client = {client_id: [] for client_id in range(num_clients)}

    true_positives = true_negatives = false_positives = false_negatives = 0

    # Every client is evaluated on the same unshuffled test set, so the loaders are
    # built once. This also keeps them defined when num_global_epochs == 0.
    client_test_loaders = [
        torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)
        for _ in range(num_clients)
    ]

    for global_epoch in range(num_global_epochs):
        print(f"Global Epoch {global_epoch + 1}/{num_global_epochs}")

        # Inject concept drift: swap labels 3<->8 and 5<->6 in the drift clients' shards.
        if global_epoch == drift_epoch:
            print("Swapping classes 3<->8 and 5<->6 in the train dataset for each drift client")
            for client_id in drift_clients:
                idx = torch.as_tensor(client_indices[client_id], dtype=torch.long)
                old_targets = trainset.targets[idx]
                new_targets = old_targets.clone()
                for first, second in ((3, 8), (5, 6)):
                    new_targets[old_targets == first] = second
                    new_targets[old_targets == second] = first
                trainset.targets[idx] = new_targets

        client_train_loaders = [
            torch.utils.data.DataLoader(torch.utils.data.Subset(trainset, indices), batch_size=32, shuffle=True)
            for indices in client_indices
        ]
        client_models = [copy.deepcopy(global_model) for _ in range(num_clients)]

        # Indexed by client id; None means no SHAP values were computed this round.
        shap_values_per_client = [None] * num_clients
        shap_features_per_client = [None] * num_clients
        client_losses = []

        for i, train_loader in enumerate(client_train_loaders):
            print(f"\tTraining Client {i + 1}/{num_clients}")
            criterion = nn.CrossEntropyLoss()
            optimizer = optim.Adam(client_models[i].parameters(), lr=0.001)
            loss = client_training(train_loader, client_models[i], criterion, optimizer, num_local_epochs)
            client_losses.append(loss)

            # Collect exactly samples_per_class images of each digit ("<", not "<=").
            client_shap_data = []
            class_count = {label: 0 for label in range(10)}
            for images, labels in train_loader:
                for img, lbl in zip(images, labels):
                    if class_count[lbl.item()] < samples_per_class:
                        client_shap_data.append(img)
                        class_count[lbl.item()] += 1
                if all(count >= samples_per_class for count in class_count.values()):
                    break

            # Only explain clients whose sample covers every class. A short sample
            # must not fall through to another client's shap_data.
            if not all(count >= samples_per_class for count in class_count.values()):
                print(f"\t\tSkipping SHAP for Client {i + 1}/{num_clients}: too few samples per class")
                continue

            shap_data = torch.stack(client_shap_data).to(device)
            split_idx = int(shap_data.size(0) * 0.9)
            background = shap_data[split_idx:]    # last 10%: DeepExplainer background
            explain_data = shap_data[:split_idx]  # first 90%: samples to explain

            print(f"\t\tComputed SHAP values for Client {i + 1}/{num_clients}")
            shap_values, shap_features = compute_shap_values(client_models[i], explain_data, background)
            shap_values_per_client[i] = shap_values
            shap_features_per_client[i] = shap_features

        # KS test: each client's SHAP values vs. its previous round, on shared features.
        print("\tPerforming Kolmogorov–Smirnov test for SHAP values")
        p_values = {}  # client id -> p-value; keyed by id so skipped clients cannot shift indices
        for i in range(num_clients):
            shap_values_client = shap_values_per_client[i]
            shap_features_client = shap_features_per_client[i]
            if shap_values_client is None:
                continue
            previous_values = previous_shap_values_per_client[i]
            previous_features = previous_shap_features_per_client[i]
            if previous_values is not None:
                common = set(shap_features_client) & set(previous_features)
                if global_epoch >= detection_start_epoch and frequent_features_per_client[i]:
                    common &= set(frequent_features_per_client[i])
                if common:
                    common_features = list(common)
                    current_pos = {feature: k for k, feature in enumerate(shap_features_client)}
                    previous_pos = {feature: k for k, feature in enumerate(previous_features)}
                    current_shap_values = shap_values_client[[current_pos[f] for f in common_features]]
                    previous_shap_values = previous_values[[previous_pos[f] for f in common_features]]

                    ks_statistic, p_value = ks_2samp(previous_shap_values, current_shap_values)
                    p_values[i] = p_value
                    print(f"Global Epoch {global_epoch + 1}, Client {i+1}: Kolmogorov–Smirnov test statistic: {ks_statistic}, p-value: {p_value}")

            previous_shap_values_per_client[i] = shap_values_client
            previous_shap_features_per_client[i] = shap_features_client

        # Count, per client, how often each feature survived SHAP filtering.
        for client_id, features in enumerate(shap_features_per_client):
            if features is None:
                continue
            counts = feature_occurrence_dict[client_id]
            for feature in features:
                counts[feature] = counts.get(feature, 0) + 1

        # Concept drift detection and performance metrics.
        if global_epoch >= detection_start_epoch:
            for client_id in range(num_clients):
                frequent_features_per_client[client_id] = [
                    feature for feature, count in feature_occurrence_dict[client_id].items()
                    if count >= frequent_feature_min_count
                ]

            # Drift detection from SHAP values.
            detected_drift_clients_1 = [
                client_id for client_id, p_value in p_values.items() if p_value < p_value_threshold
            ]

            # Drift detection from losses: a drastic jump that then stays high.
            for i, (prev_loss, curr_loss) in enumerate(zip(previous_client_losses, client_losses)):
                if not drastic_change_detected[i]:
                    if curr_loss > (prev_loss * drastic_change_factor):
                        drastic_change_detected[i] = True
                elif curr_loss >= high_loss_threshold:
                    sustained_high_loss[i] = True
                else:
                    drastic_change_detected[i] = False
                    sustained_high_loss[i] = False

            detected_drift_clients_2 = [
                i for i, (drastic, sustained) in enumerate(zip(drastic_change_detected, sustained_high_loss))
                if drastic and sustained
            ]

            # NOTE(review): the union of both detectors is computed here, but the
            # scoring below uses only the SHAP-based detections
            # (detected_drift_clients_1). Confirm which is intended.
            detected_drift_clients = list(set(detected_drift_clients_1) | set(detected_drift_clients_2))

            for client_id in range(num_clients):
                ground_truth_drift = (client_id in drift_clients) and (global_epoch >= drift_epoch)
                detected_drift = client_id in detected_drift_clients_1

                if detected_drift and ground_truth_drift:
                    true_positives += 1
                elif not detected_drift and not ground_truth_drift:
                    true_negatives += 1
                elif detected_drift and not ground_truth_drift:
                    false_positives += 1
                else:
                    false_negatives += 1

                if detected_drift == ground_truth_drift:
                    print(f"Global Epoch {global_epoch + 1}, Client {client_id + 1}: Drift detection is accurate.")
                else:
                    print(f"Global Epoch {global_epoch + 1}, Client {client_id + 1}: Drift detection is not accurate.")

        previous_client_losses = client_losses.copy()

        print("\tAggregating client models")
        global_model = aggregate_models(global_model, client_models)

    # Test the final global model once per client.
    accuracies = []
    for i, test_loader in enumerate(client_test_loaders):
        print(f"Testing Client {i + 1}/{num_clients}")
        accuracy = client_testing(test_loader, global_model)
        accuracies.append(accuracy)  # previously appended the list to itself
        print(f"\tAccuracy for Client {i + 1}/{num_clients}: {accuracy}")

    # Drift detection metrics. The guards avoid division by zero when no rounds were scored.
    scored = true_positives + true_negatives + false_positives + false_negatives
    detection_accuracy = (true_positives + true_negatives) / scored if scored > 0 else 0
    precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
    recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0
    f1_score = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

    print(f"Drift Detection metircs::: true_positives:{true_positives},true_negatives:{true_negatives},false_positives:{false_positives},false_negatives:{false_negatives}")
    print(f"Drift Detection Performance:\nAccuracy: {detection_accuracy}\nPrecision: {precision}\nRecall: {recall}\nF1 Score: {f1_score}")

    return accuracies

# Aggregation of models' weights using FedAvg
def aggregate_models(global_model, client_models):
    """Federated averaging: load the element-wise mean of the clients' weights into ``global_model``.

    Every ``state_dict`` entry of ``global_model`` is replaced by the unweighted
    mean (computed in float) of the same entry across ``client_models``.
    ``global_model`` is modified in place and also returned.

    NOTE(review): clients are weighted equally regardless of shard size.

    Raises:
        ValueError: if ``client_models`` is empty.
    """
    if not client_models:
        raise ValueError("aggregate_models needs at least one client model")
    # Take each client's state_dict once instead of once per parameter key.
    client_states = [client_model.state_dict() for client_model in client_models]
    global_dict = global_model.state_dict()
    for key in global_dict:
        global_dict[key] = torch.stack([state[key].float() for state in client_states], 0).mean(0)
    global_model.load_state_dict(global_dict)
    return global_model
    

# Set device
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")  # current CUDA device if visible, else CPU

# Example usage
if __name__ == "__main__":
    # 60 clients, 1 local epoch per round, 40 federated rounds.
    run_kwargs = {"num_clients": 60, "num_local_epochs": 1, "num_global_epochs": 40}
    client_accuracies = federated_learning(**run_kwargs)
    print("Client Accuracies:", client_accuracies)


Global Epoch 1/40
	Training Client 1/60
Epoch 1/1, Loss: 51.99495720863342, Accuracy: 0.477
		Computed SHAP values for Client 1/60
1
(99, 1, 28, 28)
	Training Client 2/60
Epoch 1/1, Loss: 53.000810980796814, Accuracy: 0.484
		Computed SHAP values for Client 2/60
1
(99, 1, 28, 28)
	Training Client 3/60
Epoch 1/1, Loss: 53.53186875581741, Accuracy: 0.485
		Computed SHAP values for Client 3/60
1
(98, 1, 28, 28)
	Training Client 4/60
Epoch 1/1, Loss: 53.37810117006302, Accuracy: 0.508
		Computed SHAP values for Client 4/60
1
(98, 1, 28, 28)
	Training Client 5/60
Epoch 1/1, Loss: 54.2434920668602, Accuracy: 0.485
		Computed SHAP values for Client 5/60
1
(99, 1, 28, 28)
	Training Client 6/60
Epoch 1/1, Loss: 51.45190727710724, Accuracy: 0.521
		Computed SHAP values for Client 6/60
1
(99, 1, 28, 28)
	Training Client 7/60
Epoch 1/1, Loss: 51.16041475534439, Accuracy: 0.506
		Computed SHAP values for Client 7/60
1
(99, 1, 28, 28)
	Training Client 8/60
Epoch 1/1, Loss: 50.179173827171326, Accura

KeyboardInterrupt: 

In [None]:
# Federated learning process with historical shap values
def federated_learning(num_clients=5, num_local_epochs=5, num_global_epochs=5):
    """Simulate FedAvg on MNIST and detect label-swap drift against each client's SHAP history.

    Each global round works as follows:
      - Every client fine-tunes a copy of the global model on its own MNIST shard.
      - SHAP values are computed on a small per-client sample and appended to
        that client's history.
      - The current round is compared (two-sample Kolmogorov–Smirnov test) with
        the element-wise mean of the client's earlier rounds, after all rounds
        are zero-padded to a common length with ``pad_shap_values``.

    A loss-based detector also runs, and a client counts as flagged if either
    detector fires. At round index ``drift_epoch`` (10), labels 3<->8 are
    swapped in the shards of ``drift_clients``. Detection is scored from round
    ``detection_start_epoch`` (5) onward.

    Args:
        num_clients: number of simulated clients. The MNIST train set is split
            evenly, with any remainder going to the last client.
        num_local_epochs: local epochs per client per round.
        num_global_epochs: number of federated rounds.

    Returns:
        list[float]: test accuracy of the final global model, one entry per
        client. All clients are evaluated on the same MNIST test set.
    """
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
    trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

    # Experiment constants (previously inline magic numbers).
    drift_epoch = 10              # 0-based round at which the label swap is injected
    detection_start_epoch = 5     # rounds before this only warm up the detectors
    samples_per_class = 10        # SHAP sample size per digit class
    p_value_threshold = 0.005     # KS p-value below which a client is flagged
    drastic_change_factor = 3     # loss must exceed previous loss * factor to count as a drastic change
    high_loss_threshold = 4       # summed epoch loss at/above this counts as "still high"

    # One shard per client; the last client absorbs the remainder so the lengths
    # always sum to len(trainset), as random_split requires.
    partition_size = len(trainset) // num_clients
    lengths = [partition_size] * num_clients
    lengths[-1] += len(trainset) - partition_size * num_clients
    datasets = random_split(trainset, lengths)
    client_indices = [dataset.indices for dataset in datasets]

    # Clients that will experience concept drift. Ids that do not exist for this
    # num_clients are dropped instead of raising IndexError at the swap.
    drift_clients = [c for c in (2, 6, 9) if c < num_clients]

    global_model = CNN().to(device)
    previous_client_losses = [0.0] * num_clients
    drastic_change_detected = [False] * num_clients
    sustained_high_loss = [False] * num_clients

    true_positives = true_negatives = false_positives = false_negatives = 0

    # Every round's SHAP values per client; entry [-1] is the most recent round.
    historical_shap_values_per_client = [[] for _ in range(num_clients)]

    # Every client is evaluated on the same unshuffled test set, so the loaders are
    # built once. This also keeps them defined when num_global_epochs == 0.
    client_test_loaders = [
        torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)
        for _ in range(num_clients)
    ]

    for global_epoch in range(num_global_epochs):
        print(f"Global Epoch {global_epoch + 1}/{num_global_epochs}")

        # Inject concept drift: swap labels 3<->8 in the drift clients' shards.
        if global_epoch == drift_epoch:
            print("Swapping classes 3 and 8 in the train dataset for each drift client")
            for client_id in drift_clients:
                idx = torch.as_tensor(client_indices[client_id], dtype=torch.long)
                old_targets = trainset.targets[idx]
                new_targets = old_targets.clone()
                new_targets[old_targets == 8] = 3
                new_targets[old_targets == 3] = 8
                trainset.targets[idx] = new_targets

        client_train_loaders = [
            torch.utils.data.DataLoader(torch.utils.data.Subset(trainset, indices), batch_size=32, shuffle=True)
            for indices in client_indices
        ]
        client_models = [copy.deepcopy(global_model) for _ in range(num_clients)]

        has_shap_this_round = [False] * num_clients
        client_losses = []

        for i, train_loader in enumerate(client_train_loaders):
            print(f"\tTraining Client {i + 1}/{num_clients}")
            criterion = nn.CrossEntropyLoss()
            optimizer = optim.Adam(client_models[i].parameters(), lr=0.001)
            loss = client_training(train_loader, client_models[i], criterion, optimizer, num_local_epochs)
            client_losses.append(loss)

            # Collect exactly samples_per_class images of each digit.
            client_shap_data = []
            class_count = {label: 0 for label in range(10)}
            for images, labels in train_loader:
                for img, lbl in zip(images, labels):
                    if class_count[lbl.item()] < samples_per_class:
                        client_shap_data.append(img)
                        class_count[lbl.item()] += 1
                if all(count >= samples_per_class for count in class_count.values()):
                    break

            if not all(count >= samples_per_class for count in class_count.values()):
                print(f"\t\tSkipping SHAP for Client {i + 1}/{num_clients}: too few samples per class")
                continue

            shap_data = torch.stack(client_shap_data).to(device)
            split_idx = int(shap_data.size(0) * 0.9)
            background = shap_data[split_idx:]    # last 10%: DeepExplainer background
            explain_data = shap_data[:split_idx]  # first 90%: samples to explain

            print(f"\t\tComputed SHAP values for Client {i + 1}/{num_clients}")
            shap_values, _shap_features = compute_shap_values(client_models[i], explain_data, background)
            historical_shap_values_per_client[i].append(shap_values)
            has_shap_this_round[i] = True

        # KS test: this round's SHAP values vs. the mean of the client's earlier rounds.
        print("\tPerforming Kolmogorov–Smirnov test for SHAP values")
        p_values = {}  # client id -> p-value; keyed by id so skipped clients cannot shift indices
        for i in range(num_clients):
            history = historical_shap_values_per_client[i]
            # Need this round's values (history[-1]) plus at least one earlier round.
            if not has_shap_this_round[i] or len(history) < 2:
                continue
            max_length = max(len(values) for values in history)
            if max_length == 0:
                continue  # no feature passed the SHAP filter in any round; ks_2samp needs data
            # NOTE(review): padding aligns rounds by position, not by feature name,
            # and the zero padding enters both samples.
            aligned = pad_shap_values(history, max_length)
            historical_shap_values_mean = np.mean(aligned[:-1], axis=0)  # earlier rounds only
            current_shap_values = aligned[-1]

            ks_stat, p_value = ks_2samp(historical_shap_values_mean.flatten(), current_shap_values.flatten())
            p_values[i] = p_value
            print(f"Global Epoch {global_epoch + 1}, Client {i+1}: Kolmogorov–Smirnov test statistic: {ks_stat}, p-value: {p_value}")

        # Concept drift detection and performance metrics.
        if global_epoch >= detection_start_epoch:
            # Drift detected from SHAP values.
            detected_drift_clients_1 = [
                client_id for client_id, p_value in p_values.items() if p_value < p_value_threshold
            ]

            # Drift detection from losses: a drastic jump that then stays high.
            for i, (prev_loss, curr_loss) in enumerate(zip(previous_client_losses, client_losses)):
                if not drastic_change_detected[i]:
                    if curr_loss > (prev_loss * drastic_change_factor):
                        drastic_change_detected[i] = True
                elif curr_loss >= high_loss_threshold:
                    sustained_high_loss[i] = True
                else:
                    drastic_change_detected[i] = False
                    sustained_high_loss[i] = False

            detected_drift_clients_2 = [
                i for i, (drastic, sustained) in enumerate(zip(drastic_change_detected, sustained_high_loss))
                if drastic and sustained
            ]

            # A client is flagged if either detector fires.
            detected_drift_clients = set(detected_drift_clients_1) | set(detected_drift_clients_2)

            for client_id in range(num_clients):
                # The swap happens before local training in round drift_epoch, so
                # that round already trains on drifted labels and counts as drifted.
                ground_truth_drift = (client_id in drift_clients) and (global_epoch >= drift_epoch)
                detected_drift = client_id in detected_drift_clients

                if detected_drift and ground_truth_drift:
                    true_positives += 1
                elif not detected_drift and not ground_truth_drift:
                    true_negatives += 1
                elif detected_drift and not ground_truth_drift:
                    false_positives += 1
                else:
                    false_negatives += 1

                if detected_drift == ground_truth_drift:
                    print(f"Global Epoch {global_epoch + 1}, Client {client_id + 1}: Drift detection is accurate.")
                else:
                    print(f"Global Epoch {global_epoch + 1}, Client {client_id + 1}: Drift detection is inaccurate.")

        previous_client_losses = client_losses.copy()

        print("\tAggregating client models")
        global_model = aggregate_models(global_model, client_models)

    # Test the final global model once per client.
    accuracies = []
    for i, test_loader in enumerate(client_test_loaders):
        print(f"Testing Client {i + 1}/{num_clients}")
        accuracy = client_testing(test_loader, global_model)
        accuracies.append(accuracy)
        print(f"\tAccuracy for Client {i + 1}/{num_clients}: {accuracy}")

    # Drift detection metrics. The guards avoid division by zero when no rounds were scored.
    scored = true_positives + true_negatives + false_positives + false_negatives
    detection_accuracy = (true_positives + true_negatives) / scored if scored > 0 else 0
    precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
    recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0
    f1_score = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

    print(f"Drift Detection metircs::: true_positives:{true_positives},true_negatives:{true_negatives},false_positives:{false_positives},false_negatives:{false_negatives}")
    print(f"Drift Detection Performance:\nAccuracy: {detection_accuracy}\nPrecision: {precision}\nRecall: {recall}\nF1 Score: {f1_score}")

    return accuracies

# Aggregation of models' weights using FedAvg
def aggregate_models(global_model, client_models):
    """Federated averaging: load the element-wise mean of the clients' weights into ``global_model``.

    Every ``state_dict`` entry of ``global_model`` is replaced by the unweighted
    mean (computed in float) of the same entry across ``client_models``.
    ``global_model`` is modified in place and also returned.

    NOTE(review): clients are weighted equally regardless of shard size.

    Raises:
        ValueError: if ``client_models`` is empty.
    """
    if not client_models:
        raise ValueError("aggregate_models needs at least one client model")
    # Take each client's state_dict once instead of once per parameter key.
    client_states = [client_model.state_dict() for client_model in client_models]
    global_dict = global_model.state_dict()
    for key in global_dict:
        global_dict[key] = torch.stack([state[key].float() for state in client_states], 0).mean(0)
    global_model.load_state_dict(global_dict)
    return global_model

# Function to visualize images with specific classes
def visualize_images(images, labels, classes_to_visualize=(0, 1), num_images=10):
    """Show, in one row, up to ``num_images`` images whose label is in ``classes_to_visualize``.

    Args:
        images: iterable of image tensors; each is moved to CPU and squeezed
            before ``imshow`` with a gray colormap.
        labels: iterable of scalar label tensors aligned with ``images``.
        classes_to_visualize: labels to include. The default is an immutable
            tuple instead of a shared mutable list.
        num_images: number of panels in the row.
    """
    # squeeze=False always returns a 2-D array of Axes. With num_images == 1,
    # plt.subplots would otherwise return a bare Axes and indexing it would fail.
    fig, axes = plt.subplots(1, num_images, figsize=(15, 15), squeeze=False)
    row = axes[0]
    count = 0
    for img, lbl in zip(images, labels):
        if count >= num_images:
            break
        if lbl.item() in classes_to_visualize:
            ax = row[count]
            ax.imshow(img.cpu().numpy().squeeze(), cmap='gray')
            ax.set_title(f"Class {lbl.item()}")
            ax.axis('off')
            count += 1
    # Hide any panels left empty because fewer matching images were found.
    for ax in row[count:]:
        ax.axis('off')
    plt.show()
    
    
# Function to pad and align SHAP values to a common shape
def pad_shap_values(shap_values_list, max_length):
    """Zero-pad each SHAP array to ``max_length`` rows and stack them.

    1-D inputs are first turned into a single column (shape ``(n, 1)``). Arrays
    shorter than ``max_length`` get zero rows appended at the end; longer arrays
    are left as-is. The result is ``np.array`` of the padded arrays.
    """
    def _pad_one(values):
        if len(values.shape) == 1:
            values = values.reshape(-1, 1)  # column vector
        missing_rows = max_length - len(values)
        if missing_rows <= 0:
            return values
        return np.vstack((values, np.zeros((missing_rows, values.shape[1]))))

    return np.array([_pad_one(values) for values in shap_values_list])

# Set device
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")  # current CUDA device if visible, else CPU

# Example usage
if __name__ == "__main__":
    # 10 clients, 1 local epoch per round, 20 federated rounds.
    run_kwargs = {"num_clients": 10, "num_local_epochs": 1, "num_global_epochs": 20}
    client_accuracies = federated_learning(**run_kwargs)
    print("Client Accuracies:", client_accuracies)

In [None]:
#p-value
def federated_learning(num_clients=5, num_local_epochs=5, num_global_epochs=5,
                       drift_clients=(2, 6, 9), drift_epoch=10, swap_classes=(3, 8)):
    """Simulate FedAvg on MNIST with injected label-swap drift and score a SHAP/KS drift detector.

    In each global epoch:
    - every client fine-tunes a copy of the global model on its own MNIST partition;
    - SHAP values are computed (via ``compute_shap_values``, defined elsewhere) on a
      class-balanced sample of the client's data;
    - a two-sample Kolmogorov-Smirnov test compares them with the client's SHAP values
      from the previous epoch.

    From 0-based epoch 5 on, a client is flagged when its current p-value is below the
    square of its historical mean p-value. Flags are scored against the injected ground
    truth, and client models are averaged with ``aggregate_models``.

    Args:
        num_clients: number of equally sized client partitions of the MNIST train set
            (samples left over from integer division are not assigned to any client).
        num_local_epochs: local training epochs per client per global epoch.
        num_global_epochs: number of federated rounds.
        drift_clients: 0-based ids of clients whose labels get swapped; ids outside
            ``range(num_clients)`` are ignored.
        drift_epoch: 0-based global epoch at which the swap is applied; from then on the
            drift clients count as drifted for scoring.
        swap_classes: pair of class labels exchanged for the drift clients.

    Returns:
        list: the value ``client_testing`` reports for the final global model on the MNIST
        test set, one entry per client.
    """
    detection_start_epoch = 5        # 0-based epoch from which detection is scored
    drastic_change_factor = 3        # loss jump factor that flags a drastic change
    sustained_loss_threshold = 4     # loss at/above which a flagged change stays "sustained"
    frequent_feature_min_count = 20  # occurrences needed for a SHAP feature to be "frequent"
    samples_per_class = 10           # images per class collected for SHAP

    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
    trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

    # random_split requires the lengths to sum to len(trainset); an extra (discarded) split
    # absorbs the remainder so any num_clients works. When it divides evenly, the split is
    # identical to the previous behavior.
    partition_size = len(trainset) // num_clients
    remainder = len(trainset) - partition_size * num_clients
    lengths = [partition_size] * num_clients + ([remainder] if remainder else [])
    client_indices = [subset.indices for subset in random_split(trainset, lengths)[:num_clients]]

    drift_clients = [c for c in drift_clients if 0 <= c < num_clients]

    global_model = CNN().to(device)
    previous_shap_values_per_client = [None] * num_clients
    previous_shap_features_per_client = [None] * num_clients
    previous_client_losses = [0.0] * num_clients
    drastic_change_detected = [False] * num_clients
    sustained_high_loss = [False] * num_clients

    feature_occurrence_dict = {}
    frequent_features = []
    p_values_per_client = [[] for _ in range(num_clients)]

    true_positives = true_negatives = false_positives = false_negatives = 0

    # Every client is evaluated on the same full test set, so one loader suffices; building
    # it up front also keeps testing valid when num_global_epochs == 0.
    test_loader = DataLoader(testset, batch_size=64, shuffle=False)

    for global_epoch in range(num_global_epochs):
        print(f"Global Epoch {global_epoch + 1}/{num_global_epochs}")

        if global_epoch == drift_epoch:
            class_a, class_b = swap_classes
            print(f"Swapping classes {class_a} and {class_b} in the train dataset for clients {drift_clients}")
            _swap_client_labels(trainset, client_indices, drift_clients, class_a, class_b)

        client_train_loaders = [
            DataLoader(torch.utils.data.Subset(trainset, indices), batch_size=20, shuffle=True)
            for indices in client_indices
        ]
        client_models = [copy.deepcopy(global_model) for _ in range(num_clients)]

        # Keyed by client id so that skipped clients cannot shift later clients' results.
        shap_results = {}
        client_losses = []
        for client_id, train_loader in enumerate(client_train_loaders):
            print(f"\tTraining Client {client_id + 1}/{num_clients}")
            criterion = nn.CrossEntropyLoss()
            optimizer = optim.Adam(client_models[client_id].parameters(), lr=0.001)
            loss = client_training(train_loader, client_models[client_id], criterion, optimizer, num_local_epochs)
            client_losses.append(loss)

            sample_images, complete = _collect_class_balanced_samples(train_loader, samples_per_class)
            if not complete:
                continue  # not enough images of some class: no SHAP values this epoch

            shap_data = torch.stack(sample_images).to(device)
            split_idx = int(shap_data.size(0) * 0.9)
            background = shap_data[split_idx:]  # last 10% as SHAP background
            shap_data = shap_data[:split_idx]   # first 90% to explain

            print(f"\t\tComputed SHAP values for Client {client_id + 1}/{num_clients}")
            shap_values, shap_features = compute_shap_values(client_models[client_id], shap_data, background)
            shap_results[client_id] = (shap_values, shap_features)

        print("\tPerforming Kolmogorov–Smirnov test for SHAP values")
        current_p_values = {}
        for client_id, (shap_values_client, shap_features_client) in shap_results.items():
            previous_values = previous_shap_values_per_client[client_id]
            previous_features = previous_shap_features_per_client[client_id]
            if previous_values is not None:
                common = set(shap_features_client) & set(previous_features)
                if global_epoch >= detection_start_epoch and frequent_features:
                    common &= set(frequent_features)
                common_features = list(common)
                if common_features:
                    current_idx = [shap_features_client.index(f) for f in common_features]
                    previous_idx = [previous_features.index(f) for f in common_features]
                    ks_statistic, p_value = ks_2samp(previous_values[previous_idx], shap_values_client[current_idx])
                    current_p_values[client_id] = p_value
                    p_values_per_client[client_id].append(p_value)
                    print(f"Global Epoch {global_epoch + 1}, Client {client_id+1}: Kolmogorov–Smirnov test statistic: {ks_statistic}, p-value: {p_value}")

            previous_shap_values_per_client[client_id] = shap_values_client
            previous_shap_features_per_client[client_id] = shap_features_client

        for _, shap_features_client in shap_results.values():
            for feature in shap_features_client:
                feature_occurrence_dict[feature] = feature_occurrence_dict.get(feature, 0) + 1

        if global_epoch >= detection_start_epoch:
            frequent_features = [
                feature for feature, count in feature_occurrence_dict.items()
                if count >= frequent_feature_min_count
            ]

            # SHAP/KS detector: current p-value far below the client's historical mean.
            detected_drift_clients_1 = []
            for client_id, p_value in current_p_values.items():
                history = p_values_per_client[client_id][:-1]  # exclude the current p-value
                if history:
                    avg_p_value = sum(history) / len(history)
                    if p_value < avg_p_value ** 2:
                        detected_drift_clients_1.append(client_id)

            # Loss detector: a drastic jump followed by loss staying high.
            for i, (prev_loss, curr_loss) in enumerate(zip(previous_client_losses, client_losses)):
                if not drastic_change_detected[i]:
                    if curr_loss > (prev_loss * drastic_change_factor):
                        drastic_change_detected[i] = True
                else:
                    if curr_loss >= sustained_loss_threshold:
                        sustained_high_loss[i] = True
                    else:
                        drastic_change_detected[i] = False
                        sustained_high_loss[i] = False

            detected_drift_clients_2 = [
                i for i, (drastic, sustained) in enumerate(zip(drastic_change_detected, sustained_high_loss))
                if drastic and sustained
            ]
            # NOTE(review): only the SHAP/KS detector is scored below; the loss detector is
            # reported for information, as in the original where its result went unused.
            print(f"\tDrift flagged by SHAP/KS: {detected_drift_clients_1}, by loss: {detected_drift_clients_2}")

            for client_id in range(num_clients):
                ground_truth_drift = (client_id in drift_clients) and (global_epoch >= drift_epoch)
                detected_drift = client_id in detected_drift_clients_1

                if detected_drift and ground_truth_drift:
                    true_positives += 1
                elif not detected_drift and not ground_truth_drift:
                    true_negatives += 1
                elif detected_drift:
                    false_positives += 1
                else:
                    false_negatives += 1

                if detected_drift == ground_truth_drift:
                    print(f"Global Epoch {global_epoch + 1}, Client {client_id + 1}: Drift detection is accurate.")
                else:
                    print(f"Global Epoch {global_epoch + 1}, Client {client_id + 1}: Drift detection is NOT accurate.")

        previous_client_losses = client_losses

        print("\tAggregating client models")
        global_model = aggregate_models(global_model, client_models)

    accuracies = []
    for client_id in range(num_clients):
        print(f"Testing Client {client_id + 1}/{num_clients}")
        accuracy = client_testing(test_loader, global_model)
        accuracies.append(accuracy)
        print(f"\tAccuracy for Client {client_id + 1}/{num_clients}: {accuracy}")

    total = true_positives + true_negatives + false_positives + false_negatives
    detection_accuracy = (true_positives + true_negatives) / total if total else 0.0
    precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) else 0.0
    recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) else 0.0
    f1_score = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0

    print("\nFinal Drift Detection Performance Metrics:")
    print(f"True Positives: {true_positives}")
    print(f"True Negatives: {true_negatives}")
    print(f"False Positives: {false_positives}")
    print(f"False Negatives: {false_negatives}")
    print(f"Accuracy: {detection_accuracy}\nPrecision: {precision}\nRecall: {recall}\nF1 Score: {f1_score}")

    return accuracies


def _swap_client_labels(dataset, client_indices, client_ids, class_a, class_b):
    """Exchange labels ``class_a`` <-> ``class_b`` in ``dataset.targets`` for the given clients' samples.

    Mutates ``dataset.targets`` in place. Assumes it is an indexable tensor (true for
    torchvision MNIST); vectorized instead of a per-sample Python loop.
    """
    targets = dataset.targets
    for client_id in client_ids:
        idx = torch.as_tensor(client_indices[client_id], dtype=torch.long)
        client_targets = targets[idx]
        swapped = client_targets.clone()
        swapped[client_targets == class_a] = class_b
        swapped[client_targets == class_b] = class_a
        targets[idx] = swapped


def _collect_class_balanced_samples(loader, per_class, num_classes=10):
    """Gather up to ``per_class`` images for each of ``num_classes`` labels from ``loader``.

    Returns:
        (images, complete): list of image tensors in encounter order, and whether every
        class reached ``per_class`` images.
    """
    class_counts = {c: 0 for c in range(num_classes)}
    collected = []
    for images, labels in loader:
        for img, lbl in zip(images, labels):
            label = lbl.item()
            if class_counts[label] < per_class:
                collected.append(img)
                class_counts[label] += 1
        if all(count >= per_class for count in class_counts.values()):
            break
    complete = all(count >= per_class for count in class_counts.values())
    return collected, complete

# Aggregation of models' weights using FedAvg
def aggregate_models(global_model, client_models):
    """Average client weights into ``global_model`` (FedAvg with equal client weights).

    Every entry of ``global_model.state_dict()`` is replaced by the element-wise mean of
    the same entry across ``client_models``; the mean is computed in float32 and cast
    back to the entry's original dtype before loading.

    Args:
        global_model: model whose parameters/buffers are overwritten in place.
        client_models: non-empty sequence of models with the same state-dict keys and shapes.

    Returns:
        The same ``global_model`` object, now holding the averaged weights.

    Raises:
        ValueError: if ``client_models`` is empty.
    """
    if not client_models:
        raise ValueError("aggregate_models requires at least one client model")
    # Build each client's state dict once instead of once per key (was O(keys * clients) rebuilds).
    client_states = [client_model.state_dict() for client_model in client_models]
    global_dict = global_model.state_dict()
    for key, reference in global_dict.items():
        stacked = torch.stack([state[key].float() for state in client_states], 0)
        global_dict[key] = stacked.mean(0).to(reference.dtype)
    global_model.load_state_dict(global_dict)
    return global_model

    


# Pick the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Entry point: run the p-value drift experiment and report the per-client accuracies.
if __name__ == "__main__":
    client_accuracies = federated_learning(
        num_clients=10,
        num_local_epochs=1,
        num_global_epochs=20,
    )
    print("Client Accuracies:", client_accuracies)
