In [1]:
import os

# Move to the repository root so that the `src.*` imports and the relative
# "notebooks/fashion_mnist/results/..." output path used later resolve.
# The original `%cd ../..` climbed two levels on every execution; guarding on the
# presence of the `src` package makes the cell safe to re-run.
if not os.path.isdir("src"):
    os.chdir("../..")
print(os.getcwd())

/Users/davideleo/Desktop/Projects/research/papers/fl_wavelet_v0


## Setup

In [2]:
import random
import torch 
import numpy as np 
from src.data.fashion_mnist import get_federation 
from src.data.attacks import GaussianBlur, GaussianNoise
from src.federated_learning.standard.fedavg import Client as Client
from src.models.neural_networks import LeNet5
from src.models.metrics import Accuracy, WeightedAccuracy
from copy import deepcopy

# Seed every RNG (Python, NumPy, PyTorch) before building the model and the
# federation so that weight initialisation and the federated split are reproducible.
random.seed(42)
np.random.seed(42)
torch.random.manual_seed(42)

# Federation
NUM_SHARDS = 100

# Template model: every aggregator below trains a deepcopy of it, so all methods
# start from identical initial weights.
model = LeNet5(in_channels = 1, in_padding = 2, num_classes = 10)

# Shards are perturbed with Gaussian noise / Gaussian blur with probability 0.4.
federation = get_federation(
    num_shards = NUM_SHARDS,
    alpha = 1000,
    attacks = [GaussianNoise(sigma = .5), GaussianBlur(kernel_size = 11)],
    attacks_proba = 0.4
)

# One client per dataset returned by the federation.
clients = [
    Client(
        train_dataset = dataset["train"],
        test_dataset = dataset["test"],
        distribution = dataset["distribution"],
        batch_size = 64,
        device = "cpu"
    ) for dataset in federation
]

# Client indices listed by WAFFLE; these clients are excluded from training.
# NOTE(review): presumably the shards WAFFLE flagged as attacked — confirm against
# the WAFFLE run that produced this list.
waffle_preds = [1,  6, 11, 16, 18, 19, 20, 26, 37, 39, 40, 45, 53, 54, 55, 57, 68, 73, 75, 82, 85, 90, 92, 97]
waffle_flagged = set(waffle_preds)
# The indices are hard-coded; fail loudly if they no longer match the federation size.
assert all(0 <= i < len(clients) for i in waffle_flagged), "WAFFLE prediction index out of range"
waffle_clients = [client for i, client in enumerate(clients) if i not in waffle_flagged]

# Evaluation population: clients whose dataset id contains no ".".
# NOTE(review): presumably attacked shards carry a dotted id suffix — confirm in get_federation.
benign_clients = [client for dataset, client in zip(federation, clients) if "." not in dataset["id"]]

In [3]:
# Run-specific seed. The reported experiments use new_seed in [None, 7, 365]:
# None keeps the seed (42) applied in the setup cell, any other value re-seeds
# all RNGs here. The value also tags the results file written at the end.
new_seed = 365

if new_seed is not None:
    # Same order as the setup cell: Python, NumPy, PyTorch.
    for seed_fn in (random.seed, np.random.seed, torch.random.manual_seed):
        seed_fn(new_seed)

## FedAvg

In [4]:
from src.federated_learning.standard.fedavg import Server as FedAvgServer

# --- Train: plain FedAvg over the clients WAFFLE did not list, starting from a
# fresh copy of the shared template model.
# NOTE(review): participation_rate=10 — confirm whether this is a count or a percentage.
server = FedAvgServer(clients=waffle_clients, participation_rate=10, model=deepcopy(model))

fedavg_train_results = server.train(
    num_rounds=500,
    num_local_epochs=1,
    criterion=torch.nn.CrossEntropyLoss(),
    optimizer_class=torch.optim.Adam,
    optimizer_params={"lr": 1e-3},
    evaluation_step=201,
    metrics={},
)

# --- Evaluate: swap in the benign clients, then score the trained global model.
server.clients = benign_clients
fedavg_evaluation_results = server.evaluate(
    criterion=torch.nn.CrossEntropyLoss(),
    metrics={"acc": Accuracy(), "wacc": WeightedAccuracy()},
)

