In [1]:
# Work from the repository root (two directories up). Assumes the kernel starts in
# this notebook's own folder. Later cells presumably rely on this: the `src.*`
# imports and the relative `notebooks/cifar100/results/...` output path both
# resolve against the current working directory.
%cd ../..

/Users/davideleo/Desktop/Projects/research/papers/fl_wavelet_v0


## Setup

In [2]:
import random
import torch 
import numpy as np 
from src.data.cifar100 import get_federation 
from src.data.attacks import GaussianBlur, GaussianNoise
from src.federated_learning.standard.fedavg import Client as Client
from src.models.neural_networks import LeNet5
from src.models.metrics import Accuracy, WeightedAccuracy
from copy import deepcopy

# Seed every global RNG the pipeline may draw from (stdlib, NumPy, PyTorch) so the
# objects built below (model weights, federation) are reproducible, assuming they
# only use these global generators.
for _seed_rng in (random.seed, np.random.seed, torch.random.manual_seed):
    _seed_rng(42)

# Federation
# Shared initial model: every aggregation rule below trains a deepcopy of it, so all
# runs start from the same weights. LeNet5 with a 3-channel input and 100 output
# classes (CIFAR-100).
# NOTE(review): keep the statement order of this cell. If LeNet5 init and
# get_federation both draw from the seeded RNGs, reordering them would change both
# the initial weights and the federation split.
model = LeNet5(in_channels = 3, in_padding = 0, num_classes = 100)

# Split CIFAR-100 into `num_shards` client shards (one client is built per shard below).
# NOTE(review): `alpha` is presumably a Dirichlet concentration (large -> near-IID
# shards), and `attacks_proba` the chance that a shard is corrupted by one of
# `attacks`. Confirm both against src.data.cifar100.get_federation.
federation = get_federation(
    num_shards = 100,
    alpha = 1000,
    attacks = [GaussianNoise(sigma = .5), GaussianBlur(kernel_size = 11)],
    attacks_proba = 0.4
)

# One local client per federation shard; all clients run on CPU with mini-batches of 64.
clients = [
    Client(
        train_dataset=shard["train"],
        test_dataset=shard["test"],
        distribution=shard["distribution"],
        batch_size=64,
        device="cpu",
    )
    for shard in federation
]

# Indices into `clients` of the shards flagged by WAFFLE; these clients are excluded
# from training below.
# NOTE(review): presumably these are the shards WAFFLE predicted as attacked.
# Confirm against the WAFFLE detection run that produced this list.
waffle_preds = [1,  6,  7, 14, 16, 17, 18, 19, 20, 26, 39, 40, 42, 45, 46, 53, 54, 55,
        57, 60, 61, 65, 68, 73, 75, 80, 82, 85, 90, 92, 95, 97]
# Fail fast if the hard-coded indices no longer fit the federation (e.g. after
# changing `num_shards`) instead of raising IndexError or silently dropping clients.
assert all(0 <= idx < len(clients) for idx in waffle_preds), "waffle_preds index out of range"

# Training pool: every client WAFFLE did not flag, in the original client order.
# A set makes each membership test O(1) instead of scanning the list.
flagged_ids = set(waffle_preds)
waffle_clients = [client for idx, client in enumerate(clients) if idx not in flagged_ids]

# Evaluation pool: clients whose shard id contains no "." (the original test was
# `len(id.split(".")) == 1`, which is equivalent).
# NOTE(review): presumably attacked shards carry a "<base>.<attack>" id. TODO confirm
# in get_federation.
benign_clients = [client for dataset, client in zip(federation, clients) if "." not in dataset["id"]]

In [3]:
# Training seed for the comparison below. The federation and clients above are
# always built under seed 42; `new_seed` only re-seeds the global RNGs from here on.
# Experiments were run with new_seed in [None, 7, 365]. None keeps the RNG state
# left over from the setup cell.
new_seed = 365

if new_seed is not None:
    for _reseed in (random.seed, np.random.seed, torch.random.manual_seed):
        _reseed(new_seed)

