In [1]:
%cd ../..

/Users/davideleo/Desktop/Projects/research/papers/fl_wavelet_v0


## Setup

In [2]:
import random
import torch 
import numpy as np 
from src.data.fashion_mnist import get_federation 
from src.data.attacks import GaussianBlur, GaussianNoise
from src.federated_learning.standard.fedavg import Client as Client
from src.models.neural_networks import LeNet5
from src.models.metrics import Accuracy, WeightedAccuracy
from copy import deepcopy

random.seed(42)
np.random.seed(42)
torch.random.manual_seed(42)

# Federation
model = LeNet5(in_channels = 1, in_padding = 2, num_classes = 10)

federation = get_federation(
    num_shards = 100,
    alpha = 1000,
    attacks = [GaussianNoise(sigma = .5), GaussianBlur(kernel_size = 11)],
    attacks_proba = 0.4
)

clients = [
    Client(
        train_dataset = dataset["train"],
        test_dataset = dataset["test"],
        distribution = dataset["distribution"],
        batch_size = 64,
        device = "cpu"
    ) for dataset in federation
]

benign_clients = [client for dataset, client in zip(federation, clients) if len(dataset["id"].split(".")) == 1] 

In [3]:
# Default experiments: new_seed = None -> [None, 7, 365]
new_seed = 365

if new_seed is not None: 
    random.seed(new_seed)
    np.random.seed(new_seed)
    torch.random.manual_seed(new_seed)

## FedAvg

In [4]:
from src.federated_learning.standard.fedavg import Server as FedAvgServer

# Training
server = FedAvgServer(
    clients = clients,
    participation_rate = 10, 
    model = deepcopy(model)
)

fedavg_train_results = server.train(
    num_rounds = 500,
    num_local_epochs = 1,
    criterion = torch.nn.CrossEntropyLoss(),
    optimizer_class = torch.optim.Adam, 
    optimizer_params = {"lr": 1e-3},
    evaluation_step = 201,
    metrics = dict()
)

# Evaluation 
server.clients = benign_clients
fedavg_evaluation_results = server.evaluate(
    criterion = torch.nn.CrossEntropyLoss(),
    metrics = {"acc": Accuracy(), "wacc": WeightedAccuracy()}
)

print(fedavg_evaluation_results["server"])

  0%|          | 0/500 [00:00<?, ?it/s]2025-05-21 18:23:15,094 - FEDAVG/Server - INFO - Round 1: training_loss = 0.0343
  0%|          | 1/500 [00:06<51:21,  6.17s/it]2025-05-21 18:23:20,612 - FEDAVG/Server - INFO - Round 2: training_loss = 0.0332
  0%|          | 2/500 [00:11<48:02,  5.79s/it]2025-05-21 18:23:26,024 - FEDAVG/Server - INFO - Round 3: training_loss = 0.0315
  1%|          | 3/500 [00:17<46:31,  5.62s/it]2025-05-21 18:23:31,233 - FEDAVG/Server - INFO - Round 4: training_loss = 0.0297
  1%|          | 4/500 [00:22<45:05,  5.46s/it]2025-05-21 18:23:36,090 - FEDAVG/Server - INFO - Round 5: training_loss = 0.0287
  1%|          | 5/500 [00:27<43:13,  5.24s/it]2025-05-21 18:23:41,440 - FEDAVG/Server - INFO - Round 6: training_loss = 0.0282
  1%|          | 6/500 [00:32<43:26,  5.28s/it]2025-05-21 18:23:46,577 - FEDAVG/Server - INFO - Round 7: training_loss = 0.0274
  1%|▏         | 7/500 [00:37<42:59,  5.23s/it]2025-05-21 18:23:51,166 - FEDAVG/Server - INFO - Round 8: trainin

{'loss': 0.6698376302222411, 'metrics': {'acc': 0.7463333342373372, 'wacc': 0.7463333401978016}}


## Krum

In [5]:
from src.federated_learning.standard.krum import Server as KrumServer

# Training
server = KrumServer(
    clients = clients,
    participation_rate = 10, 
    model = deepcopy(model),
    K = 1
)

krum_train_results = server.train(
    num_rounds = 500,
    num_local_epochs = 1,
    criterion = torch.nn.CrossEntropyLoss(),
    optimizer_class = torch.optim.Adam, 
    optimizer_params = {"lr": 1e-3},
    evaluation_step = 201,
    metrics = dict()
)

# Evaluation 
server.clients = benign_clients
krum_evaluation_results = server.evaluate(
    criterion = torch.nn.CrossEntropyLoss(),
    metrics = {"acc": Accuracy(), "wacc": WeightedAccuracy()}
)

