In [1]:
# Work from two levels above the starting working directory, where the
# project-local `src` package imported in the next cell lives. The guard makes
# the cell idempotent: re-running it does not keep climbing the directory tree.
import os

if not os.path.isdir(os.path.join("src", "federated_learning")):
    os.chdir(os.path.join("..", ".."))
os.getcwd()

/Users/davideleo/Desktop/Projects/research/papers/fl_wavelet_v0


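# Malicious clients analysis

Effect of malicious (attacked) clients on FedAvg: the federation is built from CIFAR-10 shards, some of which are corrupted with Gaussian noise or Gaussian blur.

## FedAvg w/ malicious clients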

In [2]:
import random
import torch
import numpy as np
import matplotlib.pyplot as plt
from src.data.cifar10 import get_federation
from src.data.attacks import GaussianBlur, GaussianNoise
from src.federated_learning.standard.fedavg import Client, Server
from src.models.neural_networks import LeNet5
from src.models.metrics import Accuracy, WeightedAccuracy

# Configuration: the knobs a reader may want to tune.
SEED = 42            # default seed for the whole notebook
NUM_SHARDS = 100     # number of client shards requested from get_federation
ALPHA = 1000         # `alpha` of get_federation; presumably a Dirichlet concentration for the label split (large -> near-IID) — TODO confirm
ATTACKS_PROBA = 0.4  # `attacks_proba` of get_federation; presumably the chance a shard gets one of the attacks — TODO confirm
BATCH_SIZE = 64
DEVICE = "cpu"

# Seed every RNG library in use so anything drawn from them below is reproducible.
random.seed(SEED)
np.random.seed(SEED)
torch.random.manual_seed(SEED)

# Federation: CIFAR-10 shards, presumably some corrupted by Gaussian noise or blur.
federation = get_federation(
    num_shards = NUM_SHARDS,
    alpha = ALPHA,
    attacks = [GaussianNoise(sigma = .5), GaussianBlur(kernel_size = 11)],
    attacks_proba = ATTACKS_PROBA
)

# One FedAvg client per shard, in federation order.
clients = [
    Client(
        train_dataset = dataset["train"],
        test_dataset = dataset["test"],
        distribution = dataset["distribution"],
        batch_size = BATCH_SIZE,
        device = DEVICE
    ) for dataset in federation
]

# Benign clients are those whose shard id contains no "."; presumably
# get_federation tags attacked shards with a dotted id — TODO confirm.
benign_clients = [
    client
    for dataset, client in zip(federation, clients)
    if "." not in dataset["id"]
]

# The benign-only baseline is only meaningful if the filter actually dropped
# some clients and kept some.
assert 0 < len(benign_clients) < len(clients), (
    "expected a mix of benign and attacked clients; check the shard id convention"
)

## FedAvg w/o malicious clients

Baseline: FedAvg trained only on the benign clients, i.e. without any attacked shard.

In [3]:
# Seed for this run. The reported experiments use seeds 42, 7 and 365;
# presumably new_seed = None stands for the seed-42 run, i.e. no reseeding
# here, so the RNGs keep the state left by the seeding in the setup cell.
# NOTE: with new_seed = None the results depend on the RNG state consumed by
# the earlier cells, so they only reproduce under Restart & Run All, not when
# this cell is re-run on its own. An explicit seed makes it re-runnable.
new_seed = 7

if new_seed is not None:
    random.seed(new_seed)
    np.random.seed(new_seed)
    torch.random.manual_seed(new_seed)

# Training budget. evaluation_step is derived from it (one past the last
# round) instead of a separate literal, so changing NUM_ROUNDS keeps it out of
# range. Presumably Server.train evaluates every `evaluation_step` rounds, so
# this disables in-training evaluation — TODO confirm.
NUM_ROUNDS = 500

# Training
server = Server(
    clients = benign_clients,
    participation_rate = 10,
    model = LeNet5(in_channels = 3, in_padding = 0, num_classes = 10)
)

criterion = torch.nn.CrossEntropyLoss()

benign_train_results = server.train(
    num_rounds = NUM_ROUNDS,
    num_local_epochs = 1,
    criterion = criterion,
    optimizer_class = torch.optim.Adam,
    optimizer_params = {"lr": 1e-3},
    evaluation_step = NUM_ROUNDS + 1,
    metrics = dict()
)

# Evaluation of the trained model, run once after training.
evaluation_results = server.evaluate(
    criterion = criterion,
    metrics = {"acc": Accuracy(), "wacc": WeightedAccuracy()}
)

evaluation_results["server"]

  0%|          | 0/500 [00:00<?, ?it/s]2025-05-22 00:36:23,901 - FEDAVG/Server - INFO - Round 1: training_loss = 0.0323
  0%|          | 1/500 [00:02<23:09,  2.78s/it]2025-05-22 00:36:26,663 - FEDAVG/Server - INFO - Round 2: training_loss = 0.0321
  0%|          | 2/500 [00:05<23:00,  2.77s/it]2025-05-22 00:36:29,191 - FEDAVG/Server - INFO - Round 3: training_loss = 0.0318
  1%|          | 3/500 [00:08<22:02,  2.66s/it]2025-05-22 00:36:31,571 - FEDAVG/Server - INFO - Round 4: training_loss = 0.0311
  1%|          | 4/500 [00:10<21:04,  2.55s/it]2025-05-22 00:36:34,182 - FEDAVG/Server - INFO - Round 5: training_loss = 0.0305
  1%|          | 5/500 [00:13<21:12,  2.57s/it]2025-05-22 00:36:36,700 - FEDAVG/Server - INFO - Round 6: training_loss = 0.0299
  1%|          | 6/500 [00:15<21:01,  2.55s/it]2025-05-22 00:36:39,263 - FEDAVG/Server - INFO - Round 7: training_loss = 0.0297
  1%|▏         | 7/500 [00:18<21:00,  2.56s/it]2025-05-22 00:36:41,917 - FEDAVG/Server - INFO - Round 8: trainin

{'loss': 1.3936750630338988, 'metrics': {'acc': 0.5066666635870933, 'wacc': 0.5066666730344296}}
