<a href="https://colab.research.google.com/github/kunal-shetty/Chest-X-ray-disease-prediction/blob/main/ML_Model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
# Mount Google Drive so the CSV and image folders referenced below are reachable.
DRIVE_MOUNT_POINT = "/content/drive"
drive.mount(DRIVE_MOUNT_POINT)


Mounted at /content/drive


In [2]:
import os
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torchvision import models, transforms

from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.model_selection import GroupShuffleSplit
from sklearn.metrics import roc_auc_score


In [3]:
CSV_PATH = "/content/drive/MyDrive/Colab Notebooks/outputs/FINAL_clean_data.csv"
IMAGE_ROOT = "/content/drive/MyDrive/Colab Notebooks/real data/images_0013/images"

df = pd.read_csv(CSV_PATH)

# "A|B|C" -> ["A", "B", "C"]: one list of findings per row.
df["finding_labels"] = df["finding_labels"].map(lambda joined: joined.split("|"))

# Absolute path of every image file named in the CSV.
df["image_path"] = df["Image Index"].astype(str).map(
    lambda file_name: os.path.join(IMAGE_ROOT, file_name)
)

# Keep only rows whose image is actually present in IMAGE_ROOT.
df["image_exists"] = df["image_path"].map(os.path.exists)
df = df.loc[df["image_exists"]].reset_index(drop=True)

print("Images that exist:", len(df))


Images that exist: 5830


In [4]:
# def is_readable(path):
#     try:
#         with Image.open(path) as img:
#             img.verify()
#         return True
#     except:
#         return False

# df["image_readable"] = [is_readable(p) for p in tqdm(df["image_path"])]

# print("Unreadable images:", (~df["image_readable"]).sum())

# df = df[df["image_readable"]].reset_index(drop=True)
# print("Final usable images:", len(df))


In [5]:
# One-hot encode the per-row finding lists; column j corresponds to mlb.classes_[j].
mlb = MultiLabelBinarizer()
Y = mlb.fit_transform(df["finding_labels"])

NUM_CLASSES = len(mlb.classes_)
print("Number of disease labels:", NUM_CLASSES)


Number of disease labels: 13


In [6]:
gss = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=42)

# Group by patient so no patient's images end up in both train and test.
train_idx, test_idx = next(gss.split(df, Y, groups=df["Patient ID"]))

train_df, test_df = (
    df.iloc[positions].reset_index(drop=True)
    for positions in (train_idx, test_idx)
)
Y_train, Y_test = Y[train_idx], Y[test_idx]

print("Train:", len(train_df))
print("Test:", len(test_df))


Train: 4655
Test: 1175


In [7]:
# ImageNet statistics: the backbone below is ImageNet-pretrained, so inputs are
# normalized the same way.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)


In [None]:
def _load_split_into_memory(frame, labels, desc):
    """
    Decode, transform and collect every readable image of one split.

    Args:
        frame: split DataFrame with a 0..n-1 index and an "image_path" column.
        labels: binarized label matrix whose row i belongs to frame row i.
        desc: label shown on the progress bar.

    Returns:
        (images, label_tensors, kept_rows): transformed image tensors, float32
        label tensors, and the positions (into `frame`) of the rows that loaded.
    """
    images, label_tensors, kept_rows = [], [], []
    for row in tqdm(range(len(frame)), desc=desc):
        path = frame.loc[row, "image_path"]
        try:
            # `with` closes the file handle; convert() forces the full decode
            # inside the try, so truncated/corrupt files are caught here.
            with Image.open(path) as raw:
                tensor = transform(raw.convert("RGB"))
        except (OSError, ValueError) as err:
            # Narrow catch: unreadable files are skipped *and reported*, while
            # KeyboardInterrupt and genuine bugs are no longer swallowed.
            tqdm.write(f"Skipping unreadable image {path}: {err}")
            continue
        images.append(tensor)
        label_tensors.append(torch.tensor(labels[row], dtype=torch.float32))
        kept_rows.append(row)
    return images, label_tensors, kept_rows


print("Loading TRAIN images into memory...")
X_train, Y_train_mem, kept_train_rows = _load_split_into_memory(train_df, Y_train, "train")

print("Loading TEST images into memory...")
X_test, Y_test_mem, kept_test_rows = _load_split_into_memory(test_df, Y_test, "test")

print(
    f"Skipped {len(train_df) - len(kept_train_rows)} train / "
    f"{len(test_df) - len(kept_test_rows)} test images"
)

