# In [None]:
# ============================================================
# Imports
# ============================================================
import os
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader, random_split

from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    roc_auc_score, roc_curve, auc, confusion_matrix
)

# ============================================================
# Configuration
# ============================================================
class Config:
    """Run-wide settings: dataset locations, hyper-parameters and device.

    Constructing an instance also creates the checkpoint directory
    (``save_path``) if it does not exist yet.
    """

    def __init__(self):
        # --- Dataset locations -------------------------------------------
        dataset_root = "/workspace/DATASETS/CheXpert-v1.0-small"
        self.base_path = dataset_root
        self.train_csv = f"{dataset_root}/chexpert-train.csv"
        self.valid_csv = f"{dataset_root}/chexpert-valid.csv"

        # --- Output directory (created eagerly) --------------------------
        checkpoint_dir = "./ckpt/fed_chexpert_visual"
        self.save_path = checkpoint_dir
        os.makedirs(checkpoint_dir, exist_ok=True)

        # --- Optimisation / model hyper-parameters -----------------------
        self.batch_size = 32
        self.img_size = 224
        self.lr = 1e-4
        self.num_classes = 14

        # --- Federated schedule ------------------------------------------
        self.num_clients = 5
        self.client_epoch = 2
        self.rounds = 5

        # --- Compute device: GPU when available, CPU otherwise -----------
        device_name = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device_name)

# Module-wide configuration instance; read by the dataset, model, training
# and evaluation code below. Creating it also creates the checkpoint directory.
opt = Config()

# Label column names selected from the CSV files. Their order fixes the column
# order of the label matrix and therefore the index of each class in the
# model output and in the per-class metrics/plots.
CLASS_NAMES = [
    'No Finding','Enlarged Cardiomediastinum','Cardiomegaly','Lung Opacity',
    'Lung Lesion','Edema','Consolidation','Pneumonia','Atelectasis',
    'Pneumothorax','Pleural Effusion','Pleural Other','Fracture','Support Devices'
]

# ============================================================
# Dataset
# ============================================================
class CheXpertDataset(Dataset):
    """Image/label dataset backed by a CheXpert-style DataFrame.

    Args:
        df: DataFrame with a ``"Path"`` column and one column per entry of
            ``CLASS_NAMES``.
        transform: Callable applied to each loaded RGB PIL image.

    Label values equal to ``-1`` are replaced by ``0``; labels are stored
    as ``float32``.
    """

    def __init__(self, df, transform):
        self.paths = df["Path"].values
        raw_labels = df[CLASS_NAMES].values.astype(np.float32)
        # Map every -1 entry to 0, leaving all other values untouched.
        self.labels = np.where(raw_labels == -1, np.float32(0), raw_labels)
        self.transform = transform

    def __len__(self):
        """Number of samples (one per row of the DataFrame)."""
        return len(self.paths)

    def __getitem__(self, idx):
        """Return ``(transformed_image, label_row)`` for sample ``idx``."""
        # Paths in the CSV are resolved relative to the configured base path.
        image_file = os.path.join(opt.base_path, self.paths[idx])
        rgb_image = Image.open(image_file).convert("RGB")
        sample = self.transform(rgb_image)
        return sample, self.labels[idx]