print(krum_evaluation_results["server"])

  0%|          | 0/500 [00:00<?, ?it/s]2025-05-21 19:08:35,357 - KRUM/Server - INFO - Round 1: training_loss = 0.0343
  0%|          | 1/500 [00:04<33:42,  4.05s/it]2025-05-21 19:08:39,061 - KRUM/Server - INFO - Round 2: training_loss = 0.0328
  0%|          | 2/500 [00:07<31:56,  3.85s/it]2025-05-21 19:08:43,050 - KRUM/Server - INFO - Round 3: training_loss = 0.0315
  1%|          | 3/500 [00:11<32:24,  3.91s/it]2025-05-21 19:08:47,121 - KRUM/Server - INFO - Round 4: training_loss = 0.0302
  1%|          | 4/500 [00:15<32:51,  3.97s/it]2025-05-21 19:08:50,627 - KRUM/Server - INFO - Round 5: training_loss = 0.0292
  1%|          | 5/500 [00:19<31:23,  3.81s/it]2025-05-21 19:08:54,111 - KRUM/Server - INFO - Round 6: training_loss = 0.0284
  1%|          | 6/500 [00:22<30:25,  3.70s/it]2025-05-21 19:08:58,151 - KRUM/Server - INFO - Round 7: training_loss = 0.0279
  1%|▏         | 7/500 [00:26<31:17,  3.81s/it]2025-05-21 19:09:01,205 - KRUM/Server - INFO - Round 8: training_loss = 0.0273


{'loss': 0.7112725607355436, 'metrics': {'acc': 0.7448333282768727, 'wacc': 0.7448333401978016}}


## Multi-Krum

In [6]:
# Training
server = KrumServer(
    clients = clients,
    participation_rate = 10, 
    model = deepcopy(model),
    K = 5
)

mkrum_train_results = server.train(
    num_rounds = 500,
    num_local_epochs = 1,
    criterion = torch.nn.CrossEntropyLoss(),
    optimizer_class = torch.optim.Adam, 
    optimizer_params = {"lr": 1e-3},
    evaluation_step = 201,
    metrics = dict()
)

# Evaluation 
server.clients = benign_clients
mkrum_evaluation_results = server.evaluate(
    criterion = torch.nn.CrossEntropyLoss(),
    metrics = {"acc": Accuracy(), "wacc": WeightedAccuracy()}
)

print(mkrum_evaluation_results["server"])

  0%|          | 0/500 [00:00<?, ?it/s]2025-05-21 19:39:28,330 - KRUM/Server - INFO - Round 1: training_loss = 0.0343
  0%|          | 1/500 [00:03<28:49,  3.47s/it]2025-05-21 19:39:32,419 - KRUM/Server - INFO - Round 2: training_loss = 0.033
  0%|          | 2/500 [00:07<31:48,  3.83s/it]2025-05-21 19:39:36,647 - KRUM/Server - INFO - Round 3: training_loss = 0.0314
  1%|          | 3/500 [00:11<33:14,  4.01s/it]2025-05-21 19:39:40,663 - KRUM/Server - INFO - Round 4: training_loss = 0.0302
  1%|          | 4/500 [00:15<33:11,  4.01s/it]2025-05-21 19:39:43,891 - KRUM/Server - INFO - Round 5: training_loss = 0.0289
  1%|          | 5/500 [00:19<30:46,  3.73s/it]2025-05-21 19:39:47,002 - KRUM/Server - INFO - Round 6: training_loss = 0.0281
  1%|          | 6/500 [00:22<28:58,  3.52s/it]2025-05-21 19:39:50,365 - KRUM/Server - INFO - Round 7: training_loss = 0.0279
  1%|▏         | 7/500 [00:25<28:30,  3.47s/it]2025-05-21 19:39:54,116 - KRUM/Server - INFO - Round 8: training_loss = 0.0272
 

{'loss': 0.6674748397370179, 'metrics': {'acc': 0.7523333372275035, 'wacc': 0.7523333312670389}}


## TrimmedMean

In [7]:
from src.federated_learning.standard.trimmedmean import Server as TrimmedMeanServer

# Training
server = TrimmedMeanServer(
    clients = clients,
    participation_rate = 10, 
    model = deepcopy(model),
    tail_size = 2
)

tmean_train_results = server.train(
    num_rounds = 500,
    num_local_epochs = 1,
    criterion = torch.nn.CrossEntropyLoss(),
    optimizer_class = torch.optim.Adam, 
    optimizer_params = {"lr": 1e-3},
    evaluation_step = 201,
    metrics = dict()
)

# Evaluation 
server.clients = benign_clients
tmean_evaluation_results = server.evaluate(
    criterion = torch.nn.CrossEntropyLoss(),
    metrics = {"acc": Accuracy(), "wacc": WeightedAccuracy()}
)