# Drop the skipped rows so metadata and label matrices stay aligned with the
# in-memory tensors: later cells read X_train[i] together with train_df.loc[i].
train_df = train_df.iloc[kept_train_rows].reset_index(drop=True)
test_df = test_df.iloc[kept_test_rows].reset_index(drop=True)
Y_train = Y_train[kept_train_rows]
Y_test = Y_test[kept_test_rows]


Loading TRAIN images into memory...


 31%|███▏      | 1456/4655 [12:37<26:26,  2.02it/s]

In [None]:
# Turn the per-image lists into single (N, ...) tensors.
X_train, Y_train_mem = torch.stack(X_train), torch.stack(Y_train_mem)
X_test, Y_test_mem = torch.stack(X_test), torch.stack(Y_test_mem)

for split_name, split_images in (("train", X_train), ("test", X_test)):
    print(f"Final {split_name} samples:", split_images.shape[0])


In [None]:
BATCH_SIZE = 8


def _tensor_loader(images, label_tensor, shuffle):
    """Wrap pre-stacked image/label tensors in a single-process DataLoader."""
    return DataLoader(
        TensorDataset(images, label_tensor),
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        num_workers=0,
    )


train_loader = _tensor_loader(X_train, Y_train_mem, shuffle=True)
test_loader = _tensor_loader(X_test, Y_test_mem, shuffle=False)


In [None]:
SEED = 42  # same value as the GroupShuffleSplit random_state
torch.manual_seed(SEED)  # reproducible fc-head init and DataLoader shuffling

device = "cuda" if torch.cuda.is_available() else "cpu"

# `pretrained=True` is deprecated since torchvision 0.13; IMAGENET1K_V1 is the
# checkpoint it used to load.
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)  # new multi-label head
model = model.to(device)

# Multi-label task: one independent sigmoid + BCE per finding (no softmax).
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)


In [None]:
EPOCHS = 3

for epoch_no in range(1, EPOCHS + 1):
    model.train()
    running_loss = 0.0

    for batch_imgs, batch_labels in tqdm(train_loader):
        batch_imgs = batch_imgs.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        batch_loss = criterion(model(batch_imgs), batch_labels)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()

    # Mean of per-batch losses for this epoch.
    mean_loss = running_loss / len(train_loader)
    print(f"Epoch {epoch_no}/{EPOCHS} | Train Loss: {mean_loss:.4f}")


In [None]:
def _predict_probs(loader):
    """Run `model` over `loader`; return (sigmoid probabilities, targets) as stacked numpy arrays."""
    model.eval()
    prob_batches, target_batches = [], []
    with torch.no_grad():
        for imgs, labels in loader:
            prob_batches.append(torch.sigmoid(model(imgs.to(device))).cpu().numpy())
            target_batches.append(labels.numpy())
    return np.vstack(prob_batches), np.vstack(target_batches)


# Training AUROC only measures fit; the patient-disjoint test split is what
# indicates generalization, so report both.
preds, targets = _predict_probs(train_loader)
train_micro_auc = roc_auc_score(targets, preds, average="micro")
print("TRAIN MICRO-AUROC:", train_micro_auc)

test_preds, test_targets = _predict_probs(test_loader)
test_micro_auc = roc_auc_score(test_targets, test_preds, average="micro")
print("TEST MICRO-AUROC:", test_micro_auc)
# NOTE: results are not bound to `auc`, which would shadow sklearn.metrics.auc
# used by the ROC-curve cell.

In [None]:
CHECKPOINT_PATH = "xray_multilabel_model.pth"

# Store the label order next to the weights so inference can map logits back
# to finding names.
checkpoint_payload = {
    "model_state": model.state_dict(),
    "label_names": mlb.classes_,
}
torch.save(checkpoint_payload, CHECKPOINT_PATH)


In [None]:
# weights_only=False: label_names is a numpy array (mlb.classes_), which is
# presumably not on torch's weights_only allowlist. Full unpickling executes
# arbitrary code, so only load checkpoints you created yourself.
checkpoint = torch.load(
    "xray_multilabel_model.pth",
    map_location=device,
    weights_only=False
)

# `pretrained=False` is deprecated since torchvision 0.13; weights=None is the
# equivalent (random init, immediately overwritten by the saved state dict).
model = models.resnet18(weights=None)
model.fc = nn.Linear(model.fc.in_features, len(checkpoint["label_names"]))
model.load_state_dict(checkpoint["model_state"])
model = model.to(device)
model.eval()

