In [None]:
# Report whether PyTorch can see a CUDA GPU and, if so, the name of device 0.
# (This only checks availability; it does not enable anything.)
import torch
print("GPU Available:", torch.cuda.is_available())
print("GPU:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")


In [None]:
# Install missing libraries
!pip install barbar tqdm scikit-learn pillow matplotlib pandas


In [None]:
# Mount Google Drive at /content/drive (Colab only); the dataset CSVs and
# checkpoints configured below live under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')


In [None]:
# =========================
# Imports
# =========================
import os
import random
import time
import logging
import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader, random_split

from sklearn.metrics import roc_auc_score
from barbar import Bar

import torch.backends.cudnn as cudnn
cudnn.benchmark = True

# =========================
# Logger
# =========================
def log(path, file):
    """Configure the root logger to also write to ``path/file`` and return it.

    Args:
        path: Directory that will hold the log file; created if missing.
        file: Log file name inside ``path``.

    Returns:
        The root ``logging.Logger``.

    The function is safe to call repeatedly (e.g. when a notebook cell is
    re-run): a file handler for the same log file is attached only once, so
    messages are not duplicated.
    """
    os.makedirs(path, exist_ok=True)
    # FileHandler stores its target as an absolute path, so compare like with like.
    log_file = os.path.abspath(os.path.join(path, file))

    # basicConfig is a no-op when the root logger already has handlers
    # (as in Colab/Jupyter), so the level is enforced explicitly below.
    logging.basicConfig(level=logging.INFO, format="%(message)s")
    logger = logging.getLogger()
    if not logger.isEnabledFor(logging.INFO):
        # Otherwise INFO records would be dropped before reaching any handler.
        logger.setLevel(logging.INFO)

    # Re-use an existing handler for this file instead of stacking duplicates.
    for existing in logger.handlers:
        if (isinstance(existing, logging.FileHandler)
                and existing.baseFilename == log_file):
            return logger

    handler = logging.FileHandler(log_file)
    handler.setLevel(logging.INFO)
    handler.setFormatter(logging.Formatter("%(asctime)s: %(message)s"))
    logger.addHandler(handler)

    return logger

# =========================
# Config (Colab Version)
# =========================
class Config:
    """Run configuration for the Colab federated CheXpert experiment.

    Instantiating the class creates the checkpoint directory and attaches a
    file logger (``<save_path>/<name>.log``) exposed as ``self.logger``.
    """

    def __init__(self):
        # --- Experiment identity and storage layout (Google Drive) ---
        self.name = 'fed_chexpert_colab'
        root = '/content/drive/MyDrive/chexpert'
        self.base_path = root
        self.save_path = f'{root}/ckpt'

        # --- CSV splits ---
        self.train_csv = f'{root}/chexpert-train.csv'
        self.valid_csv = f'{root}/chexpert-valid.csv'
        self.test_csv = f'{root}/chexpert-test.csv'

        # --- Model ---
        self.model_name = 'resnet18'
        self.pre_train = True
        self.num_classes = 14

        # --- Optimisation / data loading ---
        self.img_size = 224
        self.batch_size = 16
        self.lr = 1e-4
        self.num_workers = 2

        # Prefer the GPU when one is visible, otherwise fall back to the CPU.
        backend = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(backend)

        # --- Federated learning schedule ---
        self.num_clients = 5
        self.client_epoch = 1
        self.com_round = 3
        self.fraction = 1.0

        # Side effects: make sure the checkpoint dir exists, then open the log.
        os.makedirs(self.save_path, exist_ok=True)
        self.logger = log(self.save_path, f'{self.name}.log')

# Build the configuration once at import time; the dataset, model, and
# training code below read it through the module-level name `opt`.
opt = Config()
opt.logger.info("Config Loaded")

