In [1]:
import os

# Move to the repository root so the `src.*` imports and the relative
# "notebooks/cifar10/results/..." path used when saving resolve.
# Guarded so that re-running this cell does not climb two more levels
# (a bare `%cd ../..` is not idempotent). The root is recognised by the
# presence of the `src` package imported below.
if not os.path.isdir("src"):
    os.chdir("../..")
print(os.getcwd())

/Users/davideleo/Desktop/Projects/research/papers/fl_wavelet_v0


## Setup

In [2]:
import random
import torch 
import numpy as np 
from src.data.cifar10 import get_federation 
from src.data.attacks import GaussianBlur, GaussianNoise
from src.federated_learning.standard.fedavg import Client as Client
from src.models.neural_networks import LeNet5
from src.models.metrics import Accuracy, WeightedAccuracy
from copy import deepcopy

# Fix the Python, NumPy and PyTorch RNGs before the model and the federation
# below are constructed, so this setup cell starts from a known RNG state.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # same function object as torch.random.manual_seed

# Federation
# Shared initial model: every experiment below trains its own `deepcopy` of it.
model = LeNet5(in_channels=3, in_padding=0, num_classes=10)

# 100 CIFAR-10 shards, some of them corrupted by the listed attacks.
# NOTE(review): `alpha` is presumably a Dirichlet concentration (1000 ~ near-IID)
# and `attacks_proba` the per-shard probability of being attacked — confirm in
# src.data.cifar10.get_federation.
federation = get_federation(
    num_shards=100,
    alpha=1000,
    attacks=[GaussianNoise(sigma=0.5), GaussianBlur(kernel_size=11)],
    attacks_proba=0.4,
)

# One client per shard, kept in federation order; `benign_clients` relies on
# this ordering when it zips `federation` with `clients`.
clients = [
    Client(
        train_dataset=shard["train"],
        test_dataset=shard["test"],
        distribution=shard["distribution"],
        batch_size=64,
        device="cpu",
    )
    for shard in federation
]

# Client indices listed in `waffle_preds` are excluded. The remaining clients
# (`waffle_clients`) are the training set for every aggregation rule below.
# NOTE(review): presumably these are the indices WAFFLE flagged (e.g. as
# attacked) — confirm against the detection run that produced them.
waffle_preds = [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 24, 25, 26, 27, 29, 30, 31, 32, 33, 34, 35, 36, 37,
        38, 39, 40, 41, 42, 43, 44, 45, 46, 48, 49, 50, 52, 53, 54, 55, 56, 57,
        58, 60, 61, 62, 63, 64, 65, 66, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77,
        78, 79, 80, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96,
        97, 98, 99]

# `waffle_preds` indexes a 100-client federation (indices 0-99). Fail loudly if
# the federation size changes (e.g. a different `num_shards`) instead of raising
# an IndexError or silently skipping clients past index 99.
assert len(clients) == 100, f"waffle_preds assumes 100 clients, got {len(clients)}"

_waffle_pred_ids = set(waffle_preds)  # set membership instead of scanning the list
waffle_clients = [
    client for idx, client in enumerate(clients) if idx not in _waffle_pred_ids
]
# An empty training set would only fail later, deep inside the server.
assert waffle_clients, "every client index is listed in waffle_preds; nothing left to train on"

# Evaluation clients: those whose shard id contains no ".".
# NOTE(review): presumably attacked shards get a "."-separated suffix in their
# id — confirm in get_federation.
benign_clients = [
    client
    for shard, client in zip(federation, clients)
    if "." not in shard["id"]
]

In [3]:
# Optional reseed before the experiments. Runs were made with new_seed in
# [None, 7, 365]. None keeps the RNG state left by the setup cell (seed 42).
# The value is also embedded in the results filename when saving.
# NOTE(review): all experiment cells draw from the same global RNG streams, so
# re-running a single experiment cell on its own will presumably not reproduce
# its Run-All result; re-run from this cell instead.
new_seed = 7

