In [70]:
import random
required_numbers = [6, 50, 20, 5, 37, 42, 47, 13, 28, 15, 57, 10, 9, 31, 2, 41, 35, 1, 23, 58, 38, 46, 21, 48, 24, 43, 32, 4, 51, 56]
range_values = list(range(0, 60))

# Start from a shuffled copy of the required ids ...
final_list = random.sample(required_numbers, len(required_numbers))

# ... then top the list up to 42 entries with ids from the range that are not required.
required_lookup = set(required_numbers)
filler_pool = [candidate for candidate in range_values if candidate not in required_lookup]
final_list += random.sample(filler_pool, 42 - len(required_numbers))

# Mix required and filler ids so their positions carry no information.
random.shuffle(final_list)
a = len(final_list)
a

42

In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from scipy.ndimage import rotate
import copy
import random
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, random_split

In [2]:
# Set seed for reproducibility
seed = 42

# Seed the Python, NumPy and PyTorch (CPU) generators with the same value.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(seed)

# Seed the CUDA generators as well when a GPU is present.
if torch.cuda.is_available():
    for _seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        _seed_fn(seed)

# Ask cuDNN for deterministic kernels and disable its benchmark-based autotuner.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [3]:
# Define the model architecture
class CNN(nn.Module):
    """Small two-convolution classifier producing 10-class log-probabilities.

    Layout: conv(1->32, 3x3) -> ReLU -> 2x2 max-pool -> conv(32->64, 3x3) ->
    ReLU -> 2x2 max-pool -> flatten -> linear(1600->128) -> ReLU ->
    linear(128->10) -> log-softmax.

    The flatten size 64 * 5 * 5 matches single-channel 28x28 inputs
    (28 -> 26 -> 13 -> 11 -> 5 per spatial side).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3)
        self.fc1 = nn.Linear(in_features=64 * 5 * 5, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Return per-class log-probabilities for a batch of images."""
        # Both convolution stages share the same ReLU + 2x2 max-pool tail.
        for conv_layer in (self.conv1, self.conv2):
            x = torch.max_pool2d(torch.relu(conv_layer(x)), kernel_size=2, stride=2)
        flattened = x.view(-1, 64 * 5 * 5)
        hidden = torch.relu(self.fc1(flattened))
        return torch.log_softmax(self.fc2(hidden), dim=1)

