In [1]:
# Move two directories up from this notebook's folder to the directory that
# contains `src/` (the path printed below), so the `src.*` imports in the next
# cell resolve.
# NOTE(review): not idempotent — re-running this cell climbs two more levels;
# restart the kernel before re-running from the top.
%cd ../..

/Users/davideleo/Desktop/Projects/research/papers/fl_wavelet_v0


# Malicious clients analysis
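## FedAvg w/ malicious clients

The cell below builds the federation once, with some clients corrupted by a `ShiftEmbedding` attack, and extracts the subset of benign clients used for the baseline section that follows.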

In [2]:
import random
import torch 
import numpy as np 
import matplotlib.pyplot as plt
from src.data.newsgroups import get_federation 
from src.data.attacks import ShiftEmbedding
from src.federated_learning.standard.fedavg import Client, Server 
from src.models.neural_networks import NewsCNNClassifier
from src.models.metrics import Accuracy, WeightedAccuracy

# Seed the Python, NumPy and PyTorch RNGs before anything random happens, so the
# federation built below is reproducible on a fresh kernel.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

# Federation: 100 shards, a `ShiftEmbedding` attack configured with proba=0.5,
# applied with attacks_proba=0.4 (per-client probability presumably — confirm in
# get_federation). Each element is read below via its "id", "train", "test" and
# "distribution" keys.
federation = get_federation(
    num_shards=100,
    alpha=1000,
    attacks=[ShiftEmbedding(proba=0.5)],
    attacks_proba=0.4,
)

# One CPU client per shard, in the same order as `federation`.
clients = [
    Client(
        train_dataset=shard["train"],
        test_dataset=shard["test"],
        distribution=shard["distribution"],
        batch_size=64,
        device="cpu",
    )
    for shard in federation
]

# A client counts as benign when its shard id has no "." in it
# (assumes attacked shards get a dotted id suffix — confirm in get_federation).
benign_clients = [
    client
    for shard, client in zip(federation, clients)
    if "." not in shard["id"]
]

## FedAvg w/o malicious clients

In [3]:
# Seed override for repeated runs.
# Default experiments: new_seed = None -> [42, 7, 365]
# (presumably: None keeps the RNG state left by the setup cell, which was seeded
# with 42; 7 and 365 reseed every RNG explicitly).
# NOTE(review): with new_seed = None the RNG state is whatever the setup cell and
# any cell run since left behind, so re-running only this cell will NOT reproduce
# the "42" result — use Restart Kernel -> Run All for that run.
new_seed = 7

if new_seed is not None:
    random.seed(new_seed)
    np.random.seed(new_seed)
    torch.random.manual_seed(new_seed)

# Training: FedAvg on the benign clients only (baseline without attackers).
NUM_ROUNDS = 500

server = Server(
    clients = benign_clients,
    # NOTE(review): unit of participation_rate (percent vs. client count) is
    # defined by Server — confirm.
    participation_rate = 10,
    model = NewsCNNClassifier()
)

benign_train_results = server.train(
    num_rounds = NUM_ROUNDS,
    num_local_epochs = 1,
    criterion = torch.nn.CrossEntropyLoss(),
    optimizer_class = torch.optim.Adam,
    optimizer_params = {"lr": 1e-3},
    # One past the last round, derived from NUM_ROUNDS so it stays in sync when
    # the round count changes; the intent appears to be "no intermediate
    # evaluation during training" (assumes Server evaluates every
    # `evaluation_step` rounds — TODO confirm).
    evaluation_step = NUM_ROUNDS + 1,
    metrics = {}
)

# Evaluation of the final model with accuracy and weighted accuracy.
evaluation_results = server.evaluate(
    criterion = torch.nn.CrossEntropyLoss(),
    metrics = {"acc": Accuracy(), "wacc": WeightedAccuracy()}
)

print(evaluation_results["server"])

  0%|          | 0/500 [00:00<?, ?it/s]2025-05-22 00:36:19,097 - FEDAVG/Server - INFO - Round 1: training_loss = 0.0344
  0%|          | 1/500 [00:01<14:15,  1.71s/it]2025-05-22 00:36:20,537 - FEDAVG/Server - INFO - Round 2: training_loss = 0.0336
  0%|          | 2/500 [00:03<12:53,  1.55s/it]2025-05-22 00:36:22,740 - FEDAVG/Server - INFO - Round 3: training_loss = 0.0318
  1%|          | 3/500 [00:05<15:19,  1.85s/it]2025-05-22 00:36:24,767 - FEDAVG/Server - INFO - Round 4: training_loss = 0.0303
  1%|          | 4/500 [00:07<15:52,  1.92s/it]2025-05-22 00:36:27,241 - FEDAVG/Server - INFO - Round 5: training_loss = 0.0293
  1%|          | 5/500 [00:09<17:29,  2.12s/it]2025-05-22 00:36:29,316 - FEDAVG/Server - INFO - Round 6: training_loss = 0.0283
  1%|          | 6/500 [00:11<17:19,  2.10s/it]2025-05-22 00:36:31,528 - FEDAVG/Server - INFO - Round 7: training_loss = 0.0278
  1%|▏         | 7/500 [00:14<17:34,  2.14s/it]2025-05-22 00:36:33,958 - FEDAVG/Server - INFO - Round 8: trainin

{'loss': 0.653973323593537, 'metrics': {'acc': 0.7463333314756553, 'wacc': 0.7463333424031734}}