if new_seed is not None:
    for _reseed in (random.seed, np.random.seed, torch.random.manual_seed):
        _reseed(new_seed)

## FedAvg

In [4]:
from src.federated_learning.standard.fedavg import Server as FedAvgServer

# Training
# FedAvg over the WAFFLE-filtered clients, starting from a fresh copy of the
# shared initial `model` so weights are not shared between experiments.
# NOTE(review): `participation_rate` receives a client *count*, not a fraction;
# presumably the server treats it as "every client, every round" — confirm.
server = FedAvgServer(
    clients=waffle_clients,
    participation_rate=len(waffle_clients),
    model=deepcopy(model),
)

# Same training hyperparameters as the other aggregation rules.
# NOTE(review): `evaluation_step` presumably sets how often (in rounds) the
# server evaluates during training; no training-time metrics are requested.
train_config = {
    "num_rounds": 500,
    "num_local_epochs": 1,
    "criterion": torch.nn.CrossEntropyLoss(),
    "optimizer_class": torch.optim.Adam,
    "optimizer_params": {"lr": 1e-3},
    "evaluation_step": 201,
    "metrics": {},
}
fedavg_train_results = server.train(**train_config)

# Evaluation
# Point the trained server at the benign clients only, then score it.
server.clients = benign_clients
fedavg_evaluation_results = server.evaluate(
    criterion=torch.nn.CrossEntropyLoss(),
    metrics={"acc": Accuracy(), "wacc": WeightedAccuracy()},
)

print(fedavg_evaluation_results["server"])

  0%|          | 0/500 [00:00<?, ?it/s]2025-05-21 21:12:25,455 - FEDAVG/Server - INFO - Round 1: training_loss = 0.0322
  0%|          | 1/500 [00:01<16:33,  1.99s/it]2025-05-21 21:12:26,827 - FEDAVG/Server - INFO - Round 2: training_loss = 0.0321
  0%|          | 2/500 [00:03<13:29,  1.63s/it]2025-05-21 21:12:28,194 - FEDAVG/Server - INFO - Round 3: training_loss = 0.0319
  1%|          | 3/500 [00:04<12:29,  1.51s/it]2025-05-21 21:12:29,654 - FEDAVG/Server - INFO - Round 4: training_loss = 0.0314
  1%|          | 4/500 [00:06<12:18,  1.49s/it]2025-05-21 21:12:31,111 - FEDAVG/Server - INFO - Round 5: training_loss = 0.0308
  1%|          | 5/500 [00:07<12:11,  1.48s/it]2025-05-21 21:12:32,862 - FEDAVG/Server - INFO - Round 6: training_loss = 0.0303
  1%|          | 6/500 [00:09<12:55,  1.57s/it]2025-05-21 21:12:34,524 - FEDAVG/Server - INFO - Round 7: training_loss = 0.0298
  1%|▏         | 7/500 [00:11<13:08,  1.60s/it]2025-05-21 21:12:36,035 - FEDAVG/Server - INFO - Round 8: trainin

{'loss': 1.5166973026394843, 'metrics': {'acc': 0.4739999977648258, 'wacc': 0.4740000171413024}}


## Krum

In [5]:
from src.federated_learning.standard.krum import Server as KrumServer

# Training
# Krum over the WAFFLE-filtered clients, from a fresh copy of the initial model.
# NOTE(review): `K` is presumably the number of best-scored updates kept per
# round (K=1 being classic Krum, per the section title) — confirm in the Server.
# `participation_rate` receives a client *count*; presumably "all clients".
server = KrumServer(
    clients=waffle_clients,
    participation_rate=len(waffle_clients),
    model=deepcopy(model),
    K=1,
)

# Same training hyperparameters as the other aggregation rules.
train_config = {
    "num_rounds": 500,
    "num_local_epochs": 1,
    "criterion": torch.nn.CrossEntropyLoss(),
    "optimizer_class": torch.optim.Adam,
    "optimizer_params": {"lr": 1e-3},
    "evaluation_step": 201,
    "metrics": {},
}
krum_train_results = server.train(**train_config)