label_names = checkpoint["label_names"]


In [None]:

def predict_xray(image_path, threshold=0.5):
    """
    Score one X-ray image file and return the findings above a threshold.

    Uses the notebook-level `model`, `transform`, `device` and `label_names`.

    Args:
        image_path: path of an image file that PIL can open.
        threshold: minimum sigmoid probability for a label to be reported.

    Returns:
        dict mapping each label whose probability is >= threshold to that
        probability as a Python float.
    """
    pil_image = Image.open(image_path).convert("RGB")
    batch = transform(pil_image).unsqueeze(0).to(device)  # add batch dim

    with torch.no_grad():
        class_probs = torch.sigmoid(model(batch)).cpu().numpy()[0]

    findings = {}
    for i in range(len(label_names)):
        if class_probs[i] >= threshold:
            findings[label_names[i]] = float(class_probs[i])
    return findings


In [None]:
def batch_predict(indices, threshold=0.5):
    """
    Print metadata and thresholded findings for several in-memory training images.

    Uses the notebook-level `X_train`, `train_df`, `model`, `device` and
    `label_names`. All requested images are scored in one forward pass.

    Args:
        indices: positions into X_train / train_df.
        threshold: minimum sigmoid probability for a finding to be listed;
            "• Normal" is printed when no class reaches it.
    """
    stacked = torch.stack([X_train[i] for i in indices]).to(device)
    with torch.no_grad():
        prob_rows = torch.sigmoid(model(stacked)).cpu().numpy()

    def lookup(row, column):
        # Metadata columns are optional in the CSV; fall back to "NA".
        return train_df.loc[row, column] if column in train_df.columns else "NA"

    for row_probs, idx in zip(prob_rows, indices):
        print("=" * 60)
        print(f"Image index: {idx}")

        age = lookup(idx, "patient_age")
        gender = lookup(idx, "gender")
        view = lookup(idx, "view_position")

        print(f"Age: {age}, Gender: {gender}, View: {view}")
        print("Detected findings:")

        any_finding = False
        for class_idx, p in enumerate(row_probs):
            if p >= threshold:
                print(f"• {label_names[class_idx]} → {p:.3f}")
                any_finding = True

        if not any_finding:
            print("• Normal")


In [None]:
def predict_raw(idx):
    """Print every class probability for the in-memory training image X_train[idx]."""
    single = X_train[idx][None].to(device)  # add batch dim

    with torch.no_grad():
        probs = torch.sigmoid(model(single)).cpu().numpy()[0]

    print(f"\nRAW probabilities for image index {idx}:")
    for class_idx in range(len(probs)):
        print(f"{label_names[class_idx]:25s} : {probs[class_idx]:.4f}")


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc
from torch.utils.data import DataLoader, TensorDataset

# --------------------------------------------------
# STEP 1: Score the (unshuffled) training tensors
# --------------------------------------------------

model.eval()
roc_loader = DataLoader(
    TensorDataset(X_train, Y_train_mem),
    batch_size=16,
    shuffle=False,
)

prob_chunks, target_chunks = [], []
with torch.no_grad():
    for batch_imgs, batch_labels in roc_loader:
        batch_probs = torch.sigmoid(model(batch_imgs.to(device)))
        prob_chunks.append(batch_probs.cpu().numpy())
        target_chunks.append(batch_labels.numpy())

preds = np.vstack(prob_chunks)
targets = np.vstack(target_chunks)

print("Predictions shape:", preds.shape)
print("Targets shape:", targets.shape)

# --------------------------------------------------
# STEP 2: One ROC curve per disease on shared axes
# --------------------------------------------------

fig, ax = plt.subplots(figsize=(10, 8))

auc_scores = {}
valid_disease_count = 0
for col, disease in enumerate(label_names):
    y_true = targets[:, col]
    # AUROC is undefined when a column contains only one class.
    if np.unique(y_true).size < 2:
        auc_scores[disease] = None
        continue

    fpr, tpr, _ = roc_curve(y_true, preds[:, col])
    roc_auc = auc(fpr, tpr)
    ax.plot(fpr, tpr, lw=1.5, label=f"{disease} (AUC={roc_auc:.2f})")

    auc_scores[disease] = roc_auc
    valid_disease_count += 1

