# Chest X‑ray 4‑Class Classifier (Colab, PyTorch, T4 GPU)

This notebook trains a classifier for **COVID**, **Lung_Opacity**, **Normal**, and **Viral Pneumonia** using **PyTorch** on Google Colab (T4 GPU).  
It supports:
- Reading a `dataset.csv` which contains the paths to images and their labels.
- Robust path resolution (handles Windows paths or relative paths).
- Class imbalance handling (weighted sampler).
- Two training tracks:
  - **A. Small CNN baseline** (from scratch)
  - **B. Transfer Learning** with **ResNet‑18** or **EfficientNet‑B0** (recommended)
- Mixed precision (AMP), early stopping, best‑model checkpoint, and evaluation (confusion matrix + per‑class metrics).
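- TensorBoard: the package is installed and the launch command is shown in Section 10; logging itself is not wired into the training loop yet (add a `SummaryWriter` if you want it).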

> **How to use**: Upload your images (folder tree) and `dataset.csv` to Google Drive, set the two variables in the **Configuration** cell, then run cells top‑to‑bottom.


## 0) Runtime & installs

In [None]:
#@title (Colab) Check GPU & install deps
# List the visible GPUs; `|| true` keeps the cell from failing when
# nvidia-smi is missing or returns a non-zero status (e.g. CPU runtime).
!nvidia-smi -L || true

# Optional: upgrade pip & core libs
%pip -q install --upgrade pip
%pip -q install pandas scikit-learn matplotlib seaborn tqdm tensorboard

# Torch/torchvision come preinstalled on Colab; if needed, uncomment next:
# %pip -q install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

# Print the versions actually in use so a run can be reproduced later.
import torch, torchvision
print('Torch:', torch.__version__, '| Torchvision:', torchvision.__version__)

## 1) Imports & global config

In [None]:
import os, sys, time, math, shutil, random, json
from pathlib import Path
import pandas as pd
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torch.cuda import amp
import torch.optim as optim

from torchvision import transforms, models

from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
from collections import Counter
from tqdm import tqdm

# Reproducibility
def set_seed(seed: int = 42):
    """Seed the Python, NumPy and PyTorch (CPU + all CUDA devices) RNGs.

    cuDNN is left in its fast, non-deterministic mode (benchmark on,
    deterministic off), so GPU runs are repeatable only up to cuDNN's
    kernel selection.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    # Speed over bit-exact determinism: faster + okay with AMP.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = False
    cudnn.benchmark = True

# Seed the Python, NumPy and PyTorch RNGs once for this session.
set_seed(42)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Bare expression: shown as the cell's output in a notebook.
device

## 2) Mount Drive (optional) & Configuration

In [None]:
#@title Mount Drive (optional)
# If your data is in Google Drive, mount it:
# NOTE: `google.colab` is imported unconditionally here, so skip this cell
# when the data is not on Drive (presumably it is unavailable outside Colab).
from google.colab import drive
# Mount point matches the '/content/drive/...' prefix used by the default
# IMAGES_ROOT / CSV_PATH values in the Configuration cell.
drive.mount('/content/drive')

In [None]:
#@title Configuration (edit these two paths)
# Path to your images ROOT and to the CSV with image paths and labels.
IMAGES_ROOT = "/content/drive/MyDrive/datasets/chest_xray"  #@param {type:"string"}
CSV_PATH    = "/content/drive/MyDrive/datasets/chest_xray/dataset.csv"  #@param {type:"string"}

# Column names in CSV (auto-detected if blank)
# Auto-detection matches a fixed list of common column names, case-insensitively.
LABEL_COL   = ""  # e.g., "label" (leave empty to auto-detect)
PATH_COL    = ""  # e.g., "img_path" or "path" (leave empty to auto-detect)

# Model selection: "cnn_small", "resnet18", or "efficientnet_b0"
MODEL_NAME  = "resnet18"  #@param ["cnn_small", "resnet18", "efficientnet_b0"]
# Side length (pixels) of the square crop fed to the network.
IMAGE_SIZE  = 224         #@param {type:"integer"}
BATCH_SIZE  = 32          #@param {type:"integer"}
# Upper bound on epochs; the training loop may stop earlier (early stopping).
EPOCHS      = 15          #@param {type:"integer"}
# Learning rate and weight decay passed to the AdamW optimizer.
LR          = 1e-4        #@param {type:"number"}
WEIGHT_DECAY= 1e-4        #@param {type:"number"}

# Whether to use WeightedRandomSampler to handle class imbalance
USE_WEIGHTED_SAMPLER = True  #@param {type:"boolean"}

# Where to save checkpoints and logs
# NOTE(review): this path is outside the mounted Drive folder, so outputs
# presumably do not survive a runtime reset — copy them out if needed.
OUT_DIR = Path('/content/xray_runs')
OUT_DIR.mkdir(parents=True, exist_ok=True)
print('Outputs:', OUT_DIR)

## 3) Load CSV & resolve image paths

In [None]:
# Helper: normalize Windows/relative paths and find images within IMAGES_ROOT
def index_all_images(root: Path):
    """Build a lookup of image files found anywhere below *root*.

    Returns a dict mapping the lower-cased file name to the list of full
    paths (as strings) that carry that name. Only files ending in
    .png/.jpg/.jpeg/.bmp/.tiff (case-insensitive) are indexed.
    """
    image_exts = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff')
    by_name = {}
    for folder, _subdirs, files in os.walk(root):
        folder_path = Path(folder)
        for file_name in files:
            key = file_name.lower()
            if not key.endswith(image_exts):
                continue
            by_name.setdefault(key, []).append(str(folder_path / file_name))
    return by_name

def normalize_path(s: str) -> str:
    """Return *s* as a clean, forward-slash path string.

    The value is converted with ``str()``, surrounding whitespace and
    single/double quotes are stripped, and every backslash (Windows
    separator) is replaced by ``/``. Nothing else is changed: the path is
    not resolved, collapsed or checked for existence.
    """
    # Strip whitespace first, then any quotes (and spaces inside them).
    cleaned = str(s).strip().strip('"\' ')
    # A single replace handles every backslash; no need to repeat it.
    return cleaned.replace('\\', '/')

def resolve_paths(df: pd.DataFrame, images_root: Path, path_col: str):
    """Map every entry of ``df[path_col]`` to an existing image file.

    Each raw value is normalized with ``normalize_path`` and then tried, in
    order, as:
      1. a path that exists as given,
      2. a path relative to *images_root*,
      3. a file anywhere under *images_root* with the same (case-insensitive)
         base name.

    Returns a list with one item per row of *df*: the resolved path as a
    string, or ``None`` when no file was found. Unresolved and ambiguous
    entries are reported on stdout.

    Raises AssertionError if *images_root* does not exist.
    """
    images_root = Path(images_root)
    assert images_root.exists(), f"IMAGES_ROOT not found: {images_root}"

    # Build index once (can be ~1-2s for ~20k images)
    print('Indexing images under', images_root, '...')
    name_index = index_all_images(images_root)
    resolved, missing = [], []
    ambiguous = 0
    for p in tqdm(df[path_col].tolist(), desc='Resolving paths'):
        p_norm = normalize_path(p)
        cand = Path(p_norm)
        # Case 1: path that actually exists as given (rare in Colab).
        # is_file() rather than exists(): an empty value becomes Path('.'),
        # which "exists" but is a directory, not an image.
        if cand.is_file():
            resolved.append(str(cand))
            continue
        # Case 2: treat as relative to IMAGES_ROOT
        cand2 = images_root / p_norm
        if cand2.is_file():
            resolved.append(str(cand2))
            continue
        # Case 3: try basename lookup
        base = os.path.basename(p_norm).lower()
        hits = name_index.get(base, [])
        if hits:
            # Several files may share a base name; the first one indexed is
            # used, and the count is reported below.
            if len(hits) > 1:
                ambiguous += 1
            resolved.append(hits[0])
        else:
            resolved.append(None)
            missing.append(p_norm)
    if ambiguous:
        print(f"WARNING: {ambiguous} paths matched several files by base name; "
              "the first match was used.")
    if missing:
        print(f"WARNING: {len(missing)} images not found. First 10:\n", missing[:10])
    return resolved

# Load CSV
df = pd.read_csv(CSV_PATH)
print('CSV shape:', df.shape)
print('Columns:', df.columns.tolist())

# Auto-detect columns if user left them blank (case-insensitive match against
# a fixed list of common names; the first matching column wins).
if not PATH_COL:
    candidates = [c for c in df.columns if c.lower() in ['img_path','path','image_path','filepath','file','filename']]
    assert len(candidates) >= 1, "Could not auto-detect the image path column. Set PATH_COL manually."
    PATH_COL = candidates[0]
if not LABEL_COL:
    candidates = [c for c in df.columns if c.lower() in ['label','class','target','category']]
    assert len(candidates) >= 1, "Could not auto-detect the label column. Set LABEL_COL manually."
    LABEL_COL = candidates[0]

print('Using PATH_COL =', PATH_COL, '| LABEL_COL =', LABEL_COL)

# Resolve full paths
df['resolved_path'] = resolve_paths(df, Path(IMAGES_ROOT), PATH_COL)

# Drop rows with missing images
before = len(df)
df = df.dropna(subset=['resolved_path']).reset_index(drop=True)
after = len(df)
print(f'Retained {after}/{before} rows after resolving image paths.')

# Drop rows without a label BEFORE casting to str; otherwise missing values
# turn into the literal string "nan" and show up as an extra class.
n_unlabeled = int(df[LABEL_COL].isna().sum())
if n_unlabeled:
    print(f'Dropping {n_unlabeled} rows with a missing label.')
    df = df.dropna(subset=[LABEL_COL]).reset_index(drop=True)

# Normalize labels (strip spaces)
df[LABEL_COL] = df[LABEL_COL].astype(str).str.strip()
print(df[LABEL_COL].value_counts())

# Look at a few rows
df.head()

## 4) Train/Val/Test split & Datasets

In [None]:
# Label vocabulary: alphabetical order fixes the class <-> index mapping.
classes = sorted(df[LABEL_COL].unique().tolist())
print('Classes:', classes)
idx_to_cls = dict(enumerate(classes))
cls_to_idx = {name: idx for idx, name in idx_to_cls.items()}

# Stratified 70/15/15 split: hold out 30%, then halve it into val and test.
df_train, df_temp = train_test_split(
    df,
    test_size=0.30,
    stratify=df[LABEL_COL],
    random_state=42,
)
df_val, df_test = train_test_split(
    df_temp,
    test_size=0.50,
    stratify=df_temp[LABEL_COL],
    random_state=42,
)

print('Train:', df_train.shape, 'Val:', df_val.shape, 'Test:', df_test.shape)

# Per-channel normalization statistics shared by every transform pipeline.
mean = [0.485, 0.456, 0.406]
std  = [0.229, 0.224, 0.225]

# Training pipeline: replicate the grey channel to 3 channels, then apply a
# mild random crop/rotation. No horizontal flip: for X-rays, confirm it is
# acceptable for your task before adding one.
train_tfms = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.RandomResizedCrop(size=IMAGE_SIZE, scale=(0.85, 1.0)),
    transforms.RandomRotation(degrees=5),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])

# Validation/test pipeline: deterministic resize to 1.14x, then centre crop.
val_tfms = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize(size=int(IMAGE_SIZE * 1.14)),
    transforms.CenterCrop(size=IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])

class XRayDataset(Dataset):
    """Map-style dataset yielding ``(image, class_index)`` pairs.

    Args:
        df: DataFrame with one row per image; its index is reset internally.
        path_col: name of the column holding the image file path.
        label_col: name of the column holding the class label.
        transforms: optional callable applied to the grayscale PIL image.
        cls_to_idx: mapping from label to integer class index. When omitted,
            it is built from the sorted unique labels found in *df*.
    """

    def __init__(self, df, path_col, label_col, transforms=None, cls_to_idx=None):
        self.df = df.reset_index(drop=True)
        self.path_col = path_col
        self.label_col = label_col
        self.transforms = transforms
        if cls_to_idx is None:
            # Fall back to a deterministic mapping instead of failing later
            # with "'NoneType' object is not subscriptable" in __getitem__.
            labels = sorted(self.df[label_col].unique().tolist())
            cls_to_idx = {label: i for i, label in enumerate(labels)}
        self.cls_to_idx = cls_to_idx

    def __len__(self):
        return len(self.df)

    def __getitem__(self, i):
        row = self.df.iloc[i]
        # Honour the configured column rather than a hard-coded name.
        p = row[self.path_col]
        y = self.cls_to_idx[row[self.label_col]]
        # convert() returns a new image, so the file can be closed right away.
        with Image.open(p) as src:
            img = src.convert('L')  # ensure grayscale
        if self.transforms is not None:
            img = self.transforms(img)
        return img, y

# One dataset per split; only the training split gets the augmenting transforms.
train_ds, val_ds, test_ds = (
    XRayDataset(frame, 'resolved_path', LABEL_COL, transforms=tfms, cls_to_idx=cls_to_idx)
    for frame, tfms in (
        (df_train, train_tfms),
        (df_val, val_tfms),
        (df_test, val_tfms),
    )
)

# Optional imbalance handling: draw each training sample with probability
# proportional to 1 / (number of training samples in its class).
sampler = None
if USE_WEIGHTED_SAMPLER:
    counts = Counter(df_train[LABEL_COL])
    class_count = torch.tensor([counts[c] for c in classes], dtype=torch.float)
    class_weight = class_count.reciprocal()
    sample_weights = class_weight[[cls_to_idx[label] for label in df_train[LABEL_COL]]].tolist()
    sampler = WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights),
        replacement=True,
    )

# A sampler and shuffle=True are mutually exclusive, hence the conditional.
train_loader = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=sampler is None,
    sampler=sampler,
    num_workers=2,
    pin_memory=True,
)
val_loader = DataLoader(
    val_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)
test_loader = DataLoader(
    test_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)

len(train_ds), len(val_ds), len(test_ds)

## 5) Define Models (Small CNN / ResNet‑18 / EfficientNet‑B0)

In [None]:
class SmallCNN(nn.Module):
    """Four-stage convolutional baseline trained from scratch.

    Each stage is Conv3x3 -> BatchNorm -> ReLU followed by a pooling layer:
    2x2 max-pooling after the first three stages, global average pooling
    after the last. The 256-d pooled vector goes through a small MLP head
    with dropout and ends in ``num_classes`` logits.
    """

    @staticmethod
    def _stage(c_in, c_out, pool):
        """Return the layers of one Conv-BN-ReLU stage followed by *pool*."""
        return [
            nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
            nn.BatchNorm2d(c_out),
            nn.ReLU(inplace=True),
            pool,
        ]

    def __init__(self, num_classes):
        super().__init__()
        widths = (3, 32, 64, 128, 256)
        last_stage = len(widths) - 2
        layers = []
        for stage, (c_in, c_out) in enumerate(zip(widths, widths[1:])):
            # For a 224x224 input the max-pools give 112, 56, 28; the final
            # stage collapses the map to 1x1 instead.
            if stage == last_stage:
                pool = nn.AdaptiveAvgPool2d((1, 1))
            else:
                pool = nn.MaxPool2d(2)
            layers.extend(self._stage(c_in, c_out, pool))
        self.features = nn.Sequential(*layers)

        head = [
            nn.Flatten(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        return self.classifier(self.features(x))

def create_model(name: str, num_classes: int):
    """Build the network selected by *name* with a ``num_classes``-way head.

    Accepted names (case-insensitive): ``'cnn_small'`` (SmallCNN from
    scratch), ``'resnet18'`` and ``'efficientnet_b0'`` (torchvision models
    loaded with their DEFAULT weights and a freshly initialised final
    linear layer).

    Raises ValueError for any other name.
    """
    name = name.lower()

    if name == 'cnn_small':
        return SmallCNN(num_classes)

    if name == 'resnet18':
        model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        # Freeze first layers optionally (uncomment to fine-tune last layers only)
        # for p in list(model.parameters())[:-5]: p.requires_grad = False
        model.fc = nn.Linear(model.fc.in_features, num_classes)
        return model

    if name == 'efficientnet_b0':
        model = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT)
        old_head = model.classifier[-1]
        model.classifier[-1] = nn.Linear(old_head.in_features, num_classes)
        return model

    raise ValueError(f"Unknown model: {name}")

## 6) Training & Evaluation Utilities

In [None]:
def train_one_epoch(model, loader, optimizer, scaler, criterion):
    """Run one training pass over *loader* with mixed precision.

    Args:
        model: network to train; put into ``train()`` mode here.
        loader: iterable of ``(inputs, labels)`` batches.
        optimizer: optimizer stepped once per batch through *scaler*.
        scaler: ``GradScaler`` used to scale the loss and step the optimizer.
        criterion: loss called as ``criterion(logits, labels)``.

    Returns:
        ``(mean_loss, accuracy)`` averaged over all samples seen.

    Raises ZeroDivisionError if *loader* yields no samples. Uses the
    module-level ``device``.
    """
    model.train()
    total_loss, correct, total = 0.0, 0, 0
    for x, y in tqdm(loader, leave=False):
        x = x.to(device, non_blocking=True)
        # Labels usually arrive as a tensor already (default collate);
        # as_tensor moves it without the copy-construct warning that
        # torch.tensor(tensor) emits, and still accepts a plain list.
        y = torch.as_tensor(y, device=device)
        optimizer.zero_grad(set_to_none=True)
        with amp.autocast():
            logits = model(x)
            loss = criterion(logits, y)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        batch = x.size(0)
        # The loss is a batch mean; weight it by batch size so the epoch
        # average stays correct when the last batch is smaller.
        total_loss += loss.item() * batch
        correct += (logits.argmax(1) == y).sum().item()
        total += batch
    return total_loss / total, correct / total

@torch.no_grad()
def evaluate(model, loader, criterion):
    """Compute mean loss and accuracy of *model* over *loader*.

    The model is switched to ``eval()`` mode and no gradients are tracked.
    Inference runs without autocast.

    Returns:
        ``(mean_loss, accuracy)`` averaged over all samples seen.

    Raises ZeroDivisionError if *loader* yields no samples. Uses the
    module-level ``device``.
    """
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    for x, y in tqdm(loader, leave=False):
        x = x.to(device, non_blocking=True)
        # as_tensor avoids the copy-construct warning raised by
        # torch.tensor() when the labels are already a tensor.
        y = torch.as_tensor(y, device=device)
        logits = model(x)
        loss = criterion(logits, y)

        batch = x.size(0)
        # Weight the batch-mean loss by batch size for a correct average.
        total_loss += loss.item() * batch
        correct += (logits.argmax(1) == y).sum().item()
        total += batch
    return total_loss / total, correct / total

def save_checkpoint(state, is_best, out_dir: Path, tag='last'):
    """Serialize *state* to ``<out_dir>/checkpoint_<tag>.pt``.

    The directory is created when missing. If *is_best* is true, the file
    is additionally copied to ``<out_dir>/best.pt``.

    Returns the path of the tagged checkpoint file.
    """
    target_dir = Path(out_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    ckpt_file = target_dir / f'checkpoint_{tag}.pt'
    torch.save(state, ckpt_file)

    if not is_best:
        return ckpt_file

    # Keep a stable name for the best model so far.
    shutil.copy2(ckpt_file, target_dir / 'best.pt')
    return ckpt_file

def plot_confusion(y_true, y_pred, labels, normalize=True):
    """Draw a confusion-matrix heatmap.

    Args:
        y_true: true class indices.
        y_pred: predicted class indices.
        labels: class names; position ``i`` names class index ``i``.
        normalize: when true, each row is divided by its sum so cells show
            the fraction of that true class; otherwise raw counts are shown.
    """
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(labels))))
    if normalize:
        row_sums = cm.sum(axis=1, keepdims=True)
        # A class with no true samples has an all-zero row; divide by 1
        # there so the row shows 0.00 instead of NaN.
        row_sums[row_sums == 0] = 1
        cm = cm.astype('float') / row_sums
    plt.figure(figsize=(6,5))
    sns.heatmap(cm, annot=True, fmt='.2f' if normalize else 'd',
                xticklabels=labels, yticklabels=labels)
    plt.xlabel('Predicted'); plt.ylabel('True'); plt.title('Confusion Matrix')
    plt.tight_layout(); plt.show()

## 7) Train Run

In [None]:
model = create_model(MODEL_NAME, num_classes=len(classes)).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
# Halve the LR after 2 epochs without a validation-accuracy improvement.
# `verbose=` is deliberately not passed: it is deprecated in the PyTorch LR
# schedulers and rejected by recent releases; LR changes are printed manually
# in the loop below instead.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2)
scaler = amp.GradScaler()

# Start below any reachable accuracy so the first epoch always counts as
# "best" and best.pt is guaranteed to exist for the evaluation cell.
best_acc = float('-inf')
# Early stopping: give up after `patience` consecutive non-improving epochs.
patience, bad = 5, 0

history = {'epoch': [], 'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}

for epoch in range(1, EPOCHS+1):
    print(f"Epoch {epoch}/{EPOCHS}")
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, scaler, criterion)
    val_loss, val_acc = evaluate(model, val_loader, criterion)

    lr_before = optimizer.param_groups[0]['lr']
    scheduler.step(val_acc)
    lr_after = optimizer.param_groups[0]['lr']
    if lr_after != lr_before:
        print(f"  lr reduced: {lr_before:.2e} -> {lr_after:.2e}")

    history['epoch'].append(epoch)
    history['train_loss'].append(train_loss); history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss);     history['val_acc'].append(val_acc)

    print(f"  train: loss={train_loss:.4f} acc={train_acc:.4f} | val: loss={val_loss:.4f} acc={val_acc:.4f}")

    is_best = val_acc > best_acc
    if is_best:
        best_acc = val_acc
        bad = 0
    else:
        bad += 1

    # One checkpoint file per epoch; the best one is also copied to best.pt.
    save_checkpoint({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'val_acc': val_acc,
        'classes': classes,
        'model_name': MODEL_NAME,
        'image_size': IMAGE_SIZE
    }, is_best=is_best, out_dir=OUT_DIR, tag=f'epoch{epoch:03d}')

    if bad >= patience:
        print(f"Early stopping (no val acc improvement in {patience} epochs). Best val acc: {best_acc:.4f}")
        break

# Save training history
with open(OUT_DIR / 'history.json', 'w') as f:
    json.dump(history, f, indent=2)

print('Best val acc:', best_acc)

## 8) Evaluate on Test Set (confusion matrix & per‑class metrics)

In [None]:
# Load best model
ckpt = torch.load(OUT_DIR / 'best.pt', map_location=device)
model = create_model(ckpt.get('model_name', MODEL_NAME), num_classes=len(classes)).to(device)
model.load_state_dict(ckpt['model_state_dict'])
model.eval()

# Collect plain Python ints for both lists so sklearn receives homogeneous
# integer sequences rather than a list of 0-d tensors.
y_true, y_pred = [], []
with torch.no_grad():
    for x, y in tqdm(test_loader):
        x = x.to(device, non_blocking=True)
        logits = model(x)
        y_pred.extend(logits.argmax(1).cpu().tolist())
        y_true.extend(torch.as_tensor(y).tolist())

# Pass the full label set explicitly so the report still lines up with
# `classes` when some class is absent from the test split or predictions.
print(classification_report(y_true, y_pred, labels=list(range(len(classes))),
                            target_names=classes, digits=4))
plot_confusion(y_true, y_pred, labels=classes, normalize=True)

## 9) Inference on a single image

In [None]:
@torch.no_grad()
def predict_image(image_path: str):
    """Classify a single image file with the module-level ``model``.

    The image is loaded as grayscale and put through the same deterministic
    steps as the validation pipeline (3-channel replicate, resize to 1.14x,
    centre crop to IMAGE_SIZE, tensor conversion, normalization).

    Returns:
        ``(predicted_class_name, {class_name: probability})`` where the
        probabilities are the softmax over the model's logits.
    """
    model.eval()
    # Close the file handle as soon as the pixel data has been converted.
    with Image.open(image_path) as src:
        img = src.convert('L')
    tfm = transforms.Compose([
        transforms.Grayscale(num_output_channels=3),
        transforms.Resize(int(IMAGE_SIZE * 1.14)),
        transforms.CenterCrop(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
    ])
    # unsqueeze(0) adds the batch dimension the model expects.
    x = tfm(img).unsqueeze(0).to(device)
    logits = model(x)
    probs = torch.softmax(logits, dim=1).squeeze(0).cpu().numpy()
    pred_idx = int(np.argmax(probs))
    return classes[pred_idx], {cls: float(probs[i]) for i, cls in enumerate(classes)}

# Example:
# img_path = test_ds.df.sample(1).iloc[0]['resolved_path']
# pred, probs = predict_image(img_path)
# print('Pred:', pred); probs

## 10) (Optional) TensorBoard

In [None]:
# You can also log to TensorBoard if you add SummaryWriter in the train loop.
# For now, just show how to launch it (if logs exist):
# %load_ext tensorboard
# %tensorboard --logdir /content/xray_runs