# =========================
# Dataset
# =========================
class CheXpertDataSet(Dataset):
    """Image/multi-label dataset backed by a CheXpert-style DataFrame.

    Args:
        df: DataFrame with a ``"Path"`` column and one column per class name.
        class_names: Label columns to read, in output order.
        transform: Callable applied to each loaded RGB ``PIL.Image``.
        policy: How labels equal to ``-1`` are resolved: ``"zeroes"`` maps
            them to 0; any other value maps them to 1.
    """

    def __init__(self, df, class_names, transform, policy="zeroes"):
        self.image_filepaths = df["Path"].values
        self.class_names = class_names
        self.transform = transform

        # One row per class, transposed to (num_samples, num_classes).
        per_class = [df[name].values for name in class_names]
        self.labels = np.array(per_class).T.astype(np.float32)

        # Resolve the -1 entries according to the chosen policy.
        replacement = 0 if policy == "zeroes" else 1
        uncertain = self.labels == -1
        self.labels[uncertain] = replacement

    def __len__(self):
        """Return the number of samples."""
        return len(self.image_filepaths)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(transformed_image, label_row)``."""
        # Paths in the CSV are relative to the configured dataset root.
        full_path = os.path.join(opt.base_path, self.image_filepaths[idx])
        rgb = Image.open(full_path).convert("RGB")
        tensor = self.transform(rgb)
        return tensor, self.labels[idx]

# =========================
# Transforms
# =========================
def get_transforms(train=True):
    """Build the image preprocessing pipeline.

    Args:
        train: When True (the default, matching the previous behaviour) a
            random horizontal flip is included as augmentation. Pass False
            for validation/test data so evaluation is deterministic.

    Returns:
        A ``transforms.Compose`` that resizes to ``opt.img_size`` square,
        optionally flips, converts to a tensor, and normalises with the
        ImageNet mean/std expected by the pretrained backbone.
    """
    steps = [transforms.Resize((opt.img_size, opt.img_size))]
    if train:
        # Augmentation belongs in the training pipeline only.
        steps.append(transforms.RandomHorizontalFlip())
    steps.extend([
        transforms.ToTensor(),
        transforms.Normalize([0.485,0.456,0.406], [0.229,0.224,0.225]),
    ])
    return transforms.Compose(steps)

# =========================
# Load Datasets
# =========================
def get_dataloaders(samples_per_client=1000):
    """Load the CSV splits and build per-client training loaders.

    Args:
        samples_per_client: Upper bound on the number of training samples
            given to each client (kept small for Colab). It is reduced
            automatically when the training set is too small.

    Returns:
        ``(client_sets, train_loaders, val_loader, test_loader)`` where
        ``client_sets`` and ``train_loaders`` each hold ``opt.num_clients``
        entries.

    Raises:
        ValueError: If the training set has fewer rows than clients.
    """
    # NOTE(review): fillna(-1) makes blank cells indistinguishable from
    # explicit -1 labels, so with policy="ones" (test set) blanks become
    # positives -- confirm this is intended.
    train_df = pd.read_csv(opt.train_csv).fillna(-1)
    valid_df = pd.read_csv(opt.valid_csv).fillna(-1)
    test_df  = pd.read_csv(opt.test_csv).fillna(-1)

    classes = [
        'No Finding','Enlarged Cardiomediastinum','Cardiomegaly','Lung Opacity',
        'Lung Lesion','Edema','Consolidation','Pneumonia','Atelectasis',
        'Pneumothorax','Pleural Effusion','Pleural Other','Fracture','Support Devices'
    ]

    train_tf = get_transforms()
    # Evaluation must be deterministic: reuse the training pipeline but drop
    # the random flip augmentation.
    eval_tf = transforms.Compose([
        step for step in train_tf.transforms
        if not isinstance(step, transforms.RandomHorizontalFlip)
    ])

    train_set = CheXpertDataSet(train_df, classes, train_tf)
    val_set   = CheXpertDataSet(valid_df, classes, eval_tf)
    test_set  = CheXpertDataSet(test_df, classes, eval_tf, policy="ones")

    # Small split for Colab. random_split requires the lengths to sum to
    # len(dataset), so the unused remainder is split off and discarded.
    per_client = min(samples_per_client, len(train_set) // opt.num_clients)
    if per_client < 1:
        raise ValueError(
            f"Training set has {len(train_set)} samples, which is not enough "
            f"for {opt.num_clients} clients"
        )
    lengths = [per_client] * opt.num_clients
    leftover = len(train_set) - per_client * opt.num_clients
    if leftover:
        *client_sets, _unused = random_split(train_set, lengths + [leftover])
    else:
        client_sets = random_split(train_set, lengths)

    # Pinned memory only helps host-to-GPU copies; it warns on CPU-only runs.
    use_pinned = opt.device.type == "cuda"
    train_loaders = [
        DataLoader(cs, batch_size=opt.batch_size, shuffle=True,
                   num_workers=opt.num_workers, pin_memory=use_pinned)
        for cs in client_sets
    ]

    val_loader = DataLoader(val_set, batch_size=opt.batch_size)
    test_loader = DataLoader(test_set, batch_size=1)

    return client_sets, train_loaders, val_loader, test_loader

# =========================
# Model
# =========================
class Classifier(nn.Module):
    """ResNet-18 whose final layer is replaced by a sigmoid multi-label head.

    The head outputs ``opt.num_classes`` values in [0, 1]; ImageNet weights
    are loaded when ``opt.pre_train`` is true.
    """

    def __init__(self):
        super().__init__()
        backbone = torchvision.models.resnet18(pretrained=opt.pre_train)
        # Swap the ImageNet classifier for one sized to our label set.
        head = nn.Sequential(
            nn.Linear(backbone.fc.in_features, opt.num_classes),
            nn.Sigmoid(),
        )
        backbone.fc = head
        self.model = backbone

    def forward(self, x):
        """Return the per-class sigmoid outputs for the image batch ``x``."""
        probabilities = self.model(x)
        return probabilities

# =========================
# Metrics
# =========================
def compute_auroc(gt, pred):
    """Return the mean per-class AUROC over the classes that can be scored.

    Args:
        gt: Tensor of ground-truth labels, indexed as ``[sample, class]``.
        pred: Tensor of predicted scores with the same layout as ``gt``.

    Returns:
        The mean of the per-class ``roc_auc_score`` values, or ``nan`` when
        no class could be scored (including when there are no class columns).

    A class is skipped when ``roc_auc_score`` raises ``ValueError`` (for
    example because only one label value is present in that column). Any
    other exception is a real error and is allowed to propagate.
    """
    gt = gt.detach().cpu().numpy()
    pred = pred.detach().cpu().numpy()

    scores = []
    # Iterate over the columns actually present rather than a global count.
    for i in range(gt.shape[1]):
        try:
            scores.append(roc_auc_score(gt[:, i], pred[:, i]))
        except ValueError:
            # AUROC is undefined for this class on this data; leave it out.
            continue

    if not scores:
        # np.mean([]) would also give nan, but with a RuntimeWarning.
        return float("nan")
    return np.mean(scores)

# =========================
# Train / Val
# =========================
def train_epoch(model, loader, optimizer, loss_fn):
    """Train ``model`` for one pass over ``loader``.

    Args:
        model: Network to optimise; switched to training mode.
        loader: Iterable of ``(inputs, targets)`` batches.
        optimizer: Optimiser stepping ``model``'s parameters.
        loss_fn: Callable ``loss_fn(outputs, targets)`` returning a scalar loss.

    Returns:
        The mean batch loss over the epoch.
    """
    model.train()
    running = 0
    # Bar wraps the loader to display a progress bar while iterating.
    for images, targets in Bar(loader):
        images = images.to(opt.device)
        targets = targets.to(opt.device)

        batch_loss = loss_fn(model(images), targets)

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        running += batch_loss.item()

    return running / len(loader)

def validate(model, loader, loss_fn):
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    Args:
        model: Network to evaluate; switched to eval mode.
        loader: Iterable of ``(inputs, targets)`` batches.
        loss_fn: Callable ``loss_fn(outputs, targets)`` returning a scalar loss.

    Returns:
        ``(mean_batch_loss, mean_auroc)`` over the whole loader.
    """
    model.eval()
    seen_targets = []
    seen_outputs = []
    running = 0

    with torch.no_grad():
        for images, targets in loader:
            images = images.to(opt.device)
            targets = targets.to(opt.device)
            outputs = model(images)

            running += loss_fn(outputs, targets).item()
            seen_targets.append(targets)
            seen_outputs.append(outputs)

    # AUROC needs all samples at once, so stitch the batches back together.
    all_targets = torch.cat(seen_targets)
    all_outputs = torch.cat(seen_outputs)
    return running / len(loader), compute_auroc(all_targets, all_outputs)

