In [1]:
!pip install scikit-multiflow

Collecting scikit-multiflow
  Downloading scikit-multiflow-0.5.3.tar.gz (450 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m450.6/450.6 kB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25h  Preparing metadata (setup.py) ... [?25ldone
Building wheels for collected packages: scikit-multiflow
  Building wheel for scikit-multiflow (setup.py) ... [?25ldone
[?25h  Created wheel for scikit-multiflow: filename=scikit_multiflow-0.5.3-cp310-cp310-linux_x86_64.whl size=576330 sha256=9de41833016b48fdc6ebcb3ffdbb1fe2c3ac58f83048d9c47ae03f1b4a2a602b
  Stored in directory: /root/.cache/pip/wheels/6e/1b/56/45b17a6cf203d98000a45976cb0dd0c4c3f11960e6a505f231
Successfully built scikit-multiflow
Installing collected packages: scikit-multiflow
Successfully installed scikit-multiflow-0.5.3


In [3]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from scipy.ndimage import rotate
from scipy.stats import mannwhitneyu, ttest_ind, ks_2samp, ttest_ind_from_stats
import copy
import random
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, random_split
from skmultiflow.drift_detection import ADWIN, DDM, EDDM, HDDM_W

In [4]:
# Set seed for reproducibility
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [5]:
# Define the model architecture
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.fc1 = nn.Linear(64 * 5 * 5, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.max_pool2d(x, kernel_size=2, stride=2)
        x = torch.relu(self.conv2(x))
        x = torch.max_pool2d(x, kernel_size=2, stride=2)
        x = x.view(-1, 64 * 5 * 5)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return torch.log_softmax(x, dim=1)

# Function to simulate client training with MNIST data
def client_training(train_loader, model, criterion, optimizer, num_epochs=1):
    model.train()
    batch_losses = []
    for epoch in range(num_epochs):
        total_loss, correct, total = 0.0, 0.0, 0.0
        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            batch_losses.append(loss.item())
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()        
        accuracy = correct / total
        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {total_loss}, Accuracy: {accuracy}")
        
    return batch_losses

# Function to test the trained model on client data
def client_testing(test_loader, model):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct / total

In [6]:
# Function to apply gradual rotation drift
def apply_rotation_drift(images, current_epoch, max_rotation, start_epoch, end_epoch, total_epochs):
    if current_epoch < start_epoch:
        rotation_angle = 0
    elif start_epoch <= current_epoch <= end_epoch:
        transition_progress = (current_epoch - start_epoch) / (end_epoch - start_epoch)
        rotation_angle = transition_progress * max_rotation
    else:
        rotation_angle = 0
    
    fraction_rotated = (current_epoch - start_epoch + 1) / (end_epoch - start_epoch + 1)
    num_images_to_rotate = int(fraction_rotated * len(images))
    
    drifted_images = images.clone()
    if num_images_to_rotate > 0 and fraction_rotated<=1:
        indices_to_rotate = torch.randperm(len(images))[:num_images_to_rotate]
        for idx in indices_to_rotate:
            drifted_images[idx] = torch.tensor(rotate(images[idx].numpy(), rotation_angle, reshape=False))
    
    return drifted_images

In [7]:
#sudden

def federated_learning(num_clients=60, num_local_epochs=5, num_global_epochs=40):
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
    trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

    partition_size = len(trainset) // num_clients
    lengths = [partition_size] * num_clients
    datasets = random_split(trainset, lengths)

    client_indices = [dataset.indices for dataset in datasets]

    # Clients that will experience rotation drift
#     rotation_drift_clients = [2, 6, 9,13,15,27]

    # Clients that will experience class swap drift
    class_swap_drift_clients = [2, 6, 9, 13, 20, 27, 33, 37, 41, 55]

    # Initialize global model
    global_model = CNN().to(device)
    previous_client_losses = [[] for _ in range(num_clients)]

    # Initialize ADWIN, DDM, EDDM detectors
    hddm_detectors = [HDDM_W(drift_confidence=0.001, warning_confidence=0.005, lambda_option=0.05, two_side_option=True) for _ in range(num_clients)]
    ddm_detectors = [DDM(min_num_instances=3, warning_level=1, out_control_level=3) for _ in range(num_clients)]
    # Create an instance of EDDM
    edddm = EDDM()

    # Modify parameters after initialization if necessary
    edddm.min_num_instances = 3
    edddm.warning_level = 1
    edddm.out_control_level = 3
    
    eddm_detectors = [edddm for _ in range(num_clients)]

    # Initialize counters for drift detection performance
    metrics = {
        'HDDM': {'TP': 0, 'TN': 0, 'FP': 0, 'FN': 0},
        'DDM': {'TP': 0, 'TN': 0, 'FP': 0, 'FN': 0},
        'EDDM': {'TP': 0, 'TN': 0, 'FP': 0, 'FN': 0},
        'Combined': {'TP': 0, 'TN': 0, 'FP': 0, 'FN': 0}
    }
    # Federated learning process
    for global_epoch in range(num_global_epochs):
        print(f"Global Epoch {global_epoch + 1}/{num_global_epochs}")

        # Apply rotation drift gradually starting from a certain epoch
#         if global_epoch in range(10, 31):  # Example threshold for starting drift
#             print("Applying rotation drift to specific clients")
#             for client_id in rotation_drift_clients:
#                 indices = client_indices[client_id]
#                 images = trainset.data[indices]
#                 drifted_images = apply_rotation_drift(images, global_epoch, max_rotation=30, start_epoch=10, end_epoch=30, total_epochs=num_global_epochs)
#                 trainset.data[indices] = drifted_images

        # Apply class swap drift at a specific epoch
        if global_epoch == 10:
            print("Swapping classes 3 and 8 in the train dataset for selected clients")
            for client_id in class_swap_drift_clients:
                indices = client_indices[client_id]
                for idx in indices:
                    if trainset.targets[idx] == 8:
                        trainset.targets[idx] = 3
                    elif trainset.targets[idx] == 3:
                        trainset.targets[idx] = 8
                        
                    elif trainset.targets[idx] == 5:
                        trainset.targets[idx] = 6
                    elif trainset.targets[idx] == 6:
                        trainset.targets[idx] = 5

        client_train_loaders = [torch.utils.data.DataLoader(torch.utils.data.Subset(trainset, indices), batch_size=32, shuffle=True) for indices in client_indices]
        client_test_loaders = [torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False) for _ in range(num_clients)]
        client_models = [copy.deepcopy(global_model) for _ in range(num_clients)]

        # Perform local training and compute SHAP values
        client_losses = []

        for i, train_loader in enumerate(client_train_loaders):
            print(f"\tTraining Client {i + 1}/{num_clients}")
            criterion = nn.CrossEntropyLoss()
            optimizer = optim.Adam(client_models[i].parameters(), lr=0.001)
            loss = client_training(train_loader, client_models[i], criterion, optimizer, num_local_epochs)
            client_losses.append(loss)

            # Update previous client losses
            previous_client_losses[i].append(loss)
            
            
            

        # Concept drift detection using ADWIN, DDM, and EDDM
        if global_epoch >= 5:
            detected_drift_clients_ddm = []
            detected_drift_clients_eddm = []
            detected_drift_clients_hddm = []

            for i, losses in enumerate(client_losses):
                for loss in losses:
                    hddm_detectors[i].add_element(loss)
                    ddm_detectors[i].add_element(loss)
                    eddm_detectors[i].add_element(loss)

                    if hddm_detectors[i].detected_change():
                        detected_drift_clients_hddm.append(i)
                    if ddm_detectors[i].detected_change():
                        detected_drift_clients_ddm.append(i)
                    if eddm_detectors[i].detected_change():
                        detected_drift_clients_eddm.append(i)
                        
                        
                    
            detected_drift_clients_combined = list(set(  detected_drift_clients_hddm + detected_drift_clients_ddm + detected_drift_clients_eddm))
            print(f"ddm:{detected_drift_clients_ddm}, eddm:{detected_drift_clients_eddm}, hddm{detected_drift_clients_hddm}, combined{detected_drift_clients_combined}")
            print(f"metrics:{metrics}")
            


            for detector_name, detected_drift_clients in zip(['DDM', 'EDDM', 'HDDM', 'Combined'], [detected_drift_clients_ddm, detected_drift_clients_eddm, detected_drift_clients_hddm, detected_drift_clients_combined]):
                for client_id in range(num_clients):
                    ground_truth_drift = ((client_id in class_swap_drift_clients) and (global_epoch >= 10)) #| ((client_id in rotation_drift_clients) and (global_epoch >= 19) and (global_epoch <= 29))
                    detected_drift = client_id in detected_drift_clients

                    if detected_drift and ground_truth_drift:
                        metrics[detector_name]['TP'] += 1
                    elif not detected_drift and not ground_truth_drift:
                        metrics[detector_name]['TN'] += 1
                    elif detected_drift and not ground_truth_drift:
                        metrics[detector_name]['FP'] += 1
                    elif not detected_drift and ground_truth_drift:
                        metrics[detector_name]['FN'] += 1

                    if detected_drift == ground_truth_drift:
                        print(f"Global Epoch {global_epoch + 1}, Client {client_id + 1}: {detector_name} Drift detection is accurate.")
                    else:
                        print(f"Global Epoch {global_epoch + 1}, Client {client_id + 1}: {detector_name} Drift detection is inaccurate.")

        # Aggregate client models into global model
        print("\tAggregating client models")
        global_model = aggregate_models(global_model, client_models)

    # Test models on each client
    accuracies = []
    for i, test_loader in enumerate(client_test_loaders):
        print(f"Testing Client {i + 1}/{num_clients}")
        accuracy = client_testing(test_loader, global_model)
        accuracies.append(accuracy)
        print(f"\tAccuracy for Client {i + 1}/{num_clients}: {accuracy}")

    # Calculate performance metrics of drift detection
    for detector_name in ['DDM', 'EDDM','HDDM', 'Combined']:
        TP = metrics[detector_name]['TP']
        TN = metrics[detector_name]['TN']
        FP = metrics[detector_name]['FP']
        FN = metrics[detector_name]['FN']

        accuracy = (TP + TN) / (TP + TN + FP + FN)
        precision = TP / (TP + FP) if (TP + FP) > 0 else 0
        recall = TP / (TP + FN) if (TP + FN) > 0 else 0
        f1_score = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
        
        print(f"Drift Detection metrics::: true_positives:{TP}, true_negatives:{TN}, false_positives:{FP}, false_negatives:{FN}")
        print(f"Drift Detection Performance ({detector_name}):\nAccuracy: {accuracy}\nPrecision: {precision}\nRecall: {recall}\nF1 Score: {f1_score}")

    return accuracies

# Aggregation of models' weights using FedAvg
def aggregate_models(global_model, client_models):
    global_dict = global_model.state_dict()
    for k in global_dict.keys():
        global_dict[k] = torch.stack([client_model.state_dict()[k].float() for client_model in client_models], 0).mean(0)
    global_model.load_state_dict(global_dict)
    return global_model

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Example usage
if __name__ == "__main__":
    client_accuracies = federated_learning(num_clients=60, num_local_epochs=1, num_global_epochs=40)
    print("Client Accuracies:", client_accuracies)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 35938039.16it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 1035604.50it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 10481259.52it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 2733610.10it/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw

Global Epoch 1/40
	Training Client 1/60
Epoch 1/1, Loss: 49.7998705804348, Accuracy: 0.507
	Training Client 2/60
Epoch 1/1, Loss: 48.39897897839546, Accuracy: 0.551
	Training Client 3/60
Epoch 1/1, Loss: 47.44576150178909, Accuracy: 0.507
	Training Client 4/60
Epoch 1/1, Loss: 47.85296183824539, Accuracy: 0.533
	Training Client 5/60
Epoch 1/1, Loss: 50.38479936122894, Accuracy: 0.516
	Training Client 6/60
Epoch 1/1, Loss: 48.611885249614716, Accuracy: 0.528
	Training Client 7/60
Epoch 1/1, Loss: 51.794359147548676, Accuracy: 0.464
	Training Client 8/60
Epoch 1/1, Loss: 52.096949100494385, Accuracy: 0.494
	Training Client 9/60
Epoch 1/1, Loss: 51.13297080993652, Accuracy: 0.525
	Training Client 10/60
Epoch 1/1, Loss: 47.53140366077423, Accuracy: 0.56
	Training Client 11/60
Epoch 1/1, Loss: 47.49449124932289, Accuracy: 0.564
	Training Client 12/60
Epoch 1/1, Loss: 50.82881420850754, Accuracy: 0.507
	Training Clien

  self.miss_std = np.sqrt(self.miss_prob * (1 - self.miss_prob) / float(self.sample_count))


Epoch 1/1, Loss: 3.5448168758302927, Accuracy: 0.96
	Training Client 2/60
Epoch 1/1, Loss: 4.447187410667539, Accuracy: 0.956
	Training Client 3/60
Epoch 1/1, Loss: 20.79115556180477, Accuracy: 0.796
	Training Client 4/60
Epoch 1/1, Loss: 3.7643631948158145, Accuracy: 0.961
	Training Client 5/60
Epoch 1/1, Loss: 3.2484924364835024, Accuracy: 0.972
	Training Client 6/60
Epoch 1/1, Loss: 4.334619461558759, Accuracy: 0.964
	Training Client 7/60
Epoch 1/1, Loss: 18.814032297581434, Accuracy: 0.796
	Training Client 8/60
Epoch 1/1, Loss: 3.7736006751656532, Accuracy: 0.965
	Training Client 9/60
Epoch 1/1, Loss: 3.426117251627147, Accuracy: 0.973
	Training Client 10/60
Epoch 1/1, Loss: 20.846773959696293, Accuracy: 0.802
	Training Client 11/60
Epoch 1/1, Loss: 3.035882818076061, Accuracy: 0.968
	Training Client 12/60
Epoch 1/1, Loss: 3.5730089442804456, Accuracy: 0.965
	Training Client 13/60
Epoch 1/1, Loss: 3.3276825584471226, Accuracy: 0.961
	Training Client 14/60
Epoch 1/1, Loss: 19.070665

In [8]:
# incremental

#sudden

def federated_learning(num_clients=60, num_local_epochs=5, num_global_epochs=40):
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
    trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

    partition_size = len(trainset) // num_clients
    lengths = [partition_size] * num_clients
    datasets = random_split(trainset, lengths)

    client_indices = [dataset.indices for dataset in datasets]

    # Clients that will experience rotation drift
    rotation_drift_clients = [6, 50, 20, 5, 37, 42, 47, 13, 28, 15, 57, 10, 9, 31, 2, 41, 35, 1, 23, 58, 38, 46, 21, 48, 24, 43, 32, 4, 51, 56]

    # Clients that will experience class swap drift
#     class_swap_drift_clients = [2, 6, 9, 13, 20, 27, 33, 37, 41, 55]
    
    #     drift_clients = [6, 50, 20, 5, 37, 42, 47, 13, 28, 15, 57, 10, 9, 31, 2, 41, 35, 1, 23, 58, 38, 46, 21, 48, 24, 43, 32, 4, 51, 56]

#     drift_clients = [2, 6, 9, 10, 13, 15, 21, 23, 28, 35, 37, 41, 43, 47, 51, 56]
#     drift_clients = [2, 6, 9, 28, 41, 56]

    # Initialize global model
    global_model = CNN().to(device)
    previous_client_losses = [[] for _ in range(num_clients)]

    # Initialize ADWIN, DDM, EDDM detectors
    hddm_detectors = [HDDM_W(drift_confidence=0.001, warning_confidence=0.005, lambda_option=0.05, two_side_option=True) for _ in range(num_clients)]
    ddm_detectors = [DDM(min_num_instances=3, warning_level=1, out_control_level=3) for _ in range(num_clients)]
    # Create an instance of EDDM
    edddm = EDDM()

    # Modify parameters after initialization if necessary
    edddm.min_num_instances = 3
    edddm.warning_level = 1
    edddm.out_control_level = 3
    
    eddm_detectors = [edddm for _ in range(num_clients)]

    # Initialize counters for drift detection performance
    metrics = {
        'HDDM': {'TP': 0, 'TN': 0, 'FP': 0, 'FN': 0},
        'DDM': {'TP': 0, 'TN': 0, 'FP': 0, 'FN': 0},
        'EDDM': {'TP': 0, 'TN': 0, 'FP': 0, 'FN': 0},
        'Combined': {'TP': 0, 'TN': 0, 'FP': 0, 'FN': 0}
    }
    # Federated learning process
    for global_epoch in range(num_global_epochs):
        print(f"Global Epoch {global_epoch + 1}/{num_global_epochs}")

        # Apply rotation drift gradually starting from a certain epoch
        if global_epoch in range(10, 31):  # Example threshold for starting drift
            print("Applying rotation drift to specific clients")
            for client_id in rotation_drift_clients:
                indices = client_indices[client_id]
                images = trainset.data[indices]
                drifted_images = apply_rotation_drift(images, global_epoch, max_rotation=30, start_epoch=10, end_epoch=30, total_epochs=num_global_epochs)
                trainset.data[indices] = drifted_images

#         # Apply class swap drift at a specific epoch
#         if global_epoch == 10:
#             print("Swapping classes 3 and 8 in the train dataset for selected clients")
#             for client_id in class_swap_drift_clients:
#                 indices = client_indices[client_id]
#                 for idx in indices:
#                     if trainset.targets[idx] == 8:
#                         trainset.targets[idx] = 3
#                     elif trainset.targets[idx] == 3:
#                         trainset.targets[idx] = 8
                        
#                     elif trainset.targets[idx] == 5:
#                         trainset.targets[idx] = 6
#                     elif trainset.targets[idx] == 6:
#                         trainset.targets[idx] = 5

        client_train_loaders = [torch.utils.data.DataLoader(torch.utils.data.Subset(trainset, indices), batch_size=32, shuffle=True) for indices in client_indices]
        client_test_loaders = [torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False) for _ in range(num_clients)]
        client_models = [copy.deepcopy(global_model) for _ in range(num_clients)]

        # Perform local training and compute SHAP values
        client_losses = []

        for i, train_loader in enumerate(client_train_loaders):
            print(f"\tTraining Client {i + 1}/{num_clients}")
            criterion = nn.CrossEntropyLoss()
            optimizer = optim.Adam(client_models[i].parameters(), lr=0.001)
            loss = client_training(train_loader, client_models[i], criterion, optimizer, num_local_epochs)
            client_losses.append(loss)

            # Update previous client losses
            previous_client_losses[i].append(loss)
            
            
            

        # Concept drift detection using ADWIN, DDM, and EDDM
        if global_epoch >= 5 and global_epoch <= 28:
            detected_drift_clients_ddm = []
            detected_drift_clients_eddm = []
            detected_drift_clients_hddm = []

            for i, losses in enumerate(client_losses):
                for loss in losses:
                    hddm_detectors[i].add_element(loss)
                    ddm_detectors[i].add_element(loss)
                    eddm_detectors[i].add_element(loss)

                    if hddm_detectors[i].detected_change():
                        detected_drift_clients_hddm.append(i)
                    if ddm_detectors[i].detected_change():
                        detected_drift_clients_ddm.append(i)
                    if eddm_detectors[i].detected_change():
                        detected_drift_clients_eddm.append(i)
                        
                        
                    
            detected_drift_clients_combined = list(set(  detected_drift_clients_hddm + detected_drift_clients_ddm + detected_drift_clients_eddm))
            print(f"ddm:{detected_drift_clients_ddm}, eddm:{detected_drift_clients_eddm}, hddm{detected_drift_clients_hddm}, combined{detected_drift_clients_combined}")
            print(f"metrics:{metrics}")
            


            for detector_name, detected_drift_clients in zip(['DDM', 'EDDM', 'HDDM', 'Combined'], [detected_drift_clients_ddm, detected_drift_clients_eddm, detected_drift_clients_hddm, detected_drift_clients_combined]):
                for client_id in range(num_clients):
                    ground_truth_drift = (client_id in rotation_drift_clients) and (global_epoch >= 19) and (global_epoch <= 29)
                    detected_drift = client_id in detected_drift_clients

                    if detected_drift and ground_truth_drift:
                        metrics[detector_name]['TP'] += 1
                    elif not detected_drift and not ground_truth_drift:
                        metrics[detector_name]['TN'] += 1
                    elif detected_drift and not ground_truth_drift:
                        metrics[detector_name]['FP'] += 1
                    elif not detected_drift and ground_truth_drift:
                        metrics[detector_name]['FN'] += 1

                    if detected_drift == ground_truth_drift:
                        print(f"Global Epoch {global_epoch + 1}, Client {client_id + 1}: {detector_name} Drift detection is accurate.")
                    else:
                        print(f"Global Epoch {global_epoch + 1}, Client {client_id + 1}: {detector_name} Drift detection is inaccurate.")

        # Aggregate client models into global model
        print("\tAggregating client models")
        global_model = aggregate_models(global_model, client_models)

    # Test models on each client
    accuracies = []
    for i, test_loader in enumerate(client_test_loaders):
        print(f"Testing Client {i + 1}/{num_clients}")
        accuracy = client_testing(test_loader, global_model)
        accuracies.append(accuracy)
        print(f"\tAccuracy for Client {i + 1}/{num_clients}: {accuracy}")

    # Calculate performance metrics of drift detection
    for detector_name in ['DDM', 'EDDM','HDDM', 'Combined']:
        TP = metrics[detector_name]['TP']
        TN = metrics[detector_name]['TN']
        FP = metrics[detector_name]['FP']
        FN = metrics[detector_name]['FN']

        accuracy = (TP + TN) / (TP + TN + FP + FN)
        precision = TP / (TP + FP) if (TP + FP) > 0 else 0
        recall = TP / (TP + FN) if (TP + FN) > 0 else 0
        f1_score = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
        
        print(f"Drift Detection metrics::: true_positives:{TP}, true_negatives:{TN}, false_positives:{FP}, false_negatives:{FN}")
        print(f"Drift Detection Performance ({detector_name}):\nAccuracy: {accuracy}\nPrecision: {precision}\nRecall: {recall}\nF1 Score: {f1_score}")

    return accuracies

# Aggregation of models' weights using FedAvg
def aggregate_models(global_model, client_models):
    global_dict = global_model.state_dict()
    for k in global_dict.keys():
        global_dict[k] = torch.stack([client_model.state_dict()[k].float() for client_model in client_models], 0).mean(0)
    global_model.load_state_dict(global_dict)
    return global_model

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Example usage
if __name__ == "__main__":
    client_accuracies = federated_learning(num_clients=60, num_local_epochs=1, num_global_epochs=40)
    print("Client Accuracies:", client_accuracies)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 33618212.99it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 988136.83it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 9371825.00it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 3819272.01it/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw

Global Epoch 1/40
	Training Client 1/60
Epoch 1/1, Loss: 49.7998705804348, Accuracy: 0.507
	Training Client 2/60
Epoch 1/1, Loss: 48.39897897839546, Accuracy: 0.551
	Training Client 3/60
Epoch 1/1, Loss: 47.44576150178909, Accuracy: 0.507
	Training Client 4/60
Epoch 1/1, Loss: 47.85296183824539, Accuracy: 0.533
	Training Client 5/60
Epoch 1/1, Loss: 50.38479936122894, Accuracy: 0.516
	Training Client 6/60
Epoch 1/1, Loss: 48.611885249614716, Accuracy: 0.528
	Training Client 7/60
Epoch 1/1, Loss: 51.794359147548676, Accuracy: 0.464
	Training Client 8/60
Epoch 1/1, Loss: 52.096949100494385, Accuracy: 0.494
	Training Client 9/60
Epoch 1/1, Loss: 51.13297080993652, Accuracy: 0.525
	Training Client 10/60
Epoch 1/1, Loss: 47.53140366077423, Accuracy: 0.56
	Training Client 11/60
Epoch 1/1, Loss: 47.49449124932289, Accuracy: 0.564
	Training Client 12/60
Epoch 1/1, Loss: 50.82881420850754, Accuracy: 0.507
	Training Clien

  self.miss_std = np.sqrt(self.miss_prob * (1 - self.miss_prob) / float(self.sample_count))


	Training Client 1/60
Epoch 1/1, Loss: 3.19597311122925, Accuracy: 0.967
	Training Client 2/60
Epoch 1/1, Loss: 17.639482602477074, Accuracy: 0.844
	Training Client 3/60
Epoch 1/1, Loss: 17.872203215956688, Accuracy: 0.83
	Training Client 4/60
Epoch 1/1, Loss: 3.1000210917554796, Accuracy: 0.968
	Training Client 5/60
Epoch 1/1, Loss: 16.73876717686653, Accuracy: 0.831
	Training Client 6/60
Epoch 1/1, Loss: 18.284272700548172, Accuracy: 0.836
	Training Client 7/60
Epoch 1/1, Loss: 19.12045055627823, Accuracy: 0.832
	Training Client 8/60
Epoch 1/1, Loss: 3.5463344079616945, Accuracy: 0.962
	Training Client 9/60
Epoch 1/1, Loss: 2.623308934736997, Accuracy: 0.973
	Training Client 10/60
Epoch 1/1, Loss: 17.768496111035347, Accuracy: 0.836
	Training Client 11/60
Epoch 1/1, Loss: 17.22267486155033, Accuracy: 0.838
	Training Client 12/60
Epoch 1/1, Loss: 2.105351832753513, Accuracy: 0.973
	Training Client 13/60
Epoch 1/1, Loss: 3.059891506229178, Accuracy: 0.971
	Training Client 14/60
Epoch 1