In [1]:
# Move the working directory two levels up before anything else runs; every
# relative path used later (e.g. the results JSON) is resolved from there.
# Presumably this is the repository root that contains `src/` - confirm.
# NOTE(review): the path is relative, so re-running this cell climbs two more
# levels; run it exactly once per kernel.
%cd ../..

/Users/davideleo/Desktop/Projects/research/papers/fl_wavelet_v0


## Setup

In [2]:
import random
import torch 
import numpy as np 
from src.data.cifar100 import get_federation 
from src.data.attacks import GaussianBlur, GaussianNoise
from src.federated_learning.standard.fedavg import Client as Client
from src.models.neural_networks import LeNet5
from src.models.metrics import Accuracy, WeightedAccuracy
from copy import deepcopy

# Seed the global Python, NumPy and PyTorch RNGs with one shared value before
# the model and the federation are constructed.
SEED = 42
for seed_rng in (random.seed, np.random.seed, torch.random.manual_seed):
    seed_rng(SEED)

# Federation
# Shared initial network; each server cell trains a deepcopy of it.
model = LeNet5(in_channels=3, in_padding=0, num_classes=100)

# Attack transforms handed to get_federation together with attacks_proba.
federation_attacks = [GaussianNoise(sigma=.5), GaussianBlur(kernel_size=11)]

federation = get_federation(
    num_shards=100,
    alpha=1000,
    attacks=federation_attacks,
    attacks_proba=0.4,
)

# Build one Client per federation entry, wiring in that entry's train/test
# datasets and its distribution. Every client uses batches of 64 on the CPU.
clients = []
for shard in federation:
    clients.append(
        Client(
            train_dataset=shard["train"],
            test_dataset=shard["test"],
            distribution=shard["distribution"],
            batch_size=64,
            device="cpu",
        )
    )

# Client indices listed by the WAFFLE predictions.
# NOTE(review): presumably produced by an earlier detection run - confirm
# where this hard-coded list comes from.
waffle_preds = [ 0,  1,  2,  3,  4,  6,  8,  9, 10, 11, 12, 13, 14, 15, 16, 18, 19, 20,
        21, 22, 23, 24, 25, 26, 28, 29, 31, 32, 33, 36, 37, 38, 39, 40, 41, 42,
        44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 60, 61, 62, 63,
        64, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82,
        83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 94, 95, 96, 97, 99]

# Keep only the clients whose index is NOT in waffle_preds.
# - The set makes each membership test O(1) instead of a list scan.
# - enumerate() follows the actual number of clients rather than assuming
#   exactly 100, so a smaller federation no longer raises IndexError.
# NOTE(review): selection is by exclusion (87 indices are listed, so 13 of
# 100 clients remain) - confirm waffle_preds holds the indices to drop
# rather than the ones to keep.
waffle_excluded_indices = set(waffle_preds)
waffle_clients = [
    client
    for index, client in enumerate(clients)
    if index not in waffle_excluded_indices
]

# Clients whose federation entry has an id without any "." in it.
# NOTE(review): presumably attacked shards carry a dotted id - confirm
# against get_federation.
benign_clients = [
    client
    for entry, client in zip(federation, clients)
    if "." not in entry["id"]
]

In [3]:
# Optional re-seeding applied after the setup cell.
# None leaves the RNG state exactly as the setup cell left it; per the
# original note, the default experiments use new_seed in [None, 7, 365].
new_seed = 7

if new_seed is not None:
    for reseed_rng in (random.seed, np.random.seed, torch.random.manual_seed):
        reseed_rng(new_seed)

## FedAvg

In [4]:
from src.federated_learning.standard.fedavg import Server as FedAvgServer

# Training
# The server receives its own deepcopy of `model`, so the shared initial
# network object is not handed to the server itself.
server = FedAvgServer(
    clients=waffle_clients,
    participation_rate=10,
    model=deepcopy(model),
)

# All training hyper-parameters in one place; no metrics are requested
# during training.
fedavg_train_config = dict(
    num_rounds=500,
    num_local_epochs=1,
    criterion=torch.nn.CrossEntropyLoss(),
    optimizer_class=torch.optim.Adam,
    optimizer_params={"lr": 1e-3},
    evaluation_step=201,
    metrics={},
)
fedavg_train_results = server.train(**fedavg_train_config)

# Evaluation
# Swap the server's client list so the final evaluation runs on
# benign_clients instead of the clients used for training.
server.clients = benign_clients
fedavg_eval_config = dict(
    criterion=torch.nn.CrossEntropyLoss(),
    metrics={"acc": Accuracy(), "wacc": WeightedAccuracy()},
)
fedavg_evaluation_results = server.evaluate(**fedavg_eval_config)