# =========================
# Federated Training
# =========================
def _fed_avg(states, sizes):
    """Return the sample-count-weighted average of client state dicts."""
    total = float(sum(sizes))
    averaged = {}
    for key, reference in states[0].items():
        acc = torch.zeros_like(reference, dtype=torch.float32)
        for state, size in zip(states, sizes):
            acc += state[key].to(torch.float32) * (size / total)
        if not reference.is_floating_point():
            # Integer buffers (e.g. BatchNorm's num_batches_tracked) must
            # stay integral; round instead of truncating float error.
            acc = acc.round()
        averaged[key] = acc.to(reference.dtype)
    return averaged


def main():
    """Run federated averaging (FedAvg) training and save the global model.

    Each communication round: a fraction (``opt.fraction``) of the clients
    start from the current global weights, train locally for
    ``opt.client_epoch`` epochs, and their weights are averaged (weighted by
    client dataset size) into the global model, which is then validated.
    The final global weights are written to
    ``<opt.save_path>/global_model.pth``.
    """
    client_sets, train_loaders, val_loader, _test_loader = get_dataloaders()
    loss_fn = nn.BCELoss()

    global_model = Classifier().to(opt.device)
    client_models = [Classifier().to(opt.device) for _ in range(opt.num_clients)]

    # Number of clients that participate in each round (at least one).
    num_selected = round(opt.fraction * opt.num_clients)
    num_selected = min(opt.num_clients, max(1, num_selected))

    for rnd in range(opt.com_round):
        opt.logger.info(f"\n===== ROUND {rnd+1} =====")

        if num_selected == opt.num_clients:
            selected = range(opt.num_clients)
        else:
            selected = sorted(random.sample(range(opt.num_clients), num_selected))

        client_states = []
        client_sizes = []
        for i in selected:
            # Every participant starts the round from the global weights.
            client_models[i].load_state_dict(global_model.state_dict())
            optimizer = optim.Adam(client_models[i].parameters(), lr=opt.lr)

            for _epoch in range(opt.client_epoch):
                train_epoch(client_models[i], train_loaders[i], optimizer, loss_fn)

            client_states.append(client_models[i].state_dict())
            client_sizes.append(len(client_sets[i]))

        # FedAvg
        global_model.load_state_dict(_fed_avg(client_states, client_sizes))

        val_loss, val_auc = validate(global_model, val_loader, loss_fn)
        opt.logger.info(f"Validation Loss: {val_loss:.4f}, AUROC: {val_auc:.4f}")

    torch.save(global_model.state_dict(), f"{opt.save_path}/global_model.pth")
    opt.logger.info("Training Complete")

# =========================
# Run
# =========================
# Entry point: start federated training when this code is executed directly
# rather than imported as a module.
if __name__ == "__main__":
    main()
