In [1]:
# Change the kernel's working directory two levels up from where the notebook
# starts (presumably the project root, so that the `src.*` imports and the
# relative `notebooks/...` output path used later resolve — confirm).
# NOTE: the path is relative, so re-running this cell without restarting the
# kernel climbs two more levels; run it exactly once per kernel session.
%cd ../..

/Users/davideleo/Desktop/Projects/research/papers/fl_wavelet_v0


## Setup

In [2]:
import random
import torch 
import numpy as np 
from src.data.fashion_mnist import get_federation 
from src.data.attacks import GaussianBlur, GaussianNoise
from src.federated_learning.standard.fedavg import Client as Client
from src.models.neural_networks import LeNet5
from src.models.metrics import Accuracy, WeightedAccuracy
from copy import deepcopy

# Put Python's, NumPy's and PyTorch's global RNGs on the same fixed seed
# before the model and the federation below are built.
for _seed_rng in (random.seed, np.random.seed, torch.random.manual_seed):
    _seed_rng(42)

# Federation
# Template network; each server cell later trains its own deepcopy of it.
model = LeNet5(in_channels=1, in_padding=2, num_classes=10)

# Attack transforms handed to the federation builder together with
# attacks_proba (how the builder applies them is defined in src.data).
data_attacks = [GaussianNoise(sigma=0.5), GaussianBlur(kernel_size=11)]

federation = get_federation(
    num_shards=100,
    alpha=1000,
    attacks=data_attacks,
    attacks_proba=0.4,
)

# One CPU client per federation entry, in the same order as `federation`,
# so that position i in `clients` corresponds to position i in `federation`.
clients = []
for shard in federation:
    clients.append(
        Client(
            train_dataset=shard["train"],
            test_dataset=shard["test"],
            distribution=shard["distribution"],
            batch_size=64,
            device="cpu",
        )
    )

# Hard-coded client indices to leave out of training.
# NOTE(review): presumably the indices flagged by the WAFFLE detector in a
# separate run — confirm provenance; they are not recomputed here.
waffle_preds = [1,  5,  6,  8,  9, 12, 13, 14, 16, 18, 19, 20, 24, 26, 28, 29, 30, 31,
                33, 34, 37, 39, 40, 45, 47, 50, 53, 54, 55, 57, 63, 64, 65, 66, 68, 69,
                73, 74, 75, 82, 85, 91, 92, 96, 97, 99]

# Training pool: the clients at positions 0..99 whose index is not listed above.
_flagged = set(waffle_preds)
waffle_clients = [clients[i] for i in range(100) if i not in _flagged]

# Evaluation pool: clients whose dataset id contains no "." separator
# (presumably dotted ids mark attacked shards — confirm against get_federation).
benign_clients = [
    client
    for dataset, client in zip(federation, clients)
    if "." not in dataset["id"]
]

In [3]:
# Seed override for this repetition of the experiment. The original note
# lists None, 7 and 365 as the values used; None skips the re-seeding below
# and leaves the RNGs in whatever state the setup cell left them.
new_seed = 7

# Re-seed all three global RNGs only when an explicit override is given.
if new_seed is not None:
    for _reseed in (random.seed, np.random.seed, torch.random.manual_seed):
        _reseed(new_seed)

## FedAvg

In [4]:
from src.federated_learning.standard.fedavg import Server as FedAvgServer

# --- Training: FedAvg on the WAFFLE-filtered client pool ---
# The server gets its own copy of the template model so the other
# aggregators start from the same untouched weights.
server = FedAvgServer(
    clients=waffle_clients,
    participation_rate=10,
    model=deepcopy(model),
)

# NOTE(review): evaluation_step=201 with 500 rounds and an empty metrics dict
# presumably keeps in-training evaluation to a minimum — confirm in Server.train.
fedavg_train_kwargs = {
    "num_rounds": 500,
    "num_local_epochs": 1,
    "criterion": torch.nn.CrossEntropyLoss(),
    "optimizer_class": torch.optim.Adam,
    "optimizer_params": {"lr": 1e-3},
    "evaluation_step": 201,
    "metrics": {},
}
fedavg_train_results = server.train(**fedavg_train_kwargs)

# --- Evaluation ---
# Swap the server's client list to the benign clients before evaluating
# (presumably evaluate() runs over server.clients — confirm).
server.clients = benign_clients
fedavg_evaluation_results = server.evaluate(
    criterion=torch.nn.CrossEntropyLoss(),
    metrics={"acc": Accuracy(), "wacc": WeightedAccuracy()},
)

