## Imports & Setup


In [None]:
import uuid
from datetime import datetime


def make_auto_filename(prefix: str, suffix: str = "", ext: str = "pth"):
    """Build a timestamped filename ``<prefix>_<YYYYmmdd_HHMMSS>[_<suffix>].<ext>``.

    The timestamp comes from the local clock at call time; ``suffix`` is only
    appended when non-empty.

    Example:
        fedplus_global_model_20251126_101530.pth
    """
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    stem = f"{prefix}_{stamp}_{suffix}" if suffix else f"{prefix}_{stamp}"
    return f"{stem}.{ext}"



In [None]:
import os
import time
import json
from datetime import datetime
from collections import OrderedDict
from typing import List, Dict, Tuple, Optional
import contextlib

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
try:
    from torch.amp import autocast as torch_autocast, GradScaler
except ImportError:
    from torch.cuda.amp import autocast as torch_autocast, GradScaler

import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from sklearn.metrics import (
    confusion_matrix,
    classification_report,
    precision_recall_fscore_support,
    roc_auc_score,
    roc_curve,
    auc,
    f1_score,
    precision_score,
    recall_score,
    accuracy_score,
)
from sklearn.preprocessing import label_binarize
import pandas as pd
import seaborn as sns

# Fixed-point (non-scientific), wider tensor printing for readable notebook logs.
torch.set_printoptions(sci_mode=False, linewidth=120)

## Config


In [None]:
CONFIG = dict(
    # --- Data / output locations ---
    data_dir="/kaggle/input/data-10clients",   # path to step2-version3 output
    output_dir="./results",
    checkpoint_dir="./checkpoints",            # checkpoint directory

    # --- Federation ---
    num_clients=10,
    algorithm="fedavgm",                       # one of: "fedavg", "fedprox", "fedavgm", "fedplus"

    # --- FedAvgM server-side parameters ---
    server_momentum=0.9,
    server_lr=1,

    # --- Model (None here; presumably filled in once the data is loaded) ---
    input_shape=None,
    num_classes=None,

    # --- Local training ---
    num_rounds=5,
    local_epochs=3,
    learning_rate=1e-3,
    batch_size=1024,
    mu=0.01,                                   # regularization strength passed to FedProx / Fed+

    # --- Hardware ---
    device="cuda" if torch.cuda.is_available() else "cpu",

    # --- Evaluation & checkpointing ---
    eval_every=1,
    save_checkpoint_every=1,                   # save a checkpoint every N rounds
    resume_from_checkpoint=None,               # path of a checkpoint to resume from, or None
)

device = CONFIG["device"]
print(f"Using device: {device}")
print(f"Algorithm: {CONFIG['algorithm']}")

# Make sure the output folders exist before anything writes into them.
os.makedirs(CONFIG["checkpoint_dir"], exist_ok=True)
os.makedirs(CONFIG["output_dir"], exist_ok=True)

## Checkpoint Manager & Metrics


In [None]:
class CheckpointManager:
    """
    Checkpoint manager for a federated-learning run.

    After each round it saves the global model weights, the server's FedAvgM
    ``velocity`` (when the server has one) and the training history, and can
    restore them later. All files for one ``<algorithm>_<num_clients>``
    combination live in the same folder, so a new run with identical settings
    overwrites the previous run's checkpoints.
    """

    def __init__(self, checkpoint_dir: str, algorithm: str, num_clients: int):
        self.checkpoint_dir = checkpoint_dir
        self.algorithm = algorithm
        self.num_clients = num_clients
        # Timestamp id of this run, exposed so result files can share it.
        self.run_id = datetime.now().strftime("%Y%m%d_%H%M%S")
        # Folder is named <algorithm>_<num_clients> (e.g. fedplus_5), not by run_id.
        self.folder_name = f"{algorithm}_{num_clients}"
        self.run_dir = os.path.join(checkpoint_dir, self.folder_name)
        os.makedirs(self.run_dir, exist_ok=True)
        print(f"✓ Checkpoint directory: {self.run_dir}")

    def get_run_id(self):
        """Return the run timestamp id, for naming result files consistently."""
        return self.run_id

    def get_folder_name(self):
        """Return the checkpoint folder name, for naming result files consistently."""
        return self.folder_name

    @staticmethod
    def _to_builtin(obj):
        """Recursively convert NumPy scalars/arrays inside dicts/lists/tuples to Python builtins.

        Metric values can arrive as NumPy scalars (e.g. from sklearn). ``json.dump``
        rejects types such as ``np.float32``/``np.int64``, and ``torch.load`` with
        ``weights_only=True`` may reject pickled NumPy scalars. Tensors and other
        objects are returned unchanged.
        """
        if isinstance(obj, dict):
            return {k: CheckpointManager._to_builtin(v) for k, v in obj.items()}
        if isinstance(obj, list):
            return [CheckpointManager._to_builtin(v) for v in obj]
        if isinstance(obj, tuple):
            return tuple(CheckpointManager._to_builtin(v) for v in obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            return obj.item()
        return obj

    def save_checkpoint(self, round_idx: int, server, history: Dict,
                        config: Dict, extra_info: Dict = None):
        """
        Save a checkpoint for the round that just finished.

        Writes ``checkpoint_round_<NNN>.pt`` (1-based round number),
        overwrites ``checkpoint_latest.pt`` with the same content, and
        rewrites ``history.json``.

        Args:
            round_idx: Round index (0-based).
            server: Object exposing ``global_model`` and optionally ``velocity``
                (e.g. a FederatedServerV2).
            history: Training history dict; NumPy scalars/arrays are converted
                to builtins in the saved copy (the caller's dict is not modified).
            config: CONFIG dict; values that are not int/float/str/bool/None are
                stored as their ``str()``.
            extra_info: Optional additional info to store under ``"extra_info"``.

        Returns:
            Path of the per-round checkpoint file.
        """
        history_clean = self._to_builtin(history)
        checkpoint = {
            "round": round_idx,
            "global_model_state": server.global_model.state_dict(),
            "velocity": getattr(server, "velocity", None),
            "history": history_clean,
            "config": {
                k: v if isinstance(v, (int, float, str, bool, type(None))) else str(v)
                for k, v in config.items()
            },
            "timestamp": datetime.now().isoformat(),
        }
        if extra_info:
            checkpoint["extra_info"] = self._to_builtin(extra_info)

        ckpt_path = os.path.join(self.run_dir, f"checkpoint_round_{round_idx+1:03d}.pt")
        torch.save(checkpoint, ckpt_path)

        # Also keep a stable "latest" file for easy resuming.
        latest_path = os.path.join(self.run_dir, "checkpoint_latest.pt")
        torch.save(checkpoint, latest_path)

        # Human-readable copy of the history.
        hist_path = os.path.join(self.run_dir, "history.json")
        with open(hist_path, "w", encoding="utf-8") as f:
            json.dump(history_clean, f, indent=2)

        print(f"  💾 Saved checkpoint: round {round_idx+1}")
        return ckpt_path

    @staticmethod
    def load_checkpoint(checkpoint_path: str, server, device: str):
        """
        Load a checkpoint and restore it into ``server``.

        Restores ``server.global_model`` weights, ``server.velocity`` (only when
        the checkpoint stored one) and ``server.history``.

        Returns:
            round_idx: 0-based index of the last completed round.
            history: The restored training history (same object as ``server.history``).
        """
        print(f"\n📂 Loading checkpoint: {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location=device)

        server.global_model.load_state_dict(checkpoint["global_model_state"])

        # FedAvgM momentum buffer, if it was saved.
        if checkpoint.get("velocity") is not None:
            server.velocity = checkpoint["velocity"]

        history = checkpoint["history"]
        server.history = history

        round_idx = checkpoint["round"]
        print(f"  ✓ Restored from round {round_idx + 1}")
        # History may have no evaluations yet (e.g. eval_every > rounds done).
        accuracies = history.get("test_accuracy") or []
        if accuracies:
            print(f"  ✓ Last accuracy: {accuracies[-1]*100:.2f}%")
        else:
            print("  ✓ No evaluation recorded yet")

        return round_idx, history

    def get_all_checkpoints(self):
        """Return the sorted paths of all per-round checkpoints in the run directory."""
        return sorted(
            os.path.join(self.run_dir, name)
            for name in os.listdir(self.run_dir)
            if name.startswith("checkpoint_round_") and name.endswith(".pt")
        )

In [None]:
class MetricsCalculator:
    """
    Metrics for multiclass classification:
    - Accuracy, Precision, Recall, F1 (macro, micro, weighted)
    - ROC-AUC (one-vs-rest)
    - Per-class metrics
    """

    @staticmethod
    def compute_all_metrics(y_true: np.ndarray, y_pred: np.ndarray,
                           y_proba: np.ndarray = None, num_classes: int = None):
        """
        Compute all metrics.

        Args:
            y_true: Ground-truth labels (assumed integers in ``0..num_classes-1``).
            y_pred: Predicted labels.
            y_proba: Predicted probabilities (softmax output), one column per
                class - needed for AUC.
            num_classes: Number of classes. When given, per-class arrays have
                exactly ``num_classes`` entries (index i == class i).

        Returns:
            Dict with all metrics. AUC keys (``auc_macro_ovr``,
            ``auc_weighted_ovr``, ``auc_per_class``) are only present when both
            ``y_proba`` and ``num_classes`` are given. AUC of a class is None
            when it is undefined (class absent from, or the only class in,
            ``y_true``); macro/weighted AUC average over the defined classes only,
            weighted by each class's number of positive samples.
        """
        metrics = {}

        metrics["accuracy"] = accuracy_score(y_true, y_pred)

        # Precision, Recall, F1 - multiple averaging strategies.
        for avg in ["macro", "micro", "weighted"]:
            metrics[f"precision_{avg}"] = precision_score(y_true, y_pred, average=avg, zero_division=0)
            metrics[f"recall_{avg}"] = recall_score(y_true, y_pred, average=avg, zero_division=0)
            metrics[f"f1_{avg}"] = f1_score(y_true, y_pred, average=avg, zero_division=0)

        # Per-class metrics. Pin the label set when it is known so that entry i
        # always refers to class i, even if a class never occurs.
        labels = list(range(num_classes)) if num_classes is not None else None
        precision_per_class, recall_per_class, f1_per_class, support = \
            precision_recall_fscore_support(y_true, y_pred, labels=labels,
                                            average=None, zero_division=0)

        metrics["precision_per_class"] = precision_per_class.tolist()
        metrics["recall_per_class"] = recall_per_class.tolist()
        metrics["f1_per_class"] = f1_per_class.tolist()
        metrics["support_per_class"] = support.tolist()

        # ROC-AUC (requires probability scores).
        if y_proba is not None and num_classes is not None:
            try:
                y_true_bin = label_binarize(y_true, classes=list(range(num_classes)))

                # With 2 classes label_binarize returns a single (positive-class)
                # column; expand to one column per class to line up with y_proba.
                if y_true_bin.shape[1] == 1:
                    y_true_bin = np.hstack([1 - y_true_bin, y_true_bin])

                n_samples = y_true_bin.shape[0]
                positives = y_true_bin.sum(axis=0)

                # One-vs-rest AUC per class; only defined when the class column
                # has both positive and negative samples.
                auc_per_class = []
                for i in range(num_classes):
                    if 0 < positives[i] < n_samples:
                        auc_per_class.append(float(roc_auc_score(y_true_bin[:, i], y_proba[:, i])))
                    else:
                        auc_per_class.append(None)

                # Aggregate over defined classes only, so one missing class does
                # not wipe out the overall AUC (sklearn raises in that case).
                valid = [i for i, a in enumerate(auc_per_class) if a is not None]
                if valid:
                    scores = np.array([auc_per_class[i] for i in valid])
                    metrics["auc_macro_ovr"] = float(np.average(scores))
                    metrics["auc_weighted_ovr"] = float(
                        np.average(scores, weights=positives[valid])
                    )
                else:
                    metrics["auc_macro_ovr"] = None
                    metrics["auc_weighted_ovr"] = None
                metrics["auc_per_class"] = auc_per_class

            except Exception as e:
                # Best-effort: AUC must never abort evaluation (e.g. shape mismatch).
                print(f"  ⚠ Warning: Could not compute AUC: {e}")
                metrics["auc_macro_ovr"] = None
                metrics["auc_weighted_ovr"] = None
                metrics["auc_per_class"] = None

        return metrics

    @staticmethod
    def print_metrics_summary(metrics: Dict, round_idx: int = None):
        """Print a short summary of the main metrics (as percentages)."""
        header = f"Round {round_idx+1}" if round_idx is not None else "Evaluation"
        print(f"\n{'='*60}")
        print(f"📊 METRICS SUMMARY - {header}")
        print(f"{'='*60}")
        print(f"  Accuracy:           {metrics['accuracy']*100:.2f}%")
        print(f"  F1 (macro):         {metrics['f1_macro']*100:.2f}%")
        print(f"  F1 (weighted):      {metrics['f1_weighted']*100:.2f}%")
        print(f"  Precision (macro):  {metrics['precision_macro']*100:.2f}%")
        print(f"  Recall (macro):     {metrics['recall_macro']*100:.2f}%")
        if metrics.get("auc_macro_ovr") is not None:
            print(f"  AUC (macro OvR):    {metrics['auc_macro_ovr']*100:.2f}%")
        print(f"{'='*60}")

In [None]:
class NonIIDAnalyzer:
    """
    Analyze and visualize the (non-IID) label distribution across clients.
    """

    @staticmethod
    def analyze_client_distribution(client_data: List[Dict], num_classes: int):
        """
        Compute the label distribution of every client.

        Args:
            client_data: One dict per client with a ``'y_train'`` tensor of
                integer labels (read via ``.cpu().numpy()``).
            num_classes: Length of each per-class count vector.

        Returns:
            List of dicts (one per client, in order) with keys ``client_id``,
            ``total_samples``, ``num_classes`` (distinct labels present),
            ``class_distribution`` (count per class), ``dominant_class`` and
            ``dominant_ratio``. For a client with no samples ``dominant_class``
            is None and ``dominant_ratio`` is 0.0.
        """
        print("\n" + "="*60)
        print("📊 NON-IID ANALYSIS: Client Data Distribution")
        print("="*60)

        stats = []
        for cid, data in enumerate(client_data):
            y = data['y_train'].cpu().numpy()
            unique, counts = np.unique(y, return_counts=True)
            n = len(y)

            # Dense count vector over all classes (absent classes stay 0).
            full_dist = np.zeros(num_classes)
            full_dist[unique] = counts

            if n == 0:
                # np.argmax on an empty array would raise; report an empty client.
                dominant_class, dominant_ratio = None, 0.0
                print(f"  Client {cid}: 0 samples (empty)")
            else:
                top = np.argmax(counts)
                dominant_class = unique[top]
                dominant_ratio = counts[top] / n
                print(f"  Client {cid}: {n:,} samples, {len(unique)} classes, "
                      f"dominant class {dominant_class} ({dominant_ratio*100:.1f}%)")

            stats.append({
                "client_id": cid,
                "total_samples": n,
                "num_classes": len(unique),
                "class_distribution": full_dist,
                "dominant_class": dominant_class,
                "dominant_ratio": dominant_ratio,
            })

        return stats

    @staticmethod
    def plot_client_distribution(client_stats: List[Dict], num_classes: int,
                                 save_path: str = None):
        """
        Plot side-by-side heatmaps of per-client label counts and percentages.

        Args:
            client_stats: Output of ``analyze_client_distribution``.
            num_classes: Number of classes (heatmap columns).
            save_path: If given, the figure is also saved there (dpi=150).

        Returns:
            The matplotlib Figure.
        """
        num_clients = len(client_stats)
        dist_matrix = np.zeros((num_clients, num_classes))
        for stat in client_stats:
            dist_matrix[stat["client_id"]] = stat["class_distribution"]

        # Row-normalize to percentages; empty clients get an all-zero row
        # instead of NaNs from a division by zero.
        row_sums = dist_matrix.sum(axis=1, keepdims=True)
        dist_matrix_pct = np.divide(
            dist_matrix, row_sums, out=np.zeros_like(dist_matrix), where=row_sums > 0
        ) * 100

        fig, axes = plt.subplots(1, 2, figsize=(18, max(6, num_clients * 0.5)))
        class_labels = [f"C{i}" for i in range(num_classes)]
        client_labels = [f"Client {i}" for i in range(num_clients)]

        # Heatmap - raw counts
        sns.heatmap(dist_matrix, ax=axes[0], cmap="YlOrRd",
                    xticklabels=class_labels,
                    yticklabels=client_labels,
                    cbar_kws={"label": "Sample Count"})
        axes[0].set_title("Label Distribution (Raw Counts)", fontsize=12)
        axes[0].set_xlabel("Class")
        axes[0].set_ylabel("Client")

        # Heatmap - percentages
        sns.heatmap(dist_matrix_pct, ax=axes[1], cmap="YlGnBu",
                    xticklabels=class_labels,
                    yticklabels=client_labels,
                    cbar_kws={"label": "Percentage (%)"}, vmin=0, vmax=100)
        axes[1].set_title("Label Distribution (Percentage)", fontsize=12)
        axes[1].set_xlabel("Class")
        axes[1].set_ylabel("Client")

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=150, bbox_inches='tight')
            print(f"  💾 Saved Non-IID plot: {save_path}")

        plt.show()
        return fig

## Model: CNN-GRU


In [None]:
class CNN_GRU_Model(nn.Module):
    """
    Hybrid 1-D CNN + stacked-GRU classifier for fixed-length feature vectors.

    The input ``(batch, seq_len)`` (a trailing channel dim of 1 is added when
    missing) is fed in parallel to:
      * a CNN branch: 3 x [Conv1d -> BatchNorm1d -> ReLU -> MaxPool1d(2) -> Dropout],
      * a GRU branch: GRU(1->128) -> GRU(128->64), keeping the last time step.
    The flattened CNN features and the GRU output are concatenated and passed
    through a 2-layer MLP with BatchNorm and dropout.

    Args:
        input_shape: Sequence length, as an int or a tuple/list whose first
            element is the length (e.g. ``(seq_len,)``).
        num_classes: Number of output logits.
    """

    def __init__(self, input_shape, num_classes: int = 34):
        super().__init__()
        if isinstance(input_shape, (tuple, list)):
            seq_length = int(input_shape[0])
        else:
            seq_length = int(input_shape)

        self.input_shape = input_shape
        self.num_classes = num_classes

        # CNN blocks
        self.conv1 = nn.Conv1d(1, 64, 3, padding=1)
        self.bn1 = nn.BatchNorm1d(64)
        self.pool1 = nn.MaxPool1d(2)
        self.dropout_cnn1 = nn.Dropout(0.2)

        self.conv2 = nn.Conv1d(64, 128, 3, padding=1)
        self.bn2 = nn.BatchNorm1d(128)
        self.pool2 = nn.MaxPool1d(2)
        self.dropout_cnn2 = nn.Dropout(0.2)

        self.conv3 = nn.Conv1d(128, 256, 3, padding=1)
        self.bn3 = nn.BatchNorm1d(256)
        self.pool3 = nn.MaxPool1d(2)
        self.dropout_cnn3 = nn.Dropout(0.3)

        # Length after three MaxPool1d(2) layers (each floors the length by 2);
        # padding=1 with kernel 3 keeps the conv output length unchanged.
        cnn_len = seq_length
        for _ in range(3):
            cnn_len //= 2
        self.cnn_output_size = 256 * cnn_len

        # GRU
        self.gru1 = nn.GRU(1, 128, batch_first=True)
        self.gru2 = nn.GRU(128, 64, batch_first=True)
        self.gru_output_size = 64

        # MLP
        concat_size = self.cnn_output_size + self.gru_output_size
        self.dense1 = nn.Linear(concat_size, 256)
        self.bn_mlp1 = nn.BatchNorm1d(256)
        self.dropout1 = nn.Dropout(0.4)
        self.dense2 = nn.Linear(256, 128)
        self.bn_mlp2 = nn.BatchNorm1d(128)
        self.dropout2 = nn.Dropout(0.3)
        self.output = nn.Linear(128, num_classes)
        self.relu = nn.ReLU()

    def _mlp_bn(self, bn: nn.BatchNorm1d, z: torch.Tensor) -> torch.Tensor:
        """Apply an MLP BatchNorm, skipping it only when batch statistics are unavailable.

        In training mode BatchNorm1d cannot estimate statistics from fewer than
        two samples, so it is skipped there. In eval mode it uses running
        statistics and is always applied, so a sample's prediction does not
        depend on the size of the batch it happens to be in.
        """
        if self.training and z.size(0) < 2:
            return z
        return bn(z)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``."""
        if x.ndim == 2:
            x = x.unsqueeze(-1)  # (batch, seq_len) -> (batch, seq_len, 1)

        # CNN branch: channels-first (batch, 1, seq_len)
        x_cnn = x.permute(0, 2, 1)
        x_cnn = self.pool1(self.relu(self.bn1(self.conv1(x_cnn))))
        x_cnn = self.dropout_cnn1(x_cnn)
        x_cnn = self.pool2(self.relu(self.bn2(self.conv2(x_cnn))))
        x_cnn = self.dropout_cnn2(x_cnn)
        x_cnn = self.pool3(self.relu(self.bn3(self.conv3(x_cnn))))
        x_cnn = self.dropout_cnn3(x_cnn)
        cnn_output = x_cnn.flatten(1)

        # GRU branch: keep the output of the last time step
        x_gru, _ = self.gru1(x)
        x_gru, _ = self.gru2(x_gru)
        gru_output = x_gru[:, -1, :]

        # Concat & MLP
        z = torch.cat([cnn_output, gru_output], dim=1)
        z = self.dropout1(self.relu(self._mlp_bn(self.bn_mlp1, self.dense1(z))))
        z = self.dropout2(self.relu(self._mlp_bn(self.bn_mlp2, self.dense2(z))))
        return self.output(z)

## NEW: Pre-load All Data into Memory (tensors are moved to the configured device — GPU memory on CUDA)


In [None]:
def load_all_client_data_v3(data_dir: str, num_clients: int, device: str):
    """
    Pre-load every client's training split and the global test split (step2-version3 format).

    Expects ``client_{cid}_train.npz`` (arrays ``X_train``, ``y_train``) for each
    ``cid`` in ``range(num_clients)`` plus ``global_test_data.npz`` (arrays
    ``X_test``, ``y_test``) in ``data_dir``. Features are cast to float32 and
    labels to int64, and every tensor is moved to ``device`` up front (on a
    CUDA device the data therefore lives in GPU memory, not host RAM).

    Returns:
        client_data: List[Dict] with ``'X_train'``, ``'y_train'`` tensors and ``'num_samples'``.
        test_data: Dict with ``'X_test'``, ``'y_test'`` tensors and ``'num_samples'``.
        input_shape: ``(num_features,)``, taken from client 0 (assumes 2-D ``X_train``).
        num_classes: ``max label + 1`` over train AND test labels (labels are
            assumed to be non-negative integers), so the output layer covers
            every label that can occur.

    Raises:
        ValueError: if ``num_clients < 1``.
        FileNotFoundError: if an expected ``.npz`` file is missing.
    """
    if num_clients < 1:
        raise ValueError(f"num_clients must be >= 1, got {num_clients}")

    print("\n" + "="*80)
    print("PRE-LOADING ALL DATA INTO RAM (V3 Format)")
    print("="*80)

    client_data = []
    all_labels = []

    for cid in range(num_clients):
        path = os.path.join(data_dir, f"client_{cid}_train.npz")
        if not os.path.exists(path):
            raise FileNotFoundError(f"Missing: {path}")

        # Context manager closes the NpzFile's underlying file handle.
        with np.load(path) as data:
            X_train = data['X_train'].astype(np.float32)
            y_train = data['y_train'].astype(np.int64)

        client_data.append({
            'X_train': torch.from_numpy(X_train).to(device),
            'y_train': torch.from_numpy(y_train).to(device),
            'num_samples': len(y_train),
        })
        all_labels.append(y_train)

        if (cid + 1) % 100 == 0 or cid == num_clients - 1:
            print(f"  Loaded client {cid+1}/{num_clients}: {len(y_train):,} samples")

    test_path = os.path.join(data_dir, "global_test_data.npz")
    if not os.path.exists(test_path):
        raise FileNotFoundError(f"Missing: {test_path}")

    with np.load(test_path) as test_npz:
        X_test = test_npz['X_test'].astype(np.float32)
        y_test = test_npz['y_test'].astype(np.int64)

    test_data = {
        'X_test': torch.from_numpy(X_test).to(device),
        'y_test': torch.from_numpy(y_test).to(device),
        'num_samples': len(y_test),
    }

    print(f"\n  ✓ Loaded global test: {len(y_test):,} samples")

    input_shape = (client_data[0]['X_train'].shape[1],)
    # Counting distinct train labels undercounts when a label only occurs in
    # the test set or labels have gaps; size by the largest label instead.
    labels_np = np.concatenate(all_labels + [y_test])
    num_classes = int(labels_np.max()) + 1 if labels_np.size else 0

    total_train = sum(c['num_samples'] for c in client_data)
    print(f"\n  input_shape: {input_shape}")
    print(f"  num_classes: {num_classes}")
    print(f"  total_train: {total_train:,}")
    print(f"  total_test:  {len(y_test):,}")
    print("\n  ✓ All data pre-loaded into RAM!")
    print("="*80)

    return client_data, test_data, input_shape, num_classes

## FederatedClient V2 (trains on pre-loaded in-memory tensors, no DataLoader)


In [None]:
class FederatedClientV2:
    """
    Federated client that works with pre-loaded RAM data.
    No DataLoader overhead - direct tensor operations.
    """
    def __init__(self, client_id: int, model: nn.Module, 
                 X_train: torch.Tensor, y_train: torch.Tensor, device: str):
        self.client_id = client_id
        self.model = model
        self.X_train = X_train
        self.y_train = y_train
        self.device = device
        self.num_samples = len(y_train)
        
        self.use_amp = (device == "cuda" and torch.cuda.is_available())
    
    def _amp_ctx(self):
        return (
            torch_autocast(device_type="cuda", dtype=torch.float16)
            if self.use_amp else contextlib.nullcontext()
        )
    
    def _create_batches(self, batch_size: int):
        """Create random batches from pre-loaded data."""
        indices = torch.randperm(self.num_samples, device=self.device)
        for i in range(0, self.num_samples, batch_size):
            batch_idx = indices[i:i+batch_size]
            yield self.X_train[batch_idx], self.y_train[batch_idx]
    
    def train_fedavg(self, epochs: int, batch_size: int, lr: float, verbose: bool = True):
        self.model.train()
        optimizer = optim.Adam(self.model.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss()
        scaler = GradScaler(enabled=self.use_amp)
        
        total_loss = 0.0
        total_samples = 0
        
        for ep in range(epochs):
            ep_loss = 0.0
            ep_samples = 0
            
            for X_batch, y_batch in self._create_batches(batch_size):
                optimizer.zero_grad()
                
                with self._amp_ctx():
                    out = self.model(X_batch)
                    loss = criterion(out, y_batch)
                
                if self.use_amp:
                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    loss.backward()
                    optimizer.step()
                
                bs = len(y_batch)
                ep_loss += loss.item() * bs
                ep_samples += bs
            
            total_loss += ep_loss
            total_samples += ep_samples
        
        return {
            "client_id": self.client_id,
            "num_samples": self.num_samples,
            "loss": total_loss / max(1, total_samples)
        }
    
    def train_fedprox(self, epochs: int, batch_size: int, global_params: OrderedDict,
                      mu: float, lr: float, verbose: bool = True):
        self.model.train()
        optimizer = optim.Adam(self.model.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss()
        scaler = GradScaler(enabled=self.use_amp)
        
        total_loss = 0.0
        total_samples = 0
        
        for ep in range(epochs):
            ep_loss = 0.0
            ep_samples = 0
            
            for X_batch, y_batch in self._create_batches(batch_size):
                optimizer.zero_grad()
                
                with self._amp_ctx():
                    out = self.model(X_batch)
                    ce_loss = criterion(out, y_batch)
                    
                    # Proximal term
                    prox = 0.0
                    for name, param in self.model.named_parameters():
                        if param.requires_grad:
                            gp = global_params[name].to(self.device)
                            prox += torch.sum((param - gp)**2)
                    
                    loss = ce_loss + (mu / 2.0) * prox
                
                if self.use_amp:
                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    loss.backward()
                    optimizer.step()
                
                bs = len(y_batch)
                ep_loss += loss.item() * bs
                ep_samples += bs
            
            total_loss += ep_loss
            total_samples += ep_samples
        
        return {
            "client_id": self.client_id,
            "num_samples": self.num_samples,
            "loss": total_loss / max(1, total_samples)
        }
    def train_fedplus(self, epochs: int, batch_size: int, global_params: OrderedDict,
                      mu: float, lr: float, verbose: bool = True, use_sgd: bool = True):
        """
        Train using Fed+ update rule (Eq. 8 in the paper).
        
        Paper: "Federated Learning for Mobile Keyboard Prediction" (Yang et al.)
        Reference [290] in your citations.
        
        Update Rule (Eq. 8):
            θ = 1 / (1 + ν * λ)       where ν=learning_rate, λ=mu
            W_new = θ * W_updated + (1-θ) * W_global
        
        Note: Paper derives this formula for vanilla SGD. Using Adam is an approximation
        since Adam has adaptive per-parameter learning rates.
        
        Args:
            use_sgd: If True (default), use SGD as per paper. If False, use Adam.
        """
        self.model.train()
        
        # SGD is recommended for Fed+ as the paper derives formula for fixed lr
        if use_sgd:
            optimizer = optim.SGD(self.model.parameters(), lr=lr)
        else:
            optimizer = optim.Adam(self.model.parameters(), lr=lr)
            
        criterion = nn.CrossEntropyLoss()
        scaler = GradScaler(enabled=self.use_amp)
        
        # Calculate theta: θ = 1 / (1 + ν * λ)
        # ν = learning rate, λ = regularization strength (mu)
        theta = 1.0 / (1.0 + lr * mu)
        
        total_loss = 0.0
        total_samples = 0
        
        for ep in range(epochs):
            ep_loss = 0.0
            ep_samples = 0
            
            for X_batch, y_batch in self._create_batches(batch_size):
                optimizer.zero_grad()
                
                # 1. Standard Forward & Backward
                with self._amp_ctx():
                    out = self.model(X_batch)
                    loss = criterion(out, y_batch)
                
                if self.use_amp:
                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    loss.backward()
                    optimizer.step()  # W = W - lr * grad
                
                # 2. Apply Fed+ Correction (Eq. 8)
                # After optimizer step: W_local = W_old - lr * grad
                # Now interpolate towards global: W = θ * W_local + (1-θ) * W_global
                with torch.no_grad():
                    for name, param in self.model.named_parameters():
                        if name in global_params:
                            gp = global_params[name].to(self.device)
                            param.data = theta * param.data + (1.0 - theta) * gp

                bs = len(y_batch)
                ep_loss += loss.item() * bs
                ep_samples += bs
            
            total_loss += ep_loss
            total_samples += ep_samples
        
        return {
            "client_id": self.client_id,
            "num_samples": self.num_samples,
            "loss": total_loss / max(1, total_samples)
        }
    def get_model_params(self):
        return OrderedDict((k, v.detach().cpu().clone())
                           for k, v in self.model.state_dict().items())
    
    def set_model_params(self, params: OrderedDict):
        self.model.load_state_dict(params)

## FederatedServer V2


In [None]:
class FederatedServerV2:
    """
    Federated server that trains ALL clients every round.

    FedAvg, FedProx and Fed+ aggregate client weights by sample-weighted
    averaging; FedAvgM additionally applies a server-side momentum step
    (Hsu et al., 2019).
    """

    def __init__(self, global_model: nn.Module, clients: "List[FederatedClientV2]",
                 test_data: Dict, device: str, config: Dict = None):
        cfg = config or {}
        self.global_model = global_model.to(device)
        self.clients = clients
        self.test_data = test_data
        self.device = device
        # CONFIG["num_classes"] may still be None: fall back to the model's own
        # num_classes attribute (if any), then to 34.
        self.num_classes = (cfg.get("num_classes")
                            or getattr(global_model, "num_classes", None)
                            or 34)

        # Per-evaluation history of losses and metrics.
        self.history = {
            "train_loss": [],
            "test_loss": [],
            "test_accuracy": [],
            "test_f1_macro": [],
            "test_f1_weighted": [],
            "test_precision_macro": [],
            "test_recall_macro": [],
            "test_auc_macro": [],
        }

        # --- FedAvgM setup ---
        self.server_momentum = cfg.get("server_momentum", 0.0)
        self.server_lr = cfg.get("server_lr", 1.0)

        # Server momentum buffers, one per state-dict entry (only parameters are updated).
        self.velocity = OrderedDict(
            (k, torch.zeros_like(v).to(device))
            for k, v in self.global_model.state_dict().items()
        )

    def get_global_params(self):
        """Return a detached copy of the global state dict (on the server device)."""
        return OrderedDict((k, v.detach().clone())
                           for k, v in self.global_model.state_dict().items())

    def set_global_params(self, params: OrderedDict):
        """Load ``params`` into the global model."""
        self.global_model.load_state_dict(params)

    def distribute_model(self):
        """Distribute the global model to ALL clients."""
        g = self.get_global_params()
        for c in self.clients:
            c.set_model_params(g)

    def aggregate_fedavg(self, results: List[Dict]):
        """
        Sample-weighted average of the clients' weights listed in ``results``.

        Floating-point entries are averaged with weights ``num_samples / total``;
        non-floating entries (e.g. BatchNorm ``num_batches_tracked``) are taken
        from the last client in ``results``. Assumes ``r["client_id"]`` is the
        client's index in ``self.clients``. With no results (or zero total
        samples) the current global weights are returned unchanged instead of an
        all-zero model.
        """
        agg = self.get_global_params()
        total_samples = sum(r["num_samples"] for r in results)
        if not results or total_samples <= 0:
            return agg

        for k, v in agg.items():
            if v.dtype.is_floating_point:
                agg[k] = torch.zeros_like(v)

        for r in results:
            w_i = r["num_samples"] / total_samples
            client_params = self.clients[r["client_id"]].get_model_params()
            for k in agg:
                p = client_params[k].to(self.device)
                if p.dtype.is_floating_point:
                    agg[k] = agg[k] + w_i * p
                else:
                    agg[k] = p
        return agg

    def _train_clients(self, train_one):
        """Run ``train_one(client)`` for every client (with a progress bar) and collect results."""
        return [train_one(c) for c in tqdm(self.clients, desc="Training clients", leave=False)]

    @staticmethod
    def _round_summary(results: List[Dict], verbose: bool):
        """Return ``{"train_loss": unweighted mean of client losses}``, printing it if verbose."""
        avg_loss = float(np.mean([r["loss"] for r in results]))
        if verbose:
            print(f"→ Train loss: {avg_loss:.4f}")
        return {"train_loss": avg_loss}

    def train_round_fedavg(self, num_epochs: int, batch_size: int, lr: float,
                           verbose: bool = True):
        """Train one round with FedAvg using ALL clients."""
        if verbose:
            print(f"→ FedAvg: training ALL {len(self.clients)} clients.")

        self.distribute_model()
        results = self._train_clients(
            lambda c: c.train_fedavg(num_epochs, batch_size, lr, verbose=False)
        )
        self.set_global_params(self.aggregate_fedavg(results))
        return self._round_summary(results, verbose)

    def train_round_fedprox(self, num_epochs: int, batch_size: int, mu: float,
                            lr: float, verbose: bool = True):
        """Train one round with FedProx using ALL clients."""
        if verbose:
            print(f"→ FedProx: training ALL {len(self.clients)} clients.")

        global_params = self.get_global_params()
        self.distribute_model()
        results = self._train_clients(
            lambda c: c.train_fedprox(num_epochs, batch_size, global_params, mu, lr, verbose=False)
        )
        self.set_global_params(self.aggregate_fedavg(results))
        return self._round_summary(results, verbose)

    def _apply_server_momentum(self, w_t: OrderedDict, w_avg: OrderedDict) -> OrderedDict:
        """
        FedAvgM server update, given the old weights ``w_t`` and the client average ``w_avg``.

        For trainable parameters:
            v_{t+1} = β·v_t + (w_avg − w_t)
            w_{t+1} = w_t + η·v_{t+1}
        Every other entry, including floating-point buffers such as BatchNorm
        running_mean/running_var, takes the plain average. Momentum on those
        statistics can overshoot, e.g. drive running_var negative.
        Updates ``self.velocity`` in place and returns the new state dict.
        """
        beta = self.server_momentum
        param_names = {name for name, _ in self.global_model.named_parameters()}
        new_params = OrderedDict()
        for k, current in w_t.items():
            if k in param_names and current.dtype.is_floating_point:
                curr_w = current.to(self.device)
                delta = w_avg[k].to(self.device) - curr_w
                new_v = beta * self.velocity[k].to(self.device) + delta
                self.velocity[k] = new_v
                new_params[k] = curr_w + self.server_lr * new_v
            else:
                new_params[k] = w_avg[k]
        return new_params

    def train_round_fedavgm(self, num_epochs: int, batch_size: int, lr: float, verbose: bool = True):
        """
        Train one round with FedAvgM (Hsu et al., 2019) using ALL clients.
        """
        if verbose:
            print(f"→ FedAvgM (β={self.server_momentum}, η={self.server_lr}): training ALL {len(self.clients)} clients.")

        w_t = self.get_global_params()
        self.distribute_model()
        results = self._train_clients(
            lambda c: c.train_fedavg(num_epochs, batch_size, lr, verbose=False)
        )
        w_avg = self.aggregate_fedavg(results)
        self.set_global_params(self._apply_server_momentum(w_t, w_avg))
        return self._round_summary(results, verbose)

    def train_round_fedplus(self, num_epochs: int, batch_size: int, mu: float, lr: float, verbose: bool = True):
        """Train one round with Fed+ using ALL clients."""
        if verbose:
            print(f"→ Fed+ (mu={mu}): training ALL {len(self.clients)} clients.")

        global_params = self.get_global_params()
        self.distribute_model()
        results = self._train_clients(
            lambda c: c.train_fedplus(num_epochs, batch_size, global_params, mu, lr, verbose=False)
        )
        self.set_global_params(self.aggregate_fedavg(results))
        return self._round_summary(results, verbose)

    def evaluate_global(self, batch_size: int = 1024, compute_auc: bool = True):
        """
        Evaluate the global model on the test set with full metrics.

        Returns:
            Dict from ``MetricsCalculator.compute_all_metrics`` (accuracy, F1,
            precision, recall and, when ``compute_auc``, AUC) plus ``"loss"``
            (mean cross-entropy over the test set).

        Raises:
            ValueError: if the test set is empty.
        """
        X_test = self.test_data['X_test']
        y_test = self.test_data['y_test']
        n_test = len(y_test)
        if n_test == 0:
            raise ValueError("Test set is empty; cannot evaluate.")

        self.global_model.eval()
        criterion = nn.CrossEntropyLoss()

        pred_chunks, target_chunks, proba_chunks = [], [], []
        total_loss = 0.0

        with torch.no_grad():
            for start in range(0, n_test, batch_size):
                X_batch = X_test[start:start + batch_size]
                y_batch = y_test[start:start + batch_size]

                out = self.global_model(X_batch)
                total_loss += criterion(out, y_batch).item() * len(y_batch)

                pred_chunks.append(out.argmax(dim=1).cpu().numpy())
                target_chunks.append(y_batch.cpu().numpy())
                if compute_auc:  # probabilities are only needed for AUC
                    proba_chunks.append(F.softmax(out, dim=1).cpu().numpy())

        y_true = np.concatenate(target_chunks)
        y_pred = np.concatenate(pred_chunks)
        y_proba = np.vstack(proba_chunks) if compute_auc else None

        metrics = MetricsCalculator.compute_all_metrics(
            y_true, y_pred, y_proba, self.num_classes
        )
        metrics["loss"] = total_loss / n_test
        return metrics

    def evaluate_global_simple(self, batch_size: int = 1024):
        """Backward-compatible simple evaluation: only accuracy and loss."""
        metrics = self.evaluate_global(batch_size, compute_auc=False)
        return {"accuracy": metrics["accuracy"], "loss": metrics["loss"]}

## Training Loop


In [None]:
def train_federated_v3(server: FederatedServerV2, config: Dict, 
                       checkpoint_manager: CheckpointManager = None,
                       start_round: int = 0):
    """
    Run the federated training loop with periodic evaluation and checkpointing.

    Each global round trains the clients with the configured algorithm. The
    global model is evaluated every ``eval_every`` rounds and always after the
    final round, so the returned history is never empty when at least one
    round runs. A checkpoint is saved every ``save_checkpoint_every`` rounds.

    Args:
        server: FederatedServerV2 exposing ``train_round_<algorithm>``,
            ``evaluate_global`` and a ``history`` dict of lists.
        config: CONFIG dict.
            Required keys: ``algorithm``, ``num_rounds``, ``local_epochs``,
            ``batch_size``, ``learning_rate``, ``mu``.
            Optional keys: ``eval_every`` and ``save_checkpoint_every``
            (both default to 1).
        checkpoint_manager: CheckpointManager instance; when None no
            checkpoints are written.
        start_round: zero-based round to start from (used when resuming).

    Returns:
        ``server.history``, the same dict object, extended in place.

    Raises:
        ValueError: unknown algorithm, or a non-positive eval/save interval.
    """
    algo = config["algorithm"].lower()
    R = config["num_rounds"]
    E = config["local_epochs"]
    bs = config["batch_size"]
    lr = config["learning_rate"]
    eval_every = config.get("eval_every", 1)
    save_every = config.get("save_checkpoint_every", 1)
    mu = config["mu"]

    if eval_every < 1 or save_every < 1:
        raise ValueError(
            f"eval_every ({eval_every}) and save_checkpoint_every ({save_every}) must be >= 1"
        )

    # Dispatch table: each entry runs one round of the given algorithm.
    # FedProx / FedPlus additionally take the proximal coefficient mu.
    round_fns = {
        "fedavg": lambda: server.train_round_fedavg(E, bs, lr, verbose=True),
        "fedprox": lambda: server.train_round_fedprox(E, bs, mu, lr, verbose=True),
        "fedavgm": lambda: server.train_round_fedavgm(E, bs, lr, verbose=True),
        "fedplus": lambda: server.train_round_fedplus(E, bs, mu, lr, verbose=True),
    }
    if algo not in round_fns:
        # Validate before any training starts rather than inside the first round.
        raise ValueError(f"Unknown algorithm: {algo}")
    run_round = round_fns[algo]

    history = server.history

    for ridx in tqdm(range(start_round, R), desc="Global Rounds", initial=start_round, total=R):
        print(f"\n{'='*60}")
        print(f"ROUND {ridx+1}/{R} ({algo})")
        print(f"{'='*60}")

        # Train
        r_res = run_round()

        # Evaluate on the configured schedule, and always after the last round,
        # so the final model's metrics are recorded.
        is_last_round = ridx == R - 1
        if (ridx + 1) % eval_every == 0 or is_last_round:
            print("\n  📊 Evaluating global model...")
            metrics = server.evaluate_global(bs, compute_auc=True)

            # train_loss is only appended together with the test metrics, so
            # every history list has one entry per evaluation.
            history["train_loss"].append(r_res["train_loss"])
            history["test_loss"].append(metrics["loss"])
            history["test_accuracy"].append(metrics["accuracy"])
            history["test_f1_macro"].append(metrics["f1_macro"])
            history["test_f1_weighted"].append(metrics["f1_weighted"])
            history["test_precision_macro"].append(metrics["precision_macro"])
            history["test_recall_macro"].append(metrics["recall_macro"])
            history["test_auc_macro"].append(metrics.get("auc_macro_ovr"))

            MetricsCalculator.print_metrics_summary(metrics, ridx)

        # Save checkpoint
        if checkpoint_manager is not None and (ridx + 1) % save_every == 0:
            checkpoint_manager.save_checkpoint(ridx, server, history, config)

    return history

## Main Execution


In [None]:
# Step 1: load every client's training split and the global test split once,
# up front. The loader also reports the data-derived model dimensions.
client_data, test_data, input_shape, num_classes = load_all_client_data_v3(
    CONFIG["data_dir"], CONFIG["num_clients"], CONFIG["device"]
)

# Record the discovered dimensions in CONFIG. Both keys already exist there, so
# their position in the printed dump is unchanged.
CONFIG.update(input_shape=input_shape, num_classes=num_classes)

print("\nConfig: " + json.dumps(CONFIG, indent=2, default=str))

In [None]:
# Step 2: Initialize the global model, one local model per client, the server
# and the checkpoint manager.
# NOTE(review): `device` is not defined in this cell; presumably it is set in an
# earlier cell (e.g. from CONFIG["device"]) — confirm.
print("\nInitializing global model & clients...")

global_model = CNN_GRU_Model(input_shape, num_classes).to(device)
# Snapshot of the global weights; every client model is loaded from it so that
# all clients start from identical parameters.
base_state = global_model.state_dict()

clients = []
for cid in range(CONFIG["num_clients"]):
    m = CNN_GRU_Model(input_shape, num_classes).to(device)
    # Overwrite the fresh random init with the shared global weights.
    m.load_state_dict(base_state)
    
    c = FederatedClientV2(
        client_id=cid,
        model=m,
        X_train=client_data[cid]['X_train'],
        y_train=client_data[cid]['y_train'],
        device=device
    )
    clients.append(c)

server = FederatedServerV2(global_model, clients, test_data, device, config=CONFIG)

# Initialize checkpoint manager (its folder name is reused for the result files).
ckpt_manager = CheckpointManager(CONFIG["checkpoint_dir"], CONFIG["algorithm"], CONFIG["num_clients"])

print(f"✓ Initialized {len(clients)} clients")

# Analyze the Non-IID label distribution across clients.
noniid_stats = NonIIDAnalyzer.analyze_client_distribution(client_data, num_classes)
# Visualization moved to the non_iid_visualize.py script to avoid figure-size errors when the number of clients is large.

In [None]:
# Step 3: Train with checkpoints.
# The measured duration covers checkpoint loading (when resuming) plus training.
start_time = datetime.now()

# Resume support: continue after the round stored in the checkpoint.
start_round = 0
if CONFIG.get("resume_from_checkpoint"):
    # NOTE(review): the second return value is discarded. If it is the saved
    # history and load_checkpoint does not also restore server.history, metrics
    # from rounds before the resume point will be missing — confirm.
    start_round, _ = CheckpointManager.load_checkpoint(
        CONFIG["resume_from_checkpoint"], server, device
    )
    # Checkpoints are saved with the zero-based index of the completed round,
    # so training restarts at the following one.
    start_round += 1  # Start from next round

history = train_federated_v3(server, CONFIG, ckpt_manager, start_round)

end_time = datetime.now()
duration = (end_time - start_time).total_seconds()

print(f"\n✓ Training complete! Duration: {duration:.2f}s ({duration/60:.2f} min)")

## Visualization Functions


In [None]:
def plot_roc_curves(server, num_classes: int, class_names: List[str] = None,
                    save_path: str = None, top_k: int = 10):
    """
    Plot one-vs-rest ROC curves of the global model on the global test set.

    Class probabilities are the softmax of the model logits, computed in batches.
    Non-finite probabilities (NaN/Inf) are replaced so that plotting can proceed.
    The left panel shows up to 20 classes with the highest AUC. The right panel
    shows the best ``top_k`` classes, which are also printed as a summary.

    Args:
        server: object exposing ``global_model`` and ``test_data`` with
            ``'X_test'`` / ``'y_test'`` tensors. The batches are fed to the model
            as-is, so presumably they already live on the model's device.
        num_classes: number of output classes of the model.
        class_names: optional display names indexed by class id. Ids without a
            name fall back to ``"Class i"``.
        save_path: if given, the figure is also written to this path.
        top_k: number of best-AUC classes shown in the right panel and summary.

    Returns:
        dict mapping class index -> AUC. The value is None when the ROC curve is
        undefined for that class (no positive or no negative test samples) or
        could not be computed.
    """
    print("\n📈 Generating ROC Curves...")
    server.global_model.eval()

    X_test = server.test_data['X_test']
    y_test = server.test_data['y_test']
    batch_size = 4096

    all_proba = []
    all_targets = []

    with torch.no_grad():
        for i in tqdm(range(0, len(y_test), batch_size), desc="Computing probabilities"):
            X_batch = X_test[i:i+batch_size]
            y_batch = y_test[i:i+batch_size]

            out = server.global_model(X_batch)
            proba = F.softmax(out, dim=1)

            all_proba.append(proba.cpu().numpy())
            all_targets.append(y_batch.cpu().numpy())

    if not all_proba:
        # Empty test set: nothing to plot, and np.vstack would fail on [].
        print("  ⚠ Empty test set - no ROC curves to plot.")
        return {i: None for i in range(num_classes)}

    y_true = np.concatenate(all_targets)
    y_proba = np.vstack(all_proba)

    # Guard against NaN/Inf in the model's predictions (e.g. from a diverged model).
    if not np.isfinite(y_proba).all():
        print("\n⚠️ WARNING: Phát hiện NaN/Inf trong dự đoán của model!")
        print("  -> Đang thay thế NaN bằng 0.0 để tiếp tục vẽ biểu đồ.")
        # NaN -> 0.0, +Inf -> 1.0, -Inf -> 0.0 (valid probability bounds).
        y_proba = np.nan_to_num(y_proba, nan=0.0, posinf=1.0, neginf=0.0)

    # One-vs-rest indicator matrix with exactly num_classes columns.
    # label_binarize(..., classes=range(num_classes)) returns a single column
    # (the class-1 indicator) when num_classes == 2, so it cannot be used
    # directly. Comparing against arange is correct for every class count and
    # also covers classes absent from the test set.
    y_true_bin = (y_true[:, None] == np.arange(num_classes)[None, :]).astype(int)
    n_samples = len(y_true)

    def _label(idx):
        # Display name for a class id, tolerant of a short class_names list.
        if class_names and idx < len(class_names):
            return class_names[idx]
        return f"Class {idx}"

    # Compute ROC curve and AUC for each class.
    fpr = {}
    tpr = {}
    roc_auc = {}

    for i in range(num_classes):
        n_pos = int(y_true_bin[:, i].sum())
        # ROC is only defined when the class has both positive and negative samples.
        if 0 < n_pos < n_samples:
            try:
                fpr[i], tpr[i], _ = roc_curve(y_true_bin[:, i], y_proba[:, i])
                roc_auc[i] = auc(fpr[i], tpr[i])
            except Exception as e:
                # Best effort: report and skip this class instead of aborting the plot.
                print(f"  ⚠ Lỗi tính ROC class {i}: {e}")
                roc_auc[i] = None
        else:
            roc_auc[i] = None

    # Sort by AUC (descending) to get the top-k. NaN AUCs sort last.
    valid_aucs = [(i, val) for i, val in roc_auc.items() if val is not None]
    valid_aucs.sort(key=lambda x: x[1] if not np.isnan(x[1]) else -1, reverse=True)

    fig, axes = plt.subplots(1, 2, figsize=(16, 6))

    # Plot 1: best ROC curves, limited to 20 to avoid clutter.
    colors = plt.cm.tab20(np.linspace(0, 1, min(20, num_classes)))
    plot_limit = min(20, len(valid_aucs))

    for idx, (i, auc_val) in enumerate(valid_aucs[:plot_limit]):
        val_str = f"{auc_val:.3f}" if auc_val is not None else "NaN"
        axes[0].plot(fpr[i], tpr[i], color=colors[idx % len(colors)], lw=1, alpha=0.7,
                     label=f'{_label(i)} (AUC={val_str})')

    axes[0].plot([0, 1], [0, 1], 'k--', lw=1, label='Random')
    axes[0].set_xlabel('False Positive Rate')
    axes[0].set_ylabel('True Positive Rate')
    axes[0].set_title(f'ROC Curves - Top {plot_limit} Classes')
    axes[0].legend(loc='lower right', fontsize=7)
    axes[0].grid(True, alpha=0.3)

    # Plot 2: top-k classes by AUC.
    actual_top_k = min(top_k, len(valid_aucs))
    for i, auc_val in valid_aucs[:actual_top_k]:
        val_str = f"{auc_val:.3f}" if auc_val is not None else "NaN"
        axes[1].plot(fpr[i], tpr[i], lw=1, label=f'{_label(i)} (AUC={val_str})')

    axes[1].plot([0, 1], [0, 1], 'k--', lw=1, label='Random')
    axes[1].set_xlabel('False Positive Rate')
    axes[1].set_ylabel('True Positive Rate')
    axes[1].set_title(f'ROC Curves - Top {actual_top_k} Classes by AUC')
    axes[1].legend(loc='lower right')
    axes[1].grid(True, alpha=0.3)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"  💾 Saved ROC curves: {save_path}")

    plt.show()

    # Print AUC summary.
    print(f"\n📊 AUC Summary (Top {actual_top_k}):")
    for rank, (cls_idx, auc_val) in enumerate(valid_aucs[:actual_top_k], start=1):
        val_str = f"{auc_val:.4f}" if auc_val is not None else "NaN"
        print(f"  {rank}. {_label(cls_idx)}: {val_str}")

    return roc_auc

In [None]:
def plot_training_curves(history: Dict, save_path: str = None):
    """
    Plot the training history as a 2x3 grid.

    Panels: loss, accuracy, F1, precision/recall, AUC, and a bar chart of the
    most recent evaluation's metrics. Each history list holds one entry per
    evaluation, so the x-axis is the evaluation index. It equals the round index
    when the model is evaluated every round.

    Args:
        history: dict with lists ``train_loss``, ``test_loss``,
            ``test_accuracy``, ``test_f1_macro``, ``test_f1_weighted``,
            ``test_precision_macro``, ``test_recall_macro`` and optionally
            ``test_auc_macro`` (whose entries may be None). Metric values are
            expected as fractions; they are scaled by 100 for display.
        save_path: if given, the figure is also saved to this path.

    Returns:
        The matplotlib Figure.
    """
    fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    rounds = range(len(history["test_loss"]))

    # Plot 1: Loss
    axes[0, 0].plot(rounds, history["test_loss"], 'b-o', label="Test Loss")
    axes[0, 0].plot(rounds, history["train_loss"], 'r-s', label="Train Loss")
    axes[0, 0].set_xlabel("Round")
    axes[0, 0].set_ylabel("Loss")
    axes[0, 0].set_title("Loss")
    axes[0, 0].legend()
    axes[0, 0].grid(True, alpha=0.3)

    # Plot 2: Accuracy
    acc = [a * 100 for a in history["test_accuracy"]]
    axes[0, 1].plot(rounds, acc, 'g-o', linewidth=1)
    axes[0, 1].set_xlabel("Round")
    axes[0, 1].set_ylabel("Accuracy (%)")
    axes[0, 1].set_title("Test Accuracy")
    axes[0, 1].set_ylim([0, 100])
    axes[0, 1].grid(True, alpha=0.3)

    # Plot 3: F1 Score
    f1_macro = [f * 100 for f in history["test_f1_macro"]]
    f1_weighted = [f * 100 for f in history["test_f1_weighted"]]
    axes[0, 2].plot(rounds, f1_macro, 'purple', marker='o', label="F1 Macro")
    axes[0, 2].plot(rounds, f1_weighted, 'orange', marker='s', label="F1 Weighted")
    axes[0, 2].set_xlabel("Round")
    axes[0, 2].set_ylabel("F1 Score (%)")
    axes[0, 2].set_title("F1 Score")
    axes[0, 2].legend()
    axes[0, 2].grid(True, alpha=0.3)

    # Plot 4: Precision & Recall
    prec = [p * 100 for p in history["test_precision_macro"]]
    rec = [r * 100 for r in history["test_recall_macro"]]
    axes[1, 0].plot(rounds, prec, 'b-o', label="Precision (Macro)")
    axes[1, 0].plot(rounds, rec, 'r-s', label="Recall (Macro)")
    axes[1, 0].set_xlabel("Round")
    axes[1, 0].set_ylabel("Score (%)")
    axes[1, 0].set_title("Precision & Recall")
    axes[1, 0].legend()
    axes[1, 0].grid(True, alpha=0.3)

    # Plot 5: AUC. Shown whenever at least one evaluation produced an AUC.
    # Missing values become NaN, so they appear as gaps instead of a fake 0%.
    auc_vals = history.get("test_auc_macro", [])
    if any(a is not None for a in auc_vals):
        auc_pct = [a * 100 if a is not None else np.nan for a in auc_vals]
        axes[1, 1].plot(range(len(auc_pct)), auc_pct, 'teal', marker='o', linewidth=1)
        axes[1, 1].set_xlabel("Round")
        axes[1, 1].set_ylabel("AUC (%)")
        axes[1, 1].set_title("AUC (Macro OvR) ")
        axes[1, 1].grid(True, alpha=0.3)
    else:
        axes[1, 1].text(0.5, 0.5, "AUC not computed", ha='center', va='center')
        axes[1, 1].set_title("AUC (Macro OvR)")

    # Plot 6: all metrics from the most recent evaluation.
    summary_ax = axes[1, 2]
    summary_ax.set_ylabel("Score (%)")
    summary_ax.set_title("Final Metrics Summary")
    summary_ax.set_ylim([0, 100])
    if not history["test_accuracy"]:
        # No evaluation recorded yet; indexing [-1] would raise IndexError.
        summary_ax.text(0.5, 0.5, "No evaluations recorded", ha='center', va='center',
                        transform=summary_ax.transAxes)
    else:
        final_metrics = {
            "Accuracy": history["test_accuracy"][-1] * 100,
            "F1 Macro": history["test_f1_macro"][-1] * 100,
            "F1 Weighted": history["test_f1_weighted"][-1] * 100,
            "Precision": history["test_precision_macro"][-1] * 100,
            "Recall": history["test_recall_macro"][-1] * 100,
        }
        if auc_vals and auc_vals[-1] is not None:
            final_metrics["AUC"] = auc_vals[-1] * 100

        names = list(final_metrics.keys())
        values = list(final_metrics.values())
        bars = summary_ax.bar(names, values,
                              color=['green', 'purple', 'orange', 'blue', 'red', 'teal'][:len(values)])
        for bar, val in zip(bars, values):
            summary_ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 1,
                            f'{val:.1f}%', ha='center', va='bottom', fontsize=9)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"💾 Saved training curves: {save_path}")

    plt.show()
    return fig

## Save Results


In [None]:
# Save the final model, the history and the plots.
# Use the same folder name as the checkpoints (algorithm_num_clients).
folder_name = ckpt_manager.get_folder_name()
out_dir = os.path.join(CONFIG["output_dir"], folder_name)
os.makedirs(out_dir, exist_ok=True)

algo_name = CONFIG["algorithm"]

# One timestamp shared by every artifact of this run.
# Calling make_auto_filename() once per file re-reads the clock each time, so
# files written after slow steps (torch.save, plotting) could get different
# timestamps and no longer be matched to each other.
# Same name format as make_auto_filename: <prefix>_<YYYYmmdd_HHMMSS>.<ext>
run_ts = datetime.now().strftime("%Y%m%d_%H%M%S")

# Model: <algo>_global_model_<time>.pth
model_filename = f"{algo_name}_global_model_{run_ts}.pth"
model_path = os.path.join(out_dir, model_filename)
torch.save(server.global_model.state_dict(), model_path)
print(f"Saved model: {model_path}")

# Full history: <algo>_global_history_<time>.json
hist_filename = f"{algo_name}_global_history_{run_ts}.json"
hist_path = os.path.join(out_dir, hist_filename)
with open(hist_path, "w") as f:
    json.dump(history, f, indent=2)
print(f"Saved history: {hist_path}")

# Training curves: <algo>_training_curves_<time>.png
curves_filename = f"{algo_name}_training_curves_{run_ts}.png"
plot_training_curves(history, save_path=os.path.join(out_dir, curves_filename))

# ROC curves: <algo>_roc_curves_<time>.png
roc_filename = f"{algo_name}_roc_curves_{run_ts}.png"
plot_roc_curves(server, num_classes, save_path=os.path.join(out_dir, roc_filename))

# Save run_info.json with aggregated metadata about this run.
def save_run_info(out_dir, config, history, start_time, end_time, 
                  model_filename, hist_filename, curves_filename, roc_filename, cm_filename=None,
                  run_id=None):
    """Write ``run_info.json`` summarising one training run and return the dict.

    Args:
        out_dir: existing directory to write ``run_info.json`` into.
        config: CONFIG dict.
            Required keys: ``algorithm``, ``num_clients``, ``num_rounds``,
            ``local_epochs``, ``learning_rate``, ``batch_size``.
            Optional keys: ``mu``, ``server_momentum``, ``server_lr``.
        history: training history. Reads the lists ``test_accuracy``,
            ``test_f1_macro``, ``test_f1_weighted`` and ``test_auc_macro``.
            None entries are skipped, and a missing or empty list yields None.
        start_time, end_time: ``datetime`` objects bounding the run.
        model_filename, hist_filename, curves_filename, roc_filename: artifact
            file names recorded under ``"files"``.
        cm_filename: optional confusion-matrix file name.
        run_id: run identifier. When omitted it is taken from the
            notebook-level ``ckpt_manager.get_run_id()``.

    Returns:
        The ``run_info`` dict that was written. Metric values are percentages
        (fraction * 100), or None when no value was recorded.
    """
    if run_id is None:
        run_id = ckpt_manager.get_run_id()

    def _values(key):
        # Recorded values for `key`, skipping evaluations where it was not computed.
        return [v for v in (history.get(key) or []) if v is not None]

    def _best_pct(key):
        vals = _values(key)
        return max(vals) * 100 if vals else None

    def _final_pct(key):
        vals = history.get(key) or []
        return vals[-1] * 100 if vals and vals[-1] is not None else None

    duration_s = (end_time - start_time).total_seconds()

    run_info = {
        "run_id": run_id,
        "algorithm": config["algorithm"],
        "timestamp": {
            "start": start_time.isoformat(),
            "end": end_time.isoformat(),
            "duration_seconds": duration_s,
            "duration_minutes": duration_s / 60
        },
        "config": {
            "num_clients": config["num_clients"],
            "num_rounds": config["num_rounds"],
            "local_epochs": config["local_epochs"],
            "learning_rate": config["learning_rate"],
            "batch_size": config["batch_size"],
            "mu": config.get("mu"),
            "server_momentum": config.get("server_momentum"),
            "server_lr": config.get("server_lr"),
        },
        "best_metrics": {
            "best_accuracy": _best_pct("test_accuracy"),
            "best_f1_macro": _best_pct("test_f1_macro"),
            "best_f1_weighted": _best_pct("test_f1_weighted"),
            "best_auc": _best_pct("test_auc_macro"),
            "final_accuracy": _final_pct("test_accuracy"),
            "final_f1_macro": _final_pct("test_f1_macro"),
        },
        "files": {
            "model": model_filename,
            "history": hist_filename,
            "training_curves": curves_filename,
            "roc_curves": roc_filename,
        }
    }

    if cm_filename:
        run_info["files"]["confusion_matrix"] = cm_filename

    info_path = os.path.join(out_dir, "run_info.json")
    with open(info_path, "w") as f:
        json.dump(run_info, f, indent=2)
    print(f"💾 Saved run info: {info_path}")
    return run_info

# Write run_info.json now. The confusion-matrix filename is added once that plot exists.
run_info = save_run_info(
    out_dir, CONFIG, history, start_time, end_time,
    model_filename=model_filename,
    hist_filename=hist_filename,
    curves_filename=curves_filename,
    roc_filename=roc_filename,
)

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

def generate_confusion_matrix(server, class_names=None, out_dir=None):
    """
    Predict on the global test set, then plot and save the confusion matrix.

    The matrix always covers every label id from 0 up to the number of classes.
    That count is taken from ``server.num_classes`` if present, otherwise from
    ``len(class_names)``, and is widened if a larger id is observed. Classes
    absent from the test set therefore still get a row and column, and the
    axis labels cannot shift.

    Args:
        server: FederatedServerV2 instance (uses ``global_model``,
            ``test_data``, ``history`` and, if present, ``num_classes``).
        class_names: optional list of display names indexed by class id. If
            its length does not match the matrix size, numeric ``"Class i"``
            labels are used instead and a warning is printed.
        out_dir: directory for the image. If None, ``CONFIG["output_dir"]`` is
            used. It is created if missing.

    Returns:
        cm_filename: name (not full path) of the saved image.
    """
    print("\nGenerating Confusion Matrix...")
    server.global_model.eval()

    X_test = server.test_data['X_test']
    y_test = server.test_data['y_test']
    batch_size = 4096  # Large batches: inference only, no gradients kept.

    pred_chunks = []
    target_chunks = []

    with torch.no_grad():
        for i in tqdm(range(0, len(y_test), batch_size), desc="Inferencing"):
            outputs = server.global_model(X_test[i:i+batch_size])
            pred_chunks.append(outputs.argmax(dim=1).cpu().numpy())
            target_chunks.append(y_test[i:i+batch_size].cpu().numpy())

    all_preds = np.concatenate(pred_chunks) if pred_chunks else np.array([], dtype=np.int64)
    all_targets = np.concatenate(target_chunks) if target_chunks else np.array([], dtype=np.int64)

    # Fix the label set explicitly. Without `labels=`, sklearn only uses labels
    # that were observed, which shrinks the matrix and misaligns the names.
    expected = getattr(server, "num_classes", None)
    if expected is None:
        expected = len(class_names) if class_names is not None else 0
    observed = int(max(all_targets.max(initial=-1), all_preds.max(initial=-1))) + 1
    n_labels = max(expected, observed)

    cm = confusion_matrix(all_targets, all_preds, labels=list(range(n_labels)))

    if class_names is not None and len(class_names) != n_labels:
        print(f"  ⚠ class_names has {len(class_names)} entries but the matrix has "
              f"{n_labels} classes; using numeric labels instead.")
        class_names = None
    if class_names is None:
        class_names = [f"Class {i}" for i in range(n_labels)]

    plt.figure(figsize=(20, 16))

    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names)

    plt.title(f'Confusion Matrix ( Rounds: {len(server.history["test_loss"])})', fontsize=16)
    plt.ylabel('True Label', fontsize=14)
    plt.xlabel('Predicted Label', fontsize=14)
    plt.xticks(rotation=45)
    plt.yticks(rotation=0)

    # Auto-named image: <algo>_confusion_matrix_<time>.png
    algo_name = CONFIG["algorithm"]
    cm_filename = make_auto_filename(
        prefix=f"{algo_name}_confusion_matrix",
        ext="png",
    )
    # Save next to the other results of this run.
    if out_dir is None:
        out_dir = CONFIG["output_dir"]
    os.makedirs(out_dir, exist_ok=True)
    save_path = os.path.join(out_dir, cm_filename)
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    print(f"Saved Confusion Matrix to: {save_path}")
    plt.show()

    return cm_filename

# --- Draw the confusion matrix ---
# Display names for the label ids (index i -> class id i). They must follow the
# exact label encoding used when the dataset was built.
# NOTE(review): the original notes speak of 34 classes (ids 0-33), but this list
# has 33 entries. If the data is CICIoT2023, "ICMP fragmentation" appears to be
# missing. Confirm the names and their order against the label encoder.
class_names_list = [
    "Benign",

    "ACK fragmentation",
    "UDP flood",
    "SlowLoris",
    "ICMP flood",
    "RSTFIN flood",
    "PSHACK flood",
    "HTTP flood",
    "UDP fragmentation",
    "TCP flood",
    "SYN flood",
    "SynonymousIP flood",

    "Dictionary brute force",

    "Arp spoofing",
    "DNS spoofing",

    "TCP flood (DoS)",
    "HTTP flood (DoS)",
    "SYN flood (DoS)",
    "UDP flood (DoS)",

    "Ping sweep",
    "OS scan",
    "Vulnerability scan",
    "Port scan",
    "Host discovery",

    "Sql injection",
    "Command injection",
    "Backdoor malware",
    "Uploading attack",
    "XSS",
    "Browser hijacking",

    "GREIP flood",
    "Greeth flood",
    "UDPPlain"
]

# Passing None makes generate_confusion_matrix number the classes itself.
# Fall back to that when the names do not match the model's class count, so the
# axis labels cannot silently shift onto the wrong rows/columns.
cm_class_names = class_names_list
if len(class_names_list) != num_classes:
    print(f"⚠ class_names_list has {len(class_names_list)} names but num_classes={num_classes}; "
          f"using numeric labels for the confusion matrix.")
    cm_class_names = None

# Draw and save the confusion matrix next to the other results.
cm_filename = generate_confusion_matrix(server, class_names=cm_class_names, out_dir=out_dir)

# Update run_info.json with the confusion matrix filename.
run_info["files"]["confusion_matrix"] = cm_filename
info_path = os.path.join(out_dir, "run_info.json")
with open(info_path, "w") as f:
    json.dump(run_info, f, indent=2)
print(f"💾 Updated run_info.json with confusion matrix filename")

## Summary

### Key Features of FedLearning V2:

1. **✅ Compatible with step2-version3 format**

   - Loads `client_X_train.npz` (train only)
   - Loads `global_test_data.npz` (independent test)

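2. **✅ Pre-loads all data into memory**

   - Load once at startup (the loader receives `CONFIG["device"]`, so with CUDA the tensors may live in GPU memory rather than host RAM)
   - Fast training (no disk I/O per epoch)
   - Direct tensor operations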

3. **✅ Optimized for large-scale FL**
   - Supports 1000+ clients
   - **Always uses ALL clients** (no client fraction sampling)
   - AMP support

### Performance Comparison:

| Metric          | V1 (DataLoader) | V2 (RAM)            |
| --------------- | --------------- | ------------------- |
| Data loading    | Per epoch       | Once                |
| Disk I/O        | High            | Low                 |
| Training speed  | ~1x             | ~2-3x faster        |
| RAM usage       | Low             | High                |
| Client sampling | Yes             | **No (always all)** |

**Best for**: Large RAM machines, repeated training runs, scenarios where all clients participate in every round