# ============================================================
# Data Loaders
# ============================================================
def get_dataloaders():
    """Build one training loader per client plus a shared validation loader.

    The training CSV is loaded, missing values are filled with ``-1`` and the
    resulting dataset is randomly partitioned into ``opt.num_clients``
    disjoint shards whose sizes differ by at most one sample.

    Returns:
        tuple: ``(train_loaders, val_loader)`` where ``train_loaders`` is a
        list of ``opt.num_clients`` shuffled ``DataLoader`` objects and
        ``val_loader`` is an unshuffled ``DataLoader`` over the validation CSV.

    Raises:
        ValueError: If the training set has fewer samples than clients, since
            at least one client would then receive no data.
    """
    transform = transforms.Compose([
        transforms.Resize((opt.img_size, opt.img_size)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    train_df = pd.read_csv(opt.train_csv).fillna(-1)
    val_df = pd.read_csv(opt.valid_csv).fillna(-1)

    train_set = CheXpertDataset(train_df, transform)
    val_set = CheXpertDataset(val_df, transform)

    num_samples = len(train_set)
    if num_samples < opt.num_clients:
        raise ValueError(
            f"Cannot split {num_samples} training samples across "
            f"{opt.num_clients} clients"
        )

    # random_split requires the lengths to sum to len(dataset) exactly, so the
    # remainder of the integer division is handed out one sample at a time to
    # the first `remainder` clients instead of being dropped.
    shard_size, remainder = divmod(num_samples, opt.num_clients)
    shard_lengths = [
        shard_size + 1 if client_idx < remainder else shard_size
        for client_idx in range(opt.num_clients)
    ]
    client_sets = random_split(train_set, shard_lengths)

    train_loaders = [
        DataLoader(cs, batch_size=opt.batch_size, shuffle=True)
        for cs in client_sets
    ]

    val_loader = DataLoader(val_set, batch_size=opt.batch_size)

    return train_loaders, val_loader

# ============================================================
# Model
# ============================================================
class Model(nn.Module):
    """ResNet-18 with its final layer replaced by a multi-label head.

    The head is a linear layer with ``opt.num_classes`` outputs followed by a
    sigmoid, so ``forward`` returns per-class values in ``[0, 1]`` rather
    than raw logits.
    """

    def __init__(self):
        super().__init__()
        backbone = torchvision.models.resnet18(pretrained=True)
        # Swap the stock classifier for one sized to this task's label count.
        feature_dim = backbone.fc.in_features
        backbone.fc = nn.Sequential(
            nn.Linear(feature_dim, opt.num_classes),
            nn.Sigmoid(),
        )
        self.net = backbone

    def forward(self, x):
        """Apply the network to a batch of images ``x``."""
        return self.net(x)

# ============================================================
# Training (Client-side)
# ============================================================
def train_client(model, loader):
    """Train ``model`` in place on one client's data for ``opt.client_epoch`` epochs.

    A fresh Adam optimiser (learning rate ``opt.lr``) and a binary
    cross-entropy loss are created on every call. The model is expected to
    output probabilities (it is compared against a 0.5 cut-off and fed to
    ``nn.BCELoss``).

    Args:
        model: Network to optimise; assumed to already live on ``opt.device``.
        loader: Iterable yielding ``(inputs, targets)`` batches.

    Returns:
        tuple: ``(mean_loss, accuracy)`` where ``mean_loss`` is the loss
        averaged over every batch processed across all local epochs and
        ``accuracy`` is the fraction of individual label entries predicted
        correctly at a 0.5 threshold. Both are ``nan`` if the loader yields
        no batches.
    """
    model.train()
    optimizer = optim.Adam(model.parameters(), lr=opt.lr)
    loss_fn = nn.BCELoss()

    total_loss, correct, total = 0.0, 0, 0
    # Counted explicitly: the loss is accumulated over every local epoch, so
    # dividing by len(loader) (one epoch's batches) would overstate it by a
    # factor of opt.client_epoch.
    num_batches = 0

    for _ in range(opt.client_epoch):
        for x, y in loader:
            x, y = x.to(opt.device), y.to(opt.device)
            out = model(x)
            loss = loss_fn(out, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            correct += ((out >= 0.5) == y).sum().item()
            total += y.numel()
            num_batches += 1

    if num_batches == 0 or total == 0:
        # Nothing was trained on; report "undefined" rather than dividing by zero.
        return float("nan"), float("nan")

    return total_loss / num_batches, correct / total

# ============================================================
# Metrics & Visualization Helpers
# ============================================================
def tune_thresholds(gt, pred):
    """Pick, per class, the decision threshold that maximises F1.

    For each of the first ``opt.num_classes`` columns, 50 evenly spaced
    candidate thresholds between 0.1 and 0.9 are scored with ``f1_score``.
    The earliest candidate reaching the highest F1 wins; if no candidate
    scores above zero the threshold falls back to 0.5.

    Args:
        gt: 2-D array of ground-truth labels, indexed ``[sample, class]``.
        pred: 2-D array of predicted scores with the same layout.

    Returns:
        list: One threshold per class.
    """
    candidates = np.linspace(0.1, 0.9, 50)
    chosen = []
    for cls in range(opt.num_classes):
        scores = np.array([
            f1_score(gt[:, cls], (pred[:, cls] >= cut), zero_division=0)
            for cut in candidates
        ])
        # argmax returns the first index of the maximum, matching a
        # "replace only on strict improvement" scan.
        winner = int(np.argmax(scores))
        chosen.append(candidates[winner] if scores[winner] > 0 else 0.5)
    return chosen

def plot_training_curves(history):
    """Show loss and accuracy per round as two side-by-side line plots.

    Args:
        history: Mapping with the keys ``"round"``, ``"train_loss"``,
            ``"val_loss"``, ``"train_acc"`` and ``"val_acc"``, each holding
            one value per round.
    """
    # (panel title, ((history key, legend label), ...)) for each subplot.
    panels = (
        ("Loss vs Round", (("train_loss", "Train Loss"), ("val_loss", "Val Loss"))),
        ("Accuracy vs Round", (("train_acc", "Train Acc"), ("val_acc", "Val Acc"))),
    )
    rounds = history["round"]

    plt.figure(figsize=(12,4))
    for position, (title, series) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        for key, legend_label in series:
            plt.plot(rounds, history[key], label=legend_label)
        plt.legend()
        plt.title(title)

    plt.show()

def plot_auc_roc(gt, pred):
    """Plot one ROC curve per class, with its AUC in the legend.

    Classes whose ground-truth column contains a single distinct value have
    no defined ROC curve; they are listed in the legend as ``n/a`` instead of
    being passed to ``roc_curve`` (which would yield NaN rates and a NaN AUC).

    Args:
        gt: 2-D array of ground-truth labels, indexed ``[sample, class]``,
            with one column per entry of ``CLASS_NAMES``.
        pred: 2-D array of predicted scores with the same layout.
    """
    plt.figure(figsize=(7, 6))
    for i, cname in enumerate(CLASS_NAMES):
        if np.unique(gt[:, i]).size < 2:
            # Empty line keeps the legend entry (and colour cycle) for this
            # class without drawing a meaningless curve.
            plt.plot([], [], label=f"{cname} (n/a)")
            continue
        fpr, tpr, _ = roc_curve(gt[:, i], pred[:, i])
        auc_score = auc(fpr, tpr)
        plt.plot(fpr, tpr, label=f"{cname} ({auc_score:.2f})")
    # Diagonal reference line from (0, 0) to (1, 1).
    plt.plot([0, 1], [0, 1], 'k--')
    plt.legend(fontsize=6)
    plt.title("AUC-ROC Curves")
    plt.show()

def plot_confusion_matrices(confusion):
    """Render one 2x2 confusion-matrix heatmap per class.

    Args:
        confusion: Mapping from class name to a dict with the integer counts
            ``"TN"``, ``"FP"``, ``"FN"`` and ``"TP"``.
    """
    column_labels = ["Pred 0","Pred 1"]
    row_labels = ["True 0","True 1"]

    for title, counts in confusion.items():
        # Rows are the true label, columns the predicted label.
        negatives_row = [counts["TN"], counts["FP"]]
        positives_row = [counts["FN"], counts["TP"]]
        grid = np.array([negatives_row, positives_row])

        plt.figure(figsize=(3,3))
        sns.heatmap(
            grid,
            annot=True,
            fmt="d",
            cmap="Blues",
            xticklabels=column_labels,
            yticklabels=row_labels,
        )
        plt.title(title)
        plt.show()

# ============================================================
# MAIN
# ============================================================
def _federated_average(client_states):
    """Return the equally weighted element-wise mean of the client state dicts."""
    num_states = len(client_states)
    averaged = {}
    for key, reference in client_states[0].items():
        mean_value = sum(state[key] for state in client_states) / num_states
        # Division turns integer entries into floats; cast back so every entry
        # keeps the dtype it has in the client state dicts.
        averaged[key] = mean_value.to(reference.dtype)
    return averaged


def _evaluate(model, loader):
    """Run ``model`` over ``loader``; return (labels, outputs, mean BCE loss)."""
    model.eval()
    loss_fn = nn.BCELoss(reduction="sum")
    labels, outputs = [], []
    loss_sum, num_entries = 0.0, 0

    with torch.no_grad():
        for x, y in loader:
            out = model(x.to(opt.device))
            loss_sum += loss_fn(out, y.to(opt.device)).item()
            num_entries += y.numel()
            labels.append(y.numpy())
            outputs.append(out.cpu().numpy())

    return np.vstack(labels), np.vstack(outputs), loss_sum / num_entries


def main():
    """Run federated training, per-round validation and final plots.

    Each round: every client starts from the current global weights, trains
    locally via ``train_client``, and the resulting weights are averaged into
    the global model. The global model is then scored on the validation
    loader and the metrics are appended to a history that is plotted once all
    rounds have finished.
    """
    train_loaders, val_loader = get_dataloaders()

    global_model = Model().to(opt.device)
    clients = [Model().to(opt.device) for _ in range(opt.num_clients)]

    history = {
        "round": [], "train_loss": [], "train_acc": [],
        "val_loss": [], "val_acc": [], "val_auc": []
    }
    # Hold the last round's validation artefacts for the final plots.
    GT = PRED = confusion = None

    for rnd in range(opt.rounds):
        print(f"\n===== ROUND {rnd+1} =====")

        client_states, losses, accs = [], [], []

        # ---------- Client Training ----------
        for client, loader in zip(clients, train_loaders):
            client.load_state_dict(global_model.state_dict())
            loss, acc = train_client(client, loader)
            losses.append(loss)
            accs.append(acc)
            client_states.append(client.state_dict())

        # ---------- FedAvg ----------
        global_model.load_state_dict(_federated_average(client_states))

        # ---------- Validation ----------
        GT, PRED, val_loss = _evaluate(global_model, val_loader)
        gt_int = GT.astype(int)

        # NOTE: thresholds are tuned on the same validation data they are
        # then applied to, so the thresholded metrics below are optimistic.
        thresholds = tune_thresholds(GT, PRED)
        pred_bin = (PRED >= thresholds).astype(int)

        # Per-label accuracy, the same definition train_client reports, so
        # the train and validation curves on the accuracy plot are comparable.
        val_acc = float((pred_bin == gt_int).mean())

        # roc_auc_score raises when a column holds a single label value, so
        # the mean is taken over the classes for which AUC is defined.
        class_aucs = [
            roc_auc_score(GT[:, i], PRED[:, i])
            for i in range(opt.num_classes)
            if np.unique(GT[:, i]).size > 1
        ]
        val_auc = float(np.mean(class_aucs)) if class_aucs else float("nan")

        confusion = {}
        for i, cname in enumerate(CLASS_NAMES):
            # labels=[0, 1] forces a 2x2 matrix even when only one label
            # value occurs, so the four-way unpacking always succeeds.
            tn, fp, fn, tp = confusion_matrix(
                gt_int[:, i], pred_bin[:, i], labels=[0, 1]
            ).ravel()
            confusion[cname] = {"TP": tp, "FP": fp, "FN": fn, "TN": tn}

        history["round"].append(rnd+1)
        history["train_loss"].append(np.mean(losses))
        history["train_acc"].append(np.mean(accs))
        # Actual validation BCE rather than 1 - accuracy.
        history["val_loss"].append(val_loss)
        history["val_acc"].append(val_acc)
        history["val_auc"].append(val_auc)

        print(f"Val Acc: {val_acc:.4f}, Val AUC: {val_auc:.4f}")

    # ---------- Visualizations ----------
    if history["round"]:
        # Only plot when at least one round ran; otherwise there is no
        # validation output to draw.
        plot_training_curves(history)
        plot_auc_roc(GT, PRED)
        plot_confusion_matrices(confusion)

    print("\n✅ Training + Visualization Completed")

# ============================================================
# Entry Point
# ============================================================
# Start the full training/evaluation pipeline only when this file is executed
# as a script, not when it is imported as a module.
if __name__ == "__main__":
    main()
