

## High‑Level Flowcharts

### FedAvg

```text
[Start]
   ↓
[Load data + init global model]
   ↓
For each round t = 1..T
   ↓
[Select subset of clients C_t]
   ↓
[Send global weights w_t to each client i ∈ C_t]
   ↓
[Client i trains locally on its data D_i]
   ↓
[Client i sends back updated weights w_i^(t+1) and |D_i|]
   ↓
[Server aggregates weighted average]
   w_(t+1) = Σ_i ( |D_i| / Σ_j |D_j| ) · w_i^(t+1)
   ↓
[Update global model w_(t+1)]
   ↓
[Evaluate or stop]
```
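
The aggregation step in the flowchart can be checked by hand on a toy case. The sketch below is illustrative only: the helper name `fedavg_aggregate`, the two-element weight vectors, and the client sizes are made up for the example and do not appear elsewhere in this notebook.

```python
from typing import List, Sequence


def fedavg_aggregate(client_weights: Sequence[Sequence[float]],
                     client_sizes: Sequence[int]) -> List[float]:
    """Sample-size-weighted average of client weight vectors (hypothetical helper)."""
    total = sum(client_sizes)
    dim = len(client_weights[0])
    aggregated = [0.0] * dim
    for weights, n_i in zip(client_weights, client_sizes):
        share = n_i / total  # |D_i| / sum_j |D_j|
        for d in range(dim):
            aggregated[d] += share * weights[d]
    return aggregated


# Toy round: three clients with made-up weights and dataset sizes.
weights = [[1.0, 0.0], [3.0, 2.0], [5.0, 4.0]]
sizes = [100, 100, 200]
new_global = fedavg_aggregate(weights, sizes)
print(new_global)  # [3.5, 2.5]
```

The third client holds half of all samples, so it contributes half of the new global weights.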

### FedProx

FedProx uses the same communication pattern as FedAvg, but **the client objective is different**:

```text
Client i local objective:

   minimize over w:
      F_i(w) + (μ/2) · || w − w_global ||²

where:
   F_i(w)  = local empirical loss on D_i (e.g., cross‑entropy)
   μ       = proximal coefficient (controls how close client stays to global model)
```

In code, the client adds the proximal penalty to the cross-entropy loss:

```python
loss = ce_loss + (mu / 2.0) * proximal_term
```

The server still aggregates exactly like FedAvg.
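
To see what the proximal term does to the gradients, here is a minimal sketch on a two-element weight vector. The helper `proximal_term` and the numbers are assumptions made for illustration. The check at the end relies on the fact that the gradient of (μ/2)·||w − w_global||² with respect to w is μ·(w − w_global), so the penalty pulls each weight back toward the global model in proportion to its drift.

```python
import torch


def proximal_term(local_params, global_params):
    """Sum of squared differences between local and global tensors (hypothetical helper)."""
    return sum(torch.sum((p - g.detach()) ** 2)
               for p, g in zip(local_params, global_params))


mu = 0.1
w = torch.tensor([1.0, 2.0], requires_grad=True)   # local weights (made up)
w_global = torch.tensor([0.0, 0.0])                # global weights (made up)

penalty = (mu / 2.0) * proximal_term([w], [w_global])
penalty.backward()

print(penalty.item())  # (0.1 / 2) * (1 + 4), approximately 0.25
print(w.grad)          # approximately mu * (w - w_global)
```

With `mu = 0` the penalty and its gradient vanish, and the client objective reduces to the FedAvg one.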


## Imports & Global Settings

In [1]:

import os
import time
import json
from datetime import datetime
from collections import OrderedDict
from typing import List, Dict, Tuple
import contextlib

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
try:
    from torch.amp import autocast as torch_autocast, GradScaler
except ImportError:
    from torch.cuda.amp import autocast as torch_autocast, GradScaler

import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from sklearn.metrics import (
    confusion_matrix,
    classification_report,
    precision_recall_fscore_support,
)
import pandas as pd

# Make prints nicer
torch.set_printoptions(linewidth=120, sci_mode=False)


## Config, Colab Optimizer & Device Setup

In [2]:

# -----------------------------------------------------------------------------
# Colab detection
# -----------------------------------------------------------------------------
def is_colab() -> bool:
    try:
        import google.colab  # type: ignore
        return True
    except Exception:
        return False


# -----------------------------------------------------------------------------
# Base CONFIG (you can edit this)
# -----------------------------------------------------------------------------
CONFIG = {
    # Path containing client_0_data.npz, client_1_data.npz, ...
    "data_dir": "/kaggle/input/fed-5clients",

    # Output
    "output_dir": "./fed_results",

    # Federation
    "num_clients": 5,
    "algorithm": "fedprox",   # "fedavg" or "fedprox"
    "client_fraction": 1.0,  # fraction of clients per round

    # Model IO
    "input_shape": None,     # auto-detected
    "num_classes": None,     # auto-detected

    # Training
    "num_rounds": 3,
    "local_epochs": 3,
    "learning_rate": 1e-3,
    "batch_size": 1024,
    "mu": 0.01,              # FedProx proximal coefficient

    # Device
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "force_gpu": True,

    # DataLoader
    "num_workers": 0,
    "pin_memory": False,
    "persistent_workers": False,

    # Eval
    "eval_every": 1,
}


def optimize_for_colab(config: Dict) -> Dict:
    """If running on Colab, adjust settings for T4 GPU & weak CPU."""
    if not is_colab():
        return config

    print("\n Detected Google Colab → applying Colab-optimized defaults...\n")

    config = dict(config)  # shallow copy
    config["batch_size"] = 1024
    config["num_workers"] = 0
    config["pin_memory"] = False
    config["persistent_workers"] = False
    config["force_gpu"] = True

    torch.backends.cudnn.benchmark = True

    return config


def setup_device(config: Dict) -> str:
    """Choose device according to availability + config."""
    if config.get("force_gpu", False) and not torch.cuda.is_available():
        print(" force_gpu=True but no CUDA device found → using CPU.")
        device = "cpu"
    elif torch.cuda.is_available() and config.get("device") == "cuda":
        device = "cuda"
        idx = torch.cuda.current_device()
        print(f" Using GPU: {torch.cuda.get_device_name(idx)}")
    else:
        device = "cpu"
        print(" Using CPU.")
    config["device"] = device
    return device
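

The branching in `setup_device` is easier to inspect as a pure function. In the sketch below, `resolve_device` is a hypothetical helper (not part of the notebook) that takes CUDA availability as an argument, so the outcomes can be checked without a GPU. It mirrors the conditions above and makes one detail visible: `force_gpu=True` does not override `device="cpu"`; it only changes which message is printed when CUDA is missing.

```python
def resolve_device(requested: str, force_gpu: bool, cuda_available: bool) -> str:
    """Mirror of the setup_device branches, without printing or touching torch."""
    if force_gpu and not cuda_available:
        return "cpu"          # GPU was requested but none exists
    if cuda_available and requested == "cuda":
        return "cuda"
    return "cpu"              # includes requested == "cpu" on a GPU machine


print(resolve_device("cuda", force_gpu=True, cuda_available=False))  # cpu
print(resolve_device("cuda", force_gpu=False, cuda_available=True))  # cuda
print(resolve_device("cpu", force_gpu=True, cuda_available=True))    # cpu
```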


## Model: CNN‑GRU

In [3]:
# -------------------------------------------------------------------------
# CNN-GRU Model: same architecture as DL.py (DeepFed), ported to PyTorch
# -------------------------------------------------------------------------
class CNN_GRU_Model(nn.Module):
    """
    CNN-GRU (CNN + GRU + MLP): PyTorch version that closely follows the architecture in DL.py
    - Input: (batch, 39) or (batch, 39, 1)
    - Output: logits of shape (batch, num_classes)
    """
    def __init__(self, input_shape, num_classes: int = 34):
        super().__init__()

        # input_shape has the form (39,) → seq_length = 39
        if isinstance(input_shape, tuple):
            seq_length = input_shape[0]
        else:
            seq_length = int(input_shape)

        self.input_shape = input_shape
        self.num_classes = num_classes

        # ===== CNN MODULE =====
        # Block 1
        self.conv1 = nn.Conv1d(in_channels=1, out_channels=64, kernel_size=3, padding=1)
        self.bn1   = nn.BatchNorm1d(64)
        self.pool1 = nn.MaxPool1d(kernel_size=2)
        self.dropout_cnn1 = nn.Dropout(0.2)

        # Block 2
        self.conv2 = nn.Conv1d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.bn2   = nn.BatchNorm1d(128)
        self.pool2 = nn.MaxPool1d(kernel_size=2)
        self.dropout_cnn2 = nn.Dropout(0.2)

        # Block 3
        self.conv3 = nn.Conv1d(in_channels=128, out_channels=256, kernel_size=3, padding=1)
        self.bn3   = nn.BatchNorm1d(256)
        self.pool3 = nn.MaxPool1d(kernel_size=2)
        self.dropout_cnn3 = nn.Dropout(0.3)

        # Compute the sequence length after the CNN to get the flattened size
        def conv_output_shape(L_in, kernel_size=1, stride=1, padding=0, dilation=1):
            return (L_in + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

        cnn_len = seq_length
        cnn_len = conv_output_shape(cnn_len, kernel_size=3, stride=1, padding=1)  # conv1
        cnn_len = conv_output_shape(cnn_len, kernel_size=2, stride=2)            # pool1
        cnn_len = conv_output_shape(cnn_len, kernel_size=3, stride=1, padding=1)  # conv2
        cnn_len = conv_output_shape(cnn_len, kernel_size=2, stride=2)            # pool2
        cnn_len = conv_output_shape(cnn_len, kernel_size=3, stride=1, padding=1)  # conv3
        cnn_len = conv_output_shape(cnn_len, kernel_size=2, stride=2)            # pool3

        self.cnn_output_size = 256 * cnn_len  # 256 filters * final sequence length

        # ===== GRU MODULE =====
        self.gru1 = nn.GRU(input_size=1,   hidden_size=128, batch_first=True)
        self.gru2 = nn.GRU(input_size=128, hidden_size=64,  batch_first=True)
        self.gru_output_size = 64

        # ===== MLP MODULE =====
        concat_size = self.cnn_output_size + self.gru_output_size

        self.dense1   = nn.Linear(concat_size, 256)
        self.bn_mlp1  = nn.BatchNorm1d(256)
        self.dropout1 = nn.Dropout(0.4)

        self.dense2   = nn.Linear(256, 128)
        self.bn_mlp2  = nn.BatchNorm1d(128)
        self.dropout2 = nn.Dropout(0.3)

        self.output = nn.Linear(128, num_classes)
        self.relu   = nn.ReLU()

    def forward(self, x):
        # x: (B, 39) or (B, 39, 1)
        if x.ndim == 2:
            x = x.unsqueeze(-1)        # (B, 39, 1)

        batch_size = x.size(0)

        # ----- CNN branch -----
        x_cnn = x.permute(0, 2, 1)     # (B, 1, 39)
        x_cnn = self.pool1(self.relu(self.bn1(self.conv1(x_cnn))))
        x_cnn = self.dropout_cnn1(x_cnn)

        x_cnn = self.pool2(self.relu(self.bn2(self.conv2(x_cnn))))
        x_cnn = self.dropout_cnn2(x_cnn)

        x_cnn = self.pool3(self.relu(self.bn3(self.conv3(x_cnn))))
        x_cnn = self.dropout_cnn3(x_cnn)

        cnn_output = x_cnn.view(batch_size, -1)  # flatten

        # ----- GRU branch -----
        x_gru = x                              # (B, 39, 1)
        x_gru, _ = self.gru1(x_gru)
        x_gru, _ = self.gru2(x_gru)
        gru_output = x_gru[:, -1, :]           # take the last time step (B, 64)

        # ----- CONCAT -----
        z = torch.cat([cnn_output, gru_output], dim=1)

        # ----- MLP -----
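        z = self.dense1(z)
        # BatchNorm1d cannot compute batch statistics from a single sample in
        # training mode; in eval mode it uses running statistics, so it is applied.
        if not self.training or z.size(0) > 1:
            z = self.bn_mlp1(z)
        z = self.relu(z)
        z = self.dropout1(z)

        z = self.dense2(z)
        if not self.training or z.size(0) > 1:
            z = self.bn_mlp2(z)
        z = self.relu(z)
        z = self.dropout2(z)

        logits = self.output(z)                # no softmax; CrossEntropyLoss expects raw logits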
        return logits


def build_model(input_shape, num_classes: int):
    print(f"Building CNN_GRU_Model with input_shape={input_shape}, num_classes={num_classes}")
    return CNN_GRU_Model(input_shape, num_classes=num_classes)
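

The size of the first dense layer follows from the length arithmetic in `__init__`. As a worked check, the sketch below repeats that arithmetic outside the model for the 39-feature input mentioned in the docstring. The function name `conv_output_len` is a stand-in for the nested `conv_output_shape` helper, and 39 is assumed rather than read from data.

```python
def conv_output_len(length, kernel_size=1, stride=1, padding=0, dilation=1):
    """Output length of a 1-D conv/pool layer (same formula as in the model)."""
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


length = 39  # assumed number of input features, as in the docstring
trace = [length]
for _ in range(3):
    length = conv_output_len(length, kernel_size=3, stride=1, padding=1)  # Conv1d keeps the length
    length = conv_output_len(length, kernel_size=2, stride=2)             # MaxPool1d halves it (floor)
    trace.append(length)

cnn_output_size = 256 * length        # 256 channels after the third block
concat_size = cnn_output_size + 64    # plus the 64-unit output of the second GRU
print(trace, cnn_output_size, concat_size)  # [39, 19, 9, 4] 1024 1088
```

So for a 39-feature input, `dense1` receives 1088 values per sample: 1024 from the CNN branch and 64 from the GRU branch.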


## Data Utilities (Dataset & Loaders)

In [4]:

# -----------------------------------------------------------------------------
# Dataset & loading utilities
# -----------------------------------------------------------------------------
class NumpyDataset(TensorDataset):
    """Wraps numpy arrays (X, y) into a TensorDataset."""
    def __init__(self, X: np.ndarray, y: np.ndarray):
        if X.ndim == 3:
            X = X.squeeze(-1)
        X = X.astype(np.float32)
        X_t = torch.from_numpy(X)
        y_t = torch.from_numpy(y).long()
        super().__init__(X_t, y_t)


def auto_detect_data_parameters(data_dir: str, num_clients: int) -> Tuple[Tuple[int], int, Dict]:
    """Scan client_x_data.npz to infer input_shape & num_classes."""
    print("\n Auto-detecting data parameters...")
    all_labels = []
    stats = {}
    total_train = 0
    total_test = 0

    path0 = os.path.join(data_dir, "client_0_data.npz")
    if not os.path.exists(path0):
        raise FileNotFoundError(path0)

    with np.load(path0) as d0:
        X0 = d0["X_train"]
        input_features = X0.shape[1]
        input_shape = (input_features,)

    for cid in range(num_clients):
        p = os.path.join(data_dir, f"client_{cid}_data.npz")
        with np.load(p) as d:
            X_tr, y_tr = d["X_train"], d["y_train"]
            X_te, y_te = d["X_test"], d["y_test"]

        total_train += len(X_tr)
        total_test += len(X_te)
        all_labels.append(y_tr)

        u, c = np.unique(y_tr, return_counts=True)
        stats[cid] = {
            "train_samples": int(len(X_tr)),
            "test_samples": int(len(X_te)),
            "num_labels": int(len(u)),
            "label_dist": {int(k): int(v) for k, v in zip(u, c)},
        }
        print(f"   Client {cid}: {len(X_tr):,} train, {len(X_te):,} test, {len(u)} labels")

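    combined = np.concatenate(all_labels)
    # Labels are assumed to be non-negative integers. Using max + 1 as a lower
    # bound keeps the output layer large enough when label ids have gaps.
    num_classes = int(max(len(np.unique(combined)), combined.max() + 1))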

    print(f"\n input_shape = {input_shape}, num_classes = {num_classes}")
    print(f"   total train = {total_train:,}, total test = {total_test:,}")
    return input_shape, num_classes, stats


def load_federated_data(config: Dict):
    """Create train/test DataLoaders for each client."""
    data_dir = config["data_dir"]
    K = config["num_clients"]
    bs = config["batch_size"]
    num_workers = config.get("num_workers", 0)
    pin_memory = config.get("pin_memory", False)
    persistent_workers = config.get("persistent_workers", False)

    train_loaders, test_loaders = [], []
    for cid in range(K):
        p = os.path.join(data_dir, f"client_{cid}_data.npz")
        # Use a context manager so the .npz file handle is closed after reading
        with np.load(p) as d:
            X_tr, y_tr = d["X_train"], d["y_train"]
            X_te, y_te = d["X_test"], d["y_test"]

        ds_tr = NumpyDataset(X_tr, y_tr)
        ds_te = NumpyDataset(X_te, y_te)

        tr_loader = DataLoader(
            ds_tr, batch_size=bs, shuffle=True, drop_last=False,
            num_workers=num_workers, pin_memory=pin_memory,
            persistent_workers=persistent_workers if num_workers > 0 else False,
        )
        te_loader = DataLoader(
            ds_te, batch_size=bs, shuffle=False, drop_last=False,
            num_workers=num_workers, pin_memory=pin_memory,
            persistent_workers=persistent_workers if num_workers > 0 else False,
        )
        train_loaders.append(tr_loader)
        test_loaders.append(te_loader)

    return train_loaders, test_loaders
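

The loaders expect one `client_{cid}_data.npz` file per client with the keys `X_train`, `y_train`, `X_test`, and `y_test`. To smoke-test the pipeline without the real dataset, one option is to generate random files in that layout. The sketch below is an assumption-laden example: `write_synthetic_clients`, the sample counts, and the 4-class setup are invented for illustration, and the random data carries no signal.

```python
import os
import tempfile

import numpy as np


def write_synthetic_clients(data_dir, num_clients=2, num_features=39,
                            num_classes=4, n_train=64, n_test=16, seed=0):
    """Write random client_{cid}_data.npz files in the layout the loaders read."""
    rng = np.random.default_rng(seed)
    for cid in range(num_clients):
        np.savez(
            os.path.join(data_dir, f"client_{cid}_data.npz"),
            X_train=rng.normal(size=(n_train, num_features)).astype(np.float32),
            y_train=rng.integers(0, num_classes, size=n_train),
            X_test=rng.normal(size=(n_test, num_features)).astype(np.float32),
            y_test=rng.integers(0, num_classes, size=n_test),
        )


tmp_dir = tempfile.mkdtemp()
write_synthetic_clients(tmp_dir)

# Read one file back and list the array shapes it contains.
with np.load(os.path.join(tmp_dir, "client_0_data.npz")) as d:
    shapes = {key: d[key].shape for key in d.files}
print(shapes)
```

Pointing `CONFIG["data_dir"]` at `tmp_dir` and setting `CONFIG["num_clients"] = 2` would then let the rest of the notebook run end to end on this toy data.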


## FederatedClient (FedAvg + FedProx, AMP)

In [5]:
# -------------------------------------------------------------------------
# Federated Client (FedAvg + FedProx) – AMP-compatible for PyTorch 2.x (Colab)
# -------------------------------------------------------------------------



class FederatedClient:
    def __init__(self, client_id: int, model: nn.Module,
                 train_loader: DataLoader, test_loader: DataLoader,
                 device: str = "cpu"):
        self.client_id = client_id
        self.model = model.to(device)
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.device = device

        # Enable AMP only when a CUDA GPU is present
        self.use_amp = (device == "cuda" and torch.cuda.is_available())

    # ----------------------------- AMP Context -----------------------------
    def _amp_ctx(self):
        """
        Return an autocast context on CUDA, or a no-op context otherwise.
        torch.autocast takes an explicit device_type and is called directly,
        because the legacy torch.cuda.amp.autocast fallback import does not
        accept that argument.
        """
        return (
            torch.autocast(device_type="cuda", dtype=torch.float16)
            if self.use_amp else contextlib.nullcontext()
        )

    # ---------------------------- FedAvg Training ---------------------------
    def train_fedavg(self, epochs: int, lr: float = 1e-3, verbose: bool = True) -> Dict:
        self.model.train()
        optimizer = optim.Adam(self.model.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss()

        # No device argument is passed, so the call works with GradScaler
        # imported from either torch.amp or torch.cuda.amp
        scaler = GradScaler(enabled=self.use_amp)

        total_loss = 0.0
        total_samples = 0

        for ep in range(epochs):
            ep_loss = 0.0
            ep_samples = 0

            loader = (
                tqdm(self.train_loader,
                     desc=f"[Client {self.client_id}] FedAvg Epoch {ep+1}/{epochs}",
                     unit="batch", leave=False)
                if verbose else self.train_loader
            )

            for data, target in loader:
                t0 = time.time()
                data, target = data.to(self.device), target.to(self.device)

                optimizer.zero_grad()

                # ---------------- AMP Region ----------------
                with self._amp_ctx():
                    out = self.model(data)
                    loss = criterion(out, target)

                if self.use_amp:
                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    loss.backward()
                    optimizer.step()
                # ---------------------------------------------

                bs = data.size(0)
                ep_loss += loss.item() * bs
                ep_samples += bs

                if verbose:
                    bt = (time.time() - t0) * 1000
                    loader.set_postfix(
                        loss=f"{loss.item():.4f}",
                        lr=f"{optimizer.param_groups[0]['lr']:.1e}",
                        bt=f"{bt:.0f}ms"
                    )

            total_loss += ep_loss
            total_samples += ep_samples
            if verbose:
                print(f"Client {self.client_id} Epoch {ep+1}: loss={ep_loss / max(1, ep_samples):.4f}")

        avg_loss = total_loss / max(1, total_samples)
        return {
            "client_id": self.client_id,
            "num_samples": total_samples // max(1, epochs),
            "loss": avg_loss
        }

    # ---------------------------- FedProx Training --------------------------
    def train_fedprox(self, epochs: int, global_params: OrderedDict,
                      mu: float = 0.01, lr: float = 1e-3,
                      verbose: bool = True) -> Dict:

        self.model.train()
        optimizer = optim.Adam(self.model.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss()

        scaler = GradScaler(enabled=self.use_amp)

        total_loss = 0.0
        total_samples = 0

        for ep in range(epochs):
            ep_ce = 0.0
            ep_samples = 0

            loader = (
                tqdm(self.train_loader,
                     desc=f"[Client {self.client_id}] FedProx Epoch {ep+1}/{epochs}",
                     unit="batch", leave=False)
                if verbose else self.train_loader
            )

            for data, target in loader:
                t0 = time.time()
                data, target = data.to(self.device), target.to(self.device)

                optimizer.zero_grad()

                with self._amp_ctx():
                    out = self.model(data)
                    ce_loss = criterion(out, target)

                    # Proximal term
                    prox = 0.0
                    for name, param in self.model.named_parameters():
                        if not param.requires_grad:
                            continue
                        gp = global_params[name].to(self.device)
                        prox += torch.sum((param - gp)**2)

                    loss = ce_loss + (mu / 2.0) * prox

                if self.use_amp:
                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    loss.backward()
                    optimizer.step()

                bs = data.size(0)
                ep_ce += ce_loss.item() * bs
                ep_samples += bs

                if verbose:
                    bt = (time.time() - t0) * 1000
                    loader.set_postfix(
                        ce=f"{ce_loss.item():.4f}",
                        lr=f"{optimizer.param_groups[0]['lr']:.1e}",
                        bt=f"{bt:.0f}ms"
                    )

            total_loss += ep_ce
            total_samples += ep_samples
            if verbose:
                print(f"Client {self.client_id} Epoch {ep+1}: CE={ep_ce / max(1, ep_samples):.4f}")

        avg_ce = total_loss / max(1, total_samples)
        return {
            "client_id": self.client_id,
            "num_samples": total_samples // max(1, epochs),
            "loss": avg_ce
        }

    # ----------------------------- Evaluation -------------------------------
    def evaluate(self) -> Dict:
        if self.test_loader is None:
            return {"accuracy": 0.0, "loss": 0.0, "num_samples": 0}

        self.model.eval()
        criterion = nn.CrossEntropyLoss()

        total_loss = 0.0
        total_samples = 0
        correct = 0

        with torch.no_grad():
            for data, target in self.test_loader:
                data, target = data.to(self.device), target.to(self.device)
                out = self.model(data)
                loss = criterion(out, target)

                total_loss += loss.item() * data.size(0)
                total_samples += data.size(0)
                correct += out.argmax(dim=1).eq(target).sum().item()

        acc = correct / total_samples if total_samples else 0.0
        avg_loss = total_loss / total_samples if total_samples else 0.0

        return {
            "accuracy": acc,
            "loss": avg_loss,
            "num_samples": total_samples
        }

    # ----------------------------- Param IO --------------------------------
    def get_model_params(self) -> OrderedDict:
        return OrderedDict((k, v.detach().cpu().clone())
                           for k, v in self.model.state_dict().items())

    def set_model_params(self, params: OrderedDict):
        self.model.load_state_dict(params)
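

Both training loops accumulate `loss.item() * bs` rather than averaging the per-batch losses. The reason is that `drop_last=False` can leave a short final batch, and a plain mean of batch losses would give that batch the same weight as a full one. The sketch below shows the difference with made-up numbers; `weighted_mean_loss` is a hypothetical helper, not a method of `FederatedClient`.

```python
def weighted_mean_loss(batch_losses, batch_sizes):
    """Per-sample mean loss, reconstructed from per-batch mean losses."""
    total = sum(loss * n for loss, n in zip(batch_losses, batch_sizes))
    return total / max(1, sum(batch_sizes))


# One full batch of 1024 samples and a short final batch of 100 samples.
batch_losses = [0.50, 1.00]
batch_sizes = [1024, 100]

naive = sum(batch_losses) / len(batch_losses)            # treats both batches equally
weighted = weighted_mean_loss(batch_losses, batch_sizes)  # 612 / 1124

print(f"naive={naive:.4f} weighted={weighted:.4f}")
```

The weighted value stays close to the loss of the large batch, which is what a per-sample average should report.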


## FederatedServer, Aggregation & Training Rounds

In [6]:

# -----------------------------------------------------------------------------
# Federated Server
# -----------------------------------------------------------------------------
class FederatedServer:
    def __init__(self, global_model: nn.Module, clients: List[FederatedClient],
                 client_test_loaders: List[DataLoader], device: str = "cpu"):
        self.global_model = global_model.to(device)
        self.clients = clients
        self.client_test_loaders = client_test_loaders
        self.device = device

        self.history = {"train_loss": [], "test_loss": [], "test_accuracy": []}

    def get_global_params(self) -> OrderedDict:
        return OrderedDict((k, v.detach().clone())
                           for k, v in self.global_model.state_dict().items())

    def set_global_params(self, params: OrderedDict):
        self.global_model.load_state_dict(params)

    def distribute_model(self, selected: List[FederatedClient]):
        g = self.get_global_params()
        for c in selected:
            c.set_model_params(g)

    # ---------------------------- Aggregation -----------------------
    def aggregate_fedavg(self, client_results: List[Dict]) -> OrderedDict:
        total_samples = sum(r["num_samples"] for r in client_results)
        agg = self.get_global_params()

        # Zero out floating-point entries; they are rebuilt as a weighted sum below.
        for k, v in agg.items():
            if v.dtype.is_floating_point:
                agg[k] = torch.zeros_like(v)

        for r in client_results:
            cid = r["client_id"]
            n_i = r["num_samples"]
            w_i = n_i / max(1, total_samples)
            # Assumes self.clients is ordered so that clients[cid].client_id == cid.
            client_params = self.clients[cid].get_model_params()

            for k in agg.keys():
                p = client_params[k].to(self.device)

                if p.dtype.is_floating_point:
                    agg[k] = agg[k] + w_i * p
                else:
                    # Integer buffers (e.g. BatchNorm num_batches_tracked) are not
                    # averaged; the value from the last client is kept.
                    agg[k] = p
        return agg

    # ---------------------------- One FedAvg round ------------------
    def train_round_fedavg(self, num_epochs: int, lr: float,
                           client_fraction: float, verbose: bool = True) -> Dict:
        n_clients = len(self.clients)
        n_sel = max(1, int(n_clients * client_fraction))
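        # Sample indices rather than client objects, so NumPy never has to
        # convert the clients into an array.
        sel_idx = np.random.choice(n_clients, n_sel, replace=False)
        selected = [self.clients[int(i)] for i in sel_idx]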

        if verbose:
            print(f"→ FedAvg: selecting {n_sel}/{n_clients} clients.")

        self.distribute_model(selected)

        results = []
        for idx, c in enumerate(selected):
            if verbose:
                print(f"  • Training client {c.client_id} ({idx+1}/{n_sel})...")
            r = c.train_fedavg(num_epochs, lr, verbose)
            results.append(r)

        new_params = self.aggregate_fedavg(results)
        self.set_global_params(new_params)

        avg_loss = float(np.mean([r["loss"] for r in results])) if results else 0.0
        if verbose:
            print(f"→ Round train loss (FedAvg, avg clients): {avg_loss:.4f}")
        return {"train_loss": avg_loss, "num_clients": len(results)}

    # ---------------------------- One FedProx round -----------------
    def train_round_fedprox(self, num_epochs: int, mu: float, lr: float,
                            client_fraction: float, verbose: bool = True) -> Dict:
        n_clients = len(self.clients)
        n_sel = max(1, int(n_clients * client_fraction))
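        # Sample indices rather than client objects, so NumPy never has to
        # convert the clients into an array.
        sel_idx = np.random.choice(n_clients, n_sel, replace=False)
        selected = [self.clients[int(i)] for i in sel_idx]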

        if verbose:
            print(f"→ FedProx: selecting {n_sel}/{n_clients} clients.")

        global_params = self.get_global_params()
        self.distribute_model(selected)

        results = []
        for idx, c in enumerate(selected):
            if verbose:
                print(f"  • Training client {c.client_id} ({idx+1}/{n_sel})...")
            r = c.train_fedprox(num_epochs, global_params, mu, lr, verbose)
            results.append(r)

        new_params = self.aggregate_fedavg(results)
        self.set_global_params(new_params)

        avg_ce = float(np.mean([r["loss"] for r in results])) if results else 0.0
        if verbose:
            print(f"→ Round train CE loss (FedProx, avg clients): {avg_ce:.4f}")
        return {"train_loss": avg_ce, "num_clients": len(results)}

    # ---------------------------- Global eval -----------------------
    def evaluate_global(self) -> Dict:
        self.global_model.eval()
        criterion = nn.CrossEntropyLoss()
        total_loss = 0.0
        total = 0
        correct = 0

        for idx, loader in enumerate(self.client_test_loaders):
            pbar = tqdm(loader, desc=f"[Eval] Client {idx}", unit="batch", leave=False)
            with torch.no_grad():
                for data, target in pbar:
                    data, target = data.to(self.device), target.to(self.device)
                    out = self.global_model(data)
                    loss = criterion(out, target)

                    total_loss += loss.item() * data.size(0)
                    pred = out.argmax(dim=1)
                    correct += pred.eq(target).sum().item()
                    total += data.size(0)

                    if total > 0:
                        pbar.set_postfix(
                            acc=f"{correct/total*100:.2f}%",
                            loss=f"{total_loss/total:.4f}"
                        )

        acc = correct / total if total else 0.0
        avg_loss = total_loss / total if total else 0.0
        return {"accuracy": acc, "loss": avg_loss}
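

`aggregate_fedavg` works on full `state_dict`s, which contain integer buffers as well as float weights. The sketch below reproduces that rule on two small `BatchNorm1d` layers. The function `weighted_average_state_dicts` and the sample counts are assumptions for illustration; it follows the same policy as the server (average floats, keep the last client's integer buffers) but is not the server code itself.

```python
from collections import OrderedDict

import torch
import torch.nn as nn


def weighted_average_state_dicts(states, num_samples):
    """Average float tensors by sample count; copy integer buffers from the last state."""
    total = float(sum(num_samples))
    out = OrderedDict()
    for key in states[0]:
        if states[0][key].dtype.is_floating_point:
            out[key] = sum((n / total) * s[key] for s, n in zip(states, num_samples))
        else:
            out[key] = states[-1][key].clone()  # e.g. num_batches_tracked (int64)
    return out


bn_a, bn_b = nn.BatchNorm1d(2), nn.BatchNorm1d(2)
with torch.no_grad():
    bn_a.weight.fill_(1.0)   # pretend client A ended with weight 1.0
    bn_b.weight.fill_(3.0)   # pretend client B ended with weight 3.0

avg = weighted_average_state_dicts(
    [bn_a.state_dict(), bn_b.state_dict()], num_samples=[100, 300]
)
print(avg["weight"])                     # 0.25 * 1.0 + 0.75 * 3.0 per element
print(avg["num_batches_tracked"].dtype)  # integer buffer is copied, not averaged
```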


## train_federated() Wrapper

In [7]:

# -----------------------------------------------------------------------------
# federated training loop
# -----------------------------------------------------------------------------
def train_federated(server: FederatedServer, config: Dict) -> Dict:
    algo = config["algorithm"].lower()
    R = config["num_rounds"]
    E = config["local_epochs"]
    lr = config["learning_rate"]
    cf = config["client_fraction"]
    eval_every = config["eval_every"]
    mu = config["mu"]

    history = server.history

    rounds = tqdm(range(R), desc="Global Rounds", unit="round")
    for ridx in rounds:
        print("\n" + "-" * 60)
        print(f"ROUND {ridx+1}/{R} ({algo})")
        print("-" * 60)

        if algo == "fedavg":
            r_res = server.train_round_fedavg(E, lr, cf, verbose=True)
        elif algo == "fedprox":
            r_res = server.train_round_fedprox(E, mu, lr, cf, verbose=True)
        else:
            raise ValueError(f"Unknown algorithm: {algo}")

        if (ridx + 1) % eval_every == 0:
            print("\n Evaluating global model...")
            e_res = server.evaluate_global()

            history["train_loss"].append(r_res["train_loss"])
            history["test_loss"].append(e_res["loss"])
            history["test_accuracy"].append(e_res["accuracy"])

            print(f"   Train loss: {r_res['train_loss']:.4f}")
            print(f"   Test  loss: {e_res['loss']:.4f}")
            print(f"   Test  acc : {e_res['accuracy']*100:.2f}%")

            rounds.set_postfix(
                algo=algo,
                train_loss=f"{r_res['train_loss']:.4f}",
                test_acc=f"{e_res['accuracy']*100:.2f}%",
            )

    return history
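

The `eval_every` condition decides which rounds contribute to `history`. A small sketch makes the effect concrete; `evaluated_rounds` is a hypothetical helper that applies the same modulo test as the loop above.

```python
def evaluated_rounds(num_rounds: int, eval_every: int):
    """1-based round numbers at which the global model is evaluated."""
    return [r + 1 for r in range(num_rounds) if (r + 1) % eval_every == 0]


print(evaluated_rounds(3, 1))   # [1, 2, 3]
print(evaluated_rounds(10, 3))  # [3, 6, 9]
```

With `num_rounds=10` and `eval_every=3`, the final round is never evaluated and `history` holds three entries, so the positions in `history` are evaluation indices rather than round numbers.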


## Evaluation, Plotting & Saving

In [8]:

# -----------------------------------------------------------------------------
# Evaluation, plotting & saving utilities
# -----------------------------------------------------------------------------
def plot_history(history: Dict, save_path: str = None):
    if not history["test_accuracy"]:
        print(" No history to plot.")
        return

    rounds = range(len(history["test_loss"]))
    train_loss = history["train_loss"]
    test_loss = history["test_loss"]
    test_acc = [a * 100 for a in history["test_accuracy"]]

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    ax1 = axes[0]
    ax1.plot(rounds, test_loss, marker="o", label="Test loss")
    if len(train_loss) == len(test_loss):
        ax1.plot(rounds, train_loss, marker="s", label="Train loss (avg clients)")
    ax1.set_xlabel("Round"); ax1.set_ylabel("Loss"); ax1.grid(True, alpha=0.3)
    ax1.legend(); ax1.set_title("Loss vs round")

    ax2 = axes[1]
    ax2.plot(rounds, test_acc, marker="o", color="green", label="Test acc (%)")
    ax2.set_xlabel("Round"); ax2.set_ylabel("Accuracy (%)")
    ax2.set_ylim([0, 100]); ax2.grid(True, alpha=0.3)
    ax2.legend(); ax2.set_title("Accuracy vs round")

    plt.tight_layout()
    if save_path is not None:
        save_dir = os.path.dirname(save_path)
        if save_dir:  # dirname is "" for a bare filename, and makedirs("") would fail
            os.makedirs(save_dir, exist_ok=True)
        plt.savefig(save_path, dpi=150, bbox_inches="tight")
        print(f" Saved training curves to: {save_path}")
    plt.show()


def final_evaluation_and_reports(server: FederatedServer, history: Dict,
                                 config: Dict, data_stats: Dict,
                                 output_dir: str,
                                 start_time: datetime, end_time: datetime):
    os.makedirs(output_dir, exist_ok=True)

    # Save global model
    model_path = os.path.join(output_dir, "global_model.pth")
    torch.save(server.global_model.state_dict(), model_path)
    print(f" Saved global model to {model_path}")

    # Save history
    hist_path = os.path.join(output_dir, "training_history.json")
    with open(hist_path, "w") as f:
        json.dump(history, f, indent=2)
    print(f" Saved history to {hist_path}")

    # Plot curves
    plot_history(history, os.path.join(output_dir, "training_curves.png"))

    # Global predictions (for classification report & confusion matrix)
    all_true, all_pred = [], []
    server.global_model.eval()
    with torch.no_grad():
        for loader in server.client_test_loaders:
            for data, target in loader:
                data, target = data.to(server.device), target.to(server.device)
                out = server.global_model(data)
                pred = out.argmax(dim=1)
                all_true.append(target.cpu().numpy())
                all_pred.append(pred.cpu().numpy())

    y_true = np.concatenate(all_true)
    y_pred = np.concatenate(all_pred)

    labels = list(range(config["num_classes"]))
    target_names = [str(l) for l in labels]

    rep = classification_report(
        y_true, y_pred, labels=labels, target_names=target_names,
        digits=4, zero_division=0,
    )
    print("\n" + "="*80)
    print("CLASSIFICATION REPORT")
    print("="*80)
    print(rep)

    rep_path = os.path.join(output_dir, "classification_report.txt")
    with open(rep_path, "w") as f:
        f.write(rep)
    print(f" Saved classification report to {rep_path}")

    cm = confusion_matrix(y_true, y_pred, labels=labels)
    fig, ax = plt.subplots(figsize=(8, 6))
    im = ax.imshow(cm, cmap="Blues")
    plt.colorbar(im, ax=ax)
    ax.set_xlabel("Predicted"); ax.set_ylabel("True")
    ax.set_title("Confusion Matrix (Global Model)")
    ax.set_xticks(range(len(labels))); ax.set_yticks(range(len(labels)))
    ax.set_xticklabels(target_names, rotation=90)
    ax.set_yticklabels(target_names)
    for i in range(len(labels)):
        for j in range(len(labels)):
            ax.text(j, i, cm[i, j], ha="center", va="center", color="black", fontsize=7)
    plt.tight_layout()
    cm_path = os.path.join(output_dir, "confusion_matrix.png")
    plt.savefig(cm_path, dpi=200, bbox_inches="tight")
    print(f" Saved confusion matrix to {cm_path}")
    plt.show()

    prec, rec, f1, sup = precision_recall_fscore_support(
        y_true, y_pred, labels=labels, average=None, zero_division=0,
    )
    df = pd.DataFrame({
        "class": target_names,
        "precision": prec,
        "recall": rec,
        "f1": f1,
        "support": sup,
    })
    metrics_path = os.path.join(output_dir, "metrics_per_class.csv")
    df.to_csv(metrics_path, index=False)
    print(f" Saved per-class metrics to {metrics_path}")

    # Summary
    dur = (end_time - start_time).total_seconds()
    summary_path = os.path.join(output_dir, "SUMMARY.txt")
    with open(summary_path, "w", encoding="utf-8") as f:
        f.write("FEDERATED LEARNING SUMMARY\n")
        f.write("="*80 + "\n\n")
        f.write(f"Algorithm      : {config['algorithm']}\n")
        f.write(f"Num rounds     : {config['num_rounds']}\n")
        f.write(f"Local epochs   : {config['local_epochs']}\n")
        f.write(f"Batch size     : {config['batch_size']}\n")
        f.write(f"Learning rate  : {config['learning_rate']}\n")
        if config['algorithm'] == 'fedprox':
            f.write(f"mu (proximal)  : {config['mu']}\n")
        f.write(f"Num clients    : {config['num_clients']}\n")
        f.write(f"Input features : {config['input_shape'][0]}\n")
        f.write(f"Num classes    : {config['num_classes']}\n")
        f.write(f"Device         : {config['device']}\n\n")
        if history['test_accuracy']:
            f.write(f"Final test acc : {history['test_accuracy'][-1]*100:.2f}%\n")
            f.write(f"Final test loss: {history['test_loss'][-1]:.4f}\n\n")
        f.write(f"Total time     : {dur:.2f}s ({dur/60:.2f} min)\n")
    print(f" Saved summary to {summary_path}")
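
The per-class numbers written to `metrics_per_class.csv` can be read directly off the confusion matrix. The following is a minimal pure-Python sketch of that relationship, not the notebook's implementation (which relies on scikit-learn). Rows are true labels, columns are predictions, and a class with no predictions or no support gets 0 instead of a division error, mirroring `zero_division=0`. The helper names `confusion_counts` and `per_class_metrics` and the toy labels are assumptions made for illustration.

```python
from typing import Dict, List


def confusion_counts(y_true: List[int], y_pred: List[int], num_classes: int) -> List[List[int]]:
    """Return cm where cm[i][j] counts samples of true class i predicted as j."""
    cm = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        cm[t][p] += 1
    return cm


def per_class_metrics(cm: List[List[int]]) -> List[Dict[str, float]]:
    """Derive precision, recall, F1 and support for every class from cm."""
    n = len(cm)
    rows = []
    for k in range(n):
        tp = cm[k][k]
        support = sum(cm[k])                          # row sum: true samples of class k
        predicted = sum(cm[i][k] for i in range(n))   # column sum: predictions of class k
        precision = tp / predicted if predicted else 0.0
        recall = tp / support if support else 0.0
        denom = precision + recall
        f1 = 2 * precision * recall / denom if denom else 0.0
        rows.append({"class": k, "precision": precision, "recall": recall,
                     "f1": f1, "support": support})
    return rows


# Toy example with 3 classes; class 2 never occurs and is never predicted.
y_true = [0, 0, 0, 1, 1]
y_pred = [0, 0, 1, 1, 1]
cm = confusion_counts(y_true, y_pred, num_classes=3)
metrics = per_class_metrics(cm)
for row in metrics:
    print(row)
```

Passing the full label list, as the notebook does with `labels=labels`, keeps absent classes in the output as all-zero rows. This matters when a class is missing from the evaluated data.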


## Main() Entry Point
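
The entry point below writes every run into its own timestamped directory and then overwrites `config["output_dir"]` with that path. The sketch here reproduces the same naming scheme with a hypothetical helper, `make_run_dir`, which is not part of the notebook. It returns an updated copy of the config, so calling it again in the same session does not nest a new `run_...` folder inside the previous one. Whether `optimize_for_colab` already returns a copy is not visible in this section, so treat this as a precaution and not a confirmed bug fix.

```python
import os
import tempfile
from datetime import datetime
from typing import Dict, Optional


def make_run_dir(config: Dict, now: Optional[datetime] = None) -> Dict:
    """Create <output_dir>/run_<timestamp>_<algorithm> and return an updated copy of config."""
    now = now or datetime.now()
    ts = now.strftime("%Y%m%d_%H%M%S")
    run_dir = os.path.join(config["output_dir"], f"run_{ts}_{config['algorithm']}")
    os.makedirs(run_dir, exist_ok=True)
    run_config = dict(config)            # shallow copy: the caller's dict is not modified
    run_config["output_dir"] = run_dir
    return run_config


# Usage with a temporary base directory and a fixed timestamp.
base = tempfile.mkdtemp()
base_config = {"output_dir": base, "algorithm": "fedprox"}
run_config = make_run_dir(base_config, now=datetime(2025, 11, 16, 2, 14, 41))
print(os.path.basename(run_config["output_dir"]))  # run_20251116_021441_fedprox
```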

In [None]:

# -----------------------------------------------------------------------------
# Main: wiring everything together
# -----------------------------------------------------------------------------
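def initialize_system(config: Dict, device: str):
    """Detect data shape, build loaders, and create the server plus one client per shard.

    Note: fills in config["input_shape"] and config["num_classes"] in place.
    """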
    input_shape, num_classes, data_stats = auto_detect_data_parameters(
        config["data_dir"], config["num_clients"]
    )
    config["input_shape"] = input_shape
    config["num_classes"] = num_classes

    train_loaders, test_loaders = load_federated_data(config)

    # Global model
    global_model = build_model(input_shape, num_classes).to(device)
    base_state = global_model.state_dict()

    # Clients
    clients = []
    for cid in range(config["num_clients"]):
        m = build_model(input_shape, num_classes)
        m.load_state_dict(base_state)
        c = FederatedClient(
            client_id=cid,
            model=m,
            train_loader=train_loaders[cid],
            test_loader=test_loaders[cid],
            device=device,
        )
        clients.append(c)

    server = FederatedServer(global_model, clients, test_loaders, device)
    return server, clients, data_stats


def main():
    # Config & device
    config = optimize_for_colab(CONFIG)
    device = setup_device(config)

    # Output dir
    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    out_dir = os.path.join(config["output_dir"], f"run_{ts}_{config['algorithm']}")
    os.makedirs(out_dir, exist_ok=True)
    config["output_dir"] = out_dir

    # Printed before initialize_system(), so input_shape and num_classes still show as null here.
    print("\nFinal config:\n", json.dumps(config, indent=2, default=str))

    start_time = datetime.now()

    # Init system
    server, clients, data_stats = initialize_system(config, device)

    # Train
    history = train_federated(server, config)

    end_time = datetime.now()

    # Final eval + save
    final_evaluation_and_reports(
        server, history, config, data_stats, out_dir, start_time, end_time
    )
    print("\n Done!")



if __name__ == "__main__":
    main()



 Detected Google Colab → applying Colab-optimized defaults...

 Using GPU: Tesla T4

Final config:
 {
  "data_dir": "/kaggle/input/fed-5clients",
  "output_dir": "./fed_results/run_20251116_021441_fedprox",
  "num_clients": 5,
  "algorithm": "fedprox",
  "client_fraction": 1.0,
  "input_shape": null,
  "num_classes": null,
  "num_rounds": 3,
  "local_epochs": 3,
  "learning_rate": 0.001,
  "batch_size": 1024,
  "mu": 0.01,
  "device": "cuda",
  "force_gpu": true,
  "num_workers": 0,
  "pin_memory": false,
  "persistent_workers": false,
  "eval_every": 1
}

 Auto-detecting data parameters...
   Client 0: 6,910,741 train, 2,961,747 test, 34 labels
   Client 1: 1,897,023 train, 813,010 test, 32 labels
   Client 2: 7,610,796 train, 3,261,771 test, 34 labels
   Client 3: 7,632,755 train, 3,271,181 test, 34 labels
   Client 4: 7,462,140 train, 3,198,060 test, 34 labels

 input_shape = (39,), num_classes = 34
   total train = 31,513,455, total test = 13,505,769
Building CNN_GRU_Model with in

Global Rounds:   0%|          | 0/3 [00:00<?, ?round/s]


------------------------------------------------------------
ROUND 1/3 (fedprox)
------------------------------------------------------------
→ FedProx: selecting 5/5 clients.
  • Training client 0 (1/5)...


[Client 0] FedProx Epoch 1/3:   0%|          | 0/6749 [00:00<?, ?batch/s]

Client 0 Epoch 1: CE=0.5564


[Client 0] FedProx Epoch 2/3:   0%|          | 0/6749 [00:00<?, ?batch/s]

Client 0 Epoch 2: CE=0.5454


[Client 0] FedProx Epoch 3/3:   0%|          | 0/6749 [00:00<?, ?batch/s]

Client 0 Epoch 3: CE=0.5456
  • Training client 3 (2/5)...


[Client 3] FedProx Epoch 1/3:   0%|          | 0/7454 [00:00<?, ?batch/s]

Client 3 Epoch 1: CE=0.7022


[Client 3] FedProx Epoch 2/3:   0%|          | 0/7454 [00:00<?, ?batch/s]

Client 3 Epoch 2: CE=0.6921


[Client 3] FedProx Epoch 3/3:   0%|          | 0/7454 [00:00<?, ?batch/s]

Client 3 Epoch 3: CE=0.6915
  • Training client 2 (3/5)...


[Client 2] FedProx Epoch 1/3:   0%|          | 0/7433 [00:00<?, ?batch/s]

Client 2 Epoch 1: CE=0.1674


[Client 2] FedProx Epoch 2/3:   0%|          | 0/7433 [00:00<?, ?batch/s]

Client 2 Epoch 2: CE=0.1581


[Client 2] FedProx Epoch 3/3:   0%|          | 0/7433 [00:00<?, ?batch/s]

Client 2 Epoch 3: CE=0.1581
  • Training client 4 (4/5)...


[Client 4] FedProx Epoch 1/3:   0%|          | 0/7288 [00:00<?, ?batch/s]

Client 4 Epoch 1: CE=0.3921


[Client 4] FedProx Epoch 2/3:   0%|          | 0/7288 [00:00<?, ?batch/s]

Client 4 Epoch 2: CE=0.3830


[Client 4] FedProx Epoch 3/3:   0%|          | 0/7288 [00:00<?, ?batch/s]

Client 4 Epoch 3: CE=0.3829
  • Training client 1 (5/5)...


[Client 1] FedProx Epoch 1/3:   0%|          | 0/1853 [00:00<?, ?batch/s]

Client 1 Epoch 1: CE=0.5321


[Client 1] FedProx Epoch 2/3:   0%|          | 0/1853 [00:00<?, ?batch/s]

Client 1 Epoch 2: CE=0.4894


[Client 1] FedProx Epoch 3/3:   0%|          | 0/1853 [00:00<?, ?batch/s]

Client 1 Epoch 3: CE=0.4880
→ Round train CE loss (FedProx, avg clients): 0.4590

 Evaluating global model...


[Eval] Client 0:   0%|          | 0/2893 [00:00<?, ?batch/s]

[Eval] Client 1:   0%|          | 0/794 [00:00<?, ?batch/s]

[Eval] Client 2:   0%|          | 0/3186 [00:00<?, ?batch/s]

[Eval] Client 3:   0%|          | 0/3195 [00:00<?, ?batch/s]

[Eval] Client 4:   0%|          | 0/3124 [00:00<?, ?batch/s]

   Train loss: 0.4590
   Test  loss: 1.0032
   Test  acc : 70.42%
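
The round-1 numbers above can be cross-checked with a little arithmetic. The sketch below is inferred purely from the printed log, not from the training code. It assumes the reported round train loss is the unweighted mean over clients of each client's mean CE across its three local epochs. It also assumes the training progress-bar totals are `ceil(samples / batch_size)` with `batch_size = 1024`. Both assumptions reproduce the logged values to the printed precision.

```python
import math

# Per-epoch CE values copied from the round-1 log, keyed by client id.
epoch_ce = {
    0: [0.5564, 0.5454, 0.5456],
    1: [0.5321, 0.4894, 0.4880],
    2: [0.1674, 0.1581, 0.1581],
    3: [0.7022, 0.6921, 0.6915],
    4: [0.3921, 0.3830, 0.3829],
}

# Assumption: average over epochs per client, then a plain (unweighted) mean over clients.
client_means = {cid: sum(v) / len(v) for cid, v in epoch_ce.items()}
round_loss = sum(client_means.values()) / len(client_means)
print(f"{round_loss:.4f}")  # 0.4590

# Train-set sizes and training progress-bar totals copied from the log.
train_sizes = {0: 6_910_741, 1: 1_897_023, 2: 7_610_796, 3: 7_632_755, 4: 7_462_140}
logged_batches = {0: 6749, 1: 1853, 2: 7433, 3: 7454, 4: 7288}
batch_size = 1024

# Assumption: the last partial batch is kept, so batches = ceil(samples / batch_size).
computed_batches = {cid: math.ceil(n / batch_size) for cid, n in train_sizes.items()}
print(computed_batches == logged_batches)  # True
```

If that reading is right, the much smaller client 1 counts as much toward this training-loss figure as each of the larger clients. It says nothing about how the model weights themselves are aggregated.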

------------------------------------------------------------
ROUND 2/3 (fedprox)
------------------------------------------------------------
→ FedProx: selecting 5/5 clients.
  • Training client 1 (1/5)...


[Client 1] FedProx Epoch 1/3:   0%|          | 0/1853 [00:00<?, ?batch/s]

Client 1 Epoch 1: CE=0.4799


[Client 1] FedProx Epoch 2/3:   0%|          | 0/1853 [00:00<?, ?batch/s]

Client 1 Epoch 2: CE=0.4710


[Client 1] FedProx Epoch 3/3:   0%|          | 0/1853 [00:00<?, ?batch/s]

Client 1 Epoch 3: CE=0.4712
  • Training client 2 (2/5)...


[Client 2] FedProx Epoch 1/3:   0%|          | 0/7433 [00:00<?, ?batch/s]

Client 2 Epoch 1: CE=0.1425


[Client 2] FedProx Epoch 2/3:   0%|          | 0/7433 [00:00<?, ?batch/s]