# Evaluation
# Point the trained server at the benign clients only, then score it.
server.clients = benign_clients
krum_evaluation_results = server.evaluate(
    criterion=torch.nn.CrossEntropyLoss(),
    metrics={"acc": Accuracy(), "wacc": WeightedAccuracy()},
)

print(krum_evaluation_results["server"])

  0%|          | 0/500 [00:00<?, ?it/s]2025-05-21 21:28:05,178 - KRUM/Server - INFO - Round 1: training_loss = 0.0323
  0%|          | 1/500 [00:01<15:33,  1.87s/it]2025-05-21 21:28:07,114 - KRUM/Server - INFO - Round 2: training_loss = 0.0321
  0%|          | 2/500 [00:03<15:50,  1.91s/it]2025-05-21 21:28:09,094 - KRUM/Server - INFO - Round 3: training_loss = 0.0319
  1%|          | 3/500 [00:05<16:04,  1.94s/it]2025-05-21 21:28:11,058 - KRUM/Server - INFO - Round 4: training_loss = 0.0313
  1%|          | 4/500 [00:07<16:07,  1.95s/it]2025-05-21 21:28:13,031 - KRUM/Server - INFO - Round 5: training_loss = 0.0308
  1%|          | 5/500 [00:09<16:09,  1.96s/it]2025-05-21 21:28:14,861 - KRUM/Server - INFO - Round 6: training_loss = 0.0302
  1%|          | 6/500 [00:11<15:45,  1.91s/it]2025-05-21 21:28:16,708 - KRUM/Server - INFO - Round 7: training_loss = 0.0297
  1%|▏         | 7/500 [00:13<15:33,  1.89s/it]2025-05-21 21:28:18,600 - KRUM/Server - INFO - Round 8: training_loss = 0.0297


{'loss': 1.6673362889091174, 'metrics': {'acc': 0.4261666650275389, 'wacc': 0.4261666799336672}}


## Multi-Krum

In [6]:
# Import here as well (like every other experiment cell), so this cell does not
# depend on the Krum cell having run first and cannot fail with NameError.
from src.federated_learning.standard.krum import Server as KrumServer

# Training
# Multi-Krum: the Krum server with K=5 over the WAFFLE-filtered clients, from a
# fresh copy of the initial model.
# NOTE(review): `K` is presumably the number of best-scored updates averaged
# per round — confirm. `participation_rate` receives a client *count*;
# presumably "all clients".
server = KrumServer(
    clients = waffle_clients,
    participation_rate = len(waffle_clients), 
    model = deepcopy(model),
    K = 5
)

mkrum_train_results = server.train(
    num_rounds = 500,
    num_local_epochs = 1,
    criterion = torch.nn.CrossEntropyLoss(),
    optimizer_class = torch.optim.Adam, 
    optimizer_params = {"lr": 1e-3},
    evaluation_step = 201,
    metrics = dict()
)

# Evaluation 
# Point the trained server at the benign clients only, then score it.
server.clients = benign_clients
mkrum_evaluation_results = server.evaluate(
    criterion = torch.nn.CrossEntropyLoss(),
    metrics = {"acc": Accuracy(), "wacc": WeightedAccuracy()}
)

print(mkrum_evaluation_results["server"])

  0%|          | 0/500 [00:00<?, ?it/s]2025-05-21 21:44:22,094 - KRUM/Server - INFO - Round 1: training_loss = 0.0322
  0%|          | 1/500 [00:02<16:46,  2.02s/it]2025-05-21 21:44:24,163 - KRUM/Server - INFO - Round 2: training_loss = 0.0321
  0%|          | 2/500 [00:04<16:59,  2.05s/it]2025-05-21 21:44:26,015 - KRUM/Server - INFO - Round 3: training_loss = 0.0318
  1%|          | 3/500 [00:05<16:13,  1.96s/it]2025-05-21 21:44:28,037 - KRUM/Server - INFO - Round 4: training_loss = 0.0314
  1%|          | 4/500 [00:07<16:23,  1.98s/it]2025-05-21 21:44:29,895 - KRUM/Server - INFO - Round 5: training_loss = 0.0308
  1%|          | 5/500 [00:09<15:59,  1.94s/it]2025-05-21 21:44:31,900 - KRUM/Server - INFO - Round 6: training_loss = 0.0301
  1%|          | 6/500 [00:11<16:08,  1.96s/it]2025-05-21 21:44:33,876 - KRUM/Server - INFO - Round 7: training_loss = 0.0298
  1%|▏         | 7/500 [00:13<16:09,  1.97s/it]2025-05-21 21:44:35,920 - KRUM/Server - INFO - Round 8: training_loss = 0.0295