## FedAvg

In [4]:
from src.federated_learning.standard.fedavg import Server as FedAvgServer

# --- Training -----------------------------------------------------------------
# Plain FedAvg on the clients WAFFLE did not flag, starting from a fresh copy of
# the shared initial model.
# NOTE(review): confirm whether `participation_rate` is a client count or a percentage.
server = FedAvgServer(clients=waffle_clients, participation_rate=10, model=deepcopy(model))

fedavg_train_results = server.train(
    num_rounds=500,
    num_local_epochs=1,
    # Local optimisation: Adam (lr = 1e-3) on cross-entropy.
    criterion=torch.nn.CrossEntropyLoss(),
    optimizer_class=torch.optim.Adam,
    optimizer_params={"lr": 1e-3},
    # No metrics tracked during training.
    evaluation_step=201,
    metrics={},
)

# --- Evaluation ---------------------------------------------------------------
# Swap the client pool for the benign clients, then score with accuracy and
# weighted accuracy.
server.clients = benign_clients
fedavg_evaluation_results = server.evaluate(
    criterion=torch.nn.CrossEntropyLoss(),
    metrics={"acc": Accuracy(), "wacc": WeightedAccuracy()},
)

print(fedavg_evaluation_results["server"])

  0%|          | 0/500 [00:00<?, ?it/s]2025-05-21 18:08:51,131 - FEDAVG/Server - INFO - Round 1: training_loss = 0.0657
  0%|          | 1/500 [00:04<39:21,  4.73s/it]2025-05-21 18:08:55,742 - FEDAVG/Server - INFO - Round 2: training_loss = 0.0668
  0%|          | 2/500 [00:09<38:41,  4.66s/it]2025-05-21 18:08:59,992 - FEDAVG/Server - INFO - Round 3: training_loss = 0.0669
  1%|          | 3/500 [00:13<37:03,  4.47s/it]2025-05-21 18:09:04,047 - FEDAVG/Server - INFO - Round 4: training_loss = 0.0668
  1%|          | 4/500 [00:17<35:36,  4.31s/it]2025-05-21 18:09:08,634 - FEDAVG/Server - INFO - Round 5: training_loss = 0.066
  1%|          | 5/500 [00:22<36:22,  4.41s/it]2025-05-21 18:09:12,717 - FEDAVG/Server - INFO - Round 6: training_loss = 0.0668
  1%|          | 6/500 [00:26<35:23,  4.30s/it]2025-05-21 18:09:17,032 - FEDAVG/Server - INFO - Round 7: training_loss = 0.0662
  1%|▏         | 7/500 [00:30<35:21,  4.30s/it]2025-05-21 18:09:21,111 - FEDAVG/Server - INFO - Round 8: training

{'loss': 3.4935325503648573, 'metrics': {'acc': 0.16418410028498542, 'wacc': 0.16418410350338686}}


## Krum

In [5]:
from src.federated_learning.standard.krum import Server as KrumServer

# --- Training -----------------------------------------------------------------
# Krum with K = 1 on the clients WAFFLE did not flag, from a fresh copy of the
# shared initial model.
# NOTE(review): presumably K is the number of most-central updates kept per round
# (K = 1 -> classic Krum). Confirm in src.federated_learning.standard.krum.
server = KrumServer(
    clients=waffle_clients, participation_rate=10, model=deepcopy(model), K=1
)

krum_train_results = server.train(
    num_rounds=500,
    num_local_epochs=1,
    # Local optimisation: Adam (lr = 1e-3) on cross-entropy.
    criterion=torch.nn.CrossEntropyLoss(),
    optimizer_class=torch.optim.Adam,
    optimizer_params={"lr": 1e-3},
    # No metrics tracked during training.
    evaluation_step=201,
    metrics={},
)

