In [1]:
import os

# Run from the repository root so that `src.*` imports and the relative results path
# used in the last cell ("notebooks/cifar10/results/...") resolve. The notebook
# presumably lives two levels below the root (notebooks/cifar10/) — TODO confirm.
# Guarded so re-running this cell (or starting the kernel at the root) does not keep
# climbing the directory tree, which the bare `%cd ../..` did.
if not os.path.isdir("src"):
    os.chdir("../..")
assert os.path.isdir("src"), f"could not locate the repository root from {os.getcwd()}"
print(os.getcwd())

/Users/davideleo/Desktop/Projects/research/papers/fl_wavelet_v0


## Setup

In [2]:
import random
import torch 
import numpy as np 
from src.data.cifar10 import get_federation 
from src.data.attacks import GaussianBlur, GaussianNoise
from src.federated_learning.standard.fedavg import Client as Client
from src.models.neural_networks import LeNet5
from src.models.metrics import Accuracy, WeightedAccuracy
from copy import deepcopy

# Global seeding for the setup phase. The model and the federation below are built
# right after this. Assuming they draw from these global RNGs, they are identical for
# every experiment seed chosen in the following cell, which re-seeds only afterwards.
random.seed(42)
np.random.seed(42)
torch.random.manual_seed(42)

# Federation
# Shared initial model; each server later receives its own deepcopy of it.
model = LeNet5(in_channels = 3, in_padding = 0, num_classes = 10)

# 100 CIFAR-10 shards. `alpha` presumably controls label heterogeneity (large -> close
# to IID), and `attacks_proba` presumably is the chance that a shard is corrupted by
# one of the listed attacks — TODO confirm in get_federation.
federation = get_federation(
    num_shards = 100,
    alpha = 1000,
    attacks = [GaussianNoise(sigma = .5), GaussianBlur(kernel_size = 11)],
    attacks_proba = 0.4
)

# One client per shard, in federation order, so list positions match shard positions.
clients = [
    Client(
        train_dataset = dataset["train"],
        test_dataset = dataset["test"],
        distribution = dataset["distribution"],
        batch_size = 64,
        device = "cpu"
    ) for dataset in federation
]

# Positions in `clients` excluded from training. Presumably these are the shards that
# WAFFLE predicted as attacked. The training sections below use only the remaining
# clients (`waffle_clients`).
waffle_preds = [1,  2,  6,  7,  8, 14, 16, 18, 19, 20, 26, 32, 37, 39, 40, 43, 45, 46,
        50, 53, 54, 55, 57, 61, 63, 65, 68, 73, 75, 77, 80, 82, 85, 90, 92, 95,
        96, 97]
_waffle_excluded = set(waffle_preds)
assert all(0 <= i < len(clients) for i in _waffle_excluded), "waffle_preds has an out-of-range client index"
# Iterate over the actual client list, not a hard-coded range(100), so that changing
# num_shards neither raises IndexError nor silently drops clients.
waffle_clients = [client for i, client in enumerate(clients) if i not in _waffle_excluded]

# Clients whose dataset id contains no "." — presumably the un-attacked shards (attacked
# shards appear to carry a dotted/suffixed id; TODO confirm). Used for evaluation only.
benign_clients = [client for dataset, client in zip(federation, clients) if "." not in dataset["id"]]

In [3]:
# Per-run experiment seed. The reported experiments use new_seed in [None, 7, 365].
# None skips re-seeding, so training continues from the RNG state left by the seed-42
# setup cell. The value also appears in the results file name ("seedNone" for None).
new_seed = 365

if new_seed is not None:
    # Re-seed Python, NumPy and PyTorch. The federation built above is already fixed,
    # so only what happens from here on depends on new_seed.
    for seed_rng in (random.seed, np.random.seed, torch.random.manual_seed):
        seed_rng(new_seed)

## FedAvg

In [4]:
from src.federated_learning.standard.fedavg import Server as FedAvgServer

# Training: plain FedAvg over the clients not listed in `waffle_preds`. The server gets
# its own deepcopy of the untouched `model`, so every method in this notebook starts
# from the same initial weights.
# NOTE(review): participation_rate = 10 — clients per round or a percentage? TODO confirm.
server = FedAvgServer(
    model = deepcopy(model),
    clients = waffle_clients,
    participation_rate = 10,
)

fedavg_train_results = server.train(
    criterion = torch.nn.CrossEntropyLoss(),
    optimizer_class = torch.optim.Adam,
    optimizer_params = {"lr": 1e-3},
    num_rounds = 500,
    num_local_epochs = 1,
    # presumably: evaluate every 201 rounds, with no extra metrics beyond the loss (TODO confirm)
    evaluation_step = 201,
    metrics = {},
)