print(fedavg_evaluation_results["server"])

  0%|          | 0/500 [00:00<?, ?it/s]2025-05-21 18:05:59,431 - FEDAVG/Server - INFO - Round 1: training_loss = 0.0343
  0%|          | 1/500 [00:04<37:53,  4.56s/it]2025-05-21 18:06:04,222 - FEDAVG/Server - INFO - Round 2: training_loss = 0.0332
  0%|          | 2/500 [00:09<38:57,  4.69s/it]2025-05-21 18:06:08,233 - FEDAVG/Server - INFO - Round 3: training_loss = 0.0313
  1%|          | 3/500 [00:13<36:17,  4.38s/it]2025-05-21 18:06:12,463 - FEDAVG/Server - INFO - Round 4: training_loss = 0.03
  1%|          | 4/500 [00:17<35:43,  4.32s/it]2025-05-21 18:06:16,817 - FEDAVG/Server - INFO - Round 5: training_loss = 0.0287
  1%|          | 5/500 [00:21<35:45,  4.33s/it]2025-05-21 18:06:21,087 - FEDAVG/Server - INFO - Round 6: training_loss = 0.0279
  1%|          | 6/500 [00:26<35:30,  4.31s/it]2025-05-21 18:06:26,003 - FEDAVG/Server - INFO - Round 7: training_loss = 0.0274
  1%|▏         | 7/500 [00:31<37:03,  4.51s/it]2025-05-21 18:06:30,609 - FEDAVG/Server - INFO - Round 8: training_

{'loss': 0.6938109275400639, 'metrics': {'acc': 0.7391666645805041, 'wacc': 0.7391666705409686}}


## Krum

In [5]:
from src.federated_learning.standard.krum import Server as KrumServer

# --- Train: Krum aggregation (K=1) over the clients WAFFLE did not list,
# starting from a fresh copy of the shared template model.
server = KrumServer(clients=waffle_clients, participation_rate=10, model=deepcopy(model), K=1)

krum_train_results = server.train(
    num_rounds=500,
    num_local_epochs=1,
    criterion=torch.nn.CrossEntropyLoss(),
    optimizer_class=torch.optim.Adam,
    optimizer_params={"lr": 1e-3},
    evaluation_step=201,
    metrics={},
)

# --- Evaluate: swap in the benign clients, then score the trained global model.
server.clients = benign_clients
krum_evaluation_results = server.evaluate(
    criterion=torch.nn.CrossEntropyLoss(),
    metrics={"acc": Accuracy(), "wacc": WeightedAccuracy()},
)

print(krum_evaluation_results["server"])

  0%|          | 0/500 [00:00<?, ?it/s]2025-05-21 18:47:39,466 - KRUM/Server - INFO - Round 1: training_loss = 0.0343
  0%|          | 1/500 [00:05<44:29,  5.35s/it]2025-05-21 18:47:44,768 - KRUM/Server - INFO - Round 2: training_loss = 0.0332
  0%|          | 2/500 [00:10<44:10,  5.32s/it]2025-05-21 18:47:50,049 - KRUM/Server - INFO - Round 3: training_loss = 0.0315
  1%|          | 3/500 [00:15<43:55,  5.30s/it]2025-05-21 18:47:55,549 - KRUM/Server - INFO - Round 4: training_loss = 0.0301
  1%|          | 4/500 [00:21<44:28,  5.38s/it]2025-05-21 18:48:00,696 - KRUM/Server - INFO - Round 5: training_loss = 0.029
  1%|          | 5/500 [00:26<43:41,  5.30s/it]2025-05-21 18:48:05,830 - KRUM/Server - INFO - Round 6: training_loss = 0.0282
  1%|          | 6/500 [00:31<43:09,  5.24s/it]2025-05-21 18:48:10,836 - KRUM/Server - INFO - Round 7: training_loss = 0.0277
  1%|▏         | 7/500 [00:36<42:25,  5.16s/it]2025-05-21 18:48:15,792 - KRUM/Server - INFO - Round 8: training_loss = 0.0271
 

{'loss': 0.8381677047312259, 'metrics': {'acc': 0.6990000071922938, 'wacc': 0.6990000091791153}}


## Multi-Krum

In [6]:
# Imported here as well so this cell does not depend on the Krum cell having run
# first (previously it relied on KrumServer leaking from that cell).
from src.federated_learning.standard.krum import Server as KrumServer

# Training: Multi-Krum is the Krum server with K = 5 (the Krum cell uses K = 1),
# trained over the clients WAFFLE did not list, from a fresh copy of the template model.
server = KrumServer(
    clients = waffle_clients,
    participation_rate = 10, 
    model = deepcopy(model),
    K = 5
)