ax.plot([0, 1], [0, 1], "k--", lw=1)  # chance diagonal
ax.set_xlabel("False Positive Rate")
ax.set_ylabel("True Positive Rate")
ax.set_title(f"ROC Curves for All Diseases (Valid = {valid_disease_count})")
ax.legend(fontsize=8, loc="lower right", ncol=2)
ax.grid(True)
fig.tight_layout()
plt.show()

# --------------------------------------------------
# STEP 3: Per-disease AUROC table
# --------------------------------------------------

print("\nPer-Disease AUROC Summary")
print("-" * 45)

for disease, score in auc_scores.items():
    shown = "N/A (single class)" if score is None else f"{score:.3f}"
    print(f"{disease:25s} : {shown}")


In [None]:
import torch
import torch.nn.functional as F
import cv2
import numpy as np
import matplotlib.pyplot as plt

def gradcam_generate(model, img, target_class, target_layer):
    """
    Compute a Grad-CAM heatmap for one class of the first image in `img`.

    A temporary forward hook captures `target_layer`'s output. The gradient of
    the class logit with respect to that output comes from
    torch.autograd.grad, so no module backward hook (the deprecated
    register_backward_hook) is needed and parameter .grad fields are untouched.

    Args:
        model: network whose output is indexed as output[0, target_class].
        img: input batch passed straight to `model`; only sample 0 is explained.
        target_class: index of the logit to explain.
        target_layer: submodule of `model`; its output is assumed to be
            (N, K, h, w) — TODO confirm for non-conv layers.

    Returns:
        numpy array of shape (h, w) scaled to [0, 1] (all zeros when no
        location has positive evidence).
    """
    captured = {}

    def forward_hook(module, inputs, output):
        captured["activations"] = output

    handle = target_layer.register_forward_hook(forward_hook)
    try:
        # Grad-CAM needs an autograd graph even if the caller is inside
        # torch.no_grad().
        with torch.enable_grad():
            output = model(img)
            score = output[0, target_class]
            activations = captured["activations"]
            (gradients,) = torch.autograd.grad(score, activations)
    finally:
        # Remove the hook even when the forward/backward pass raises.
        handle.remove()

    with torch.no_grad():
        # Channel weights = spatially averaged gradients of sample 0.
        weights = gradients[0].mean(dim=(1, 2), keepdim=True)
        cam = F.relu((weights * activations[0]).sum(dim=0))
        # Normalize this sample by its own maximum.
        cam = cam / (cam.max() + 1e-8)

    return cam.detach().cpu().numpy()


In [None]:
from sklearn.metrics import roc_auc_score

model.eval()

# Score the training tensors in order (no shuffle) so rows match Y_train_mem.
with torch.no_grad():
    scored_batches = [
        (torch.sigmoid(model(imgs.to(device))).cpu().numpy(), labels.numpy())
        for imgs, labels in DataLoader(
            TensorDataset(X_train, Y_train_mem),
            batch_size=16,
            shuffle=False,
        )
    ]

preds = np.vstack([batch_probs for batch_probs, _ in scored_batches])
targets = np.vstack([batch_targets for _, batch_targets in scored_batches])


In [None]:
from torchvision import transforms
from PIL import Image

# Must mirror the training `transform` exactly (resize + ImageNet normalization).
_infer_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
infer_transform = transforms.Compose(_infer_steps)


In [None]:
import cv2
import numpy as np
from PIL import Image

