In [None]:
import os
import numpy as np
import pandas as pd
import pathlib
import imageio

In [None]:

# Mount Google Drive so files stored there are visible under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

# Folder on the mounted Drive that holds the chest X-ray dataset.
DATASET_DIR = "/content/drive/MyDrive/chest_xray"

import os, shutil, random, glob, zipfile, pathlib
from pathlib import Path

# Stop immediately with a readable message if the dataset folder is missing
# (for example when the Drive mount did not succeed).
assert os.path.exists(DATASET_DIR), f"Path not found: {DATASET_DIR}"
print("Using dataset root:", DATASET_DIR)


In [None]:
import torch
import torch.nn as nn
from torchvision import models

# ---- CBAM MODULES ----
class ChannelAttention(nn.Module):
    """Channel-attention gate of a CBAM block.

    The input is squeezed to one value per channel twice (global average
    pooling and global max pooling). Both descriptors go through the same
    two-layer 1x1-convolution bottleneck, are summed, and a sigmoid turns the
    sum into a per-channel gate.

    Args:
        in_planes: number of channels of the incoming feature map.
        ratio: reduction factor of the bottleneck. The hidden width is
            ``in_planes // ratio``, floored at 1.
    """

    def __init__(self, in_planes, ratio=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # in_planes // ratio is 0 whenever in_planes < ratio, which would
        # create a zero-width bottleneck. Keep at least one hidden channel;
        # for in_planes >= ratio the width is unchanged, so existing
        # checkpoints still load.
        hidden_planes = max(1, in_planes // ratio)

        self.mlp = nn.Sequential(
            nn.Conv2d(in_planes, hidden_planes, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_planes, in_planes, 1, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return the per-channel gate for ``x``, with 1x1 spatial size."""
        avg_out = self.mlp(self.avg_pool(x))
        max_out = self.mlp(self.max_pool(x))
        out = avg_out + max_out
        return self.sigmoid(out)


class SpatialAttention(nn.Module):
    """Spatial-attention gate of a CBAM block.

    Collapses the channel axis into two maps (mean and max over channels),
    convolves the stacked pair down to a single map and squashes it with a
    sigmoid.

    Args:
        kernel_size: side of the square convolution kernel; only 3 or 7 are
            accepted. Padding is chosen so the spatial size is preserved.
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        assert kernel_size in (3, 7)
        # "Same" padding for the two supported kernel sizes.
        same_padding = {3: 1, 7: 3}[kernel_size]
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=same_padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return a one-channel gate with the same spatial size as ``x``."""
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_max = x.max(dim=1, keepdim=True).values
        pooled = torch.cat((channel_mean, channel_max), dim=1)
        return self.sigmoid(self.conv(pooled))


class CBAMBlock(nn.Module):
    """Channel attention followed by spatial attention.

    Args:
        channels: number of channels of the feature map being refined.
        ratio: bottleneck reduction factor passed to ChannelAttention.
        kernel_size: convolution kernel size passed to SpatialAttention.
    """

    def __init__(self, channels, ratio=16, kernel_size=7):
        super().__init__()
        self.ca = ChannelAttention(channels, ratio)
        self.sa = SpatialAttention(kernel_size)

    def forward(self, x):
        """Scale ``x`` by its channel gate, then scale that result by its spatial gate."""
        channel_refined = x * self.ca(x)
        return channel_refined * self.sa(channel_refined)


# ---- RESNET18 + CBAM WRAPPER ----
class ResNet18_CBAM(nn.Module):
    """torchvision ResNet18 with one CBAM block applied to the layer4 output.

    The full torchvision network is kept as ``self.backbone``. ``forward``
    runs its stem and four residual stages, refines the result with CBAM, and
    classifies with ``self.fc``. The backbone's own ``fc`` layer is only read
    for its input width and is not called in ``forward``.

    Args:
        num_classes: size of the output layer.
        pretrained: when true, load ``ResNet18_Weights.DEFAULT``; otherwise
            the backbone starts from random weights.
    """

    def __init__(self, num_classes=2, pretrained=True):
        super().__init__()
        weights = models.ResNet18_Weights.DEFAULT if pretrained else None
        self.backbone = models.resnet18(weights=weights)

        # layer4 of ResNet18 produces 512 channels, which is what CBAM sees.
        self.cbam = CBAMBlock(512)

        # New classification head sized for this task.
        self.fc = nn.Linear(self.backbone.fc.in_features, num_classes)

    def forward(self, x):
        """Return class logits for a batch of images."""
        net = self.backbone
        # Stem followed by the four residual stages, in ResNet order.
        stages = (
            net.conv1, net.bn1, net.relu, net.maxpool,
            net.layer1, net.layer2, net.layer3, net.layer4,
        )
        for stage in stages:
            x = stage(x)

        # Attention is applied to the layer4 feature map, before pooling.
        x = self.cbam(x)

        x = net.avgpool(x)
        return self.fc(torch.flatten(x, 1))


In [None]:
import os, zipfile, glob, random, shutil
from pathlib import Path

def maybe_unzip(root):
    """Extract every ``*.zip`` archive found directly inside ``root`` into ``root``.

    Does nothing when the folder contains no zip files. Archives are extracted
    on every call, whether or not their contents already exist.
    """
    for archive_path in glob.glob(os.path.join(root, "*.zip")):
        print("Unzipping:", archive_path)
        with zipfile.ZipFile(archive_path, 'r') as archive:
            archive.extractall(root)

# Extract any archives sitting in the dataset folder before probing its layout.
maybe_unzip(DATASET_DIR)

# Folders that may contain the splits: the dataset folder itself, or a
# "chest_xray" folder nested inside it.
nested_root = os.path.join(DATASET_DIR, "chest_xray")
candidates = [DATASET_DIR, nested_root]

def has_split(root):
    """Return True when ``root`` contains ``train``, ``val`` and ``test`` directories."""
    for split_name in ("train", "val", "test"):
        if not os.path.isdir(os.path.join(root, split_name)):
            return False
    return True

# Dedicated, seeded generator for every shuffle below. Together with sorted
# directory listings this makes the hold-out selection identical each time the
# cell runs on a fresh copy of the dataset.
SPLIT_SEED = 42
split_rng = random.Random(SPLIT_SEED)

# Layout 1: train/val/test already exist in one of the candidate folders.
ROOT = None
for c in candidates:
    if has_split(c):
        ROOT = c
        break

# Layout 2: only train/ and test/ exist. Carve a validation split out of
# train/ by MOVING about 10% (at least one file) of each class into val/.
if ROOT is None:
    for c in candidates:
        if all(os.path.isdir(os.path.join(c, s)) for s in ["train","test"]):
            ROOT = c
            val_dir = os.path.join(ROOT, "val")
            os.makedirs(val_dir, exist_ok=True)
            for cls in sorted(os.listdir(os.path.join(ROOT, "train"))):
                src = os.path.join(ROOT, "train", cls)
                if not os.path.isdir(src):
                    # Stray files in train/ (e.g. .DS_Store) are not classes;
                    # creating val/<name>/ for them would add an empty class.
                    continue
                dst = os.path.join(val_dir, cls)
                os.makedirs(dst, exist_ok=True)
                files = sorted(f for f in glob.glob(os.path.join(src, "*")) if os.path.isfile(f))
                split_rng.shuffle(files)
                take = max(1, int(0.1*len(files)))
                for f in files[:take]:
                    shutil.move(f, os.path.join(dst, os.path.basename(f)))
            break

# Layout 3: flat class folders (at least "normal" and "pneumonia", compared
# case-insensitively). Copy the files into DATASET_DIR/_split as 80/10/10
# train/val/test. The copy is skipped when _split already exists.
if ROOT is None:
    classes = sorted(d for d in os.listdir(DATASET_DIR) if os.path.isdir(os.path.join(DATASET_DIR, d)))
    if set(map(str.lower, classes)) >= {"normal","pneumonia"}:
        ROOT = os.path.join(DATASET_DIR, "_split")
        if not os.path.exists(ROOT):
            print("Creating train/val/test splits (80/10/10) from flat class folders...")
            for split in ["train","val","test"]:
                for cls in classes:
                    os.makedirs(os.path.join(ROOT, split, cls), exist_ok=True)
            for cls in classes:
                files = sorted(f for f in glob.glob(os.path.join(DATASET_DIR, cls, "*")) if os.path.isfile(f))
                split_rng.shuffle(files)
                n = len(files)
                n_train = int(0.8*n); n_val = int(0.1*n)
                for i,f in enumerate(files):
                    # First n_train files -> train, next n_val -> val, rest -> test.
                    if i < n_train:
                        dst = os.path.join(ROOT, "train", cls, os.path.basename(f))
                    elif i < n_train + n_val:
                        dst = os.path.join(ROOT, "val", cls, os.path.basename(f))
                    else:
                        dst = os.path.join(ROOT, "test", cls, os.path.basename(f))
                    shutil.copy2(f, dst)

assert ROOT is not None, "Could not detect a valid dataset structure. Ensure folders are one of the expected layouts."
print("Detected dataset root with splits:", ROOT)
# Shell out to list the directories under ROOT (the notebook variable), at most two levels deep.
!find "$ROOT" -maxdepth 2 -type d -print


In [None]:
import os, torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Square input resolution and the DataLoader settings shared by all splits.
IMG_SIZE = 224
BATCH_SIZE = 32
NUM_WORKERS = 2

# Tail shared by both pipelines: tensor conversion, then per-channel
# normalization with the usual ImageNet mean/std.
_to_normalized_tensor = [
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]),
]

# Training adds light augmentation (random horizontal flip, rotation of up to
# 5 degrees) between the resize and the shared tail.
train_tfms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(5),
    *_to_normalized_tensor,
])

# Validation and test images are only resized and normalized.
val_tfms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    *_to_normalized_tensor,
])


def _split_dataset(split_name, pipeline):
    """ImageFolder over ROOT/<split_name> using the given transform pipeline."""
    return datasets.ImageFolder(os.path.join(ROOT, split_name), transform=pipeline)


def _build_loader(dataset, shuffle):
    """DataLoader with the notebook-wide batch size, worker count and pinned memory."""
    return DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        num_workers=NUM_WORKERS,
        pin_memory=True,
    )


train_ds = _split_dataset("train", train_tfms)
val_ds = _split_dataset("val", val_tfms)
test_ds = _split_dataset("test", val_tfms)

# Only the training loader shuffles; val/test keep a fixed order.
train_loader = _build_loader(train_ds, shuffle=True)
val_loader = _build_loader(val_ds, shuffle=False)
test_loader = _build_loader(test_ds, shuffle=False)

class_names = train_ds.classes
len(train_ds), len(val_ds), len(test_ds), class_names


In [None]:
import torch
from collections import Counter

# Count how many training samples each class index has.
targets = [label for _, label in train_ds.samples]
counts = Counter(targets)
num_classes = len(class_names)
total = sum(counts.values())

# weight[c] = total / (num_classes * count[c]): rarer classes get larger
# weights, and a perfectly even dataset gives 1.0 for every class.
class_weights = torch.tensor(
    [total / (num_classes * counts[c]) for c in range(num_classes)],
    dtype=torch.float,
)
print("Class counts:", counts)
print("Class weights:", class_weights)


In [None]:

import matplotlib.pyplot as plt, numpy as np, torch, os
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score, roc_curve

def _plot_history_curve(key, y_label, title, filename):
    """Plot one per-epoch series from `history`, save it under DRIVE_DIR and show it."""
    plt.figure()
    plt.plot(history[key], label=key)
    plt.xlabel("Epoch")
    plt.ylabel(y_label)
    plt.title(title)
    plt.legend()
    plt.savefig(os.path.join(DRIVE_DIR, filename))
    plt.show()


_plot_history_curve("train_loss", "Loss", "Training Loss", "training_loss.png")
_plot_history_curve("val_acc", "Accuracy", "Validation Accuracy", "val_accuracy.png")

# Restore the saved checkpoint for this architecture before evaluating.
state_dict_path = os.path.join(DRIVE_DIR, f"best_{ARCH}.pth")
model.load_state_dict(torch.load(state_dict_path, map_location=DEVICE))
model.eval()

# Collect labels, hard predictions and (two-class case only) the probability
# assigned to class index 1 over the whole test set.
all_labels, all_preds, all_probs = [], [], []
with torch.no_grad():
    for images, labels in test_loader:
        outputs = model(images.to(DEVICE))
        probs = torch.softmax(outputs, dim=1).cpu().numpy()
        preds = np.argmax(probs, axis=1)
        all_labels.extend(labels.numpy())
        all_preds.extend(preds)
        if len(class_names)==2:
            all_probs.extend(probs[:,1])

print("Classes:", class_names)
print(classification_report(all_labels, all_preds, target_names=class_names))

cm = confusion_matrix(all_labels, all_preds)
print("Confusion Matrix:\n", cm)

import itertools

# Confusion-matrix heat map with the count written inside every cell.
plt.figure()
plt.imshow(cm, interpolation='nearest')
plt.title('Confusion matrix')
plt.colorbar()
tick_marks = np.arange(len(class_names))
plt.xticks(tick_marks, class_names, rotation=45)
plt.yticks(tick_marks, class_names)
# Cells above half of the largest count get white text, the rest black.
thresh = cm.max() / 2.
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
    text_color = "white" if cm[i,j] > thresh else "black"
    plt.text(j, i, f"{cm[i,j]}", ha="center", va="center", color=text_color)
plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.tight_layout()
plt.savefig(os.path.join(DRIVE_DIR, "confusion_matrix.png"))
plt.show()

# ROC analysis is only run for two classes, using the class-1 probabilities.
if len(class_names)==2 and len(all_probs)==len(all_labels):
    auc = roc_auc_score(all_labels, all_probs)
    fpr, tpr, _ = roc_curve(all_labels, all_probs)
    print("ROC AUC:", auc)
    plt.figure()
    plt.plot(fpr, tpr, label=f"AUC = {auc:.3f}")
    plt.plot([0,1],[0,1],'--')  # diagonal reference line
    plt.xlabel("FPR")
    plt.ylabel("TPR")
    plt.title("ROC Curve")
    plt.legend(loc="lower right")
    plt.savefig(os.path.join(DRIVE_DIR,"roc_curve.png"))
    plt.show()


In [None]:
import numpy as np
from PIL import Image, ImageFilter, ImageEnhance

# Deterministic preprocessing for single images in the robustness and Grad-CAM
# cells: resize to 224x224, convert to a tensor, then normalize per channel
# with the mean/std listed below.
_BASE_MEAN = [0.485, 0.456, 0.406]
_BASE_STD = [0.229, 0.224, 0.225]

base_tfms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_BASE_MEAN, std=_BASE_STD),
])

def pil_from_tensor(x):
    """Convert a channel-first image tensor into an 8-bit PIL image.

    Args:
        x: tensor laid out as (C, H, W) with values expected in [0, 1],
            i.e. already un-normalized. Values outside that range are clipped.

    Returns:
        A PIL image built from the (H, W, C) uint8 array.
    """
    # detach() lets tensors that track gradients be converted as well.
    np_img = np.transpose(x.detach().cpu().numpy(), (1, 2, 0))
    np_img = np.clip(np_img, 0, 1)
    # Round to the nearest level before casting: astype(uint8) truncates, so a
    # pixel that comes back as 126.99999 after un-normalizing would otherwise
    # become 126 instead of 127 and darken the "clean" round trip.
    return Image.fromarray(np.rint(np_img * 255).astype(np.uint8))

def apply_perturbation(img_pil, kind, severity):
    """Return a perturbed copy of ``img_pil``; the input image is not modified.

    Args:
        img_pil: source PIL image.
        kind: one of "gaussian_noise", "blur", "brightness" or "rotation".
            Any other value returns an unmodified copy.
        severity: meaning depends on ``kind``:
            - gaussian_noise: standard deviation of the noise on a 0-1 scale
            - blur: Gaussian blur radius
            - brightness: enhancement factor (<1 darker, >1 brighter)
            - rotation: angle passed to ``Image.rotate``

    Note: the noise is drawn from NumPy's global random state, so results are
    only reproducible if that state is seeded by the caller.
    """
    img = img_pil.copy()
    if kind == "gaussian_noise":
        np_img = np.array(img)/255.0
        noise = np.random.normal(0, severity, np_img.shape)
        np_img = np.clip(np_img + noise, 0, 1)
        # Round instead of truncating so zero-mean noise does not shift the
        # image darker by half an intensity level on average.
        img = Image.fromarray(np.rint(np_img*255).astype(np.uint8))
    elif kind == "blur":
        img = img.filter(ImageFilter.GaussianBlur(radius=severity))
    elif kind == "brightness":
        enhancer = ImageEnhance.Brightness(img)
        img = enhancer.enhance(severity)  # <1 darker, >1 brighter
    elif kind == "rotation":
        img = img.rotate(severity)
    # add others if you want: contrast, translation, etc.
    return img


In [None]:
import torch
import torch.nn.functional as F

# Use the GPU when one is available and put the model in inference mode.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
model.to(DEVICE)
model.eval()

# Robustness test plan. Each entry is (display name, perturbation kind,
# severity); a kind of None means the image is left unperturbed.
perturbation_configs = [
    ("none", None, 0),
    ("gaussian_noise", "gaussian_noise", 0.05),
    ("gaussian_noise", "gaussian_noise", 0.10),
    ("blur", "blur", 2),
    ("brightness", "brightness", 0.5),
    ("brightness", "brightness", 1.5),
    ("rotation", "rotation", 5),
]

# Filled by the perturbation loop: (name, severity) -> (accuracy, flip rate).
results = {}

# Baseline pass: predictions on the untouched test set, kept in loader order
# so they can later be compared position by position to measure flips.
clean_pred_batches = []
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        preds = F.softmax(model(images), dim=1).argmax(dim=1)
        clean_pred_batches.append(preds.cpu())

all_clean_preds = torch.cat(clean_pred_batches)


In [None]:
from tqdm.auto import tqdm

# Normalization statistics used by base_tfms, shaped (1, 3, 1, 1) so a whole
# batch can be un-normalized in one broadcasted operation. Built once instead
# of once per image.
norm_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
norm_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

# For every configuration: perturb each test image, re-run the model, and
# record accuracy plus the share of predictions that differ from the clean
# baseline in all_clean_preds.
# NOTE: the gaussian-noise perturbation uses NumPy's global random state, so
# its numbers change between runs unless that state is seeded beforehand.
for name, kind, severity in perturbation_configs:
    correct = 0
    total = 0
    changed = 0
    idx_global = 0  # position in all_clean_preds; relies on test_loader not shuffling

    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc=f"Perturbation: {name}-{severity}"):
            bs = images.size(0)

            # Undo the loader's normalization so pixels are back on a 0-1 scale.
            unnormalized = images.cpu() * norm_std + norm_mean

            # Convert each image to PIL and perturb it (kind None = leave as is).
            pil_imgs = []
            for img_unnorm in unnormalized:
                pil = pil_from_tensor(img_unnorm)
                if kind is not None:
                    pil = apply_perturbation(pil, kind, severity)
                pil_imgs.append(pil)

            # Re-apply the evaluation transforms and stack into one batch.
            pert_tensors = torch.stack([base_tfms(p) for p in pil_imgs]).to(DEVICE)
            labels = labels.to(DEVICE)

            outputs = model(pert_tensors)
            probs = F.softmax(outputs, dim=1)
            preds = probs.argmax(dim=1)

            correct += (preds == labels).sum().item()
            total += bs

            # Compare to the clean predictions of the same samples for the flip rate.
            clean_batch = all_clean_preds[idx_global:idx_global+bs]
            changed += (preds.cpu() != clean_batch).sum().item()
            idx_global += bs

    acc = correct / total
    flip_rate = changed / total
    results[(name, severity)] = (acc, flip_rate)
    print(f"{name} (sev={severity}): accuracy={acc:.3f}, flip rate={flip_rate:.3f}")


In [None]:
# Tab-separated recap of every configuration stored in `results`.
print("\n=== Robustness Summary ===")
print("Perturbation\tSeverity\tAccuracy\tFlipRate")
for config_key, metrics in results.items():
    name, sev = config_key
    acc, flip = metrics
    print(f"{name}\t{sev}\t{acc:.3f}\t{flip:.3f}")


In [None]:
!pip install grad-cam
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
import numpy as np
import matplotlib.pyplot as plt

model.eval()
# Grad-CAM hooks the last block of layer4. A plain torchvision ResNet exposes
# layer4 directly, while the ResNet18_CBAM wrapper in this notebook keeps the
# ResNet under `.backbone`; handle both so this cell works for either model.
_cam_backbone = getattr(model, "backbone", model)
target_layer = _cam_backbone.layer4[-1]


In [None]:
# Display names for the two class indices, used in plot titles and printouts.
CLASS_NAMES = ["NORMAL", "PNEUMONIA"]

def gradcam_on_image(img_path, target_class_idx=None):
    """Run Grad-CAM on one image file and show it beside the original.

    Args:
        img_path: path of the image to explain.
        target_class_idx: class index to explain; when None, the class the
            model predicts is used.

    Returns:
        (pred_idx, class_idx, grayscale_cam): the predicted class index, the
        class index that was explained, and the first Grad-CAM map returned
        for the single-image batch.
    """
    # 1) Load the image; keep a 224x224 copy scaled to 0-1 for the overlay.
    img_pil = Image.open(img_path).convert("RGB")
    rgb_img = np.array(img_pil.resize((224, 224))) / 255.0

    input_tensor = base_tfms(img_pil).unsqueeze(0).to(DEVICE)

    # 2) The model's own prediction is the default class to explain.
    model.eval()
    with torch.no_grad():
        pred_idx = model(input_tensor).argmax(dim=1).item()

    class_idx = pred_idx if target_class_idx is None else target_class_idx

    # 3) Grad-CAM for the chosen class on the module-level target_layer.
    with GradCAM(model=model, target_layers=[target_layer]) as cam:
        cam_batch = cam(
            input_tensor=input_tensor,
            targets=[ClassifierOutputTarget(class_idx)],
        )
    grayscale_cam = cam_batch[0]

    # 4) Blend the heat map onto the resized image.
    cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)

    # 5) Original on the left, overlay on the right.
    panels = [
        (img_pil, "Original"),
        (cam_image, f"Grad-CAM: {CLASS_NAMES[class_idx]}"),
    ]
    plt.figure(figsize=(6, 3))
    for position, (picture, caption) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.imshow(picture)
        plt.title(caption)
        plt.axis("off")
    plt.tight_layout()
    plt.show()

    return pred_idx, class_idx, grayscale_cam


In [None]:
# Build the sample path from the detected dataset root rather than a fixed
# Drive location, so the cell also works when the splits live in a nested or
# generated folder.
img_path = os.path.join(ROOT, "test", "PNEUMONIA", "person1_virus_6.jpeg")
pred_idx, class_idx, grayscale_cam = gradcam_on_image(img_path)
print("Model predicted:", CLASS_NAMES[pred_idx])


In [None]:
def central_focus_score(grayscale_cam, central_frac=0.5):
    """Share of the total heat-map mass that lies in the central window.

    The window is centred and spans ``central_frac`` of the height and of the
    width (rounded down to whole pixels).

    Args:
        grayscale_cam: 2-D array (H, W) of heat-map values.
        central_frac: side fraction of the central window, between 0 and 1.

    Returns:
        ``central.sum() / (total.sum() + 1e-8)``; the small constant keeps an
        all-zero map from dividing by zero, in which case the score is 0.

    Raises:
        ValueError: if ``grayscale_cam`` is not 2-D or ``central_frac`` is
            outside [0, 1].
    """
    if grayscale_cam.ndim != 2:
        raise ValueError(
            f"grayscale_cam must be 2-D (H, W), got {grayscale_cam.ndim} dimensions"
        )
    # A fraction above 1 would make the offsets negative and silently select
    # the wrong region; a negative one would select nothing.
    if not 0.0 <= central_frac <= 1.0:
        raise ValueError(f"central_frac must be within [0, 1], got {central_frac}")

    h, w = grayscale_cam.shape
    ch = int(h * central_frac)
    cw = int(w * central_frac)
    y0 = (h - ch)//2
    x0 = (w - cw)//2

    central = grayscale_cam[y0:y0+ch, x0:x0+cw]
    total = grayscale_cam.sum() + 1e-8
    return central.sum() / total

# Example using the grayscale_cam inside gradcam_on_image


In [None]:
# Build the sample path from the detected dataset root rather than a fixed
# Drive location, so the cell also works when the splits live in a nested or
# generated folder.
img_path = os.path.join(ROOT, "test", "PNEUMONIA", "person1_virus_6.jpeg")

pred_idx, class_idx, grayscale_cam = gradcam_on_image(img_path)
# Fraction of the Grad-CAM mass that falls in the central half-width window.
score = central_focus_score(grayscale_cam, central_frac=0.5)

print("Model predicted:", CLASS_NAMES[pred_idx])
print("Central focus score:", score)