# Evaluation: point the trained server at `benign_clients` (presumably the un-attacked
# shards), then compute loss plus plain and weighted accuracy.
server.clients = benign_clients
fedavg_evaluation_results = server.evaluate(
    criterion = torch.nn.CrossEntropyLoss(),
    metrics = {"acc": Accuracy(), "wacc": WeightedAccuracy()},
)

print(fedavg_evaluation_results["server"])

  0%|          | 0/500 [00:00<?, ?it/s]2025-05-21 18:08:01,797 - FEDAVG/Server - INFO - Round 1: training_loss = 0.0323
  0%|          | 1/500 [00:04<34:50,  4.19s/it]2025-05-21 18:08:05,694 - FEDAVG/Server - INFO - Round 2: training_loss = 0.0321
  0%|          | 2/500 [00:08<33:20,  4.02s/it]2025-05-21 18:08:09,480 - FEDAVG/Server - INFO - Round 3: training_loss = 0.0318
  1%|          | 3/500 [00:11<32:24,  3.91s/it]2025-05-21 18:08:13,336 - FEDAVG/Server - INFO - Round 4: training_loss = 0.0314
  1%|          | 4/500 [00:15<32:09,  3.89s/it]2025-05-21 18:08:17,292 - FEDAVG/Server - INFO - Round 5: training_loss = 0.0308
  1%|          | 5/500 [00:19<32:17,  3.91s/it]2025-05-21 18:08:21,086 - FEDAVG/Server - INFO - Round 6: training_loss = 0.0302
  1%|          | 6/500 [00:23<31:53,  3.87s/it]2025-05-21 18:08:25,092 - FEDAVG/Server - INFO - Round 7: training_loss = 0.0296
  1%|▏         | 7/500 [00:27<32:10,  3.92s/it]2025-05-21 18:08:28,994 - FEDAVG/Server - INFO - Round 8: trainin

{'loss': 1.4064278420408567, 'metrics': {'acc': 0.4936666645556688, 'wacc': 0.49366667748490967}}


## Krum

In [5]:
from src.federated_learning.standard.krum import Server as KrumServer

# Training: Krum aggregation over the clients not listed in `waffle_preds`, starting from
# a fresh copy of the shared initial `model`. K = 1 is the only setting that differs from
# the Multi-Krum run (K = 5); K presumably is the number of selected updates (TODO confirm).
# NOTE(review): the global RNGs are not re-seeded between sections, so this run
# presumably depends on the state left by the sections above. Rerun top-to-bottom to reproduce.
server = KrumServer(
    model = deepcopy(model),
    clients = waffle_clients,
    participation_rate = 10,
    K = 1,
)

krum_train_results = server.train(
    criterion = torch.nn.CrossEntropyLoss(),
    optimizer_class = torch.optim.Adam,
    optimizer_params = {"lr": 1e-3},
    num_rounds = 500,
    num_local_epochs = 1,
    evaluation_step = 201,
    metrics = {},
)

# Evaluation on `benign_clients` (presumably the un-attacked shards).
server.clients = benign_clients
krum_evaluation_results = server.evaluate(
    criterion = torch.nn.CrossEntropyLoss(),
    metrics = {"acc": Accuracy(), "wacc": WeightedAccuracy()},
)

print(krum_evaluation_results["server"])

  0%|          | 0/500 [00:00<?, ?it/s]2025-05-21 18:40:51,384 - KRUM/Server - INFO - Round 1: training_loss = 0.0322
  0%|          | 1/500 [00:04<35:08,  4.23s/it]2025-05-21 18:40:55,601 - KRUM/Server - INFO - Round 2: training_loss = 0.0321
  0%|          | 2/500 [00:08<35:01,  4.22s/it]2025-05-21 18:40:59,738 - KRUM/Server - INFO - Round 3: training_loss = 0.0317
  1%|          | 3/500 [00:12<34:38,  4.18s/it]2025-05-21 18:41:04,531 - KRUM/Server - INFO - Round 4: training_loss = 0.0313
  1%|          | 4/500 [00:17<36:33,  4.42s/it]2025-05-21 18:41:09,504 - KRUM/Server - INFO - Round 5: training_loss = 0.0305
  1%|          | 5/500 [00:22<38:07,  4.62s/it]2025-05-21 18:41:14,597 - KRUM/Server - INFO - Round 6: training_loss = 0.0303
  1%|          | 6/500 [00:27<39:22,  4.78s/it]2025-05-21 18:41:19,571 - KRUM/Server - INFO - Round 7: training_loss = 0.03
  1%|▏         | 7/500 [00:32<39:48,  4.84s/it]2025-05-21 18:41:24,699 - KRUM/Server - INFO - Round 8: training_loss = 0.0296
  

