In [1]:
%cd ../..

/Users/davideleo/Desktop/Projects/research/papers/fl_wavelet_v0


In [2]:
import random
import torch 
import numpy as np 
from src.data.cifar100 import get_federation 
from src.data.attacks import GaussianBlur, GaussianNoise
from src.federated_learning.detection.vae import Server, Client
from src.models.neural_networks import LeNet5
from src.models.metrics import Accuracy, WeightedAccuracy
from copy import deepcopy

random.seed(42)
np.random.seed(42)
torch.random.manual_seed(42)

# Federation
model = LeNet5(in_channels = 3, in_padding = 0, num_classes = 100)

federation = get_federation(
    num_shards = 100,
    alpha = 1000,
    attacks = [GaussianNoise(sigma = .5), GaussianBlur(kernel_size = 11)],
    attacks_proba = 0.4
)

clients = [
    Client(
        train_dataset = dataset["train"],
        test_dataset = dataset["test"],
        distribution = dataset["distribution"],
        batch_size = 64,
        device = "cpu"
    ) for dataset in federation
]

benign_clients = [client for dataset, client in zip(federation, clients) if len(dataset["id"].split(".")) == 1] 

In [3]:
import os

# Training
server = Server(
    clients = clients,
    participation_rate = 10, 
    model = deepcopy(model),
    vae_path = os.path.join("notebooks", "cifar100", "results", "vae.pth"),
    num_features = 200
)

train_results = server.train(
    num_rounds = 20,
    num_local_epochs = 1,
    criterion = torch.nn.CrossEntropyLoss(),
    optimizer_class = torch.optim.Adam, 
    optimizer_params = {"lr": 1e-3},
    evaluation_step = 1,
    metrics = dict()
)

# Evaluation 
server.clients = benign_clients
evaluation_results = server.evaluate(
    criterion = torch.nn.CrossEntropyLoss(),
    metrics = {"acc": Accuracy(), "wacc": WeightedAccuracy()}
)

print(evaluation_results["server"])

  0%|          | 0/20 [00:00<?, ?it/s]2025-08-04 23:29:43,962 - FEDAVG/Server - INFO - Round 1: training_loss = 0.0323, evaluation_loss = 2.2967
  5%|▌         | 1/20 [00:05<01:47,  5.68s/it]2025-08-04 23:29:48,221 - FEDAVG/Server - INFO - Round 2: training_loss = 0.0321, evaluation_loss = 2.2767
 10%|█         | 2/20 [00:09<01:27,  4.84s/it]2025-08-04 23:29:52,135 - FEDAVG/Server - INFO - Round 3: training_loss = 0.0318, evaluation_loss = 2.239
 15%|█▌        | 3/20 [00:13<01:15,  4.42s/it]2025-08-04 23:29:55,917 - FEDAVG/Server - INFO - Round 4: training_loss = 0.0313, evaluation_loss = 2.178
 20%|██        | 4/20 [00:17<01:06,  4.17s/it]2025-08-04 23:29:59,624 - FEDAVG/Server - INFO - Round 5: training_loss = 0.0308, evaluation_loss = 2.1322
 25%|██▌       | 5/20 [00:21<01:00,  4.00s/it]2025-08-04 23:30:03,231 - FEDAVG/Server - INFO - Round 6: training_loss = 0.0302, evaluation_loss = 2.0838
 30%|███       | 6/20 [00:24<00:54,  3.87s/it]2025-08-04 23:30:06,925 - FEDAVG/Server - INFO

{'loss': 1.9316258279482523, 'metrics': {'acc': 0.2995000025431315, 'wacc': 0.2995000122288863}}


In [4]:
# Count
malicious_clients = dict()

for k,v in server.malicious_votes.items(): 
    if len(v) > 0:
        malicious_clients[k] = round(sum(v) / len(v))
    else:
        malicious_clients[k] = 0

malicious_clients = np.array([v for v in malicious_clients.values()])

In [5]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# Setup clients
domains_wst = []
labels = []

for d in federation: 
    if "." in d["id"]: 
        labels.append(1)
    else:
        labels.append(0)

with torch.no_grad(): 
    pred = np.zeros(100)
    if len(malicious_clients) > 0:
        pred[np.array(malicious_clients)] = 1
    print("Accuracy: ", accuracy_score(labels, pred))
    print("Precision: ", precision_score(labels, pred, pos_label = 1))
    print("Recall:  ", recall_score(labels, pred, pos_label = 1))
    print("F1-Score: ", f1_score(labels, pred, average = "binary"))
    print(torch.arange(100)[pred == 1])

Accuracy:  0.6
Precision:  0.5
Recall:   0.025
F1-Score:  0.047619047619047616
tensor([0, 1])