def safe_load_image(path):
    """
    Read an image with OpenCV and return it as an RGB PIL image.

    Any failure (cv2.imread returning None, or an exception while reading or
    converting) prints a skip message and returns None instead of raising.
    """
    try:
        bgr = cv2.imread(path)  # OpenCV read (much more stable with Drive)
        if bgr is not None:
            return Image.fromarray(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
    except Exception:
        pass

    print(f" Skipping image due to read error: {path}")
    return None


In [None]:

N_GRADCAM_IMAGES = 50        # how many training images to inspect
DETECTION_THRESHOLD = 0.6    # probability needed to report a finding

for idx in range(min(N_GRADCAM_IMAGES, len(X_train))):

    print("=" * 70)
    print(f"Image index: {idx}")

    age = train_df.loc[idx, "patient_age"] if "patient_age" in train_df.columns else "NA"
    gender = train_df.loc[idx, "gender"] if "gender" in train_df.columns else "NA"
    view = train_df.loc[idx, "view_position"] if "view_position" in train_df.columns else "NA"
    print(f"Age: {age}, Gender: {gender}, View: {view}")

    img = X_train[idx].unsqueeze(0).to(device)

    with torch.no_grad():
        probs = torch.sigmoid(model(img)).cpu().numpy()[0]

    print("\nRaw Output Probabilities (%):")
    for i, p in enumerate(probs):
        print(f"{label_names[i]:25s} : {p*100:.2f}%")

    detected = [(i, p) for i, p in enumerate(probs) if p >= DETECTION_THRESHOLD]

    print(f"\nDetected Findings (threshold = {DETECTION_THRESHOLD:.0%}):")
    if not detected:
        print("• Normal (no abnormal findings)")
        continue

    for i, p in detected:
        print(f"• {label_names[i]} → {p*100:.2f}%")

    # ---- Grad-CAM for the highest-probability finding ----
    top_class = max(detected, key=lambda pair: pair[1])[0]

    cam = gradcam_generate(model=model, img=img, target_class=top_class, target_layer=model.layer4)
    height, width = X_train[idx].shape[1:]
    cam = cv2.resize(cam, (width, height))  # cv2 takes (width, height)

    # applyColorMap returns BGR; matplotlib expects RGB, otherwise the "hot"
    # regions would be drawn in blue.
    heatmap = cv2.cvtColor(
        cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET),
        cv2.COLOR_BGR2RGB,
    )

    # Min-max rescale the normalized input tensor to [0, 1] for display.
    original = X_train[idx].permute(1, 2, 0).cpu().numpy()
    original = (original - original.min()) / (original.max() - original.min() + 1e-8)

    overlay = 0.6 * original + 0.4 * heatmap / 255.0

    fig, ax = plt.subplots(figsize=(4, 4))
    ax.imshow(overlay)
    ax.set_title(f"Grad-CAM → {label_names[top_class]}")
    ax.axis("off")
    plt.show()


In [None]:
#For user input images

# user_images = [
#     "/content/drive/MyDrive/Colab Notebooks/real data/images_0013/images/00003929_000.png"
# ]

# for idx, img_path in enumerate(user_images):

#     print("=" * 70)
#     print(f"User Image {idx+1}: {img_path}")

#     # ---- Load image ----
#     img_pil = safe_load_image(img_path)
#     if img_pil is None:
#       continue  # skip this image safely

#     img_tensor = infer_transform(img_pil).unsqueeze(0).to(device)

#     # ---- Inference ----
#     with torch.no_grad():
#         logits = model(img_tensor)
#         probs = torch.sigmoid(logits).cpu().numpy()[0]

#     # ---- Raw probabilities ----
#     print("\nRaw Output Probabilities (%):")
#     for i, p in enumerate(probs):
#         print(f"{label_names[i]:25s} : {p*100:.2f}%")

#     # ---- Thresholding ----
#     detected = [(i, p) for i, p in enumerate(probs) if p >= 0.3]

#     print("\nDetected Findings (threshold = 30%):")
#     if not detected:
#         print("• Normal (no abnormal findings)")
#         continue

#     for i, p in detected:
#         print(f"• {label_names[i]} → {p*100:.2f}%")

#     # ---- Grad-CAM only if abnormal ----
#     top_class = max(detected, key=lambda x: x[1])[0]

#     cam = gradcam_generate(
#         model=model,
#         img=img_tensor,
#         target_class=top_class,
#         target_layer=model.layer4
#     )

#     # ---- Reddish-brown overlay ----
#     cam = cam - cam.min()
#     cam = cam / (cam.max() + 1e-8)
#     cam = cv2.resize(cam, (224, 224))

#     CAM_THRESHOLD = 0.4
#     cam_mask = cam > CAM_THRESHOLD

#     original = np.array(img_pil.resize((224, 224))) / 255.0

#     brown_overlay = np.zeros_like(original)
#     brown_overlay[..., 0] = cam * 0.9
#     brown_overlay[..., 1] = cam * 0.35
#     brown_overlay[..., 2] = cam * 0.15

#     overlay = original.copy()
#     overlay[cam_mask] = (
#         0.65 * original[cam_mask] +
#         0.35 * brown_overlay[cam_mask]
#     )

#     plt.figure(figsize=(4, 4))
#     plt.imshow(overlay)
#     plt.title(f"Grad-CAM → {label_names[top_class]}")
#     plt.axis("off")
#     plt.show()