{'loss': 1.4900594565272332, 'metrics': {'acc': 0.4648333315849304, 'wacc': 0.4648333444943031}}


## Multi-Krum

In [6]:
# Imported here too (every other section imports its own server class), so this section
# also runs when the Krum section above was skipped.
from src.federated_learning.standard.krum import Server as KrumServer

# Training: Multi-Krum, i.e. the same Krum server with K = 5 instead of K = 1. K presumably
# is the number of selected updates that get averaged (TODO confirm). Clients are those
# not listed in `waffle_preds`; the model is a fresh copy of the shared initial `model`.
server = KrumServer(
    clients = waffle_clients,
    participation_rate = 10, 
    model = deepcopy(model),
    K = 5
)

mkrum_train_results = server.train(
    num_rounds = 500,
    num_local_epochs = 1,
    criterion = torch.nn.CrossEntropyLoss(),
    optimizer_class = torch.optim.Adam, 
    optimizer_params = {"lr": 1e-3},
    evaluation_step = 201,
    metrics = dict()
)

# Evaluation on `benign_clients` (presumably the un-attacked shards).
server.clients = benign_clients
mkrum_evaluation_results = server.evaluate(
    criterion = torch.nn.CrossEntropyLoss(),
    metrics = {"acc": Accuracy(), "wacc": WeightedAccuracy()}
)

print(mkrum_evaluation_results["server"])

  0%|          | 0/500 [00:00<?, ?it/s]2025-05-21 19:12:49,873 - KRUM/Server - INFO - Round 1: training_loss = 0.0323
  0%|          | 1/500 [00:02<23:54,  2.88s/it]2025-05-21 19:12:52,652 - KRUM/Server - INFO - Round 2: training_loss = 0.0321
  0%|          | 2/500 [00:05<23:23,  2.82s/it]2025-05-21 19:12:55,520 - KRUM/Server - INFO - Round 3: training_loss = 0.0319
  1%|          | 3/500 [00:08<23:31,  2.84s/it]2025-05-21 19:12:58,211 - KRUM/Server - INFO - Round 4: training_loss = 0.0312
  1%|          | 4/500 [00:11<22:59,  2.78s/it]2025-05-21 19:13:01,020 - KRUM/Server - INFO - Round 5: training_loss = 0.0306
  1%|          | 5/500 [00:14<23:01,  2.79s/it]2025-05-21 19:13:03,892 - KRUM/Server - INFO - Round 6: training_loss = 0.0301
  1%|          | 6/500 [00:16<23:12,  2.82s/it]2025-05-21 19:13:06,724 - KRUM/Server - INFO - Round 7: training_loss = 0.0297
  1%|▏         | 7/500 [00:19<23:11,  2.82s/it]2025-05-21 19:13:09,538 - KRUM/Server - INFO - Round 8: training_loss = 0.0294


{'loss': 1.4081560252308845, 'metrics': {'acc': 0.49616666852434477, 'wacc': 0.49616668044527373}}


## TrimmedMean

In [7]:
from src.federated_learning.standard.trimmedmean import Server as TrimmedMeanServer

# Training: coordinate-wise trimmed-mean aggregation over the clients not listed in
# `waffle_preds`, starting from a fresh copy of the shared initial `model`.
# tail_size = 2 presumably is the number of values trimmed from each tail (TODO confirm).
server = TrimmedMeanServer(
    model = deepcopy(model),
    clients = waffle_clients,
    participation_rate = 10,
    tail_size = 2,
)

tmean_train_results = server.train(
    criterion = torch.nn.CrossEntropyLoss(),
    optimizer_class = torch.optim.Adam,
    optimizer_params = {"lr": 1e-3},
    num_rounds = 500,
    num_local_epochs = 1,
    evaluation_step = 201,
    metrics = {},
)

# Evaluation on `benign_clients` (presumably the un-attacked shards).
server.clients = benign_clients
tmean_evaluation_results = server.evaluate(
    criterion = torch.nn.CrossEntropyLoss(),
    metrics = {"acc": Accuracy(), "wacc": WeightedAccuracy()},
)