# --- Evaluation ---------------------------------------------------------------
# Swap the client pool for the benign clients, then score with accuracy and
# weighted accuracy.
server.clients = benign_clients
krum_evaluation_results = server.evaluate(
    criterion=torch.nn.CrossEntropyLoss(),
    metrics={"acc": Accuracy(), "wacc": WeightedAccuracy()},
)

print(krum_evaluation_results["server"])

  0%|          | 0/500 [00:00<?, ?it/s]2025-05-21 18:44:14,517 - KRUM/Server - INFO - Round 1: training_loss = 0.0676
  0%|          | 1/500 [00:04<41:04,  4.94s/it]2025-05-21 18:44:19,041 - KRUM/Server - INFO - Round 2: training_loss = 0.0679
  0%|          | 2/500 [00:09<38:58,  4.69s/it]2025-05-21 18:44:23,503 - KRUM/Server - INFO - Round 3: training_loss = 0.0664
  1%|          | 3/500 [00:13<38:00,  4.59s/it]2025-05-21 18:44:28,416 - KRUM/Server - INFO - Round 4: training_loss = 0.0665
  1%|          | 4/500 [00:18<38:59,  4.72s/it]2025-05-21 18:44:32,764 - KRUM/Server - INFO - Round 5: training_loss = 0.0681
  1%|          | 5/500 [00:23<37:48,  4.58s/it]2025-05-21 18:44:37,941 - KRUM/Server - INFO - Round 6: training_loss = 0.0668
  1%|          | 6/500 [00:28<39:23,  4.79s/it]2025-05-21 18:44:42,291 - KRUM/Server - INFO - Round 7: training_loss = 0.0655
  1%|▏         | 7/500 [00:32<38:09,  4.64s/it]2025-05-21 18:44:46,487 - KRUM/Server - INFO - Round 8: training_loss = 0.0663


{'loss': 4.119729559670931, 'metrics': {'acc': 0.07598326413935073, 'wacc': 0.07598326358118067}}


## Multi-Krum

In [6]:
# Imported here as well so this cell does not silently depend on the Krum cell
# having run first (every other algorithm cell imports its own server class).
from src.federated_learning.standard.krum import Server as KrumServer

# --- Training -----------------------------------------------------------------
# Multi-Krum: the Krum server with K = 5, on the clients WAFFLE did not flag, from
# a fresh copy of the shared initial model.
# NOTE(review): presumably the 5 most-central updates are averaged per round.
# Confirm in src.federated_learning.standard.krum.
server = KrumServer(
    clients=waffle_clients,
    participation_rate=10,
    model=deepcopy(model),
    K=5,
)

mkrum_train_results = server.train(
    num_rounds=500,
    num_local_epochs=1,
    # Local optimisation: Adam (lr = 1e-3) on cross-entropy.
    criterion=torch.nn.CrossEntropyLoss(),
    optimizer_class=torch.optim.Adam,
    optimizer_params={"lr": 1e-3},
    # No metrics tracked during training.
    evaluation_step=201,
    metrics={},
)

# --- Evaluation ---------------------------------------------------------------
# Swap the client pool for the benign clients, then score with accuracy and
# weighted accuracy.
server.clients = benign_clients
mkrum_evaluation_results = server.evaluate(
    criterion=torch.nn.CrossEntropyLoss(),
    metrics={"acc": Accuracy(), "wacc": WeightedAccuracy()},
)

print(mkrum_evaluation_results["server"])

  0%|          | 0/500 [00:00<?, ?it/s]2025-05-21 19:16:55,070 - KRUM/Server - INFO - Round 1: training_loss = 0.0686
  0%|          | 1/500 [00:03<27:02,  3.25s/it]2025-05-21 19:16:58,133 - KRUM/Server - INFO - Round 2: training_loss = 0.0667
  0%|          | 2/500 [00:06<26:04,  3.14s/it]2025-05-21 19:17:01,019 - KRUM/Server - INFO - Round 3: training_loss = 0.0676
  1%|          | 3/500 [00:09<25:03,  3.02s/it]2025-05-21 19:17:04,477 - KRUM/Server - INFO - Round 4: training_loss = 0.0667
  1%|          | 4/500 [00:12<26:25,  3.20s/it]2025-05-21 19:17:07,263 - KRUM/Server - INFO - Round 5: training_loss = 0.0666
  1%|          | 5/500 [00:15<25:08,  3.05s/it]2025-05-21 19:17:10,079 - KRUM/Server - INFO - Round 6: training_loss = 0.0676
  1%|          | 6/500 [00:18<24:26,  2.97s/it]2025-05-21 19:17:13,240 - KRUM/Server - INFO - Round 7: training_loss = 0.0666
  1%|▏         | 7/500 [00:21<24:54,  3.03s/it]2025-05-21 19:17:16,185 - KRUM/Server - INFO - Round 8: training_loss = 0.0667