mkrum_train_results = server.train(
    num_rounds = 500,
    num_local_epochs = 1,
    criterion = torch.nn.CrossEntropyLoss(),
    optimizer_class = torch.optim.Adam, 
    optimizer_params = {"lr": 1e-3},
    evaluation_step = 201,
    metrics = dict()
)

# Evaluation: swap in the benign clients, then score the trained global model.
server.clients = benign_clients
mkrum_evaluation_results = server.evaluate(
    criterion = torch.nn.CrossEntropyLoss(),
    metrics = {"acc": Accuracy(), "wacc": WeightedAccuracy()}
)

print(mkrum_evaluation_results["server"])

  0%|          | 0/500 [00:00<?, ?it/s]2025-05-21 19:20:31,609 - KRUM/Server - INFO - Round 1: training_loss = 0.0343
  0%|          | 1/500 [00:03<26:51,  3.23s/it]2025-05-21 19:20:34,635 - KRUM/Server - INFO - Round 2: training_loss = 0.0329
  0%|          | 2/500 [00:06<25:48,  3.11s/it]2025-05-21 19:20:37,773 - KRUM/Server - INFO - Round 3: training_loss = 0.0313
  1%|          | 3/500 [00:09<25:51,  3.12s/it]2025-05-21 19:20:40,955 - KRUM/Server - INFO - Round 4: training_loss = 0.0298
  1%|          | 4/500 [00:12<26:00,  3.15s/it]2025-05-21 19:20:44,088 - KRUM/Server - INFO - Round 5: training_loss = 0.029
  1%|          | 5/500 [00:15<25:55,  3.14s/it]2025-05-21 19:20:47,276 - KRUM/Server - INFO - Round 6: training_loss = 0.0276
  1%|          | 6/500 [00:18<25:59,  3.16s/it]2025-05-21 19:20:50,433 - KRUM/Server - INFO - Round 7: training_loss = 0.0276
  1%|▏         | 7/500 [00:22<25:56,  3.16s/it]2025-05-21 19:20:53,626 - KRUM/Server - INFO - Round 8: training_loss = 0.0268
 

{'loss': 0.6882702090342839, 'metrics': {'acc': 0.7413333333730697, 'wacc': 0.7413333403368791}}


## TrimmedMean

In [7]:
from src.federated_learning.standard.trimmedmean import Server as TrimmedMeanServer

# --- Train: trimmed-mean aggregation over the clients WAFFLE did not list,
# starting from a fresh copy of the shared template model.
# NOTE(review): tail_size=2 presumably sets how many values are trimmed per tail — confirm.
server = TrimmedMeanServer(clients=waffle_clients, participation_rate=10, model=deepcopy(model), tail_size=2)

tmean_train_results = server.train(
    num_rounds=500,
    num_local_epochs=1,
    criterion=torch.nn.CrossEntropyLoss(),
    optimizer_class=torch.optim.Adam,
    optimizer_params={"lr": 1e-3},
    evaluation_step=201,
    metrics={},
)

# --- Evaluate: swap in the benign clients, then score the trained global model.
server.clients = benign_clients
tmean_evaluation_results = server.evaluate(
    criterion=torch.nn.CrossEntropyLoss(),
    metrics={"acc": Accuracy(), "wacc": WeightedAccuracy()},
)

print(tmean_evaluation_results["server"])

  0%|          | 0/500 [00:00<?, ?it/s]2025-05-21 19:47:08,733 - TRIMMEDMEAN/Server - INFO - Round 1: training_loss = 0.0343
  0%|          | 1/500 [00:03<28:31,  3.43s/it]2025-05-21 19:47:12,197 - TRIMMEDMEAN/Server - INFO - Round 2: training_loss = 0.0331
  0%|          | 2/500 [00:06<28:38,  3.45s/it]2025-05-21 19:47:15,583 - TRIMMEDMEAN/Server - INFO - Round 3: training_loss = 0.0312
  1%|          | 3/500 [00:10<28:20,  3.42s/it]2025-05-21 19:47:19,142 - TRIMMEDMEAN/Server - INFO - Round 4: training_loss = 0.0299
  1%|          | 4/500 [00:13<28:43,  3.48s/it]2025-05-21 19:47:22,629 - TRIMMEDMEAN/Server - INFO - Round 5: training_loss = 0.0287
  1%|          | 5/500 [00:17<28:42,  3.48s/it]2025-05-21 19:47:26,081 - TRIMMEDMEAN/Server - INFO - Round 6: training_loss = 0.0281
  1%|          | 6/500 [00:20<28:34,  3.47s/it]2025-05-21 19:47:29,464 - TRIMMEDMEAN/Server - INFO - Round 7: training_loss = 0.0274
  1%|▏         | 7/500 [00:24<28:16,  3.44s/it]2025-05-21 19:47:32,777 - TRIM