{'loss': 1.5421865633130074, 'metrics': {'acc': 0.4596666635374228, 'wacc': 0.4596666804254055}}


## TrimmedMean

In [7]:
from src.federated_learning.standard.trimmedmean import Server as TrimmedMeanServer

# Training
# Trimmed mean over the WAFFLE-filtered clients, from a fresh copy of the
# initial model.
# NOTE(review): `tail_size` is presumably how many values are dropped from each
# tail before averaging — confirm, and check it leaves at least one value for
# len(waffle_clients) clients. `participation_rate` receives a client *count*.
server = TrimmedMeanServer(
    clients=waffle_clients,
    participation_rate=len(waffle_clients),
    model=deepcopy(model),
    tail_size=2,
)

# Same training hyperparameters as the other aggregation rules.
train_config = {
    "num_rounds": 500,
    "num_local_epochs": 1,
    "criterion": torch.nn.CrossEntropyLoss(),
    "optimizer_class": torch.optim.Adam,
    "optimizer_params": {"lr": 1e-3},
    "evaluation_step": 201,
    "metrics": {},
}
tmean_train_results = server.train(**train_config)

# Evaluation
# Point the trained server at the benign clients only, then score it.
server.clients = benign_clients
tmean_evaluation_results = server.evaluate(
    criterion=torch.nn.CrossEntropyLoss(),
    metrics={"acc": Accuracy(), "wacc": WeightedAccuracy()},
)

print(tmean_evaluation_results["server"])

  0%|          | 0/500 [00:00<?, ?it/s]2025-05-21 22:00:40,086 - TRIMMEDMEAN/Server - INFO - Round 1: training_loss = 0.0323
  0%|          | 1/500 [00:01<15:15,  1.83s/it]2025-05-21 22:00:42,081 - TRIMMEDMEAN/Server - INFO - Round 2: training_loss = 0.0321
  0%|          | 2/500 [00:03<16:00,  1.93s/it]2025-05-21 22:00:44,228 - TRIMMEDMEAN/Server - INFO - Round 3: training_loss = 0.0319
  1%|          | 3/500 [00:05<16:48,  2.03s/it]2025-05-21 22:00:46,287 - TRIMMEDMEAN/Server - INFO - Round 4: training_loss = 0.0313
  1%|          | 4/500 [00:08<16:52,  2.04s/it]2025-05-21 22:00:48,277 - TRIMMEDMEAN/Server - INFO - Round 5: training_loss = 0.0306
  1%|          | 5/500 [00:10<16:41,  2.02s/it]2025-05-21 22:00:50,076 - TRIMMEDMEAN/Server - INFO - Round 6: training_loss = 0.0301
  1%|          | 6/500 [00:11<16:01,  1.95s/it]2025-05-21 22:00:52,109 - TRIMMEDMEAN/Server - INFO - Round 7: training_loss = 0.0296
  1%|▏         | 7/500 [00:13<16:13,  1.97s/it]2025-05-21 22:00:54,115 - TRIM

{'loss': 1.521643837273121, 'metrics': {'acc': 0.4614999993741512, 'wacc': 0.4615000152736902}}


## GeoMed

In [8]:
from src.federated_learning.standard.geomed import Server as GeoMedServer