{'loss': 3.6630429649353027, 'metrics': {'acc': 0.14543933161507092, 'wacc': 0.14543933194676203}}


## TrimmedMean

In [7]:
from src.federated_learning.standard.trimmedmean import Server as TrimmedMeanServer

# --- Training -----------------------------------------------------------------
# Trimmed-mean aggregation on the clients WAFFLE did not flag, from a fresh copy of
# the shared initial model.
# NOTE(review): presumably `tail_size` is how many values are trimmed from each end
# per coordinate. Confirm in src.federated_learning.standard.trimmedmean.
server = TrimmedMeanServer(
    clients=waffle_clients, participation_rate=10, model=deepcopy(model), tail_size=2
)

tmean_train_results = server.train(
    num_rounds=500,
    num_local_epochs=1,
    # Local optimisation: Adam (lr = 1e-3) on cross-entropy.
    criterion=torch.nn.CrossEntropyLoss(),
    optimizer_class=torch.optim.Adam,
    optimizer_params={"lr": 1e-3},
    # No metrics tracked during training.
    evaluation_step=201,
    metrics={},
)

# --- Evaluation ---------------------------------------------------------------
# Swap the client pool for the benign clients, then score with accuracy and
# weighted accuracy.
server.clients = benign_clients
tmean_evaluation_results = server.evaluate(
    criterion=torch.nn.CrossEntropyLoss(),
    metrics={"acc": Accuracy(), "wacc": WeightedAccuracy()},
)

print(tmean_evaluation_results["server"])

  0%|          | 0/500 [00:00<?, ?it/s]2025-05-21 19:41:55,069 - TRIMMEDMEAN/Server - INFO - Round 1: training_loss = 0.067
  0%|          | 1/500 [00:03<26:09,  3.14s/it]2025-05-21 19:41:58,382 - TRIMMEDMEAN/Server - INFO - Round 2: training_loss = 0.0674
  0%|          | 2/500 [00:06<26:55,  3.24s/it]2025-05-21 19:42:01,328 - TRIMMEDMEAN/Server - INFO - Round 3: training_loss = 0.0682
  1%|          | 3/500 [00:09<25:44,  3.11s/it]2025-05-21 19:42:04,837 - TRIMMEDMEAN/Server - INFO - Round 4: training_loss = 0.0664
  1%|          | 4/500 [00:12<27:00,  3.27s/it]2025-05-21 19:42:07,832 - TRIMMEDMEAN/Server - INFO - Round 5: training_loss = 0.0675
  1%|          | 5/500 [00:15<26:08,  3.17s/it]2025-05-21 19:42:11,004 - TRIMMEDMEAN/Server - INFO - Round 6: training_loss = 0.0668
  1%|          | 6/500 [00:19<26:05,  3.17s/it]2025-05-21 19:42:13,850 - TRIMMEDMEAN/Server - INFO - Round 7: training_loss = 0.0646
  1%|▏         | 7/500 [00:21<25:10,  3.06s/it]2025-05-21 19:42:16,921 - TRIMM

{'loss': 3.518558972410577, 'metrics': {'acc': 0.16117154830295172, 'wacc': 0.16117155067965575}}


## GeoMed

In [8]:
from src.federated_learning.standard.geomed import Server as GeoMedServer