# Show only the "server" entry of the evaluation results.
print(fedavg_evaluation_results["server"])

  0%|          | 0/500 [00:00<?, ?it/s]2025-05-21 21:11:26,717 - FEDAVG/Server - INFO - Round 1: training_loss = 0.0343
  0%|          | 1/500 [00:02<19:55,  2.40s/it]2025-05-21 21:11:28,954 - FEDAVG/Server - INFO - Round 2: training_loss = 0.0329
  0%|          | 2/500 [00:04<19:06,  2.30s/it]2025-05-21 21:11:30,950 - FEDAVG/Server - INFO - Round 3: training_loss = 0.0314
  1%|          | 3/500 [00:06<17:54,  2.16s/it]2025-05-21 21:11:32,979 - FEDAVG/Server - INFO - Round 4: training_loss = 0.0296
  1%|          | 4/500 [00:08<17:26,  2.11s/it]2025-05-21 21:11:34,644 - FEDAVG/Server - INFO - Round 5: training_loss = 0.0289
  1%|          | 5/500 [00:10<16:04,  1.95s/it]2025-05-21 21:11:36,694 - FEDAVG/Server - INFO - Round 6: training_loss = 0.028
  1%|          | 6/500 [00:12<16:19,  1.98s/it]2025-05-21 21:11:38,531 - FEDAVG/Server - INFO - Round 7: training_loss = 0.0274
  1%|▏         | 7/500 [00:14<15:54,  1.94s/it]2025-05-21 21:11:40,458 - FEDAVG/Server - INFO - Round 8: training

{'loss': 0.6702230829099814, 'metrics': {'acc': 0.7458333352804184, 'wacc': 0.7458333382407825}}


## Krum

In [5]:
from src.federated_learning.standard.krum import Server as KrumServer

# --- Training: Krum (K = 1) on the WAFFLE-filtered client pool ---
# The server gets its own copy of the template model so the other
# aggregators start from the same untouched weights.
server = KrumServer(
    clients=waffle_clients,
    participation_rate=10,
    model=deepcopy(model),
    K=1,
)

# Same training configuration as the other aggregator cells.
krum_train_kwargs = {
    "num_rounds": 500,
    "num_local_epochs": 1,
    "criterion": torch.nn.CrossEntropyLoss(),
    "optimizer_class": torch.optim.Adam,
    "optimizer_params": {"lr": 1e-3},
    "evaluation_step": 201,
    "metrics": {},
}
krum_train_results = server.train(**krum_train_kwargs)

# --- Evaluation ---
# Swap the server's client list to the benign clients before evaluating
# (presumably evaluate() runs over server.clients — confirm).
server.clients = benign_clients
krum_evaluation_results = server.evaluate(
    criterion=torch.nn.CrossEntropyLoss(),
    metrics={"acc": Accuracy(), "wacc": WeightedAccuracy()},
)

# Show only the "server" entry of the evaluation results.
print(krum_evaluation_results["server"])

  0%|          | 0/500 [00:00<?, ?it/s]2025-05-21 21:35:38,914 - KRUM/Server - INFO - Round 1: training_loss = 0.0343
  0%|          | 1/500 [00:03<27:14,  3.28s/it]2025-05-21 21:35:42,226 - KRUM/Server - INFO - Round 2: training_loss = 0.0331
  0%|          | 2/500 [00:06<27:22,  3.30s/it]2025-05-21 21:35:45,399 - KRUM/Server - INFO - Round 3: training_loss = 0.0317
  1%|          | 3/500 [00:09<26:50,  3.24s/it]2025-05-21 21:35:48,525 - KRUM/Server - INFO - Round 4: training_loss = 0.0301
  1%|          | 4/500 [00:12<26:24,  3.20s/it]2025-05-21 21:35:51,520 - KRUM/Server - INFO - Round 5: training_loss = 0.0291
  1%|          | 5/500 [00:15<25:46,  3.12s/it]2025-05-21 21:35:54,612 - KRUM/Server - INFO - Round 6: training_loss = 0.0282
  1%|          | 6/500 [00:18<25:37,  3.11s/it]2025-05-21 21:35:57,685 - KRUM/Server - INFO - Round 7: training_loss = 0.0275
  1%|▏         | 7/500 [00:22<25:28,  3.10s/it]2025-05-21 21:36:01,103 - KRUM/Server - INFO - Round 8: training_loss = 0.0271


{'loss': 0.7466761565208435, 'metrics': {'acc': 0.7199999994238218, 'wacc': 0.720000006377697}}


## Multi-Krum

In [6]:
# --- Training: Multi-Krum, i.e. the Krum server with K = 5 ---
# Reuses KrumServer imported in the Krum cell; the server gets its own copy
# of the template model.
server = KrumServer(
    clients=waffle_clients,
    participation_rate=10,
    model=deepcopy(model),
    K=5,
)