print(tmean_evaluation_results["server"])

  0%|          | 0/500 [00:00<?, ?it/s]2025-05-21 19:36:06,473 - TRIMMEDMEAN/Server - INFO - Round 1: training_loss = 0.0322
  0%|          | 1/500 [00:02<23:18,  2.80s/it]2025-05-21 19:36:09,225 - TRIMMEDMEAN/Server - INFO - Round 2: training_loss = 0.0321
  0%|          | 2/500 [00:05<23:00,  2.77s/it]2025-05-21 19:36:12,027 - TRIMMEDMEAN/Server - INFO - Round 3: training_loss = 0.0318
  1%|          | 3/500 [00:08<23:04,  2.79s/it]2025-05-21 19:36:14,794 - TRIMMEDMEAN/Server - INFO - Round 4: training_loss = 0.0313
  1%|          | 4/500 [00:11<22:58,  2.78s/it]2025-05-21 19:36:17,466 - TRIMMEDMEAN/Server - INFO - Round 5: training_loss = 0.0305
  1%|          | 5/500 [00:13<22:36,  2.74s/it]2025-05-21 19:36:20,180 - TRIMMEDMEAN/Server - INFO - Round 6: training_loss = 0.03
  1%|          | 6/500 [00:16<22:29,  2.73s/it]2025-05-21 19:36:22,901 - TRIMMEDMEAN/Server - INFO - Round 7: training_loss = 0.0297
  1%|▏         | 7/500 [00:19<22:24,  2.73s/it]2025-05-21 19:36:25,636 - TRIMME

{'loss': 1.3962401349941889, 'metrics': {'acc': 0.5001666631549597, 'wacc': 0.5001666745841503}}


## GeoMed

In [8]:
from src.federated_learning.standard.geomed import Server as GeoMedServer

# Training: geometric-median aggregation over the clients not listed in `waffle_preds`,
# starting from a fresh copy of the shared initial `model`.
# geomed_max_iter = 10 presumably caps the iterations of the median solver (TODO confirm).
server = GeoMedServer(
    model = deepcopy(model),
    clients = waffle_clients,
    participation_rate = 10,
    geomed_max_iter = 10,
)

geomed_train_results = server.train(
    criterion = torch.nn.CrossEntropyLoss(),
    optimizer_class = torch.optim.Adam,
    optimizer_params = {"lr": 1e-3},
    num_rounds = 500,
    num_local_epochs = 1,
    evaluation_step = 201,
    metrics = {},
)

# Evaluation on `benign_clients` (presumably the un-attacked shards).
server.clients = benign_clients
geomed_evaluation_results = server.evaluate(
    criterion = torch.nn.CrossEntropyLoss(),
    metrics = {"acc": Accuracy(), "wacc": WeightedAccuracy()},
)

print(geomed_evaluation_results["server"])

  0%|          | 0/500 [00:00<?, ?it/s]2025-05-21 19:59:58,914 - GEOMED/Server - INFO - Round 1: training_loss = 0.0322
  0%|          | 1/500 [00:02<24:06,  2.90s/it]2025-05-21 20:00:01,784 - GEOMED/Server - INFO - Round 2: training_loss = 0.0321
  0%|          | 2/500 [00:05<23:55,  2.88s/it]2025-05-21 20:00:04,521 - GEOMED/Server - INFO - Round 3: training_loss = 0.032
  1%|          | 3/500 [00:08<23:19,  2.82s/it]2025-05-21 20:00:07,473 - GEOMED/Server - INFO - Round 4: training_loss = 0.0314
  1%|          | 4/500 [00:11<23:43,  2.87s/it]2025-05-21 20:00:10,433 - GEOMED/Server - INFO - Round 5: training_loss = 0.0309
  1%|          | 5/500 [00:14<23:56,  2.90s/it]2025-05-21 20:00:13,209 - GEOMED/Server - INFO - Round 6: training_loss = 0.0302
  1%|          | 6/500 [00:17<23:32,  2.86s/it]2025-05-21 20:00:16,086 - GEOMED/Server - INFO - Round 7: training_loss = 0.0298
  1%|▏         | 7/500 [00:20<23:32,  2.87s/it]2025-05-21 20:00:18,907 - GEOMED/Server - INFO - Round 8: training

{'loss': 1.434299098153909, 'metrics': {'acc': 0.48033333113789556, 'wacc': 0.48033334157367547}}


In [9]:
from json import dump, dumps
from pathlib import Path

# Persist every method's train/test results for this experiment seed as one JSON file
# (new_seed = None yields "...seedNone.json").
results_path = Path(f"notebooks/cifar10/results/waffle_wst_seed{new_seed}.json")

# Build and serialize BEFORE touching the file. The old version opened with "w" first,
# so a missing variable or a non-JSON-serializable value truncated any existing results.
results = {
    "fedavg": {"train": fedavg_train_results, "test": fedavg_evaluation_results},
    "krum": {"train": krum_train_results, "test": krum_evaluation_results},
    "mkrum": {"train": mkrum_train_results, "test": mkrum_evaluation_results},
    "trimmedmean": {"train": tmean_train_results, "test": tmean_evaluation_results},
    "geomed": {"train": geomed_train_results, "test": geomed_evaluation_results},
}
serialized = dumps(results)  # same text json.dump would have written

# open(..., "w") does not create missing parent directories.
results_path.parent.mkdir(parents = True, exist_ok = True)
with open(results_path, "w") as f:
    f.write(serialized)