print(fedavg_evaluation_results["server"])

  0%|          | 0/500 [00:00<?, ?it/s]

2025-05-21 21:13:26,251 - FEDAVG/Server - INFO - Round 1: training_loss = 0.067
  0%|          | 1/500 [00:03<29:38,  3.56s/it]2025-05-21 21:13:29,280 - FEDAVG/Server - INFO - Round 2: training_loss = 0.0663
  0%|          | 2/500 [00:06<26:58,  3.25s/it]2025-05-21 21:13:31,950 - FEDAVG/Server - INFO - Round 3: training_loss = 0.0661
  1%|          | 3/500 [00:09<24:43,  2.98s/it]2025-05-21 21:13:34,648 - FEDAVG/Server - INFO - Round 4: training_loss = 0.0657
  1%|          | 4/500 [00:11<23:44,  2.87s/it]2025-05-21 21:13:37,430 - FEDAVG/Server - INFO - Round 5: training_loss = 0.0653
  1%|          | 5/500 [00:14<23:25,  2.84s/it]2025-05-21 21:13:40,242 - FEDAVG/Server - INFO - Round 6: training_loss = 0.0664
  1%|          | 6/500 [00:17<23:18,  2.83s/it]2025-05-21 21:13:43,111 - FEDAVG/Server - INFO - Round 7: training_loss = 0.0655
  1%|▏         | 7/500 [00:20<23:21,  2.84s/it]2025-05-21 21:13:45,905 - FEDAVG/Server - INFO - Round 8: training_loss = 0.065
  2%|▏         | 8/500 [0

{'loss': 4.0791774319205825, 'metrics': {'acc': 0.11481171572370509, 'wacc': 0.11481171799254716}}


## Krum

In [5]:
from src.federated_learning.standard.krum import Server as KrumServer

# Training
# Krum run: KrumServer with K = 1, trained on its own deepcopy of `model`.
server = KrumServer(
    clients=waffle_clients,
    participation_rate=10,
    model=deepcopy(model),
    K=1,
)

# All training hyper-parameters in one place; no metrics are requested
# during training.
krum_train_config = dict(
    num_rounds=500,
    num_local_epochs=1,
    criterion=torch.nn.CrossEntropyLoss(),
    optimizer_class=torch.optim.Adam,
    optimizer_params={"lr": 1e-3},
    evaluation_step=201,
    metrics={},
)
krum_train_results = server.train(**krum_train_config)

# Evaluation
# Swap the server's client list so the final evaluation runs on
# benign_clients instead of the clients used for training.
server.clients = benign_clients
krum_eval_config = dict(
    criterion=torch.nn.CrossEntropyLoss(),
    metrics={"acc": Accuracy(), "wacc": WeightedAccuracy()},
)
krum_evaluation_results = server.evaluate(**krum_eval_config)

print(krum_evaluation_results["server"])

  0%|          | 0/500 [00:00<?, ?it/s]2025-05-21 21:36:47,295 - KRUM/Server - INFO - Round 1: training_loss = 0.067
  0%|          | 1/500 [00:02<23:02,  2.77s/it]2025-05-21 21:36:50,271 - KRUM/Server - INFO - Round 2: training_loss = 0.067
  0%|          | 2/500 [00:05<24:00,  2.89s/it]2025-05-21 21:36:53,200 - KRUM/Server - INFO - Round 3: training_loss = 0.0671
  1%|          | 3/500 [00:08<24:05,  2.91s/it]2025-05-21 21:36:56,004 - KRUM/Server - INFO - Round 4: training_loss = 0.0666
  1%|          | 4/500 [00:11<23:42,  2.87s/it]2025-05-21 21:36:59,012 - KRUM/Server - INFO - Round 5: training_loss = 0.0676
  1%|          | 5/500 [00:14<24:04,  2.92s/it]2025-05-21 21:37:01,872 - KRUM/Server - INFO - Round 6: training_loss = 0.0676
  1%|          | 6/500 [00:17<23:51,  2.90s/it]2025-05-21 21:37:04,717 - KRUM/Server - INFO - Round 7: training_loss = 0.0661
  1%|▏         | 7/500 [00:20<23:40,  2.88s/it]2025-05-21 21:37:07,732 - KRUM/Server - INFO - Round 8: training_loss = 0.0666
  

{'loss': 4.603268082391268, 'metrics': {'acc': 0.08000000073290021, 'wacc': 0.08000000192281093}}


## Multi-Krum

In [6]:
# Training
# Multi-Krum run: reuses KrumServer (imported in the Krum cell) with K = 5,
# trained on its own deepcopy of `model`.
server = KrumServer(
    clients=waffle_clients,
    participation_rate=10,
    model=deepcopy(model),
    K=5,
)

# All training hyper-parameters in one place; no metrics are requested
# during training.
mkrum_train_config = dict(
    num_rounds=500,
    num_local_epochs=1,
    criterion=torch.nn.CrossEntropyLoss(),
    optimizer_class=torch.optim.Adam,
    optimizer_params={"lr": 1e-3},
    evaluation_step=201,
    metrics={},
)
mkrum_train_results = server.train(**mkrum_train_config)

# Evaluation
# Swap the server's client list so the final evaluation runs on
# benign_clients instead of the clients used for training.
server.clients = benign_clients
mkrum_eval_config = dict(
    criterion=torch.nn.CrossEntropyLoss(),
    metrics={"acc": Accuracy(), "wacc": WeightedAccuracy()},
)
mkrum_evaluation_results = server.evaluate(**mkrum_eval_config)

print(mkrum_evaluation_results["server"])

  0%|          | 0/500 [00:00<?, ?it/s]2025-05-21 22:00:23,996 - KRUM/Server - INFO - Round 1: training_loss = 0.067
  0%|          | 1/500 [00:02<24:08,  2.90s/it]2025-05-21 22:00:26,872 - KRUM/Server - INFO - Round 2: training_loss = 0.0667
  0%|          | 2/500 [00:05<23:57,  2.89s/it]2025-05-21 22:00:29,743 - KRUM/Server - INFO - Round 3: training_loss = 0.0671
  1%|          | 3/500 [00:08<23:51,  2.88s/it]2025-05-21 22:00:32,570 - KRUM/Server - INFO - Round 4: training_loss = 0.0659
  1%|          | 4/500 [00:11<23:38,  2.86s/it]2025-05-21 22:00:35,231 - KRUM/Server - INFO - Round 5: training_loss = 0.0667
  1%|          | 5/500 [00:14<22:59,  2.79s/it]2025-05-21 22:00:38,058 - KRUM/Server - INFO - Round 6: training_loss = 0.0663
  1%|          | 6/500 [00:16<23:03,  2.80s/it]2025-05-21 22:00:40,791 - KRUM/Server - INFO - Round 7: training_loss = 0.0657
  1%|▏         | 7/500 [00:19<22:49,  2.78s/it]2025-05-21 22:00:43,696 - KRUM/Server - INFO - Round 8: training_loss = 0.0655
 

{'loss': 4.294218689986352, 'metrics': {'acc': 0.1064435155318746, 'wacc': 0.10644351630374228}}


## TrimmedMean

In [7]:
from src.federated_learning.standard.trimmedmean import Server as TrimmedMeanServer

# Training
# TrimmedMean run: TrimmedMeanServer with tail_size = 2, trained on its own
# deepcopy of `model`.
server = TrimmedMeanServer(
    clients=waffle_clients,
    participation_rate=10,
    model=deepcopy(model),
    tail_size=2,
)

# All training hyper-parameters in one place; no metrics are requested
# during training.
tmean_train_config = dict(
    num_rounds=500,
    num_local_epochs=1,
    criterion=torch.nn.CrossEntropyLoss(),
    optimizer_class=torch.optim.Adam,
    optimizer_params={"lr": 1e-3},
    evaluation_step=201,
    metrics={},
)
tmean_train_results = server.train(**tmean_train_config)

# Evaluation
# Swap the server's client list so the final evaluation runs on
# benign_clients instead of the clients used for training.
server.clients = benign_clients
tmean_eval_config = dict(
    criterion=torch.nn.CrossEntropyLoss(),
    metrics={"acc": Accuracy(), "wacc": WeightedAccuracy()},
)
tmean_evaluation_results = server.evaluate(**tmean_eval_config)

print(tmean_evaluation_results["server"])

  0%|          | 0/500 [00:00<?, ?it/s]2025-05-21 22:23:23,026 - TRIMMEDMEAN/Server - INFO - Round 1: training_loss = 0.0682
  0%|          | 1/500 [00:02<24:03,  2.89s/it]2025-05-21 22:23:25,783 - TRIMMEDMEAN/Server - INFO - Round 2: training_loss = 0.066
  0%|          | 2/500 [00:05<23:21,  2.81s/it]2025-05-21 22:23:28,638 - TRIMMEDMEAN/Server - INFO - Round 3: training_loss = 0.0662
  1%|          | 3/500 [00:08<23:27,  2.83s/it]2025-05-21 22:23:31,411 - TRIMMEDMEAN/Server - INFO - Round 4: training_loss = 0.0656
  1%|          | 4/500 [00:11<23:13,  2.81s/it]2025-05-21 22:23:34,213 - TRIMMEDMEAN/Server - INFO - Round 5: training_loss = 0.0658
  1%|          | 5/500 [00:14<23:09,  2.81s/it]2025-05-21 22:23:37,011 - TRIMMEDMEAN/Server - INFO - Round 6: training_loss = 0.065
  1%|          | 6/500 [00:16<23:04,  2.80s/it]2025-05-21 22:23:39,907 - TRIMMEDMEAN/Server - INFO - Round 7: training_loss = 0.0657
  1%|▏         | 7/500 [00:19<23:17,  2.83s/it]2025-05-21 22:23:42,979 - TRIMME

{'loss': 4.371445736146871, 'metrics': {'acc': 0.10610878674866764, 'wacc': 0.106108788902166}}


## GeoMed

In [8]:
from src.federated_learning.standard.geomed import Server as GeoMedServer

# Training
# GeoMed run: GeoMedServer with geomed_max_iter = 10, trained on its own
# deepcopy of `model`.
server = GeoMedServer(
    clients=waffle_clients,
    participation_rate=10,
    model=deepcopy(model),
    geomed_max_iter=10,
)

# All training hyper-parameters in one place; no metrics are requested
# during training.
geomed_train_config = dict(
    num_rounds=500,
    num_local_epochs=1,
    criterion=torch.nn.CrossEntropyLoss(),
    optimizer_class=torch.optim.Adam,
    optimizer_params={"lr": 1e-3},
    evaluation_step=201,
    metrics={},
)
geomed_train_results = server.train(**geomed_train_config)

# Evaluation
# Swap the server's client list so the final evaluation runs on
# benign_clients instead of the clients used for training.
server.clients = benign_clients
geomed_eval_config = dict(
    criterion=torch.nn.CrossEntropyLoss(),
    metrics={"acc": Accuracy(), "wacc": WeightedAccuracy()},
)
geomed_evaluation_results = server.evaluate(**geomed_eval_config)

print(geomed_evaluation_results["server"])

  0%|          | 0/500 [00:00<?, ?it/s]2025-05-22 00:11:27,698 - GEOMED/Server - INFO - Round 1: training_loss = 0.0671
  0%|          | 1/500 [00:02<18:31,  2.23s/it]2025-05-22 00:11:29,967 - GEOMED/Server - INFO - Round 2: training_loss = 0.0662
  0%|          | 2/500 [00:04<18:41,  2.25s/it]2025-05-22 00:11:32,181 - GEOMED/Server - INFO - Round 3: training_loss = 0.0665
  1%|          | 3/500 [00:06<18:30,  2.23s/it]2025-05-22 00:11:34,486 - GEOMED/Server - INFO - Round 4: training_loss = 0.0655
  1%|          | 4/500 [00:09<18:42,  2.26s/it]2025-05-22 00:11:36,632 - GEOMED/Server - INFO - Round 5: training_loss = 0.0652
  1%|          | 5/500 [00:11<18:19,  2.22s/it]2025-05-22 00:11:38,782 - GEOMED/Server - INFO - Round 6: training_loss = 0.0654
  1%|          | 6/500 [00:13<18:05,  2.20s/it]2025-05-22 00:11:41,062 - GEOMED/Server - INFO - Round 7: training_loss = 0.065
  1%|▏         | 7/500 [00:15<18:16,  2.22s/it]2025-05-22 00:11:43,572 - GEOMED/Server - INFO - Round 8: training

{'loss': 4.116599174204231, 'metrics': {'acc': 0.11949790841989437, 'wacc': 0.11949790958580113}}


In [9]:
from json import dump
import os

# Persist the train/test results of every aggregator for the current seed.
# The path is relative to the current working directory; with
# new_seed = None the file name ends in "seedNone.json".
results_path = f"notebooks/cifar100/results/waffle_fft_seed{new_seed}.json"

all_results = {
    "fedavg": {"train": fedavg_train_results, "test": fedavg_evaluation_results},
    "krum": {"train": krum_train_results, "test": krum_evaluation_results},
    "mkrum": {"train": mkrum_train_results, "test": mkrum_evaluation_results},
    "trimmedmean": {"train": tmean_train_results, "test": tmean_evaluation_results},
    "geomed": {"train": geomed_train_results, "test": geomed_evaluation_results},
}

# The results directory may not exist yet (e.g. on a fresh checkout).
os.makedirs(os.path.dirname(results_path), exist_ok=True)

# Write to a sibling temporary file first and move it into place only after
# dump() succeeded. Opening the final path with "w" truncates it immediately,
# so a serialization error would otherwise wipe earlier results and leave a
# partial JSON file behind.
tmp_results_path = f"{results_path}.tmp"
with open(tmp_results_path, "w") as f:
    dump(all_results, f)
os.replace(tmp_results_path, results_path)