# Same training configuration as the other aggregator cells.
mkrum_train_kwargs = {
    "num_rounds": 500,
    "num_local_epochs": 1,
    "criterion": torch.nn.CrossEntropyLoss(),
    "optimizer_class": torch.optim.Adam,
    "optimizer_params": {"lr": 1e-3},
    "evaluation_step": 201,
    "metrics": {},
}
mkrum_train_results = server.train(**mkrum_train_kwargs)

# --- Evaluation ---
# Swap the server's client list to the benign clients before evaluating
# (presumably evaluate() runs over server.clients — confirm).
server.clients = benign_clients
mkrum_evaluation_results = server.evaluate(
    criterion=torch.nn.CrossEntropyLoss(),
    metrics={"acc": Accuracy(), "wacc": WeightedAccuracy()},
)

# Show only the "server" entry of the evaluation results.
print(mkrum_evaluation_results["server"])

  0%|          | 0/500 [00:00<?, ?it/s]2025-05-21 22:01:25,081 - KRUM/Server - INFO - Round 1: training_loss = 0.0343
  0%|          | 1/500 [00:03<26:12,  3.15s/it]2025-05-21 22:01:28,145 - KRUM/Server - INFO - Round 2: training_loss = 0.0331
  0%|          | 2/500 [00:06<25:43,  3.10s/it]2025-05-21 22:01:31,344 - KRUM/Server - INFO - Round 3: training_loss = 0.0312
  1%|          | 3/500 [00:09<26:03,  3.15s/it]2025-05-21 22:01:34,488 - KRUM/Server - INFO - Round 4: training_loss = 0.03
  1%|          | 4/500 [00:12<25:59,  3.14s/it]2025-05-21 22:01:37,296 - KRUM/Server - INFO - Round 5: training_loss = 0.0288
  1%|          | 5/500 [00:15<24:56,  3.02s/it]2025-05-21 22:01:40,100 - KRUM/Server - INFO - Round 6: training_loss = 0.0281
  1%|          | 6/500 [00:18<24:16,  2.95s/it]2025-05-21 22:01:43,359 - KRUM/Server - INFO - Round 7: training_loss = 0.0274
  1%|▏         | 7/500 [00:21<25:03,  3.05s/it]2025-05-21 22:01:46,343 - KRUM/Server - INFO - Round 8: training_loss = 0.0271
  

{'loss': 0.6642759261329969, 'metrics': {'acc': 0.7493333325882753, 'wacc': 0.7493333355784416}}


## TrimmedMean

In [7]:
from src.federated_learning.standard.trimmedmean import Server as TrimmedMeanServer

# --- Training: TrimmedMean (tail_size = 2) on the WAFFLE-filtered client pool ---
# The server gets its own copy of the template model so the other
# aggregators start from the same untouched weights.
server = TrimmedMeanServer(
    clients=waffle_clients,
    participation_rate=10,
    model=deepcopy(model),
    tail_size=2,
)

# Same training configuration as the other aggregator cells.
tmean_train_kwargs = {
    "num_rounds": 500,
    "num_local_epochs": 1,
    "criterion": torch.nn.CrossEntropyLoss(),
    "optimizer_class": torch.optim.Adam,
    "optimizer_params": {"lr": 1e-3},
    "evaluation_step": 201,
    "metrics": {},
}
tmean_train_results = server.train(**tmean_train_kwargs)

# --- Evaluation ---
# Swap the server's client list to the benign clients before evaluating
# (presumably evaluate() runs over server.clients — confirm).
server.clients = benign_clients
tmean_evaluation_results = server.evaluate(
    criterion=torch.nn.CrossEntropyLoss(),
    metrics={"acc": Accuracy(), "wacc": WeightedAccuracy()},
)

# Show only the "server" entry of the evaluation results.
print(tmean_evaluation_results["server"])

  0%|          | 0/500 [00:00<?, ?it/s]2025-05-21 22:26:06,476 - TRIMMEDMEAN/Server - INFO - Round 1: training_loss = 0.0343
  0%|          | 1/500 [00:03<25:09,  3.03s/it]2025-05-21 22:26:09,599 - TRIMMEDMEAN/Server - INFO - Round 2: training_loss = 0.033
  0%|          | 2/500 [00:06<25:35,  3.08s/it]2025-05-21 22:26:12,547 - TRIMMEDMEAN/Server - INFO - Round 3: training_loss = 0.0313
  1%|          | 3/500 [00:09<25:01,  3.02s/it]2025-05-21 22:26:16,022 - TRIMMEDMEAN/Server - INFO - Round 4: training_loss = 0.0302
  1%|          | 4/500 [00:12<26:27,  3.20s/it]2025-05-21 22:26:19,161 - TRIMMEDMEAN/Server - INFO - Round 5: training_loss = 0.0284
  1%|          | 5/500 [00:15<26:13,  3.18s/it]2025-05-21 22:26:22,208 - TRIMMEDMEAN/Server - INFO - Round 6: training_loss = 0.0279
  1%|          | 6/500 [00:18<25:47,  3.13s/it]2025-05-21 22:26:25,077 - TRIMMEDMEAN/Server - INFO - Round 7: training_loss = 0.0273
  1%|▏         | 7/500 [00:21<25:02,  3.05s/it]2025-05-21 22:26:28,190 - TRIMM

