# Chest X‑ray 4‑Class Classifier + **MLflow** (Colab, PyTorch, T4 GPU)

This notebook trains a classifier for **COVID**, **Lung_Opacity**, **Normal**, and **Viral Pneumonia** with **PyTorch** and tracks experiments using **MLflow** (remote or local server).  
It supports:
- Reading a `dataset.csv` with image paths + labels.
- Path normalization (converts Windows-style backslashes; resolves paths relative to `IMAGES_ROOT`, falling back to a case-insensitive filename search under it).
- Class imbalance handling via `WeightedRandomSampler`.
- Model variants: **Small CNN**, **ResNet‑18**, **EfficientNet‑B0**.
- Mixed precision (AMP), early stopping, checkpoints.
- **MLflow logging**: params, per-epoch metrics, confusion matrix & predictions as artifacts, and model via `mlflow.pytorch.log_model`.

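> **How to use**: Point `MLFLOW_TRACKING_URI` at your remote MLflow server — the default `http://YOUR_EC2_PUBLIC_IP:8050` is only a placeholder — or leave it blank to use MLflow's default local tracking store; set `IMAGES_ROOT` and `CSV_PATH`; choose the models to train in `MODELS_TO_RUN`; then run the cells from top to bottom.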


## 0) Runtime & installs

In [None]:
#@title (Colab) Check GPU & install deps
# Notebook cell (IPython `!` / `%pip` magics, not plain Python): list the
# visible GPUs, install the Python dependencies, and print library versions.
# `|| true` forces a zero exit status when nvidia-smi is missing or finds no GPU.
!nvidia-smi -L || true

# torch/torchvision are not installed here; this assumes the runtime
# (e.g. Colab) already ships them.
# NOTE(review): tensorboard is installed but no visible cell imports it.
%pip -q install --upgrade pip
%pip -q install pandas scikit-learn matplotlib seaborn tqdm tensorboard mlflow

# Sanity check: report the versions actually importable in this kernel.
import torch, torchvision, mlflow
print('Torch:', torch.__version__, '| Torchvision:', torchvision.__version__, '| MLflow:', mlflow.__version__)

## 1) Imports & global setup

In [None]:
import os, sys, time, math, shutil, random, json, re
from pathlib import Path
import pandas as pd
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torch.cuda import amp
import torch.optim as optim

from torchvision import transforms, models

from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
from collections import Counter
from tqdm import tqdm

import mlflow
import mlflow.pytorch

# Reproducibility
def set_seed(seed: int = 42, deterministic: bool = False) -> None:
    """Seed the Python, NumPy and PyTorch (CPU + all CUDA devices) RNGs.

    Args:
        seed: Seed applied to every RNG.
        deterministic: If False (default), cuDNN autotuning is enabled
            (``benchmark=True``). This is faster, but cuDNN may pick different
            kernels between runs, so results are only approximately
            reproducible. If True, cuDNN runs in deterministic mode with
            autotuning disabled.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # These two cuDNN switches trade speed against bit-exact reruns.
    torch.backends.cudnn.deterministic = deterministic
    torch.backends.cudnn.benchmark = not deterministic

set_seed(42)

# Train on the GPU when one is visible, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

## 2) Mount Drive (optional) & Configuration

In [None]:
#@title Mount Drive (optional)
# google.colab only exists inside Colab. Elsewhere, skip mounting instead of
# crashing, and point IMAGES_ROOT / CSV_PATH at local paths.
try:
    from google.colab import drive
except ImportError:
    drive = None
    print("google.colab not available; skipping Drive mount (use local paths in the config cell).")
else:
    drive.mount('/content/drive')

In [None]:
#@title Configuration
# --- Data paths ---
IMAGES_ROOT = "/content/drive/MyDrive/datasets/chest_xray"  #@param {type:"string"}
CSV_PATH    = "/content/drive/MyDrive/datasets/chest_xray/dataset.csv"  #@param {type:"string"}

# Leave empty to auto-detect names in CSV
LABEL_COL   = ""  # e.g., "label"
PATH_COL    = ""  # e.g., "img_path"

# --- Training hyperparams ---
IMAGE_SIZE   = 224          #@param {type:"integer"}
BATCH_SIZE   = 32           #@param {type:"integer"}
EPOCHS       = 15           #@param {type:"integer"}
LR           = 1e-4         #@param {type:"number"}
WEIGHT_DECAY = 1e-4         #@param {type:"number"}
USE_WEIGHTED_SAMPLER = True #@param {type:"boolean"}

# Which models to run (each will become a separate MLflow run)
MODELS_TO_RUN = ["resnet18", "efficientnet_b0", "cnn_small"]  #@param

# --- Output dirs ---
OUT_DIR = Path('/content/xray_runs_mlflow')
OUT_DIR.mkdir(parents=True, exist_ok=True)

# --- MLflow config ---
USE_MLFLOW = True  #@param {type:"boolean"}
# Remote server URI; leave empty to use MLflow's default local tracking store.
MLFLOW_TRACKING_URI = "http://YOUR_EC2_PUBLIC_IP:8050"  #@param {type:"string"}
MLFLOW_EXPERIMENT   = "xray-4class"  #@param {type:"string"}

# Only touch the tracking backend when MLflow logging is enabled.
if USE_MLFLOW:
    _tracking_uri = MLFLOW_TRACKING_URI.strip()
    # Fail fast on the template placeholder rather than later with an opaque
    # connection error from set_experiment().
    if "YOUR_" in _tracking_uri:
        raise ValueError(
            f"MLFLOW_TRACKING_URI still contains the placeholder {_tracking_uri!r}; "
            "set it to your server URL or leave it empty for local tracking."
        )
    if _tracking_uri:
        mlflow.set_tracking_uri(_tracking_uri)
    mlflow.set_experiment(MLFLOW_EXPERIMENT)

print("Outputs:", OUT_DIR)
if USE_MLFLOW:
    print("MLflow tracking URI:", mlflow.get_tracking_uri())

## 3) Load CSV & resolve image paths

In [None]:
def index_all_images(root: Path, extensions=('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')):
    """Index every image file under ``root`` by its lower-cased basename.

    Args:
        root: Directory to walk recursively (``str`` or ``Path``).
        extensions: Filename suffixes treated as images (matched case-insensitively).

    Returns:
        dict mapping ``basename.lower()`` to the list of full paths (``str``)
        that share that basename. Directories and files are visited in sorted
        order, so the list order is the same on every filesystem. Callers
        that take the first hit therefore get a reproducible choice.
    """
    exts = tuple(e.lower() for e in extensions)
    idx = {}
    for dirpath, dirnames, filenames in os.walk(root):
        # Sorting dirnames in place steers os.walk's (top-down) traversal order.
        dirnames.sort()
        for fn in sorted(filenames):
            if fn.lower().endswith(exts):
                idx.setdefault(fn.lower(), []).append(str(Path(dirpath) / fn))
    return idx

def normalize_path(s: str) -> str:
    """Canonicalize a path string taken from the CSV.

    Trims surrounding whitespace, removes one pair of matching single or
    double quotes wrapping the whole value, converts backslashes to forward
    slashes, and collapses runs of slashes into one. Non-string input is
    passed through ``str()`` first.
    """
    text = str(s).strip()
    is_quoted = len(text) >= 2 and text[0] in ("'", '"') and text[-1] == text[0]
    if is_quoted:
        text = text[1:-1]
    forward = text.replace("\\", "/")
    return re.sub(r"/+", "/", forward)

def resolve_paths(df: pd.DataFrame, images_root: Path, path_col: str):
    """Map each entry of ``df[path_col]`` to an existing image file, or ``None``.

    Each value is first cleaned with ``normalize_path``, then looked up in order:
      1. the path as given (absolute, or relative to the current directory);
      2. the path relative to ``images_root``;
      3. any file under ``images_root`` with the same basename
         (case-insensitive), taking the first hit from ``index_all_images``.

    Only regular files count as matches. An empty cell or a path that names a
    directory is reported as missing instead of resolving to a folder.

    Returns:
        List of resolved path strings (``None`` where nothing matched), in
        the same order as ``df``.

    Raises:
        FileNotFoundError: if ``images_root`` is not an existing directory.
    """
    images_root = Path(images_root)
    # Explicit raise: an ``assert`` would be stripped under ``python -O``.
    if not images_root.is_dir():
        raise FileNotFoundError(f"IMAGES_ROOT not found: {images_root}")
    print('Indexing images under', images_root, '...')
    name_index = index_all_images(images_root)
    resolved, missing, ambiguous = [], [], 0
    for p in tqdm(df[path_col].tolist(), desc='Resolving paths'):
        p_norm = normalize_path(p)
        direct = Path(p_norm)
        if direct.is_file():
            resolved.append(str(direct))
            continue
        under_root = images_root / p_norm
        if under_root.is_file():
            resolved.append(str(under_root))
            continue
        hits = name_index.get(os.path.basename(p_norm).lower(), [])
        if hits:
            if len(hits) > 1:
                ambiguous += 1
            resolved.append(hits[0])
        else:
            resolved.append(None)
            missing.append(p_norm)
    if ambiguous:
        print(f"WARNING: {ambiguous} paths matched several files by basename; the first match was used.")
    if missing:
        print(f"WARNING: {len(missing)} images not found. First 10:\n", missing[:10])
    return resolved

# Load CSV
df = pd.read_csv(CSV_PATH)
print('CSV shape:', df.shape)
print('Columns:', df.columns.tolist())

# Auto-detect columns if blank: case-insensitive match against common names;
# the first matching column (in CSV order) wins.
if not PATH_COL:
    candidates = [c for c in df.columns if str(c).lower() in ['img_path','path','image_path','filepath','file','filename']]
    if not candidates:
        raise ValueError("Could not auto-detect the image path column. Set PATH_COL manually.")
    PATH_COL = candidates[0]
if not LABEL_COL:
    candidates = [c for c in df.columns if str(c).lower() in ['label','class','target','category']]
    if not candidates:
        raise ValueError("Could not auto-detect the label column. Set LABEL_COL manually.")
    LABEL_COL = candidates[0]
# Manually configured names are not validated above; check them here.
for _col in (PATH_COL, LABEL_COL):
    if _col not in df.columns:
        raise KeyError(f"Column {_col!r} not found in CSV; available: {df.columns.tolist()}")

print('Using PATH_COL =', PATH_COL, '| LABEL_COL =', LABEL_COL)

# Normalize labels. Drop unlabeled rows *before* casting to str, otherwise
# astype(str) turns NaN into a bogus class literally named "nan".
n_csv = len(df)
has_label = df[LABEL_COL].notna() & (df[LABEL_COL].astype(str).str.strip() != '')
df = df.loc[has_label].reset_index(drop=True)
df[LABEL_COL] = df[LABEL_COL].astype(str).str.strip()
if len(df) < n_csv:
    print(f'Dropped {n_csv - len(df)} rows with an empty label.')

# Resolve full paths
df['resolved_path'] = resolve_paths(df, Path(IMAGES_ROOT), PATH_COL)

# Drop rows with missing images
before = len(df)
df = df.dropna(subset=['resolved_path']).reset_index(drop=True)
after = len(df)
print(f'Retained {after}/{before} rows after resolving image paths.')

print(df[LABEL_COL].value_counts())

df.head()

## 4) Split & Datasets

In [None]:
# Alphabetically sorted class names give a stable label -> index assignment.
classes = sorted(set(df[LABEL_COL]))
print('Classes:', classes)
cls_to_idx = dict(zip(classes, range(len(classes))))
idx_to_cls = dict(enumerate(classes))

# 70% train; the remaining 30% is halved into validation and test (15% each).
# Both splits are stratified on the label so class proportions are kept.
df_train, df_temp = train_test_split(
    df,
    test_size=0.30,
    stratify=df[LABEL_COL],
    random_state=42,
)
df_val, df_test = train_test_split(
    df_temp,
    test_size=0.50,
    stratify=df_temp[LABEL_COL],
    random_state=42,
)
print('Train:', df_train.shape, 'Val:', df_val.shape, 'Test:', df_test.shape)

# Standard ImageNet per-channel statistics, matching the torchvision pretrained weights.
mean = [0.485, 0.456, 0.406]
std  = [0.229, 0.224, 0.225]

# Training: replicate the grayscale channel to 3 channels, then light
# crop/scale and rotation augmentation.
train_tfms = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.RandomResizedCrop(size=IMAGE_SIZE, scale=(0.85, 1.0)),
    transforms.RandomRotation(degrees=5),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])
# Validation/test: deterministic resize slightly above IMAGE_SIZE, then center crop.
val_tfms = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize(size=int(IMAGE_SIZE * 1.14)),
    transforms.CenterCrop(size=IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])

class XRayDataset(Dataset):
    """Map-style dataset yielding ``(image, class_index)`` pairs from a DataFrame.

    Args:
        df: Rows to serve; its index is reset so positional access works.
        path_col: Column holding each image's file path.
        label_col: Column holding each row's class label.
        transforms: Optional callable applied to the grayscale (mode ``'L'``)
            PIL image.
        cls_to_idx: Mapping from label value to integer class index. If
            ``None``, the label is assumed to already be an integer index.
    """

    def __init__(self, df, path_col, label_col, transforms=None, cls_to_idx=None):
        self.df = df.reset_index(drop=True)
        self.path_col = path_col
        self.label_col = label_col
        self.transforms = transforms
        self.cls_to_idx = cls_to_idx

    def __len__(self):
        return len(self.df)

    def __getitem__(self, i):
        row = self.df.iloc[i]
        # Honour the configured column (previously hard-coded to 'resolved_path').
        p = row[self.path_col]
        label = row[self.label_col]
        y = int(label) if self.cls_to_idx is None else self.cls_to_idx[label]
        with Image.open(p) as img:
            # convert() loads the pixels into a new image, so it stays usable
            # after the file handle is closed.
            img = img.convert('L')
        if self.transforms is not None:
            img = self.transforms(img)
        return img, y

train_ds = XRayDataset(df_train, 'resolved_path', LABEL_COL, transforms=train_tfms, cls_to_idx=cls_to_idx)
val_ds   = XRayDataset(df_val,   'resolved_path', LABEL_COL, transforms=val_tfms,   cls_to_idx=cls_to_idx)
test_ds  = XRayDataset(df_test,  'resolved_path', LABEL_COL, transforms=val_tfms,   cls_to_idx=cls_to_idx)

sampler = None
if USE_WEIGHTED_SAMPLER:
    # Inverse-frequency weights: each class has the same total weight, so
    # classes are drawn equally often in expectation.
    counts = Counter(df_train[LABEL_COL].tolist())
    absent = [c for c in classes if counts[c] == 0]
    if absent:
        # 1/0 would give infinite weights and break WeightedRandomSampler.
        raise ValueError(f"Classes missing from the training split: {absent}")
    class_count = torch.tensor([counts[c] for c in classes], dtype=torch.float)
    class_weight = 1.0 / class_count
    label_idx = torch.tensor([cls_to_idx[label] for label in df_train[LABEL_COL]], dtype=torch.long)
    sample_weights = class_weight[label_idx]
    sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)

# Pinned host memory only speeds up host->GPU copies; skip it on CPU-only runtimes.
pin = device.type == 'cuda'
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=(sampler is None), sampler=sampler, num_workers=2, pin_memory=pin)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=pin)
test_loader  = DataLoader(test_ds,  batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=pin)

len(train_ds), len(val_ds), len(test_ds)

## 5) Models

In [None]:
class SmallCNN(nn.Module):
    """Lightweight 4-stage CNN baseline for 3-channel inputs.

    Each stage is Conv3x3 -> BatchNorm -> ReLU with widths 32/64/128/256.
    The first three stages end in 2x2 max-pooling and the last in global
    average pooling. Input height and width must therefore be at least 8
    pixels. The pooled 256-d feature goes through a dropout-regularized
    two-layer head producing ``num_classes`` logits.
    """

    def __init__(self, num_classes):
        super().__init__()
        widths = (3, 32, 64, 128, 256)
        last_stage = len(widths) - 2
        layers = []
        for stage, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.extend([
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ])
            # Flat layout keeps the same Sequential indices (and state_dict keys).
            layers.append(nn.AdaptiveAvgPool2d((1, 1)) if stage == last_stage else nn.MaxPool2d(2))
        self.features = nn.Sequential(*layers)
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(0.3),
            nn.Linear(widths[-1], 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        return self.classifier(self.features(x))

def create_model(name: str, num_classes: int):
    """Build a classifier with a ``num_classes``-way output layer.

    Supported names (case-insensitive):
      * ``'cnn_small'``       - ``SmallCNN`` trained from scratch.
      * ``'resnet18'``        - torchvision ResNet-18 with its DEFAULT
        pretrained weights; the ``fc`` layer is replaced.
      * ``'efficientnet_b0'`` - torchvision EfficientNet-B0 with its DEFAULT
        pretrained weights; the last ``classifier`` layer is replaced.

    Raises:
        ValueError: for any other name.
    """
    key = name.lower()

    def _resnet18():
        net = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        net.fc = nn.Linear(net.fc.in_features, num_classes)
        return net

    def _efficientnet_b0():
        net = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT)
        head = net.classifier
        head[-1] = nn.Linear(head[-1].in_features, num_classes)
        return net

    builders = {
        'cnn_small': lambda: SmallCNN(num_classes),
        'resnet18': _resnet18,
        'efficientnet_b0': _efficientnet_b0,
    }
    if key not in builders:
        raise ValueError(f"Unknown model: {key}")
    return builders[key]()

## 6) Train/Eval utils + MLflow logging

In [None]:
def plot_confusion(y_true, y_pred, labels, normalize=True, save_path=None):
    """Draw (and optionally save) a confusion-matrix heatmap.

    Args:
        y_true: True class indices in ``0..len(labels)-1``.
        y_pred: Predicted class indices in the same range.
        labels: Class names for the axis ticks, in index order.
        normalize: If True, divide each row by its total so each cell shows
            the fraction of that true class. Rows for classes with no true
            samples stay 0 instead of becoming NaN.
        save_path: If given, the figure is written there (dpi=180) before
            being shown.

    Returns:
        The (possibly normalized) confusion matrix as a NumPy array.
    """
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(labels))))
    if normalize:
        row_totals = cm.sum(axis=1, keepdims=True)
        # Guard empty rows (class absent from y_true) against 0/0 -> NaN.
        cm = np.divide(cm, row_totals, out=np.zeros(cm.shape, dtype=float), where=row_totals > 0)
    fig = plt.figure(figsize=(6, 5))
    sns.heatmap(cm, annot=True, fmt='.2f' if normalize else 'd',
                xticklabels=labels, yticklabels=labels)
    plt.xlabel('Predicted'); plt.ylabel('True'); plt.title('Confusion Matrix')
    plt.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=180)
    plt.show()
    # Release the figure so repeated calls (one per model) don't accumulate open figures.
    plt.close(fig)
    return cm

def get_lr(optim_):
    """Return the current learning rate of the optimizer's first parameter group."""
    groups = optim_.param_groups
    lead = groups[0]
    return lead['lr']

def _run_epoch(model, loader, criterion, optimizer=None, scaler=None, use_amp=False):
    """Make one pass over ``loader`` and return ``(mean_loss, accuracy)``.

    With an ``optimizer`` (plus its AMP ``scaler``), the pass trains the model,
    optionally under mixed precision. Without one, it evaluates under
    ``torch.no_grad()``.

    Raises:
        ValueError: if the loader yields no samples.
    """
    training = optimizer is not None
    model.train(training)
    total_loss, correct, total = 0.0, 0, 0
    with (torch.enable_grad() if training else torch.no_grad()):
        for x, y in tqdm(loader, leave=False):
            x = x.to(device, non_blocking=True)
            # The default collate already yields a LongTensor; move it instead of
            # re-wrapping with torch.tensor(), which copies and emits a UserWarning.
            y = y.to(device, non_blocking=True)
            if training:
                optimizer.zero_grad(set_to_none=True)
                with amp.autocast(enabled=use_amp):
                    logits = model(x)
                    loss = criterion(logits, y)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                logits = model(x)
                loss = criterion(logits, y)
            batch = x.size(0)
            total_loss += loss.item() * batch
            correct += (logits.argmax(1) == y).sum().item()
            total += batch
    if total == 0:
        raise ValueError("DataLoader yielded no samples; check the data split.")
    return total_loss / total, correct / total


def _predict(model, loader):
    """Return ``(y_true, y_pred)`` as lists of Python ints for every sample in ``loader``."""
    model.eval()
    y_true, y_pred = [], []
    with torch.no_grad():
        for x, y in tqdm(loader, leave=False):
            logits = model(x.to(device, non_blocking=True))
            y_pred.extend(logits.argmax(1).cpu().tolist())
            y_true.extend(y.tolist())
    return y_true, y_pred


def train_and_log(model_name: str):
    """Train ``model_name``, evaluate it on the test split, and log the results.

    Training setup:
      - AdamW, with ReduceLROnPlateau stepped on validation accuracy.
      - Mixed precision on CUDA only.
      - Early stopping after 5 epochs without a validation-accuracy gain.

    The best-validation weights are saved to ``OUT_DIR/best_<model_name>.pt``
    and restored before the test evaluation. The test metrics, artifacts and
    the logged MLflow model therefore all describe that checkpoint, not
    whatever the final epoch left behind.

    Local files under ``OUT_DIR`` are always written. Nothing is sent to
    MLflow when ``USE_MLFLOW`` is False.
    """
    from contextlib import nullcontext

    log_to_mlflow = USE_MLFLOW
    # torch.cuda.amp only helps on CUDA; disable autocast/scaling cleanly elsewhere.
    use_amp = device.type == 'cuda'

    model = create_model(model_name, num_classes=len(classes)).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2)
    scaler = amp.GradScaler(enabled=use_amp)

    run_name = f"{model_name}-IMG{IMAGE_SIZE}-B{BATCH_SIZE}"
    run_ctx = mlflow.start_run(run_name=run_name, nested=False) if log_to_mlflow else nullcontext()
    with run_ctx:
        if log_to_mlflow:
            mlflow.log_params({
                'model_name': model_name,
                'image_size': IMAGE_SIZE,
                'batch_size': BATCH_SIZE,
                'epochs': EPOCHS,
                'lr': LR,
                'weight_decay': WEIGHT_DECAY,
                'use_weighted_sampler': USE_WEIGHTED_SAMPLER,
                'num_train': len(train_ds),
                'num_val': len(val_ds),
                'num_test': len(test_ds),
            })
            mlflow.set_tags({'framework': 'pytorch', 'task': 'xray-4class', 'device': str(device)})

        # Training loop
        best_acc, best_epoch, best_state = 0.0, 0, None
        patience, bad = 5, 0
        history = {'epoch': [], 'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
        ckpt_path = OUT_DIR / f'best_{model_name}.pt'

        for epoch in range(1, EPOCHS + 1):
            train_loss, train_acc = _run_epoch(model, train_loader, criterion,
                                               optimizer=optimizer, scaler=scaler, use_amp=use_amp)
            val_loss, val_acc = _run_epoch(model, val_loader, criterion)
            scheduler.step(val_acc)

            if log_to_mlflow:
                mlflow.log_metrics({'train_loss': train_loss, 'train_acc': train_acc,
                                    'val_loss': val_loss, 'val_acc': val_acc}, step=epoch)

            history['epoch'].append(epoch)
            history['train_loss'].append(train_loss); history['train_acc'].append(train_acc)
            history['val_loss'].append(val_loss);     history['val_acc'].append(val_acc)

            print(f"Epoch {epoch}/{EPOCHS} | train: loss={train_loss:.4f} acc={train_acc:.4f} | val: loss={val_loss:.4f} acc={val_acc:.4f} | lr={get_lr(optimizer):.6f}")

            # Always checkpoint the first epoch, so a best state exists even if val acc is 0.
            if best_state is None or val_acc > best_acc:
                best_acc, best_epoch, bad = val_acc, epoch, 0
                # Keep a CPU copy in memory for restoring before the test evaluation.
                best_state = {k: v.detach().to('cpu', copy=True) for k, v in model.state_dict().items()}
                torch.save({'epoch': epoch, 'state_dict': model.state_dict(),
                            'val_acc': val_acc, 'classes': classes, 'model_name': model_name},
                           ckpt_path)
            else:
                bad += 1

            if bad >= patience:
                print(f"Early stopping. Best val acc: {best_acc:.4f}")
                break

        if best_state is not None:
            # Evaluate and log the best-validation weights, not the last epoch's.
            model.load_state_dict(best_state)
            print(f"Restored best weights from epoch {best_epoch}.")

        # Save history
        hist_path = OUT_DIR / f'history_{model_name}.json'
        with open(hist_path, 'w') as f:
            json.dump(history, f, indent=2)

        # Evaluate on test
        y_true, y_pred = _predict(model, test_loader)
        # Passing every label index keeps target_names aligned even if a class
        # never appears in y_true / y_pred; zero_division silences 0/0 warnings.
        report = classification_report(y_true, y_pred, labels=list(range(len(classes))),
                                       target_names=classes, output_dict=True, digits=4,
                                       zero_division=0)
        test_acc = sum(int(t == p) for t, p in zip(y_true, y_pred)) / max(len(y_true), 1)

        # Confusion matrix figure
        fig_path = OUT_DIR / f'confusion_{model_name}.png'
        plot_confusion(y_true, y_pred, labels=classes, normalize=True, save_path=str(fig_path))

        # Predictions CSV
        preds_path = OUT_DIR / f'preds_{model_name}.csv'
        pd.DataFrame({'y_true': [classes[i] for i in y_true],
                      'y_pred': [classes[i] for i in y_pred]}).to_csv(preds_path, index=False)

        # Class list and data split sizes
        classes_path = OUT_DIR / f'classes_{model_name}.json'
        with open(classes_path, 'w') as f:
            json.dump({'classes': classes,
                       'num_train': len(train_ds), 'num_val': len(val_ds), 'num_test': len(test_ds)}, f, indent=2)

        if log_to_mlflow:
            mlflow.log_metrics({'test_accuracy': test_acc, 'best_val_acc': best_acc,
                                'best_epoch': best_epoch})
            # Per-class F1
            mlflow.log_metrics({f"f1_{cls}": report[cls]['f1-score'] for cls in classes})
            for artifact in (hist_path, fig_path, preds_path, classes_path):
                mlflow.log_artifact(str(artifact))
            # Log the trained (best-checkpoint) model
            mlflow.pytorch.log_model(model, artifact_path="model")

        print(f"Run finished. Best val acc: {best_acc:.4f} | Test acc: {test_acc:.4f}")

## 7) Run experiments

In [None]:
# Train, evaluate and log each requested architecture in turn.
for model_name in MODELS_TO_RUN:
    rule = "=" * 24
    print(f"\n{rule} Running {model_name} {rule}")
    train_and_log(model_name)