# Training
# Geometric-median aggregation over the WAFFLE-filtered clients, from a fresh
# copy of the initial model.
# NOTE(review): `geomed_max_iter` presumably caps the iterations of the
# geometric-median solver — confirm. `participation_rate` receives a client
# *count*; presumably "all clients".
server = GeoMedServer(
    clients=waffle_clients,
    participation_rate=len(waffle_clients),
    model=deepcopy(model),
    geomed_max_iter=10,
)

# Same training hyperparameters as the other aggregation rules.
train_config = {
    "num_rounds": 500,
    "num_local_epochs": 1,
    "criterion": torch.nn.CrossEntropyLoss(),
    "optimizer_class": torch.optim.Adam,
    "optimizer_params": {"lr": 1e-3},
    "evaluation_step": 201,
    "metrics": {},
}
geomed_train_results = server.train(**train_config)

# Evaluation
# Point the trained server at the benign clients only, then score it.
server.clients = benign_clients
geomed_evaluation_results = server.evaluate(
    criterion=torch.nn.CrossEntropyLoss(),
    metrics={"acc": Accuracy(), "wacc": WeightedAccuracy()},
)

print(geomed_evaluation_results["server"])

  0%|          | 0/500 [00:00<?, ?it/s]2025-05-21 22:16:28,154 - GEOMED/Server - INFO - Round 1: training_loss = 0.0323
  0%|          | 1/500 [00:01<15:48,  1.90s/it]2025-05-21 22:16:30,039 - GEOMED/Server - INFO - Round 2: training_loss = 0.0321
  0%|          | 2/500 [00:03<15:41,  1.89s/it]2025-05-21 22:16:31,987 - GEOMED/Server - INFO - Round 3: training_loss = 0.0319
  1%|          | 3/500 [00:05<15:52,  1.92s/it]2025-05-21 22:16:33,893 - GEOMED/Server - INFO - Round 4: training_loss = 0.0314
  1%|          | 4/500 [00:07<15:48,  1.91s/it]2025-05-21 22:16:35,805 - GEOMED/Server - INFO - Round 5: training_loss = 0.0307
  1%|          | 5/500 [00:09<15:46,  1.91s/it]2025-05-21 22:16:37,670 - GEOMED/Server - INFO - Round 6: training_loss = 0.0302
  1%|          | 6/500 [00:11<15:36,  1.90s/it]2025-05-21 22:16:39,562 - GEOMED/Server - INFO - Round 7: training_loss = 0.0298
  1%|▏         | 7/500 [00:13<15:34,  1.89s/it]2025-05-21 22:16:41,371 - GEOMED/Server - INFO - Round 8: trainin

{'loss': 1.5196594678560893, 'metrics': {'acc': 0.47033332935969036, 'wacc': 0.4703333472410838}}


In [9]:
import os
from json import dump, dumps

# Persist the train/test results of every aggregation rule for this seed.
# NOTE(review): ".json" sits *before* the seed suffix, so files are named e.g.
# "waffle_fft.json_seed7" (no .json extension). The name is kept unchanged so
# any loader expecting it keeps working.
results_path = f"notebooks/cifar10/results/waffle_fft.json_seed{new_seed}"

d = {
    "fedavg": {"train": fedavg_train_results, "test": fedavg_evaluation_results},
    "krum": {"train": krum_train_results, "test": krum_evaluation_results},
    "mkrum": {"train": mkrum_train_results, "test": mkrum_evaluation_results},
    "trimmedmean": {"train": tmean_train_results, "test": tmean_evaluation_results},
    "geomed": {"train": geomed_train_results, "test": geomed_evaluation_results},
}

# Build and serialize everything *before* opening the file. A missing result
# (NameError) or a non-JSON-serializable value then fails here, instead of
# after `open(..., "w")` has already truncated a previous run's file.
# dumps() with default arguments produces exactly the text dump() would write.
payload = dumps(d)

# On a fresh checkout the results directory may not exist yet.
os.makedirs(os.path.dirname(results_path), exist_ok=True)
with open(results_path, "w") as f:
    f.write(payload)