{'loss': 0.6553637717068196, 'metrics': {'acc': 0.7518333331843218, 'wacc': 0.7518333391348521}}


## GeoMed

In [8]:
from src.federated_learning.standard.geomed import Server as GeoMedServer

# --- Training: GeoMed (geomed_max_iter = 10) on the WAFFLE-filtered client pool ---
# The server gets its own copy of the template model so the other
# aggregators start from the same untouched weights.
server = GeoMedServer(
    clients=waffle_clients,
    participation_rate=10,
    model=deepcopy(model),
    geomed_max_iter=10,
)

# Same training configuration as the other aggregator cells.
geomed_train_kwargs = {
    "num_rounds": 500,
    "num_local_epochs": 1,
    "criterion": torch.nn.CrossEntropyLoss(),
    "optimizer_class": torch.optim.Adam,
    "optimizer_params": {"lr": 1e-3},
    "evaluation_step": 201,
    "metrics": {},
}
geomed_train_results = server.train(**geomed_train_kwargs)

# --- Evaluation ---
# Swap the server's client list to the benign clients before evaluating
# (presumably evaluate() runs over server.clients — confirm).
server.clients = benign_clients
geomed_evaluation_results = server.evaluate(
    criterion=torch.nn.CrossEntropyLoss(),
    metrics={"acc": Accuracy(), "wacc": WeightedAccuracy()},
)

# Show only the "server" entry of the evaluation results.
print(geomed_evaluation_results["server"])

  0%|          | 0/500 [00:00<?, ?it/s]2025-05-22 00:14:14,378 - GEOMED/Server - INFO - Round 1: training_loss = 0.0343
  0%|          | 1/500 [00:02<18:16,  2.20s/it]2025-05-22 00:14:16,528 - GEOMED/Server - INFO - Round 2: training_loss = 0.0331
  0%|          | 2/500 [00:04<18:00,  2.17s/it]2025-05-22 00:14:18,692 - GEOMED/Server - INFO - Round 3: training_loss = 0.0314
  1%|          | 3/500 [00:06<17:56,  2.17s/it]2025-05-22 00:14:20,939 - GEOMED/Server - INFO - Round 4: training_loss = 0.0297
  1%|          | 4/500 [00:08<18:10,  2.20s/it]2025-05-22 00:14:23,115 - GEOMED/Server - INFO - Round 5: training_loss = 0.0287
  1%|          | 5/500 [00:10<18:04,  2.19s/it]2025-05-22 00:14:25,245 - GEOMED/Server - INFO - Round 6: training_loss = 0.028
  1%|          | 6/500 [00:13<17:51,  2.17s/it]2025-05-22 00:14:27,419 - GEOMED/Server - INFO - Round 7: training_loss = 0.0273
  1%|▏         | 7/500 [00:15<17:50,  2.17s/it]2025-05-22 00:14:29,545 - GEOMED/Server - INFO - Round 8: training

{'loss': 0.6659508054653803, 'metrics': {'acc': 0.7419999987383683, 'wacc': 0.7420000017086665}}


In [9]:
from json import dump 

import os

# Collect every aggregator's train/test results under one key per method.
# This is done BEFORE the file is opened: mode "w" truncates the target
# immediately, so a NameError from a training cell that was never run (or
# failed) would otherwise wipe a previously saved results file.
d = {
    "fedavg": {"train": fedavg_train_results, "test": fedavg_evaluation_results},
    "krum": {"train": krum_train_results, "test": krum_evaluation_results},
    "mkrum": {"train": mkrum_train_results, "test": mkrum_evaluation_results},
    "trimmedmean": {"train": tmean_train_results, "test": tmean_evaluation_results},
    "geomed": {"train": geomed_train_results, "test": geomed_evaluation_results},
}

# One file per seed override; with new_seed = None the name ends in "seedNone".
results_path = f"notebooks/fashion_mnist/results/waffle_fft_seed{new_seed}.json"

# Make sure the output folder exists so hours of training are not lost to a
# FileNotFoundError at the very last step.
os.makedirs(os.path.dirname(results_path), exist_ok=True)

with open(results_path, "w") as f:
    dump(d, f)