# Function to simulate client training with MNIST data
def client_training(train_loader, model, criterion, optimizer, num_epochs=1):
    """Train `model` on one client's data and report the last epoch's loss.

    Args:
        train_loader: iterable yielding `(images, labels)` tensor batches.
        model: module to train in place; it is switched to training mode.
        criterion: loss callable taking `(outputs, labels)`.
        optimizer: optimizer already bound to `model`'s parameters.
        num_epochs: number of passes over `train_loader`.

    Returns:
        The sum of the per-batch loss values of the final epoch (a sum, not a
        mean), or 0.0 when `num_epochs` is 0 or the loader yields no batches.
    """
    # Move batches to the device the model already lives on, rather than
    # relying on a notebook-level `device` global defined in a later cell.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    # Defined before the loop so `num_epochs == 0` returns 0.0 instead of
    # raising UnboundLocalError at the return statement.
    total_loss = 0.0
    for epoch in range(num_epochs):
        total_loss, correct, total = 0.0, 0, 0
        for images, labels in train_loader:
            images = images.to(target_device)
            labels = labels.to(target_device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            _, predicted = torch.max(outputs.detach(), 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        # An empty loader would otherwise divide by zero.
        accuracy = correct / total if total else 0.0
        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {total_loss}, Accuracy: {accuracy}")

    return total_loss
    

# Function to test the trained model on client data
def client_testing(test_loader, model):
    """Return the classification accuracy of `model` over `test_loader`.

    Args:
        test_loader: iterable yielding `(images, labels)` tensor batches.
        model: module to evaluate; it is switched to evaluation mode.

    Returns:
        Fraction of samples whose arg-max prediction equals the label, or 0.0
        when the loader yields no samples.
    """
    # Use the model's own device instead of a notebook-level `device` global.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(target_device)
            labels = labels.to(target_device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    # An empty loader would otherwise divide by zero.
    return correct / total if total else 0.0

In [4]:
# Function to apply gradual rotation drift
def apply_rotation_drift(images, current_epoch, max_rotation, start_epoch, end_epoch, total_epochs):
    """Return a copy of `images` with a growing subset rotated by a growing angle.

    Inside the window `[start_epoch, end_epoch]` both the rotation angle and
    the fraction of rotated images increase linearly with `current_epoch`:
    the angle goes from 0 to `max_rotation`, the fraction from
    `1 / window_size` to 1. Outside the window nothing is rotated.

    Args:
        images: tensor indexed along its first dimension, one image per index.
        current_epoch: epoch for which the drift is computed.
        max_rotation: rotation angle (degrees) reached at `end_epoch`.
        start_epoch: first epoch of the drift window (inclusive).
        end_epoch: last epoch of the drift window (inclusive).
        total_epochs: unused; kept for call compatibility.

    Returns:
        A new tensor with the same shape as `images`; the input is not modified.
    """
    window_size = end_epoch - start_epoch + 1
    if window_size <= 0:
        # Empty or inverted window: no drift, and no division by zero below.
        return images.clone()

    if start_epoch <= current_epoch <= end_epoch:
        # A single-epoch window (start == end) is treated as abrupt drift at
        # full strength; the original formula divided by zero here.
        if end_epoch > start_epoch:
            transition_progress = (current_epoch - start_epoch) / (end_epoch - start_epoch)
        else:
            transition_progress = 1.0
        rotation_angle = transition_progress * max_rotation
    else:
        rotation_angle = 0

    fraction_rotated = (current_epoch - start_epoch + 1) / window_size
    num_images_to_rotate = int(fraction_rotated * len(images))

    drifted_images = images.clone()
    # `fraction_rotated > 1` means the window is over; `<= 0` means not started.
    if num_images_to_rotate > 0 and fraction_rotated <= 1:
        indices_to_rotate = torch.randperm(len(images))[:num_images_to_rotate]
        for idx in indices_to_rotate:
            # SciPy works on NumPy arrays, which requires the data on the CPU.
            rotated = rotate(images[idx].cpu().numpy(), rotation_angle, reshape=False)
            drifted_images[idx] = torch.tensor(rotated)

    return drifted_images

In [6]:
def create_global_optimum_model(global_model, global_model_sdrift, alpha):
    """Build a new model whose weights interpolate between two global models.

    Every state-dict entry of the result is
    `alpha * global_model_sdrift + (1 - alpha) * global_model`.
    Neither input model is modified; the result starts as a deep copy of
    `global_model`.
    """
    base_state = global_model.state_dict()
    drift_state = global_model_sdrift.state_dict()

    blended_model = copy.deepcopy(global_model)
    blended_state = blended_model.state_dict()
    for name, base_tensor in base_state.items():
        blended_state[name] = alpha * drift_state[name] + (1-alpha) * base_tensor

    blended_model.load_state_dict(blended_state)
    return blended_model

In [10]:
def federated_learning(num_clients=5, num_local_epochs=5, num_global_epochs=5, drift_clients=None):
    """Run the two-model federated MNIST experiment with gradual rotation drift.

    Per global epoch:
      1. While 10 <= epoch <= 30, the training shards of `drift_clients` and
         the whole test set are passed through `apply_rotation_drift`. The
         rotated training images are written back into `trainset.data`, so
         later epochs rotate already-rotated images; the drifted test set is
         rebuilt from the untouched test images every epoch.
      2. Every client trains a copy of its global model: `global_model_sdrift`
         (with Adam betas (0.6, 0.7)) for clients already flagged as drifting,
         `global_model` otherwise.
      3. During epochs 5..28, clients whose current loss exceeds both their
         short and long moving averages are flagged, and detection is scored
         against the configured ground truth.
      4. Flagged and unflagged client models are averaged separately into
         their respective global models.

    Afterwards two blends of the global models (alpha 0.5 and 0.35) are
    evaluated on the drifted test set, and detection metrics are printed.

    Args:
        num_clients: number of equally sized client shards.
        num_local_epochs: local training epochs per client per global epoch.
        num_global_epochs: number of federated rounds.
        drift_clients: ids of clients whose data drifts; defaults to the
            42-client list used by this experiment. Ids outside
            `range(num_clients)` are ignored.

    Returns:
        None. Results are printed.
    """
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
    trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

    # Equal-sized shards. random_split requires the lengths to sum to the
    # dataset size, so any remainder goes into an extra split that is dropped.
    partition_size = len(trainset) // num_clients
    lengths = [partition_size] * num_clients
    leftover = len(trainset) - partition_size * num_clients
    if leftover:
        lengths.append(leftover)
    datasets = random_split(trainset, lengths)[:num_clients]

    client_indices = [dataset.indices for dataset in datasets]

    # clients that will experience concept drift
    if drift_clients is None:
        drift_clients = [12, 30, 31, 0, 32, 20, 53, 15, 5, 13, 42, 10, 55, 35, 41, 48, 58, 40, 38, 9, 21, 14, 24, 43, 6, 57, 37, 4, 46, 47, 50, 11, 34, 59, 2, 51, 1, 23, 28, 56, 19, 16]
#     drift_clients = [6, 50, 20, 5, 37, 42, 47, 13, 28, 15, 57, 10, 9, 31, 2, 41, 35, 1, 23, 58, 38, 46, 21, 48, 24, 43, 32, 4, 51, 56]
#     drift_clients = [2, 6, 9, 10, 13, 15, 21, 23, 28, 35, 37, 41, 43, 47, 51, 56]
#     drift_clients = [2, 6, 9, 28, 41, 56]
    # Ids beyond the configured client count would index past client_indices.
    drift_clients = [client_id for client_id in drift_clients if 0 <= client_id < num_clients]

    # Drift schedule (global epoch indices are 0-based).
    drift_start_epoch = 10
    drift_end_epoch = 30
    max_rotation = 30

    # Initialize global models
    global_model = CNN().to(device)
    global_model_sdrift = CNN().to(device)

    previous_client_losses = [[] for _ in range(num_clients)]
    drastic_change_detected = [False] * num_clients
    sustained_high_loss = [False] * num_clients
    # Long-window average frozen the first time a client's loss exceeds it.
    fixed_long_averages = [None] * num_clients

    # Clients flagged in the current epoch / clients permanently routed to the drift model
    detected_drift_clients = []
    fixed_detected_drift_clients = []

    # Dictionary to track how many epochs each client has been detected for drift
    drift_count = {i: 0 for i in range(num_clients)}

    # Define window sizes (best result 12,17)
    short_window = 15
    long_window = 20

    # Initialize counters for drift detection performance
    true_positives = 0
    true_negatives = 0
    false_positives = 0
    false_negatives = 0

    # Select indices in test set for applying drift (currently the whole test set, in random order)
    drifted_test_indices = np.random.choice(len(testset), size=int(1 * len(testset)), replace=False)

    # Federated learning process
    drifted_testset = None
    # Stays None when the run never reaches the drift window; handled after the loop.
    drifted_client_test_loaders = None
    for global_epoch in range(num_global_epochs):
        print(f"Global Epoch {global_epoch + 1}/{num_global_epochs}")

        # Apply rotation drift gradually inside the drift window
        if drift_start_epoch <= global_epoch <= drift_end_epoch:
            print("Applying rotation drift to specific clients")
            for client_id in drift_clients:
                indices = client_indices[client_id]
                images = trainset.data[indices]
                drifted_images = apply_rotation_drift(images, global_epoch, max_rotation=max_rotation, start_epoch=drift_start_epoch, end_epoch=drift_end_epoch, total_epochs=num_global_epochs)
                trainset.data[indices] = drifted_images

            # Apply the same drift schedule to the selected test images
            drifted_test_images = testset.data.clone()
            drifted_test_images[drifted_test_indices] = apply_rotation_drift(testset.data[drifted_test_indices], global_epoch, max_rotation=max_rotation, start_epoch=drift_start_epoch, end_epoch=drift_end_epoch, total_epochs=num_global_epochs)
            drifted_testset = copy.deepcopy(testset)
            drifted_testset.data = drifted_test_images

        client_train_loaders = [torch.utils.data.DataLoader(torch.utils.data.Subset(trainset, indices), batch_size=32, shuffle=True) for indices in client_indices]
        if drifted_testset is not None:
            drifted_client_test_loaders = [torch.utils.data.DataLoader(drifted_testset, batch_size=64, shuffle=True) for _ in range(num_clients)]

        # Distribute the respective global models back to the clients
        client_models = []
        for i in range(num_clients):
            if i in fixed_detected_drift_clients:
                client_models.append(copy.deepcopy(global_model_sdrift))
            else:
                client_models.append(copy.deepcopy(global_model))

        # Perform local training and optimize drifted clients classifier
        client_losses = []

        for i, train_loader in enumerate(client_train_loaders):
            print(f"\tTraining Client {i + 1}/{num_clients}")
            criterion = nn.CrossEntropyLoss()

            # Change beta1 and beta2 if concept drift detected
            if i in fixed_detected_drift_clients:
                print(f"Initiating optimization for client {i + 1}")
                optimizer = optim.Adam(client_models[i].parameters(), lr=0.001, betas=(0.6, 0.7))
            else:
                optimizer = optim.Adam(client_models[i].parameters(), lr=0.001)
            loss = client_training(train_loader, client_models[i], criterion, optimizer, num_local_epochs)

            client_losses.append(loss)

            # Keep only the most recent `long_window` losses per client
            previous_client_losses[i].append(loss)
            if len(previous_client_losses[i]) > long_window:
                previous_client_losses[i].pop(0)

        # Concept drift detection and performance metrics
        if global_epoch in range(5, 29):
            # Gradual detection using moving averages
            for i in range(num_clients):
                if len(previous_client_losses[i]) >= long_window:
                    # Both windows include the loss that was just appended.
                    short_avg = np.mean(previous_client_losses[i][-short_window:])

                    if fixed_long_averages[i] is None:
                        long_avg = np.mean(previous_client_losses[i][-long_window:])
                    else:
                        long_avg = fixed_long_averages[i]

                    drastic_change_detected[i] = bool(client_losses[i] > short_avg)

                    is_sustained = bool(client_losses[i] > long_avg)
                    sustained_high_loss[i] = is_sustained
                    if is_sustained and fixed_long_averages[i] is None:
                        fixed_long_averages[i] = long_avg

            detected_drift_clients = [
                i for i, (drastic, sustained) in enumerate(zip(drastic_change_detected, sustained_high_loss))
                if drastic and sustained
            ]

            # Count one detection per flagged client for this epoch
            for client in detected_drift_clients:
                drift_count[client] += 1

            # Determine the maximum drift count among the already fixed clients
            max_drift_count = max([drift_count[client] for client in fixed_detected_drift_clients], default=0)

            # A new client is fixed only if it was flagged more often than every fixed client
            for client in detected_drift_clients:
                if client not in fixed_detected_drift_clients and drift_count[client] > max_drift_count:
                    fixed_detected_drift_clients.append(client)

            for client_id in range(num_clients):
                ground_truth_drift = (client_id in drift_clients) and (global_epoch >= 19) and (global_epoch <= 29)
                detected_drift = client_id in detected_drift_clients

                if detected_drift and ground_truth_drift:
                    true_positives += 1
                elif not detected_drift and not ground_truth_drift:
                    true_negatives += 1
                elif detected_drift and not ground_truth_drift:
                    false_positives += 1
                elif not detected_drift and ground_truth_drift:
                    false_negatives += 1

                if detected_drift == ground_truth_drift:
                    print(f"Global Epoch {global_epoch + 1}, Client {client_id + 1}: Drift detection is accurate.")
                else:
                    print(f"Global Epoch {global_epoch + 1}, Client {client_id + 1}: Drift detection is inaccurate.")

        # Aggregate client models into their respective global model
        print("\tAggregating client models")
        drifted_client_models = [client_models[i] for i in fixed_detected_drift_clients]
        non_drifted_client_models = [client_models[i] for i in range(num_clients) if i not in fixed_detected_drift_clients]

        if drifted_client_models:
            global_model_sdrift = aggregate_models(global_model_sdrift, drifted_client_models)

        if non_drifted_client_models:
            global_model = aggregate_models(global_model, non_drifted_client_models)

    # Short runs never build a drifted test set; evaluate on the clean one
    # instead of failing on an unassigned name.
    if drifted_client_test_loaders is None:
        drifted_client_test_loaders = [torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True) for _ in range(num_clients)]

    # Test models on each client
    accuracies_avg = []
    additional_accuracies_avg = []

    # Equal-weight blend and the 0.35-weighted blend of the two global models
    avg_model = create_global_optimum_model(global_model, global_model_sdrift, 0.5)
    global_optimum_model = create_global_optimum_model(global_model, global_model_sdrift, 0.35)

    for i, drifted_test_loader in enumerate(drifted_client_test_loaders):
        print(f"Testing Client {i + 1}/{num_clients}")

        # Test with the averaged model
        accuracy_avg = client_testing(drifted_test_loader, avg_model)
        accuracies_avg.append(accuracy_avg)
        print(f"\tAccuracy for Client {i + 1}/{num_clients} using avg_model: {accuracy_avg}")

        # Test with the global optimum model
        accuracy_avg_optm = client_testing(drifted_test_loader, global_optimum_model)
        additional_accuracies_avg.append(accuracy_avg_optm)
        print(f"\tAccuracy for Client {i + 1}/{num_clients} using global_optimum_model: {accuracy_avg_optm}")

    # Calculate performance metrics of drift detection
    total_decisions = true_positives + true_negatives + false_positives + false_negatives
    # No decisions are recorded when the run ends before the detection epochs.
    accuracy = (true_positives + true_negatives) / total_decisions if total_decisions > 0 else 0
    precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
    recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0
    f1_score = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

    print(f"Drift Detection metrics::: true_positives:{true_positives},true_negatives:{true_negatives},false_positives:{false_positives},false_negatives:{false_negatives}")
    print(f"Drift Detection Performance:\nAccuracy: {accuracy}\nPrecision: {precision}\nRecall: {recall}\nF1 Score: {f1_score}")

# Aggregation of models' weights using FedAvg
def aggregate_models(global_model, client_models):
    """Overwrite `global_model` with the element-wise mean of the client models (FedAvg).

    Every state-dict entry is cast to float, stacked across clients and
    averaged with equal weight per client.

    Args:
        global_model: model updated in place; supplies the set of keys.
        client_models: models with the same state-dict keys as `global_model`.

    Returns:
        `global_model` itself. With no client models it is returned unchanged
        rather than failing on an empty stack.
    """
    if not client_models:
        return global_model

    # Snapshot each client's state once instead of rebuilding it for every key.
    client_states = [client_model.state_dict() for client_model in client_models]

    global_dict = global_model.state_dict()
    for k in global_dict.keys():
        global_dict[k] = torch.stack([state[k].float() for state in client_states], 0).mean(0)
    global_model.load_state_dict(global_dict)
    return global_model

# Set device
# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

if __name__ == "__main__":
    # Full experiment: 60 clients, 1 local epoch each, 40 global rounds.
    federated_learning(num_clients=60, num_local_epochs=1, num_global_epochs=40)
    banner = "--------------------------------------------"
    print(banner, "Experiment Completed", banner, sep="\n")

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 34054808.87it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 1040202.09it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 9429543.73it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 5496401.84it/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw

Global Epoch 1/40
	Training Client 1/60
Epoch 1/1, Loss: 51.12887167930603, Accuracy: 0.488
	Training Client 2/60
Epoch 1/1, Loss: 49.150713950395584, Accuracy: 0.532
	Training Client 3/60
Epoch 1/1, Loss: 48.964728474617004, Accuracy: 0.527
	Training Client 4/60
Epoch 1/1, Loss: 46.206991314888, Accuracy: 0.536
	Training Client 5/60
Epoch 1/1, Loss: 49.01094609498978, Accuracy: 0.498
	Training Client 6/60
Epoch 1/1, Loss: 46.55858953297138, Accuracy: 0.547
	Training Client 7/60
Epoch 1/1, Loss: 46.956732988357544, Accuracy: 0.537
	Training Client 8/60
Epoch 1/1, Loss: 47.52933883666992, Accuracy: 0.557
	Training Client 9/60
Epoch 1/1, Loss: 47.53733825683594, Accuracy: 0.508
	Training Client 10/60
Epoch 1/1, Loss: 47.277563631534576, Accuracy: 0.532
	Training Client 11/60
Epoch 1/1, Loss: 48.85650992393494, Accuracy: 0.514
	Training Client 12/60
Epoch 1/1, Loss: 50.21908086538315, Accuracy: 0.499
	Training Clie

In [11]:
# Baseline: plain FedAvg with a single global model (no separate drift model)
def federated_learning(num_clients=5, num_local_epochs=5, num_global_epochs=5, drift_clients=None):
    """Run the single-model FedAvg baseline on MNIST with gradual rotation drift.

    Per global epoch:
      1. While 10 <= epoch <= 30, the training shards of `drift_clients` and
         the whole test set are passed through `apply_rotation_drift`. The
         rotated training images are written back into `trainset.data`, so
         later epochs rotate already-rotated images; the drifted test set is
         rebuilt from the untouched test images every epoch.
      2. Every client trains a copy of the single global model with Adam.
      3. During epochs 5..28, clients whose current loss exceeds both their
         short and long moving averages are flagged, and detection is scored
         against the configured ground truth. Flags are only scored here;
         they do not change training or aggregation.
      4. All client models are averaged into the global model.

    Afterwards the global model is evaluated on the drifted test set and the
    detection metrics are printed.

    Args:
        num_clients: number of equally sized client shards.
        num_local_epochs: local training epochs per client per global epoch.
        num_global_epochs: number of federated rounds.
        drift_clients: ids of clients whose data drifts; defaults to the
            6-client list used by this experiment. Ids outside
            `range(num_clients)` are ignored.

    Returns:
        None. Results are printed.
    """
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
    trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

    # Equal-sized shards. random_split requires the lengths to sum to the
    # dataset size, so any remainder goes into an extra split that is dropped.
    partition_size = len(trainset) // num_clients
    lengths = [partition_size] * num_clients
    leftover = len(trainset) - partition_size * num_clients
    if leftover:
        lengths.append(leftover)
    datasets = random_split(trainset, lengths)[:num_clients]

    client_indices = [dataset.indices for dataset in datasets]

    # clients that will experience concept drift
#     drift_clients = [12, 30, 31, 0, 32, 20, 53, 15, 5, 13, 42, 10, 55, 35, 41, 48, 58, 40, 38, 9, 21, 14, 24, 43, 6, 57, 37, 4, 46, 47, 50, 11, 34, 59, 2, 51, 1, 23, 28, 56, 19, 16]
#     drift_clients = [6, 50, 20, 5, 37, 42, 47, 13, 28, 15, 57, 10, 9, 31, 2, 41, 35, 1, 23, 58, 38, 46, 21, 48, 24, 43, 32, 4, 51, 56]
#     drift_clients = [2, 6, 9, 10, 13, 15, 21, 23, 28, 35, 37, 41, 43, 47, 51, 56]
    if drift_clients is None:
        drift_clients = [2, 6, 9, 28, 41, 56]
    # Ids beyond the configured client count would index past client_indices.
    drift_clients = [client_id for client_id in drift_clients if 0 <= client_id < num_clients]

    # Drift schedule (global epoch indices are 0-based).
    drift_start_epoch = 10
    drift_end_epoch = 30
    max_rotation = 30

    # Initialize global models
    global_model = CNN().to(device)
    # Never trained or evaluated in this baseline. NOTE(review): constructing it
    # presumably draws from the torch RNG, so it is kept to avoid shifting the
    # random stream of everything that follows - confirm before removing.
    global_model_sdrift = CNN().to(device)

    previous_client_losses = [[] for _ in range(num_clients)]
    drastic_change_detected = [False] * num_clients
    sustained_high_loss = [False] * num_clients
    # Long-window average frozen the first time a client's loss exceeds it.
    fixed_long_averages = [None] * num_clients

    # Clients flagged in the current epoch
    detected_drift_clients = []

    # Define window sizes (best result 12,17)
    short_window = 15
    long_window = 20

    # Initialize counters for drift detection performance
    true_positives = 0
    true_negatives = 0
    false_positives = 0
    false_negatives = 0

    # Select indices in test set for applying drift (currently the whole test set, in random order)
    drifted_test_indices = np.random.choice(len(testset), size=int(1 * len(testset)), replace=False)

    # Federated learning process
    drifted_testset = None
    # Stays None when the run never reaches the drift window; handled after the loop.
    drifted_client_test_loaders = None
    for global_epoch in range(num_global_epochs):
        print(f"Global Epoch {global_epoch + 1}/{num_global_epochs}")

        # Apply rotation drift gradually inside the drift window
        if drift_start_epoch <= global_epoch <= drift_end_epoch:
            print("Applying rotation drift to specific clients")
            for client_id in drift_clients:
                indices = client_indices[client_id]
                images = trainset.data[indices]
                drifted_images = apply_rotation_drift(images, global_epoch, max_rotation=max_rotation, start_epoch=drift_start_epoch, end_epoch=drift_end_epoch, total_epochs=num_global_epochs)
                trainset.data[indices] = drifted_images

            # Apply the same drift schedule to the selected test images
            drifted_test_images = testset.data.clone()
            drifted_test_images[drifted_test_indices] = apply_rotation_drift(testset.data[drifted_test_indices], global_epoch, max_rotation=max_rotation, start_epoch=drift_start_epoch, end_epoch=drift_end_epoch, total_epochs=num_global_epochs)
            drifted_testset = copy.deepcopy(testset)
            drifted_testset.data = drifted_test_images

        client_train_loaders = [torch.utils.data.DataLoader(torch.utils.data.Subset(trainset, indices), batch_size=32, shuffle=True) for indices in client_indices]
        if drifted_testset is not None:
            drifted_client_test_loaders = [torch.utils.data.DataLoader(drifted_testset, batch_size=64, shuffle=True) for _ in range(num_clients)]

        # Every client starts the round from a copy of the single global model
        client_models = [copy.deepcopy(global_model) for _ in range(num_clients)]

        # Perform local training
        client_losses = []

        for i, train_loader in enumerate(client_train_loaders):
            print(f"\tTraining Client {i + 1}/{num_clients}")
            criterion = nn.CrossEntropyLoss()

            optimizer = optim.Adam(client_models[i].parameters(), lr=0.001)
            loss = client_training(train_loader, client_models[i], criterion, optimizer, num_local_epochs)

            client_losses.append(loss)

            # Keep only the most recent `long_window` losses per client
            previous_client_losses[i].append(loss)
            if len(previous_client_losses[i]) > long_window:
                previous_client_losses[i].pop(0)

        # Concept drift detection and performance metrics
        if global_epoch in range(5, 29):
            # Gradual detection using moving averages
            for i in range(num_clients):
                if len(previous_client_losses[i]) >= long_window:
                    # Both windows include the loss that was just appended.
                    short_avg = np.mean(previous_client_losses[i][-short_window:])

                    if fixed_long_averages[i] is None:
                        long_avg = np.mean(previous_client_losses[i][-long_window:])
                    else:
                        long_avg = fixed_long_averages[i]

                    drastic_change_detected[i] = bool(client_losses[i] > short_avg)

                    is_sustained = bool(client_losses[i] > long_avg)
                    sustained_high_loss[i] = is_sustained
                    if is_sustained and fixed_long_averages[i] is None:
                        fixed_long_averages[i] = long_avg

            detected_drift_clients = [
                i for i, (drastic, sustained) in enumerate(zip(drastic_change_detected, sustained_high_loss))
                if drastic and sustained
            ]

            for client_id in range(num_clients):
                ground_truth_drift = (client_id in drift_clients) and (global_epoch >= 19) and (global_epoch <= 29)
                detected_drift = client_id in detected_drift_clients

                if detected_drift and ground_truth_drift:
                    true_positives += 1
                elif not detected_drift and not ground_truth_drift:
                    true_negatives += 1
                elif detected_drift and not ground_truth_drift:
                    false_positives += 1
                elif not detected_drift and ground_truth_drift:
                    false_negatives += 1

                if detected_drift == ground_truth_drift:
                    print(f"Global Epoch {global_epoch + 1}, Client {client_id + 1}: Drift detection is accurate.")
                else:
                    print(f"Global Epoch {global_epoch + 1}, Client {client_id + 1}: Drift detection is inaccurate.")

        # Aggregate client models into global model
        print("\tAggregating client models")
        global_model = aggregate_models(global_model, client_models)

    # Short runs never build a drifted test set; evaluate on the clean one
    # instead of failing on an unassigned name.
    if drifted_client_test_loaders is None:
        drifted_client_test_loaders = [torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True) for _ in range(num_clients)]

    # Test the global model on each client
    accuracies_avg = []

    for i, drifted_test_loader in enumerate(drifted_client_test_loaders):
        print(f"Testing Client {i + 1}/{num_clients}")

        # Test with the aggregated global model
        accuracy_avg = client_testing(drifted_test_loader, global_model)
        accuracies_avg.append(accuracy_avg)
        print(f"\tAccuracy for Client {i + 1}/{num_clients} using avg_model: {accuracy_avg}")

    # Calculate performance metrics of drift detection
    total_decisions = true_positives + true_negatives + false_positives + false_negatives
    # No decisions are recorded when the run ends before the detection epochs.
    accuracy = (true_positives + true_negatives) / total_decisions if total_decisions > 0 else 0
    precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
    recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0
    f1_score = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

    print(f"Drift Detection metrics::: true_positives:{true_positives},true_negatives:{true_negatives},false_positives:{false_positives},false_negatives:{false_negatives}")
    print(f"Drift Detection Performance:\nAccuracy: {accuracy}\nPrecision: {precision}\nRecall: {recall}\nF1 Score: {f1_score}")

# Aggregation of models' weights using FedAvg
def aggregate_models(global_model, client_models):
    """Overwrite `global_model` with the element-wise mean of the client models (FedAvg).

    Every state-dict entry is cast to float, stacked across clients and
    averaged with equal weight per client.

    Args:
        global_model: model updated in place; supplies the set of keys.
        client_models: models with the same state-dict keys as `global_model`.

    Returns:
        `global_model` itself. With no client models it is returned unchanged
        rather than failing on an empty stack.
    """
    if not client_models:
        return global_model

    # Snapshot each client's state once instead of rebuilding it for every key.
    client_states = [client_model.state_dict() for client_model in client_models]

    global_dict = global_model.state_dict()
    for k in global_dict.keys():
        global_dict[k] = torch.stack([state[k].float() for state in client_states], 0).mean(0)
    global_model.load_state_dict(global_dict)
    return global_model

# Set device
# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

if __name__ == "__main__":
    # Baseline experiment: 60 clients, 1 local epoch each, 40 global rounds.
    federated_learning(num_clients=60, num_local_epochs=1, num_global_epochs=40)
    banner = "--------------------------------------------"
    print(banner, "Experiment Completed", banner, sep="\n")

Global Epoch 1/40
	Training Client 1/60
Epoch 1/1, Loss: 52.03825771808624, Accuracy: 0.453
	Training Client 2/60
Epoch 1/1, Loss: 52.474316358566284, Accuracy: 0.498
	Training Client 3/60
Epoch 1/1, Loss: 51.20811140537262, Accuracy: 0.481
	Training Client 4/60
Epoch 1/1, Loss: 49.882207095623016, Accuracy: 0.524
	Training Client 5/60
Epoch 1/1, Loss: 52.14577031135559, Accuracy: 0.519
	Training Client 6/60
Epoch 1/1, Loss: 51.85648339986801, Accuracy: 0.476
	Training Client 7/60
Epoch 1/1, Loss: 53.88154822587967, Accuracy: 0.459
	Training Client 8/60
Epoch 1/1, Loss: 50.25312155485153, Accuracy: 0.507
	Training Client 9/60
Epoch 1/1, Loss: 52.41806694865227, Accuracy: 0.499
	Training Client 10/60
Epoch 1/1, Loss: 55.17402836680412, Accuracy: 0.484
	Training Client 11/60
Epoch 1/1, Loss: 57.820004403591156, Accuracy: 0.412
	Training Client 12/60
Epoch 1/1, Loss: 50.99312025308609, Accuracy: 0.496
	Training Client 13/60
Epoch 1/1, Loss: 50.642578423023224, Accuracy: 0.493
	Training Cl

In [7]:
# adaptive learning rate

def federated_learning(num_clients=5, num_local_epochs=5, num_global_epochs=5):
    """Run the federated MNIST experiment with rotation drift and per-client adaptive learning rates.

    The MNIST training set is split randomly across ``num_clients`` clients.
    For global epochs with index 10..30 a rotation drift is applied to the
    training images of a fixed set of clients and to the whole test set.
    Each global epoch every client trains a copy of one of two global models
    (the regular one, or the "drift" one for clients that have been flagged),
    per-client losses feed a moving-average drift detector, and the client
    models are averaged back into their respective global model. Clients
    flagged as drifted train with a learning rate that shrinks by 5% per
    global epoch. After training, two blends of the two global models are
    evaluated and the drift-detection confusion counts and derived metrics
    are printed.

    Args:
        num_clients: Number of simulated clients.
        num_local_epochs: Local epochs per client per global epoch.
        num_global_epochs: Number of global (communication) rounds.

    Returns:
        None. All results are reported through ``print``.
    """
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
    trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

    # random_split needs the lengths to add up to len(trainset); when the set
    # does not divide evenly the first `remainder` clients get one extra sample.
    # For evenly divisible sizes this is identical to equal partitions.
    partition_size, remainder = divmod(len(trainset), num_clients)
    lengths = [partition_size + 1 if i < remainder else partition_size for i in range(num_clients)]
    datasets = random_split(trainset, lengths)

    client_indices = [dataset.indices for dataset in datasets]

    # clients that will experience concept drift
    drift_clients = [12, 30, 31, 0, 32, 20, 53, 15, 5, 13, 42, 10, 55, 35, 41, 48, 58, 40, 38, 9, 21, 14, 24, 43, 6, 57, 37, 4, 46, 47, 50, 11, 34, 59, 2, 51, 1, 23, 28, 56, 19, 16]
    # The IDs above were chosen for a 60-client run; drop any that do not exist
    # for the requested number of clients so indexing client_indices stays valid.
    drift_clients = [client_id for client_id in drift_clients if client_id < num_clients]

    # Initialize global models
    global_model = CNN().to(device)
    global_model_sdrift = CNN().to(device)

    # Per-client detector state
    previous_client_losses = [[] for _ in range(num_clients)]
    drastic_change_detected = [False] * num_clients
    sustained_high_loss = [False] * num_clients
    fixed_long_averages = [None] * num_clients

    # Initialize detected drift clients
    detected_drift_clients = []
    fixed_detected_drift_clients = []

    # Dictionary to track how many epochs each client has been detected for drift
    drift_count = {i: 0 for i in range(num_clients)}

    # Define window sizes (best result 12,17)
    short_window = 15
    long_window = 20

    # Initialize counters for drift detection performance
    true_positives = 0
    true_negatives = 0
    false_positives = 0
    false_negatives = 0

    # Select indices in test set for applying drift (currently the whole test set, in random order)
    drifted_test_indices = np.random.choice(len(testset), size=int(1 * len(testset)), replace=False)

    # Learning rate for each client
    client_lrs = [0.001 for _ in range(num_clients)]

    # Federated learning process
    drifted_testset = None
    for global_epoch in range(num_global_epochs):
        print(f"Global Epoch {global_epoch + 1}/{num_global_epochs}")

        # Apply rotation drift gradually starting from a certain epoch
        if 10 <= global_epoch < 31:  # Example threshold for starting drift
            print("Applying rotation drift to specific clients")
            # NOTE(review): the rotated images are written back into trainset.data,
            # so the next epoch's call receives already-rotated images, whereas the
            # test copy below is always derived from testset.data, which is never
            # overwritten. Confirm the cumulative training rotation is intended.
            for client_id in drift_clients:
                indices = client_indices[client_id]
                images = trainset.data[indices]
                drifted_images = apply_rotation_drift(images, global_epoch, max_rotation=30, start_epoch=10, end_epoch=30, total_epochs=num_global_epochs)
                trainset.data[indices] = drifted_images

            # Apply the same drift to a subset of test data
            drifted_test_images = testset.data.clone()
            drifted_test_images[drifted_test_indices] = apply_rotation_drift(testset.data[drifted_test_indices], global_epoch, max_rotation=30, start_epoch=10, end_epoch=30, total_epochs=num_global_epochs)
            drifted_testset = copy.deepcopy(testset)
            drifted_testset.data = drifted_test_images

        client_train_loaders = [torch.utils.data.DataLoader(torch.utils.data.Subset(trainset, indices), batch_size=32, shuffle=True) for indices in client_indices]

        # Distribute the respective global models back to the clients
        client_models = []

        for i in range(num_clients):
            if i in fixed_detected_drift_clients:
                client_models.append(copy.deepcopy(global_model_sdrift))
            else:
                client_models.append(copy.deepcopy(global_model))

        # Perform local training and optimize drifted clients classifier
        client_losses = []

        for i, train_loader in enumerate(client_train_loaders):
            print(f"\tTraining Client {i + 1}/{num_clients}")
            criterion = nn.CrossEntropyLoss()

            # Adjust learning rate dynamically for clients with drift
            if i in fixed_detected_drift_clients:
                print(f"Initiating optimization for client {i + 1} with reduced learning rate.")
                client_lrs[i] *= 0.95  # Gradually reduce learning rate for drifted clients
                optimizer = optim.Adam(client_models[i].parameters(), lr=client_lrs[i], betas=(0.6, 0.7))
            else:
                optimizer = optim.Adam(client_models[i].parameters(), lr=client_lrs[i])

            loss = client_training(train_loader, client_models[i], criterion, optimizer, num_local_epochs)
            client_losses.append(loss)

            # Update previous client losses, keeping at most long_window entries
            previous_client_losses[i].append(loss)
            if len(previous_client_losses[i]) > long_window:
                previous_client_losses[i].pop(0)

        # Concept drift detection and performance metrics
        if 5 <= global_epoch < 29:
            # Gradual detection using moving averages
            for i in range(num_clients):
                if len(previous_client_losses[i]) >= long_window:
                    short_avg = np.mean(previous_client_losses[i][-short_window:])

                    # Once a client has exceeded its long average, that average is
                    # frozen and reused as the reference in later epochs.
                    if fixed_long_averages[i] is None:
                        long_avg = np.mean(previous_client_losses[i][-long_window:])
                    else:
                        long_avg = fixed_long_averages[i]

                    drastic_change_detected[i] = bool(client_losses[i] > short_avg)

                    if client_losses[i] > long_avg:
                        sustained_high_loss[i] = True
                        if fixed_long_averages[i] is None:
                            fixed_long_averages[i] = long_avg
                    else:
                        sustained_high_loss[i] = False

            detected_drift_clients = [
                i for i, (drastic, sustained) in enumerate(zip(drastic_change_detected, sustained_high_loss))
                if drastic and sustained
            ]

            # Count one more detection for every client flagged this epoch
            for client in detected_drift_clients:
                drift_count[client] += 1

            # Determine the maximum drift count in the freeze list
            max_drift_count = max([drift_count[client] for client in fixed_detected_drift_clients], default=0)

            # Update fixed_detected_drift_clients with any new drifted clients with higher drift count
            for client in detected_drift_clients:
                if client not in fixed_detected_drift_clients and drift_count[client] > max_drift_count:
                    fixed_detected_drift_clients.append(client)

            # NOTE(review): ground truth counts a drift client as drifted only for
            # epoch indices 19..29, although rotation is applied from index 10 and
            # this block only runs for indices 5..28. Confirm the window is intended.
            for client_id in range(num_clients):
                ground_truth_drift = (client_id in drift_clients) and (global_epoch >= 19) and (global_epoch <= 29)
                detected_drift = client_id in detected_drift_clients

                if detected_drift and ground_truth_drift:
                    true_positives += 1
                elif not detected_drift and not ground_truth_drift:
                    true_negatives += 1
                elif detected_drift and not ground_truth_drift:
                    false_positives += 1
                elif not detected_drift and ground_truth_drift:
                    false_negatives += 1

                if detected_drift == ground_truth_drift:
                    print(f"Global Epoch {global_epoch + 1}, Client {client_id + 1}: Drift detection is accurate.")
                else:
                    print(f"Global Epoch {global_epoch + 1}, Client {client_id + 1}: Drift detection is inaccurate.")

        # Aggregate client models into global model
        print("\tAggregating client models")
        drifted_client_models = [client_models[i] for i in fixed_detected_drift_clients]
        non_drifted_client_models = [client_models[i] for i in range(num_clients) if i not in fixed_detected_drift_clients]

        if drifted_client_models:
            global_model_sdrift = aggregate_models(global_model_sdrift, drifted_client_models)

        if non_drifted_client_models:
            global_model = aggregate_models(global_model, non_drifted_client_models)

    # Test models on each client
    accuracies_avg = []
    additional_accuracies_avg = []

    # Create the averaged model
    avg_model = create_global_optimum_model(global_model, global_model_sdrift, 0.5)
    global_optimum_model = create_global_optimum_model(global_model, global_model_sdrift, 0.35)

    # Build the evaluation loaders once, after training. If the run ended before
    # any drift was applied there is no drifted test set, so fall back to the
    # unmodified test set instead of failing on an undefined loader list.
    evaluation_testset = drifted_testset if drifted_testset is not None else testset
    drifted_client_test_loaders = [torch.utils.data.DataLoader(evaluation_testset, batch_size=64, shuffle=True) for _ in range(num_clients)]

    for i, drifted_test_loader in enumerate(drifted_client_test_loaders):
        print(f"Testing Client {i + 1}/{num_clients}")

        # Test with the averaged model
        accuracy_avg = client_testing(drifted_test_loader, avg_model)
        accuracies_avg.append(accuracy_avg)
        print(f"\tAccuracy for Client {i + 1}/{num_clients} using avg_model: {accuracy_avg}")

        # Test with the global optimum model on the same test data
        accuracy_avg_optm = client_testing(drifted_test_loader, global_optimum_model)
        additional_accuracies_avg.append(accuracy_avg_optm)
        print(f"\tAccuracy for Client {i + 1}/{num_clients} using global_optimum_model: {accuracy_avg_optm}")

    # Calculate performance metrics of drift detection. Every ratio is guarded so
    # that a run which never reached the detection window reports 0 instead of
    # raising ZeroDivisionError.
    total_decisions = true_positives + true_negatives + false_positives + false_negatives
    accuracy = (true_positives + true_negatives) / total_decisions if total_decisions > 0 else 0
    precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
    recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0
    f1_score = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

    print(f"Drift Detection metrics::: true_positives:{true_positives},true_negatives:{true_negatives},false_positives:{false_positives},false_negatives:{false_negatives}")
    print(f"Drift Detection Performance:\nAccuracy: {accuracy}\nPrecision: {precision}\nRecall: {recall}\nF1 Score: {f1_score}")

# Aggregation of models' weights using FedAvg
def aggregate_models(global_model, client_models):
    """Average client weights into ``global_model`` (FedAvg with equal client weights).

    Every entry of the global ``state_dict`` is replaced by the element-wise
    mean of the corresponding entries of all client models, computed in
    float precision, and then loaded back into ``global_model``.

    Args:
        global_model: Module whose ``state_dict`` keys define which entries
            are averaged. It is updated in place.
        client_models: Iterable of modules exposing the same ``state_dict``
            keys as ``global_model``.

    Returns:
        The same ``global_model`` object. If ``client_models`` is empty the
        model is returned unchanged.
    """
    client_models = list(client_models)
    if not client_models:
        # torch.stack cannot stack an empty list; with no clients to average
        # the current global weights are kept as they are.
        return global_model

    # Build each client's state dict once rather than once per key.
    client_states = [client_model.state_dict() for client_model in client_models]

    global_dict = global_model.state_dict()
    for k in global_dict.keys():
        global_dict[k] = torch.stack([state[k].float() for state in client_states], 0).mean(0)
    global_model.load_state_dict(global_dict)
    return global_model

# Run on the GPU when CUDA is available, otherwise on the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

if __name__ == "__main__":
    # 60 clients, one local epoch per round, 60 global rounds
    federated_learning(num_clients=60, num_local_epochs=1, num_global_epochs=60)
    print(
        "--------------------------------------------",
        "Experiment Completed",
        "--------------------------------------------",
        sep="\n",
    )

Global Epoch 1/60
	Training Client 1/60
Epoch 1/1, Loss: 53.13452070951462, Accuracy: 0.443
	Training Client 2/60
Epoch 1/1, Loss: 51.90388888120651, Accuracy: 0.471
	Training Client 3/60
Epoch 1/1, Loss: 53.83594053983688, Accuracy: 0.463
	Training Client 4/60
Epoch 1/1, Loss: 53.3793740272522, Accuracy: 0.459
	Training Client 5/60
Epoch 1/1, Loss: 53.71051913499832, Accuracy: 0.471
	Training Client 6/60
Epoch 1/1, Loss: 52.29525971412659, Accuracy: 0.482
	Training Client 7/60
Epoch 1/1, Loss: 62.116252064704895, Accuracy: 0.423
	Training Client 8/60
Epoch 1/1, Loss: 52.008520901203156, Accuracy: 0.496
	Training Client 9/60
Epoch 1/1, Loss: 51.012915194034576, Accuracy: 0.494
	Training Client 10/60
Epoch 1/1, Loss: 53.929866433143616, Accuracy: 0.489
	Training Client 11/60
Epoch 1/1, Loss: 50.757097482681274, Accuracy: 0.464
	Training Client 12/60
Epoch 1/1, Loss: 53.56581324338913, Accuracy: 0.47
	Training Client 13/60
Epoch 1/1, Loss: 55.63694900274277, Accuracy: 0.44
	Training Clie

KeyboardInterrupt: 