# --- Training -----------------------------------------------------------------
# Geometric-median aggregation on the clients WAFFLE did not flag, from a fresh
# copy of the shared initial model.
# NOTE(review): presumably `geomed_max_iter` caps the iterations of the
# geometric-median solver. Confirm in src.federated_learning.standard.geomed.
server = GeoMedServer(
    clients=waffle_clients, participation_rate=10, model=deepcopy(model), geomed_max_iter=10
)

geomed_train_results = server.train(
    num_rounds=500,
    num_local_epochs=1,
    # Local optimisation: Adam (lr = 1e-3) on cross-entropy.
    criterion=torch.nn.CrossEntropyLoss(),
    optimizer_class=torch.optim.Adam,
    optimizer_params={"lr": 1e-3},
    # No metrics tracked during training.
    evaluation_step=201,
    metrics={},
)

# --- Evaluation ---------------------------------------------------------------
# Swap the client pool for the benign clients, then score with accuracy and
# weighted accuracy.
server.clients = benign_clients
geomed_evaluation_results = server.evaluate(
    criterion=torch.nn.CrossEntropyLoss(),
    metrics={"acc": Accuracy(), "wacc": WeightedAccuracy()},
)

print(geomed_evaluation_results["server"])

  0%|          | 0/500 [00:00<?, ?it/s]2025-05-21 20:07:57,782 - GEOMED/Server - INFO - Round 1: training_loss = 0.0671
  0%|          | 1/500 [00:03<27:29,  3.31s/it]2025-05-21 20:08:01,345 - GEOMED/Server - INFO - Round 2: training_loss = 0.067
  0%|          | 2/500 [00:06<28:41,  3.46s/it]2025-05-21 20:08:04,744 - GEOMED/Server - INFO - Round 3: training_loss = 0.0673
  1%|          | 3/500 [00:10<28:24,  3.43s/it]2025-05-21 20:08:07,582 - GEOMED/Server - INFO - Round 4: training_loss = 0.064
  1%|          | 4/500 [00:13<26:25,  3.20s/it]2025-05-21 20:08:10,953 - GEOMED/Server - INFO - Round 5: training_loss = 0.0656
  1%|          | 5/500 [00:16<26:53,  3.26s/it]2025-05-21 20:08:14,256 - GEOMED/Server - INFO - Round 6: training_loss = 0.0677
  1%|          | 6/500 [00:19<26:57,  3.27s/it]2025-05-21 20:08:17,546 - GEOMED/Server - INFO - Round 7: training_loss = 0.0661
  1%|▏         | 7/500 [00:23<26:56,  3.28s/it]2025-05-21 20:08:21,005 - GEOMED/Server - INFO - Round 8: training_

{'loss': 3.5115617942012003, 'metrics': {'acc': 0.16033472853849123, 'wacc': 0.16033472835144738}}


In [9]:
from json import dumps
from pathlib import Path

# Training history and benign-client evaluation of every aggregation rule.
results = {
    "fedavg": {"train": fedavg_train_results, "test": fedavg_evaluation_results},
    "krum": {"train": krum_train_results, "test": krum_evaluation_results},
    "mkrum": {"train": mkrum_train_results, "test": mkrum_evaluation_results},
    "trimmedmean": {"train": tmean_train_results, "test": tmean_evaluation_results},
    "geomed": {"train": geomed_train_results, "test": geomed_evaluation_results},
}

# One file per training seed (new_seed = None yields "...seedNone.json").
results_path = Path(f"notebooks/cifar100/results/waffle_wst_seed{new_seed}.json")
# Create the results folder if it is missing, rather than crashing after hours of training.
results_path.parent.mkdir(parents=True, exist_ok=True)

# Serialise BEFORE opening the file. open(..., "w") truncates immediately, so a
# non-JSON-serialisable value would otherwise leave an empty/partial file and wipe
# any earlier results for this seed.
payload = dumps(results)
with open(results_path, "w") as f:
    f.write(payload)