print(tmean_evaluation_results["server"])

  0%|          | 0/500 [00:00<?, ?it/s]2025-05-21 20:11:39,509 - TRIMMEDMEAN/Server - INFO - Round 1: training_loss = 0.0342
  0%|          | 1/500 [00:03<30:25,  3.66s/it]2025-05-21 20:11:43,652 - TRIMMEDMEAN/Server - INFO - Round 2: training_loss = 0.033
  0%|          | 2/500 [00:07<32:43,  3.94s/it]2025-05-21 20:11:47,247 - TRIMMEDMEAN/Server - INFO - Round 3: training_loss = 0.0313
  1%|          | 3/500 [00:11<31:20,  3.78s/it]2025-05-21 20:11:50,978 - TRIMMEDMEAN/Server - INFO - Round 4: training_loss = 0.0299
  1%|          | 4/500 [00:15<31:06,  3.76s/it]2025-05-21 20:11:54,714 - TRIMMEDMEAN/Server - INFO - Round 5: training_loss = 0.0286
  1%|          | 5/500 [00:18<30:57,  3.75s/it]2025-05-21 20:11:58,904 - TRIMMEDMEAN/Server - INFO - Round 6: training_loss = 0.0282
  1%|          | 6/500 [00:23<32:07,  3.90s/it]2025-05-21 20:12:02,817 - TRIMMEDMEAN/Server - INFO - Round 7: training_loss = 0.0275
  1%|▏         | 7/500 [00:26<32:05,  3.91s/it]2025-05-21 20:12:06,874 - TRIMM

{'loss': 0.6774162550667922, 'metrics': {'acc': 0.7440000004867713, 'wacc': 0.7440000104208787}}


## GeoMed

In [8]:
from src.federated_learning.standard.geomed import Server as GeoMedServer

# Training
server = GeoMedServer(
    clients = clients,
    participation_rate = 10, 
    model = deepcopy(model),
    geomed_max_iter = 10
)

geomed_train_results = server.train(
    num_rounds = 500,
    num_local_epochs = 1,
    criterion = torch.nn.CrossEntropyLoss(),
    optimizer_class = torch.optim.Adam, 
    optimizer_params = {"lr": 1e-3},
    evaluation_step = 201,
    metrics = dict()
)

# Evaluation 
server.clients = benign_clients
geomed_evaluation_results = server.evaluate(
    criterion = torch.nn.CrossEntropyLoss(),
    metrics = {"acc": Accuracy(), "wacc": WeightedAccuracy()}
)

print(geomed_evaluation_results["server"])

  0%|          | 0/500 [00:00<?, ?it/s]2025-05-21 20:39:04,960 - GEOMED/Server - INFO - Round 1: training_loss = 0.0343
  0%|          | 1/500 [00:02<17:26,  2.10s/it]2025-05-21 20:39:07,237 - GEOMED/Server - INFO - Round 2: training_loss = 0.0332
  0%|          | 2/500 [00:04<18:16,  2.20s/it]2025-05-21 20:39:09,956 - GEOMED/Server - INFO - Round 3: training_loss = 0.0315
  1%|          | 3/500 [00:07<20:11,  2.44s/it]2025-05-21 20:39:12,262 - GEOMED/Server - INFO - Round 4: training_loss = 0.0298
  1%|          | 4/500 [00:09<19:43,  2.39s/it]2025-05-21 20:39:14,494 - GEOMED/Server - INFO - Round 5: training_loss = 0.0288
  1%|          | 5/500 [00:11<19:13,  2.33s/it]2025-05-21 20:39:16,897 - GEOMED/Server - INFO - Round 6: training_loss = 0.0281
  1%|          | 6/500 [00:14<19:23,  2.36s/it]2025-05-21 20:39:19,191 - GEOMED/Server - INFO - Round 7: training_loss = 0.0272
  1%|▏         | 7/500 [00:16<19:11,  2.34s/it]2025-05-21 20:39:21,397 - GEOMED/Server - INFO - Round 8: trainin

{'loss': 0.741818493972222, 'metrics': {'acc': 0.7285000055531661, 'wacc': 0.7285000115136305}}


## Results

In [9]:
from json import dump 

with open(f"notebooks/fashion_mnist/results/standard_seed{new_seed}.json", "w") as f: 
    d = {
        "fedavg": {"train": fedavg_train_results, "test": fedavg_evaluation_results},
        "krum": {"train": krum_train_results, "test": krum_evaluation_results},
        "mkrum": {"train": mkrum_train_results, "test": mkrum_evaluation_results},
        "trimmedmean": {"train": tmean_train_results, "test": tmean_evaluation_results},
        "geomed": {"train": geomed_train_results, "test": geomed_evaluation_results},
    }
    dump(d, f)