{'loss': 0.6929875004887581, 'metrics': {'acc': 0.7345000053544839, 'wacc': 0.734500009338061}}


## GeoMed

In [8]:
from src.federated_learning.standard.geomed import Server as GeoMedServer

# --- Train: geometric-median aggregation over the clients WAFFLE did not list,
# starting from a fresh copy of the shared template model.
# NOTE(review): geomed_max_iter=10 presumably caps the median solver's iterations — confirm.
server = GeoMedServer(clients=waffle_clients, participation_rate=10, model=deepcopy(model), geomed_max_iter=10)

geomed_train_results = server.train(
    num_rounds=500,
    num_local_epochs=1,
    criterion=torch.nn.CrossEntropyLoss(),
    optimizer_class=torch.optim.Adam,
    optimizer_params={"lr": 1e-3},
    evaluation_step=201,
    metrics={},
)

# --- Evaluate: swap in the benign clients, then score the trained global model.
server.clients = benign_clients
geomed_evaluation_results = server.evaluate(
    criterion=torch.nn.CrossEntropyLoss(),
    metrics={"acc": Accuracy(), "wacc": WeightedAccuracy()},
)

print(geomed_evaluation_results["server"])

  0%|          | 0/500 [00:00<?, ?it/s]2025-05-21 20:15:20,213 - GEOMED/Server - INFO - Round 1: training_loss = 0.0343
  0%|          | 1/500 [00:03<27:46,  3.34s/it]2025-05-21 20:15:23,682 - GEOMED/Server - INFO - Round 2: training_loss = 0.0331
  0%|          | 2/500 [00:06<28:21,  3.42s/it]2025-05-21 20:15:26,893 - GEOMED/Server - INFO - Round 3: training_loss = 0.0314
  1%|          | 3/500 [00:10<27:31,  3.32s/it]2025-05-21 20:15:30,338 - GEOMED/Server - INFO - Round 4: training_loss = 0.0296
  1%|          | 4/500 [00:13<27:51,  3.37s/it]2025-05-21 20:15:33,843 - GEOMED/Server - INFO - Round 5: training_loss = 0.029
  1%|          | 5/500 [00:16<28:12,  3.42s/it]2025-05-21 20:15:37,153 - GEOMED/Server - INFO - Round 6: training_loss = 0.0282
  1%|          | 6/500 [00:20<27:50,  3.38s/it]2025-05-21 20:15:40,765 - GEOMED/Server - INFO - Round 7: training_loss = 0.0277
  1%|▏         | 7/500 [00:23<28:24,  3.46s/it]2025-05-21 20:15:44,287 - GEOMED/Server - INFO - Round 8: training

{'loss': 0.6861883154511452, 'metrics': {'acc': 0.7385000033080578, 'wacc': 0.7385000082552433}}


In [9]:
from json import dumps
from pathlib import Path

# Persist train/test results of every aggregator, one JSON file per seed
# (new_seed = None produces "waffle_wst_None.json").
results_path = Path("notebooks/fashion_mnist/results") / f"waffle_wst_{new_seed}.json"

all_results = {
    "fedavg": {"train": fedavg_train_results, "test": fedavg_evaluation_results},
    "krum": {"train": krum_train_results, "test": krum_evaluation_results},
    "mkrum": {"train": mkrum_train_results, "test": mkrum_evaluation_results},
    "trimmedmean": {"train": tmean_train_results, "test": tmean_evaluation_results},
    "geomed": {"train": geomed_train_results, "test": geomed_evaluation_results},
}

# Serialise before touching the file: a non-JSON-serialisable entry then raises
# without leaving an empty or truncated results file behind.
payload = dumps(all_results)

# The results directory may not exist yet; creating it avoids losing all the
# training above to a FileNotFoundError at the very last step.
results_path.parent.mkdir(parents=True, exist_ok=True)
results_path.write_text(payload, encoding="utf-8")