# Privacy-Preserving Federated Vision Transformers for Pneumonia Detection on Chest X-Rays

This notebook implements a federated learning framework using Vision Transformers (ViT) for pneumonia classification on chest X-ray images under non-IID data distribution.
Differentially Private Federated Learning (DP-FL) is implemented in Experiment 3 using DP-SGD (via Opacus) at target privacy budgets ε = 8, 3, and 1.


## 1. Project Overview

This project explores privacy-aware federated learning for medical image classification.
Instead of centralizing sensitive chest X-ray data, model training is performed across multiple simulated clients, each holding local, non-identically distributed (non-IID) data.

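The objective is to compare three training setups for pneumonia detection: a centrally trained Vision Transformer, federated averaging on non-IID clients, and differentially private federated averaging. The comparison measures how data heterogeneity and privacy noise affect performance.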


## 2. Environment Setup and Dependencies

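This section installs and imports the required libraries:

- PyTorch and torchvision for modeling and data loading
- timm for the pretrained ViT
- Opacus for DP-SGD
- kagglehub for downloading the dataset
- scikit-learn for metrics

The code is PyTorch-specific. It can run locally or on Google Colab, but the `!pip` line and the `/content/` checkpoint paths used later assume a Colab-style notebook environment.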
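The DP results depend on the installed Opacus and timm releases, so recording package versions alongside the results makes runs easier to compare. The sketch below uses only the standard library's `importlib.metadata`. The helper `package_versions` is hypothetical and not part of the original notebook.

```python
from importlib.metadata import PackageNotFoundError, version


def package_versions(names):
    """Return {distribution name: installed version string, or None if not installed}."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


# Distribution names of the libraries this notebook installs or imports
for name, ver in package_versions(["torch", "torchvision", "timm", "opacus", "kagglehub"]).items():
    print(f"{name:12s} {ver or 'not installed'}")
```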


In [None]:
!pip install timm kagglehub opacus

import os
import numpy as np
from collections import defaultdict

import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms

from sklearn.metrics import accuracy_score, precision_recall_fscore_support

import kagglehub

from opacus import PrivacyEngine

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


## 3. Dataset Description and Preprocessing

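The Kaggle "Chest X-Ray Images (Pneumonia)" dataset (`paultimothymooney/chest-xray-pneumonia`) is used for binary classification (Pneumonia vs Normal). It ships with `train`, `val`, and `test` folders.
All images are resized to 224×224 and normalized with mean 0.5 and standard deviation 0.5. Random horizontal flips and rotations of up to ±10° are applied to the training split only, and every client uses the same transforms.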
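The sketch below shows, in NumPy, what `transforms.Normalize(mean=[0.5], std=[0.5])` does to each pixel and how to invert it for display. It assumes `ToTensor()` has already scaled pixels to [0, 1]. `ImageFolder` loads images as RGB, and the single mean/std value is broadcast across all three channels. The helper names are hypothetical.

```python
import numpy as np

# Assumption: pixel values are already in [0, 1], as produced by transforms.ToTensor().
MEAN, STD = 0.5, 0.5  # the values passed to transforms.Normalize in this notebook


def normalize(x):
    """Per-pixel effect of Normalize(mean=[0.5], std=[0.5]): maps [0, 1] onto [-1, 1]."""
    return (x - MEAN) / STD


def denormalize(z):
    """Inverse mapping, needed before showing a normalized tensor as an image."""
    return z * STD + MEAN


pixels = np.array([0.0, 0.25, 0.5, 1.0])
print(normalize(pixels))               # black -> -1, mid-gray -> 0, white -> 1
print(denormalize(normalize(pixels)))  # recovers the original pixel values
```

The saliency visualization at the end of the notebook has to use this inverse, not ImageNet statistics, because these are the statistics the model was trained with.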


In [None]:
# Download latest version of Chest X-Ray Pneumonia dataset
path = kagglehub.dataset_download("paultimothymooney/chest-xray-pneumonia")

print("Root path:", path)
print("Contents:", os.listdir(path))

data_root = os.path.join(path, "chest_xray")
print("Data root:", data_root)
print("Data root contents:", os.listdir(data_root))


**Transforms and dataset objects**

In [None]:
img_size = 224

train_tfms = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

test_tfms = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

train_ds_full = datasets.ImageFolder(os.path.join(data_root, "train"), transform=train_tfms)
val_ds        = datasets.ImageFolder(os.path.join(data_root, "val"),   transform=test_tfms)
test_ds       = datasets.ImageFolder(os.path.join(data_root, "test"),  transform=test_tfms)

print("Train size:", len(train_ds_full))
print("Val size:  ", len(val_ds))
print("Test size: ", len(test_ds))
print("Classes:   ", train_ds_full.classes)  # should be ['NORMAL', 'PNEUMONIA']
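torchvision's `ImageFolder` assigns integer labels by sorting the class sub-directory names. This is why the partitioning code below can assume label 0 = NORMAL and label 1 = PNEUMONIA. The hypothetical helper here only reproduces that sorting rule. In the notebook itself, `train_ds_full.class_to_idx` reports the real mapping.

```python
def folder_class_to_idx(class_dirs):
    """Mimic ImageFolder's labeling rule: sorted folder names get labels 0, 1, ..."""
    classes = sorted(class_dirs)
    return {name: idx for idx, name in enumerate(classes)}


mapping = folder_class_to_idx(["PNEUMONIA", "NORMAL"])  # order on disk does not matter
print(mapping)  # {'NORMAL': 0, 'PNEUMONIA': 1}
```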


## 4. Non-IID Client Data Partitioning

To simulate real-world clinical scenarios, the dataset is partitioned across multiple federated clients using a non-IID strategy.
Each client contains a distinct class distribution, reflecting data heterogeneity commonly observed across hospitals and medical institutions.


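Non-IID client splits (the code numbers clients from 0, so Client 1 is printed as `Client 0`):

- Client 1: pneumonia-heavy, holding 50% of the PNEUMONIA images and 10% of the NORMAL images
- Client 2: normal-heavy, holding 10% of the PNEUMONIA images and 50% of the NORMAL images
- Client 3: the remaining 40% of each class, a mixed split close to the overall class ratio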
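The split fractions can be checked with plain arithmetic before touching any images. The sketch below reproduces the `int(...)` truncation used in the partitioning cell and returns per-client (NORMAL, PNEUMONIA) counts. The function name and the class counts in the example are hypothetical, chosen only to illustrate the arithmetic.

```python
def client_split_counts(num_normal, num_pneu):
    """Per-client (NORMAL, PNEUMONIA) counts from the 10%/50% fractions used in this notebook."""
    c1 = (int(0.1 * num_normal), int(0.5 * num_pneu))            # Client 1: pneumonia-heavy
    c2 = (int(0.5 * num_normal), int(0.1 * num_pneu))            # Client 2: normal-heavy
    c3 = (num_normal - c1[0] - c2[0], num_pneu - c1[1] - c2[1])  # Client 3: everything left
    return [c1, c2, c3]


# Hypothetical class counts, for illustration only
for i, (n_norm, n_pneu) in enumerate(client_split_counts(1000, 3000), start=1):
    print(f"Client {i}: NORMAL={n_norm}, PNEUMONIA={n_pneu}, "
          f"pneumonia share={n_pneu / (n_norm + n_pneu):.2f}")
```

Because Client 3 takes the remainder, the three clients always cover every training image exactly once.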

In [None]:
# Build indices per class for non-IID split
class_to_indices = defaultdict(list)
for idx, (_, label) in enumerate(train_ds_full.samples):
    class_to_indices[label].append(idx)

normal_indices = np.array(class_to_indices[0])  # label 0 = NORMAL
pneu_indices   = np.array(class_to_indices[1])  # label 1 = PNEUMONIA

np.random.shuffle(normal_indices)
np.random.shuffle(pneu_indices)

num_normal = len(normal_indices)
num_pneu   = len(pneu_indices)

print("Total NORMAL:", num_normal)
print("Total PNEUMONIA:", num_pneu)

# Client 1: pneumonia-heavy
c1_pneu = int(0.5 * num_pneu)      # ~50% of pneumonia cases
c1_norm = int(0.1 * num_normal)    # small subset of normals

# Client 2: normal-heavy
c2_pneu = int(0.1 * num_pneu)
c2_norm = int(0.5 * num_normal)

# Remaining for Client 3
used_norm = c1_norm + c2_norm
used_pneu = c1_pneu + c2_pneu

c3_norm_indices = normal_indices[used_norm:]
c3_pneu_indices = pneu_indices[used_pneu:]

c1_indices = np.concatenate([pneu_indices[:c1_pneu], normal_indices[:c1_norm]])
c2_indices = np.concatenate([
    pneu_indices[c1_pneu:c1_pneu + c2_pneu],
    normal_indices[c1_norm:c1_norm + c2_norm]
])
c3_indices = np.concatenate([c3_pneu_indices, c3_norm_indices])

np.random.shuffle(c1_indices)
np.random.shuffle(c2_indices)
np.random.shuffle(c3_indices)

client_indices = [c1_indices, c2_indices, c3_indices]
client_datasets = [Subset(train_ds_full, idxs) for idxs in client_indices]

for i, idxs in enumerate(client_indices):
    labels = [train_ds_full.samples[j][1] for j in idxs]
    n_norm = sum(1 for l in labels if l == 0)
    n_pneu = sum(1 for l in labels if l == 1)
    print(f"Client {i}: total={len(idxs)}, NORMAL={n_norm}, PNEUMONIA={n_pneu}")


**DataLoaders for clients, val, test**

In [None]:
batch_size = 32

client_loaders = [
    DataLoader(cd, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
    for cd in client_datasets
]

val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)
test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)


## 5. Vision Transformer Model Architecture

The base learner on every client is ViT-Tiny (`vit_tiny_patch16_224` from timm), initialized with pretrained weights. It splits each 224×224 image into 16×16 patches and uses self-attention to model global relationships between them.
Its classification head is replaced with a 2-way linear layer for NORMAL vs PNEUMONIA.
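The patch arithmetic fixes the sequence length the transformer processes. The sketch below computes it, assuming one [CLS] token as in the standard ViT. The 192-dimensional width is the usual ViT-Tiny configuration and is stated here as an assumption; `model.head.in_features` in `create_vit_model()` reports the actual value.

```python
def vit_token_count(img_size=224, patch_size=16):
    """Tokens seen by a ViT: one per non-overlapping patch plus one [CLS] token."""
    assert img_size % patch_size == 0, "image size must be divisible by the patch size"
    patches_per_side = img_size // patch_size
    return patches_per_side ** 2 + 1


num_tokens = vit_token_count()  # 14 * 14 patches + 1 [CLS] = 197
patch_dim = 16 * 16 * 3         # raw values in one RGB patch before the linear projection
embed_dim = 192                 # assumed ViT-Tiny width; verify with model.head.in_features
print(num_tokens, patch_dim, embed_dim)
```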


**ViT model, training & evaluation helpers**

In [None]:
import timm

def create_vit_model():
    model = timm.create_model("vit_tiny_patch16_224", pretrained=True)
    in_features = model.head.in_features
    model.head = nn.Linear(in_features, 2)
    return model.to(device)

def train_one_epoch(model, loader, optimizer, criterion):
    model.train()
    total_loss = 0.0

    for imgs, labels in loader:
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * imgs.size(0)

    return total_loss / len(loader.dataset)

def evaluate(model, loader, criterion):
    model.eval()
    total_loss = 0.0
    all_preds, all_labels = [], []

    with torch.no_grad():
        for imgs, labels in loader:
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            total_loss += loss.item() * imgs.size(0)

            preds = outputs.argmax(dim=1).detach().cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(labels.detach().cpu().numpy())

    acc = accuracy_score(all_labels, all_preds)
    prec, rec, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='binary')
    return total_loss / len(loader.dataset), acc, prec, rec, f1


## 6. Federated Learning Framework

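Each communication round proceeds as follows:

1. Every client copies the current global weights.
2. The client trains locally for `local_epochs` epochs on its own data.
3. The client returns its updated weights.
4. The server averages the returned weights into the new global model.

Only model weights are exchanged, so raw images never leave a client. In this notebook, all clients are simulated in a single process.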


**FedAvg helper functions**

In [None]:
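def get_model_weights(model):
    """Return a CPU copy of the model's state_dict.

    Opacus wraps DP models in a GradSampleModule, which prefixes every key with
    "_module."; the prefix is stripped so DP and non-DP weights share the same keys.
    """
    clean_state = {}
    for k, v in model.state_dict().items():
        new_key = k[len("_module."):] if k.startswith("_module.") else k
        clean_state[new_key] = v.detach().cpu().clone()
    return clean_state


def set_model_weights(model, weights):
    """Load a prefix-free state_dict into an unwrapped (non-Opacus) model."""
    model.load_state_dict(weights)


def average_weights(weights_list):
    """Element-wise unweighted mean of client state_dicts (every client counts equally)."""
    avg = {}
    for k in weights_list[0].keys():
        stacked = torch.stack([w[k] for w in weights_list], dim=0)
        avg[k] = stacked.mean(dim=0)
    return avg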
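`average_weights` gives every client equal weight. FedAvg as described by McMahan et al. instead weights each client by its number of local samples. This distinction matters for non-IID splits like this one, where clients differ in size whenever the two classes are unbalanced. The sketch below implements the weighted variant in NumPy. The helper name is hypothetical, and NumPy arrays stand in for state_dict tensors.

```python
import numpy as np


def weighted_average_weights(weights_list, num_samples):
    """Sample-size-weighted FedAvg: client k contributes n_k / sum(n) of every tensor.

    weights_list: list of dicts {param_name: array}, one dict per client.
    num_samples:  local dataset sizes, in the same order as weights_list.
    """
    total = float(sum(num_samples))
    avg = {}
    for k in weights_list[0]:
        avg[k] = sum(w[k] * (n / total) for w, n in zip(weights_list, num_samples))
    return avg


# Two toy "clients", each with a single scalar parameter
clients = [{"w": np.array([1.0])}, {"w": np.array([3.0])}]
sizes = [100, 300]

weighted = weighted_average_weights(clients, sizes)["w"]  # (1*100 + 3*300) / 400 = 2.5
unweighted = np.mean([c["w"] for c in clients], axis=0)   # (1 + 3) / 2 = 2.0
print(weighted, unweighted)
```

The same expression also works on PyTorch tensors. A notebook version could therefore pass `num_samples=[len(ds) for ds in client_datasets]`; this is an untested suggestion, not part of the original experiments.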


**Experiment 1: Centralized baseline**

In [None]:
criterion = nn.CrossEntropyLoss()

central_model = create_vit_model()
optimizer = optim.Adam(central_model.parameters(), lr=1e-4)

central_history = {"epoch": [], "val_acc": [], "val_f1": []}

num_epochs = 5

central_train_loader = DataLoader(
    train_ds_full, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True
)

for epoch in range(1, num_epochs + 1):
    train_loss = train_one_epoch(central_model, central_train_loader, optimizer, criterion)
    val_loss, acc, prec, rec, f1 = evaluate(central_model, val_loader, criterion)
    central_history["epoch"].append(epoch)
    central_history["val_acc"].append(acc)
    central_history["val_f1"].append(f1)

    print(f"[Central] Epoch {epoch:02d}: train_loss={train_loss:.4f}, "
          f"val_acc={acc:.4f}, val_f1={f1:.4f}")

central_test_loss, central_test_acc, central_test_prec, central_test_rec, central_test_f1 = \
    evaluate(central_model, test_loader, criterion)

print("\n[Centralized ViT] Test results:")
print(f"  loss={central_test_loss:.4f}, acc={central_test_acc:.4f}, "
      f"prec={central_test_prec:.4f}, rec={central_test_rec:.4f}, "
      f"f1={central_test_f1:.4f}")

torch.save(central_model.state_dict(), "/content/central_vit.pth")

**Experiment 2: Federated ViT on NON-IID (no DP)**

In [None]:
rounds = 5
local_epochs = 5
lr = 1e-4

global_model = create_vit_model()
criterion = nn.CrossEntropyLoss()

fed_history = {"round": [], "val_acc": [], "val_f1": []}

for r in range(1, rounds + 1):
    print(f"\n=== Federated Round {r} (non-DP) ===")
    client_weights = []

    for cid, loader in enumerate(client_loaders):
        print(f"  Client {cid}:")
        client_model = create_vit_model()
        set_model_weights(client_model, get_model_weights(global_model))

        optimizer = optim.Adam(client_model.parameters(), lr=lr)

        for e in range(1, local_epochs + 1):
            train_loss = train_one_epoch(client_model, loader, optimizer, criterion)
            print(f"    local_epoch {e}: loss={train_loss:.4f}")

        client_weights.append(get_model_weights(client_model))
        del client_model
        torch.cuda.empty_cache()

    new_global_weights = average_weights(client_weights)
    set_model_weights(global_model, new_global_weights)

    val_loss, acc, prec, rec, f1 = evaluate(global_model, val_loader, criterion)
    fed_history["round"].append(r)
    fed_history["val_acc"].append(acc)
    fed_history["val_f1"].append(f1)

    print(f"[Global after round {r}] val_acc={acc:.4f}, val_f1={f1:.4f}")

fl_nondp_test_loss, fl_nondp_test_acc, fl_nondp_test_prec, fl_nondp_test_rec, fl_nondp_test_f1 = \
    evaluate(global_model, test_loader, criterion)

print("\n[Federated ViT (non-DP, non-IID)] Test results:")
print(f"  loss={fl_nondp_test_loss:.4f}, acc={fl_nondp_test_acc:.4f}, "
      f"prec={fl_nondp_test_prec:.4f}, rec={fl_nondp_test_rec:.4f}, "
      f"f1={fl_nondp_test_f1:.4f}")

torch.save(global_model.state_dict(), "/content/fed_vit.pth")

**Helper for DP-SGD Optimizer per client**

In [None]:
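def make_dp_client(model, loader, lr, local_epochs, target_epsilon, target_delta=1e-5, max_grad_norm=1.0):
    """
    Wraps model, optimizer, and DataLoader with Opacus DP-SGD for a given target epsilon.

    make_private_with_epsilon calibrates the noise multiplier so that `local_epochs`
    epochs over this client's data spend about `target_epsilon` at `target_delta`.
    A new PrivacyEngine is created on every call, i.e. once per federated round, so
    this budget covers a single round and is not accumulated across rounds.
    """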
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    privacy_engine = PrivacyEngine()

    model, optimizer, dp_loader = privacy_engine.make_private_with_epsilon(
        module=model,
        optimizer=optimizer,
        data_loader=loader,
        epochs=local_epochs,
        target_epsilon=target_epsilon,
        target_delta=target_delta,
        max_grad_norm=max_grad_norm,
    )

    return model, optimizer, dp_loader, privacy_engine
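In `make_dp_client`, `max_grad_norm=1.0` is the per-example clipping bound C, and `make_private_with_epsilon` chooses the noise multiplier σ needed to reach `target_epsilon`. The conceptual NumPy sketch below shows what one DP-SGD gradient looks like:

1. Clip each example's gradient to L2 norm ≤ C.
2. Sum the clipped gradients.
3. Add Gaussian noise with standard deviation σ·C.
4. Divide by the batch size.

This is not Opacus's actual code. Opacus computes per-sample gradients with module hooks, samples batches with Poisson sampling, and divides by the expected batch size. The helper name and the fixed noise multiplier are illustrative assumptions.

```python
import numpy as np


def dp_sgd_aggregate(per_sample_grads, max_grad_norm, noise_multiplier, rng):
    """Conceptual DP-SGD gradient for one batch.

    per_sample_grads: array of shape (batch_size, num_params), one row per example.
    Returns (noisy averaged gradient, clipped per-sample gradients).
    """
    norms = np.linalg.norm(per_sample_grads, axis=1, keepdims=True)
    scale = np.minimum(1.0, max_grad_norm / np.maximum(norms, 1e-12))  # shrink only if norm > C
    clipped = per_sample_grads * scale
    noise = rng.normal(0.0, noise_multiplier * max_grad_norm, size=per_sample_grads.shape[1])
    return (clipped.sum(axis=0) + noise) / per_sample_grads.shape[0], clipped


rng = np.random.default_rng(0)
grads = np.array([[3.0, 4.0],    # norm 5.0 -> scaled down to norm 1.0
                  [0.3, 0.4]])   # norm 0.5 -> left unchanged
noisy_grad, clipped = dp_sgd_aggregate(grads, max_grad_norm=1.0, noise_multiplier=1.0, rng=rng)
print(np.linalg.norm(clipped, axis=1))  # per-example norms after clipping
print(noisy_grad)
```

Clipping bounds how much any single X-ray can move the update, and the noise scale is tied to that same bound. A smaller target ε therefore requires a larger σ for the same number of local epochs.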


**Experiment 3: DP-FL for ε = 8, 3, 1**

In [None]:
epsilons = [8.0, 3.0, 1.0]
dp_results = {}

for target_eps in epsilons:
    print(f"\n\n==============================")
    print(f" DP-FL RUN for target epsilon = {target_eps}")
    print(f"==============================")

    global_model = create_vit_model()
    dp_history = {"round": [], "val_acc": [], "val_f1": []}

    for r in range(1, rounds + 1):
        print(f"\n=== DP Federated Round {r} (target eps={target_eps}) ===")
        client_weights = []

        for cid, loader in enumerate(client_loaders):
            print(f"  Client {cid}:")
            client_model = create_vit_model()
            set_model_weights(client_model, get_model_weights(global_model))

            # Make client DP
            client_model, optimizer, dp_loader, privacy_engine = make_dp_client(
                client_model, loader, lr=lr, local_epochs=local_epochs,
                target_epsilon=target_eps, target_delta=1e-5, max_grad_norm=1.0
            )

            for e in range(1, local_epochs + 1):
                train_loss = train_one_epoch(client_model, dp_loader, optimizer, criterion)
                print(f"    local_epoch {e}: loss={train_loss:.4f}")

            # Epsilon spent by this client in this round only (a fresh accountant is built each round)
            eps_spent = privacy_engine.get_epsilon(delta=1e-5)
            print(f"    Epsilon spent this round (client {cid}): {eps_spent:.2f}")

            client_weights.append(get_model_weights(client_model))
            del client_model, optimizer, dp_loader, privacy_engine
            torch.cuda.empty_cache()

        new_global_weights = average_weights(client_weights)
        set_model_weights(global_model, new_global_weights)

        val_loss, acc, prec, rec, f1 = evaluate(global_model, val_loader, criterion)
        dp_history["round"].append(r)
        dp_history["val_acc"].append(acc)
        dp_history["val_f1"].append(f1)

        print(f"[Global after round {r} | DP eps={target_eps}] val_acc={acc:.4f}, val_f1={f1:.4f}")

    # After all rounds for this epsilon, evaluate on test set
    test_loss, test_acc, test_prec, test_rec, test_f1 = evaluate(global_model, test_loader, criterion)

    print(f"\n[DP-FL ViT] target eps={target_eps} | Test results:")
    print(f"  loss={test_loss:.4f}, acc={test_acc:.4f}, "
          f"prec={test_prec:.4f}, rec={test_rec:.4f}, f1={test_f1:.4f}")

    dp_results[target_eps] = {
        "test_loss": test_loss,
        "test_acc": test_acc,
        "test_prec": test_prec,
        "test_rec": test_rec,
        "test_f1": test_f1,
        "val_acc_per_round": dp_history["val_acc"],
        "val_f1_per_round": dp_history["val_f1"]
    }

    # Save the final global model for this epsilon (e.g. /content/dp_vit_eps3.pth for eps=3.0)
    ckpt_path = f"/content/dp_vit_eps{target_eps:g}.pth"
    torch.save(global_model.state_dict(), ckpt_path)
    print(f"Saved DP-FL model for eps={target_eps} to {ckpt_path}")


## Experimental Results and Observations

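The experiments assess whether a Vision Transformer can be trained federatedly on non-IID chest X-ray data, and how much utility is lost when DP-SGD is added.
The cell below prints test metrics for three models: the centralized baseline, the non-DP federated model, and the DP-FL model at each target ε. The following cell plots metrics from the latest run, including the privacy-performance trade-off.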
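`evaluate()` calls `precision_recall_fscore_support(..., average='binary')`, which by default treats label 1 (PNEUMONIA) as the positive class. The formulas below, applied to a hypothetical confusion matrix, show how a model can reach recall close to 1 while precision stays lower. Recall counts missed pneumonia cases, whereas false alarms on NORMAL images lower only precision. This is the pattern visible in the Centralized and Federated rows (recall 0.9974, precision below 0.80). The helper name and the counts are illustrative.

```python
def binary_metrics(tp, fp, fn, tn):
    """Accuracy, precision, recall and F1 with PNEUMONIA (label 1) as the positive class."""
    accuracy = (tp + tn) / (tp + fp + fn + tn)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return accuracy, precision, recall, f1


# Hypothetical confusion matrix: few pneumonia cases missed (fn), many NORMAL images flagged (fp)
acc, prec, rec, f1 = binary_metrics(tp=95, fp=30, fn=5, tn=70)
print(f"acc={acc:.3f}, prec={prec:.3f}, rec={rec:.3f}, f1={f1:.3f}")
```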


In [None]:
print("\n========== SUMMARY ==========\n")
print("Centralized ViT:")
print(f"  acc={central_test_acc:.4f}, f1={central_test_f1:.4f}")

print("\nNon-DP Federated ViT (non-IID):")
print(f"  acc={fl_nondp_test_acc:.4f}, f1={fl_nondp_test_f1:.4f}")

for eps in epsilons:
    r = dp_results[eps]
    print(f"\nDP-FL ViT (target eps={eps}):")
    print(f"  acc={r['test_acc']:.4f}, f1={r['test_f1']:.4f}, "
          f"prec={r['test_prec']:.4f}, rec={r['test_rec']:.4f}")


In [None]:
import matplotlib.pyplot as plt
import numpy as np

# ===== FINAL METRICS FROM LATEST RUN (hard-coded; update after re-running the experiments) =====
methods = [
    "Centralized",
    "Federated",
    "DP-FL (ε=8.0)",
    "DP-FL (ε=3.0)",
    "DP-FL (ε=1.0)"
]

accuracy = np.array([0.8029, 0.8333, 0.7821, 0.6490, 0.7308])
precision = np.array([0.7613, 0.7907, 0.8068, 0.6472, 0.7185])
recall    = np.array([0.9974, 0.9974, 0.8564, 0.9641, 0.9359])
f1_scores = np.array([0.8635, 0.8821, 0.8308, 0.7745, 0.8129])

# For privacy–utility curve (only DP-FL runs)
epsilons = np.array([8.0, 3.0, 1.0])
acc_dp   = np.array([0.7821, 0.6490, 0.7308])
f1_dp    = np.array([0.8308, 0.7745, 0.8129])

# ===== FIGURE 1: Accuracy comparison =====
plt.figure()
plt.bar(methods, accuracy)
plt.ylim(0, 1.0)
plt.ylabel("Accuracy")
plt.title("Accuracy Comparison Across Methods")
plt.xticks(rotation=25)
plt.tight_layout()
plt.savefig("fig_accuracy.png", dpi=300)
plt.close()

# ===== FIGURE 2: F1-score comparison =====
plt.figure()
plt.bar(methods, f1_scores)
plt.ylim(0, 1.0)
plt.ylabel("F1 Score")
plt.title("F1-Score Comparison Across Methods")
plt.xticks(rotation=25)
plt.tight_layout()
plt.savefig("fig_f1score.png", dpi=300)
plt.close()

# ===== FIGURE 3: Privacy–performance trade-off (ε vs F1 & Acc) =====
plt.figure()
plt.plot(epsilons, acc_dp, marker="o", linestyle="-", label="Accuracy")
plt.plot(epsilons, f1_dp,  marker="s", linestyle="--", label="F1 Score")
plt.ylim(0, 1.0)
plt.xlabel("Privacy Budget (ε)")
plt.ylabel("Score")
plt.title("Privacy–Performance Trade-off for DP-FL ViT")
plt.grid(True)
plt.legend()
plt.tight_layout()
plt.savefig("fig_privacy_tradeoff.png", dpi=300)
plt.close()

print("Saved: fig_accuracy.png, fig_f1score.png, fig_privacy_tradeoff.png")


In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 1) Helper: load model from checkpoint using existing create_vit_model()
def load_model_from_ckpt(path):
    model = create_vit_model()  # same as in training
    state = torch.load(path, map_location=device)
    model.load_state_dict(state)
    model.to(device)
    model.eval()
    return model

central_model = load_model_from_ckpt("/content/central_vit.pth")
fed_model     = load_model_from_ckpt("/content/fed_vit.pth")
dp_model      = load_model_from_ckpt("/content/dp_vit_eps3.pth")  # eps = 3.0

# 2) Get one pneumonia image from the test set
pneumonia_label = test_ds.class_to_idx["PNEUMONIA"]  # 1 with ImageFolder's sorted class order
pneumonia_img = None

for images, labels in test_loader:
    idx = (labels == pneumonia_label).nonzero(as_tuple=True)[0]
    if len(idx) > 0:
        i = idx[0].item()
        pneumonia_img = images[i]  # [C,H,W]
        break

if pneumonia_img is None:
    raise RuntimeError("Could not find a pneumonia sample in test_loader!")

# 3) Saliency map generator (gradient-based)
def generate_saliency_heatmap(model, image_tensor, target_class=1):
    """
    image_tensor: [C, H, W], normalized, on CPU
    target_class: index of pneumonia class
    """
    model.eval()
    x = image_tensor.unsqueeze(0).to(device)  # [1, C, H, W]
    x.requires_grad_(True)

    with torch.enable_grad():
        logits = model(x)          # [1, num_classes]
        score = logits[0, target_class]

    model.zero_grad()
    score.backward()
    grad = x.grad.detach().cpu().numpy()[0]  # [C,H,W]

    saliency = np.mean(np.abs(grad), axis=0)
    saliency -= saliency.min()
    if saliency.max() > 0:
        saliency /= saliency.max()
    return saliency  # [H,W] in [0,1]

# 4) Undo the Normalize(mean=[0.5], std=[0.5]) used in test_tfms (not ImageNet statistics)
norm_mean, norm_std = 0.5, 0.5

img_np = pneumonia_img.numpy()
img_denorm = img_np * norm_std + norm_mean
img_denorm = np.clip(img_denorm, 0.0, 1.0)
img_gray = img_denorm.mean(axis=0)  # [H, W]; ImageFolder loads RGB, so average the 3 channels

# 5) Generate saliency for 3 models
sal_central = generate_saliency_heatmap(central_model, pneumonia_img, target_class=pneumonia_label)
sal_fed     = generate_saliency_heatmap(fed_model,     pneumonia_img, target_class=pneumonia_label)
sal_dp      = generate_saliency_heatmap(dp_model,      pneumonia_img, target_class=pneumonia_label)

# 6) Plot triptych
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
titles = ["Centralized ViT", "Federated ViT", "DP-FL ViT (ε=3.0)"]
sal_maps = [sal_central, sal_fed, sal_dp]

for ax, title, sal in zip(axes, titles, sal_maps):
    ax.imshow(img_gray, cmap="gray")
    ax.imshow(sal, cmap="jet", alpha=0.5)
    ax.set_title(title)
    ax.axis("off")

plt.tight_layout()
plt.savefig("attention_triptych.png", dpi=300, bbox_inches="tight")
plt.close()

print("Saved: attention_triptych.png")


## Limitations and Future Directions

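This implementation simulates federated learning on a single machine and does not include secure aggregation, so the server sees every client's individual model update.
The DP-SGD guarantee in Experiment 3 protects individual training images within each client, and its ε is tracked per round rather than accumulated across rounds.
Future work may explore client-level differential privacy, cumulative privacy accounting across rounds, secure aggregation, and real-world multi-institution federated deployment.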


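## Privacy Considerations

Experiment 3 implements differentially private federated learning with DP-SGD via the Opacus library.
On each client, per-sample gradients are clipped to an L2 norm of `max_grad_norm` and perturbed with Gaussian noise. The noise level is calibrated by the Opacus privacy accountant so that each round reaches the target ε at δ = 1e-5.
The resulting guarantee is record-level (example-level) differential privacy for each client's local dataset, not client-level DP. It bounds what an observer of the shared model weights can infer about any single chest X-ray, not about a client's dataset as a whole.
Because a new `PrivacyEngine` is created every round, the target ε covers one round of local training. The total privacy loss over all rounds is larger than the per-round target.
These guarantees hold only under the stated experimental assumptions.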
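Each training image lives on exactly one client, but it is reused in every federated round, so the per-round budgets compose. The sketch below applies basic sequential composition, which is a loose but always-valid upper bound. Tracking all rounds with a single accountant would typically give a tighter total. The function name is hypothetical, and `rounds = 5` mirrors the notebook's setting.

```python
def basic_composition(eps_per_round, delta_per_round, rounds):
    """Upper bound on total (epsilon, delta) after `rounds` DP mechanisms on the same data.

    Basic sequential composition: epsilons add up and deltas add up.
    """
    return eps_per_round * rounds, delta_per_round * rounds


rounds = 5  # as in Experiments 2 and 3
for target_eps in [8.0, 3.0, 1.0]:
    total_eps, total_delta = basic_composition(target_eps, 1e-5, rounds)
    print(f"per-round eps={target_eps}: total eps <= {total_eps}, total delta <= {total_delta:.0e}")
```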

## Disclaimer

This project is intended solely for academic and educational purposes.
It is not a clinical decision-support system and should not be used for medical diagnosis or treatment.
