In [1]:
# Move the kernel's working directory two levels up from the notebook's folder.
# NOTE(review): presumably this lands on the project root so that the `src.*`
# imports in the next cell resolve — confirm.
%cd ../..

/Users/davideleo/Desktop/Projects/research/papers/fl_wavelet_v0


# Malicious clients analysis
## FedAvg w/ malicious clients 

In [2]:
import random
import torch 
import numpy as np 
import matplotlib.pyplot as plt
from src.data.fashion_mnist import get_federation 
from src.data.attacks import GaussianBlur, GaussianNoise
from src.federated_learning.detection.fldetector import Client, Server 
from src.models.neural_networks import LeNet5
from src.models.metrics import Accuracy, WeightedAccuracy

# Pin the three RNGs used in this notebook (stdlib, NumPy, PyTorch) to the same seed.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

# Federation: 100 shards; the two attacks below are supplied together with
# an attack probability of 0.4.
federation = get_federation(
    num_shards=100,
    alpha=1000,
    attacks=[
        GaussianNoise(sigma=.5),
        GaussianBlur(kernel_size=11),
    ],
    attacks_proba=0.4,
)

# One Client per federation entry, all evaluated/trained on CPU with batches of 64.
clients = [
    Client(
        train_dataset=shard["train"],
        test_dataset=shard["test"],
        distribution=shard["distribution"],
        batch_size=64,
        device="cpu",
    )
    for shard in federation
]

# A client is considered benign when its shard id contains no "." separator.
benign_clients = [
    client
    for shard, client in zip(federation, clients)
    if "." not in shard["id"]
]

## FedAvg w/o malicious clients

In [3]:
# Optional reseed for repeated experiments. The values listed for the default
# runs are None, 7 and 365; None leaves the RNG state from the setup cell untouched.
new_seed = 365

if new_seed is not None:
    random.seed(new_seed)
    np.random.seed(new_seed)
    torch.manual_seed(new_seed)

# Training
# NOTE(review): the full `clients` list is passed here, not a benign-only
# subset — confirm this is the intended population for this section.
server = Server(
    clients=clients,
    participation_rate=0.1,
    model=LeNet5(),
    window_size=10,
    batch_size=10,
)

benign_train_results = server.train(
    num_rounds=500,
    num_local_epochs=1,
    criterion=torch.nn.CrossEntropyLoss(),
    optimizer_class=torch.optim.Adam,
    optimizer_params={"lr": 1e-3},
    # NOTE(review): 501 exceeds num_rounds; presumably this disables
    # intermediate evaluation during training — confirm.
    evaluation_step=501,
    metrics={},
)

# Evaluation
evaluation_results = server.evaluate(
    criterion=torch.nn.CrossEntropyLoss(),
    metrics={"acc": Accuracy(), "wacc": WeightedAccuracy()},
)

print(evaluation_results["server"])

  0%|          | 0/500 [00:00<?, ?it/s]2025-05-22 11:27:42,224 - FLDETECTOR/Server - INFO - Round 1: training_loss = 0.0344
  0%|          | 1/500 [00:02<22:28,  2.70s/it]2025-05-22 11:27:44,433 - FLDETECTOR/Server - INFO - Round 2: training_loss = 0.0334
  0%|          | 2/500 [00:04<20:01,  2.41s/it]2025-05-22 11:27:46,660 - FLDETECTOR/Server - INFO - Round 3: training_loss = 0.0315
  1%|          | 3/500 [00:07<19:16,  2.33s/it]2025-05-22 11:27:48,930 - FLDETECTOR/Server - INFO - Round 4: training_loss = 0.0296
  1%|          | 4/500 [00:09<19:03,  2.30s/it]2025-05-22 11:27:51,064 - FLDETECTOR/Server - INFO - Round 5: training_loss = 0.0286
  1%|          | 5/500 [00:11<18:30,  2.24s/it]2025-05-22 11:27:53,337 - FLDETECTOR/Server - INFO - Round 6: training_loss = 0.0281
  1%|          | 6/500 [00:13<18:33,  2.25s/it]2025-05-22 11:27:55,522 - FLDETECTOR/Server - INFO - Round 7: training_loss = 0.0273
  1%|▏         | 7/500 [00:16<18:19,  2.23s/it]2025-05-22 11:27:57,658 - FLDETECTOR/

{'loss': 0.5906098579863707, 'metrics': {'acc': 0.7881666690508524, 'wacc': 0.7881666710575421}}


In [4]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# Detection quality: compare the clients flagged by the server against the
# ground truth derived from the federation shard ids.

# NOTE(review): not used in this cell; kept because later cells may rely on it.
domains_wst = []

# Ground truth: a shard whose id contains "." is labelled malicious (1),
# every other shard benign (0).
labels = [int("." in shard["id"]) for shard in federation]
num_clients = len(labels)  # derived from the federation instead of a hard-coded 100

# Predictions: clients listed in `server.malicious_clients` are marked 1.
# The explicit integer dtype keeps an empty collection usable as an index
# array (a bare np.array([]) is float-typed and cannot index), so no
# emptiness guard is needed.
flagged = np.array(list(server.malicious_clients), dtype=int)
pred = np.zeros(num_clients, dtype=int)
pred[flagged] = 1

# zero_division=0 makes the "nothing flagged" / "nothing to find" cases an
# explicit 0.0 instead of relying on sklearn's warning-and-fallback path.
print("Accuracy: ", accuracy_score(labels, pred))
print("Precision: ", precision_score(labels, pred, pos_label=1, zero_division=0))
print("Recall:  ", recall_score(labels, pred, pos_label=1, zero_division=0))
print("F1-Score: ", f1_score(labels, pred, average="binary", zero_division=0))

# Indices of the flagged clients; the mask is converted to a torch tensor so
# the torch index tensor is not indexed with a NumPy array.
print(torch.arange(num_clients)[torch.from_numpy(pred == 1)])

Accuracy:  0.6
Precision:  0.0
Recall:   0.0
F1-Score:  0.0
tensor([], dtype=torch.int64)
