# X-Ray Classification Training Notebook

This notebook provides a complete pipeline for training the X-Ray classification model on Google Colab.

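**Prerequisites:**
1.  Upload your cleaned dataset (as a zip file) to your Google Drive. The zip should contain `train/`, `val/` and `test/` folders at its root, each with one sub-folder per class (`normal/`, `pneumonia/`, `tuberculosis/`).
2.  Mount your Google Drive in this notebook.
3.  Update the `DATASET_ZIP_PATH` variable to point to your zip file.
4.  Select a GPU runtime (**Runtime → Change runtime type**); training on the CPU is very slow.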

In [None]:
# @title 1. Setup & Dependencies

# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Install required libraries
!pip install timm albumentations torchmetrics tensorboard pyyaml

In [2]:
# @title 2. Dataset Extraction
import os
import shutil
import zipfile

# --- CONFIGURATION ---
# UPDATE THIS PATH to where your zip file is located on Google Drive
DATASET_ZIP_PATH = '/content/drive/MyDrive/data.zip'
EXTRACT_PATH = '/content/data'
# ---------------------


def extract_dataset(zip_path: str, extract_path: str) -> None:
    """Extract ``zip_path`` into ``extract_path`` unless it is already populated.

    The archive is first extracted into a temporary sibling directory, which
    is renamed into place only after extraction succeeds. An interrupted run
    (e.g. a Colab disconnect) therefore never leaves a half-filled
    ``extract_path`` that a later run would mistake for a complete dataset.

    Raises:
        FileNotFoundError: if ``zip_path`` does not exist.
    """
    if os.path.isdir(extract_path) and os.listdir(extract_path):
        print(f"Dataset already exists at {extract_path}")
        return
    if not os.path.isfile(zip_path):
        raise FileNotFoundError(
            f"Dataset zip not found at {zip_path}; update DATASET_ZIP_PATH."
        )

    print(f"Extracting dataset from {zip_path} to {extract_path}...")
    tmp_path = extract_path.rstrip(os.sep) + ".partial"
    shutil.rmtree(tmp_path, ignore_errors=True)  # leftovers of an aborted run
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(tmp_path)
    if os.path.isdir(extract_path):
        os.rmdir(extract_path)  # only reached when the directory is empty
    os.replace(tmp_path, extract_path)
    print("Extraction complete.")


extract_dataset(DATASET_ZIP_PATH, EXTRACT_PATH)

Extracting dataset from /content/drive/MyDrive/data.zip to /content/data...
Extraction complete.


In [3]:
# @title 3. Configuration
import torch

# Central configuration, one section per concern. Built with dict(...) so each
# setting reads as a keyword argument; keys and their order match the
# sections consumed by the cells below.
CONFIG = dict(
    paths=dict(
        data_root="/content/data",
        models_dir="/content/drive/MyDrive/CNN/models",
        logs_dir="/content/drive/MyDrive/CNN/logs",
    ),
    dataset=dict(
        splits=["train", "val", "test"],
        classes=["normal", "pneumonia", "tuberculosis"],
        image_exts=[".png", ".jpg", ".jpeg"],
        target_image_size=[512, 512],  # (height, width)
    ),
    data_module=dict(
        batch_size=16,
        num_workers=2,  # Colab usually has 2 CPUs
        pin_memory=True,
        normalization_mean=[0.485, 0.456, 0.406],
        normalization_std=[0.229, 0.224, 0.225],
        use_weighted_sampler=False,
    ),
    augmentation=dict(
        random_resized_crop=dict(
            scale=[0.70, 1.0],
            ratio=[0.95, 1.05],
        ),
        horizontal_flip=False,
        rotation_degrees=12,
        color_jitter=[0.1, 0.1, 0.1, 0.05],  # brightness, contrast, saturation, hue
        color_jitter_prob=0.7,
    ),
    training=dict(
        model_name="tf_efficientnetv2_s",
        pretrained=True,
        batch_size=16,
        epochs=30,
        learning_rate=3e-4,
        weight_decay=0.01,
        optimizer="adamw",
        scheduler="reduce_on_plateau",
        patience=3,
        factor=0.5,
        min_lr=1e-6,
        label_smoothing=0.05,
        use_class_weights=True,
        # One weight per class, indexed by label; must follow the same order
        # as the class indices produced by the data pipeline.
        class_weights=[0.9385, 1.4584, 0.8007],
    ),
    evaluation=dict(
        metrics=["accuracy", "precision", "recall", "f1"],
        confusion_matrix=True,
        roc_auc=True,
        primary_metric="accuracy",
    ),
)

# Make sure the output folders on Google Drive exist before training writes to them.
import os
os.makedirs(CONFIG["paths"]["models_dir"], exist_ok=True)
os.makedirs(CONFIG["paths"]["logs_dir"], exist_ok=True)

In [4]:
# @title 4. Utilities & Imports
import random
import numpy as np
import torch
import torch.nn as nn
from pathlib import Path
from typing import Any, Dict, Optional, Sequence, Tuple, List, Callable
import warnings
from dataclasses import dataclass
from collections import Counter

def seed_everything(seed: int, deterministic: bool = False) -> None:
    """Seed every random number generator used by this notebook.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (CPU and, when
    available, all CUDA devices). It also exports ``PYTHONHASHSEED`` so that
    subprocesses started after this call inherit it.

    Args:
        seed: Value used for every generator.
        deterministic: If True, also force cuDNN to choose deterministic
            kernels and disable its autotuner. Seeding alone does not make
            GPU runs repeatable; this flag trades some speed for that.
    """
    import os  # local import: this cell does not import os itself

    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

def prepare_device(force_cpu: bool = False) -> torch.device:
    """Pick the device to train on.

    Returns CUDA whenever a GPU is visible, unless ``force_cpu`` asks for the
    CPU explicitly.
    """
    use_cuda = (not force_cpu) and torch.cuda.is_available()
    return torch.device("cuda" if use_cuda else "cpu")

SEED = 42  # single place to change the run's random seed
seed_everything(SEED)
DEVICE = prepare_device()
print(f"Using device: {DEVICE}")
if DEVICE.type == "cpu":
    # Training a 512x512 EfficientNetV2 on the CPU is likely to be very slow;
    # on Colab switch to a GPU runtime (Runtime -> Change runtime type).
    warnings.warn("CUDA is not available; training will run on the CPU.")

Using device: cpu


In [5]:
import albumentations as A
from albumentations.pytorch import ToTensorV2
from PIL import Image
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
from torchvision.datasets import ImageFolder

@dataclass
class TransformConfig:
    """Settings consumed by ``TransformFactory``.

    Attributes:
        image_size: Output size of every pipeline, read as (height, width).
        normalization_mean: Per-channel mean passed to ``A.Normalize``.
        normalization_std: Per-channel std passed to ``A.Normalize``.
        augmentation_cfg: Training-only augmentation options (the
            ``augmentation`` section of CONFIG).
    """

    image_size: Tuple[int, int]  # (height, width)
    normalization_mean: Tuple[float, float, float]
    normalization_std: Tuple[float, float, float]
    augmentation_cfg: Dict[str, Any]

class TransformFactory:
    """Build the albumentations pipelines used for training and evaluation.

    Both pipelines finish with the same normalize-then-tensor steps; only the
    training pipeline adds random augmentation in front of them.
    """

    def __init__(self, config: TransformConfig) -> None:
        self.config = config

    def build_train_transforms(self) -> A.Compose:
        """Training pipeline: random crop (or plain resize), optional flip,
        rotation and color jitter, then normalization."""
        options = self.config.augmentation_cfg or {}
        size = self.config.image_size

        crop = options.get("random_resized_crop")
        if crop:
            steps = [
                A.RandomResizedCrop(
                    size=size,  # passed through as the (height, width) tuple
                    scale=tuple(crop.get("scale", (0.9, 1.0))),
                    ratio=tuple(crop.get("ratio", (0.9, 1.1))),
                )
            ]
        else:
            steps = [A.Resize(height=size[0], width=size[1])]

        if options.get("horizontal_flip"):
            steps.append(A.HorizontalFlip(p=0.5))

        max_angle = options.get("rotation_degrees")
        if max_angle:
            steps.append(A.Rotate(limit=max_angle, p=0.5))

        jitter = options.get("color_jitter")
        if jitter:
            brightness, contrast, saturation, hue = jitter
            steps.append(
                A.ColorJitter(
                    brightness=brightness,
                    contrast=contrast,
                    saturation=saturation,
                    hue=hue,
                    p=options.get("color_jitter_prob", 0.8),
                )
            )

        return A.Compose(steps + self._common_transforms())

    def build_eval_transforms(self) -> A.Compose:
        """Evaluation pipeline: deterministic resize, then normalization."""
        height, width = self.config.image_size[0], self.config.image_size[1]
        return A.Compose([A.Resize(height=height, width=width), *self._common_transforms()])

    def _common_transforms(self):
        """Steps shared by every pipeline: normalize 0-255 pixels, then convert to a tensor."""
        cfg = self.config
        normalize = A.Normalize(
            mean=cfg.normalization_mean,
            std=cfg.normalization_std,
            max_pixel_value=255.0,
        )
        return [normalize, ToTensorV2()]

class AlbumentationsDataset(Dataset):
    """Adapter that loads ``ImageFolder`` samples with PIL and applies an albumentations pipeline.

    Only ``dataset.samples`` (``(path, label)`` pairs) is read, so any
    transform attached to the wrapped ``ImageFolder`` itself is not used.
    """

    def __init__(self, dataset: ImageFolder, transform: A.Compose) -> None:
        self.dataset = dataset
        self.transform = transform

    def __getitem__(self, index: int):
        image_path, target = self.dataset.samples[index]
        # Every image is forced to 3-channel RGB so grayscale X-rays match
        # the 3-channel normalization statistics.
        with Image.open(image_path) as handle:
            rgb_image = handle.convert("RGB")
        pixels = np.asarray(rgb_image)
        return self.transform(image=pixels)["image"], target

    def __len__(self) -> int:
        return len(self.dataset)

@dataclass
class DataModuleConfig:
    """Settings for ``DataModule``: data location, loader options and transform parameters."""

    data_root: Path  # holds one sub-directory per split (train/val/test)
    classes: Sequence[str]
    batch_size: int
    num_workers: int
    pin_memory: bool
    image_size: Tuple[int, int]  # (height, width)
    normalization_mean: Tuple[float, float, float]
    normalization_std: Tuple[float, float, float]
    augmentation_cfg: Optional[Dict]  # None -> training pipeline is resize + normalize only
    allowed_exts: Sequence[str]  # compared case-insensitively with file suffixes
    use_weighted_sampler: bool = False  # inverse-class-frequency sampling for the train loader

class DataModule:
    """Build train/val/test datasets and dataloaders from an image-folder tree.

    Expected layout: ``data_root/<split>/<class_name>/<images...>`` for the
    splits ``train``, ``val`` and ``test``. ``setup`` skips splits whose
    directory is missing. ``train_dataloader`` raises ``RuntimeError`` when
    no train split was loaded.

    Label indices follow the order of ``config.classes`` when it is non-empty,
    so they agree with class-indexed settings such as loss class weights and
    report target names. When ``config.classes`` is empty, the mapping of the
    first split loaded (``ImageFolder``'s folder-name order) is reused for the
    later splits.
    """

    def __init__(self, config: DataModuleConfig) -> None:
        self.config = config
        transform_cfg = TransformConfig(
            image_size=config.image_size,
            normalization_mean=config.normalization_mean,
            normalization_std=config.normalization_std,
            augmentation_cfg=config.augmentation_cfg or {},
        )
        self.transform_factory = TransformFactory(transform_cfg)
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None
        # Canonical class name -> label index, shared by every split.
        self.class_to_idx = {name: idx for idx, name in enumerate(config.classes or ())}

    def setup(self) -> None:
        """Create the wrapped dataset for every split directory that exists."""
        splits = {
            "train": self.transform_factory.build_train_transforms,
            "val": self.transform_factory.build_eval_transforms,
            "test": self.transform_factory.build_eval_transforms,
        }

        for split, transform_builder in splits.items():
            split_dir = Path(self.config.data_root) / split
            if not split_dir.is_dir():
                continue

            base_dataset = ImageFolder(split_dir)
            base_dataset = self._filter_by_extension(base_dataset)
            base_dataset = self._align_class_indices(base_dataset, split)

            transform = transform_builder()
            alb_dataset = AlbumentationsDataset(base_dataset, transform)

            if split == "train":
                self.train_dataset = alb_dataset
            elif split == "val":
                self.val_dataset = alb_dataset
            else:
                self.test_dataset = alb_dataset

    def train_dataloader(self) -> DataLoader:
        """Shuffled train loader, or class-balanced sampling when ``use_weighted_sampler`` is set."""
        if self.train_dataset is None:
            raise RuntimeError(
                f"No 'train' dataset: expected a directory at "
                f"{Path(self.config.data_root) / 'train'}. Did you call setup()?"
            )
        sampler = None
        if self.config.use_weighted_sampler:
            # Inverse-frequency weights: every class is drawn equally often in expectation.
            targets = [label for _, label in self.train_dataset.dataset.samples]
            class_counts = Counter(targets)
            num_classes = len(class_counts)
            total = len(targets)
            class_weights = {cls: total / (num_classes * count) for cls, count in class_counts.items()}
            sample_weights = [class_weights[label] for label in targets]
            sampler = WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)

        # A sampler and shuffle=True are mutually exclusive in DataLoader.
        return self._make_loader(self.train_dataset, shuffle=sampler is None, sampler=sampler)

    def val_dataloader(self) -> DataLoader:
        """Unshuffled validation loader."""
        return self._make_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self) -> DataLoader:
        """Unshuffled test loader."""
        return self._make_loader(self.test_dataset, shuffle=False)

    def _make_loader(self, dataset, shuffle: bool, sampler=None) -> DataLoader:
        """Create a DataLoader with the batch/worker settings from the config."""
        return DataLoader(
            dataset,
            batch_size=self.config.batch_size,
            shuffle=shuffle,
            sampler=sampler,
            num_workers=self.config.num_workers,
            pin_memory=self.config.pin_memory,
        )

    def _filter_by_extension(self, dataset: ImageFolder) -> ImageFolder:
        """Keep only samples whose file suffix is in ``allowed_exts`` (case-insensitive)."""
        allowed = {ext.lower() for ext in self.config.allowed_exts}
        filtered_samples = [(p, l) for p, l in dataset.samples if Path(p).suffix.lower() in allowed]
        dataset.samples = filtered_samples
        dataset.imgs = filtered_samples
        dataset.targets = [l for _, l in filtered_samples]
        return dataset

    def _align_class_indices(self, dataset: ImageFolder, split: str) -> ImageFolder:
        """Relabel ``dataset`` with the canonical ``class_to_idx`` mapping.

        Samples from class folders not in the mapping are dropped, and a
        warning is issued. A warning is also issued for mapped classes that
        have no folder in this split.
        """
        if not self.class_to_idx:
            self.class_to_idx = dataset.class_to_idx.copy()
            return dataset

        # Resolve each sample's class from ImageFolder's own label instead of
        # the file's parent directory. ImageFolder also collects images from
        # sub-folders of a class directory, whose parent name is not the class.
        folder_name_of = {idx: name for name, idx in dataset.class_to_idx.items()}
        remapped_samples = []
        dropped = Counter()
        for path, folder_label in dataset.samples:
            class_name = folder_name_of[folder_label]
            if class_name in self.class_to_idx:
                remapped_samples.append((path, self.class_to_idx[class_name]))
            else:
                dropped[class_name] += 1

        if dropped:
            warnings.warn(
                f"[{split}] ignoring {sum(dropped.values())} images in folders "
                f"not listed in classes: {dict(dropped)}"
            )
        missing = sorted(set(self.class_to_idx) - set(dataset.class_to_idx))
        if missing:
            warnings.warn(f"[{split}] no folder found for classes: {missing}")

        dataset.samples = remapped_samples
        dataset.imgs = remapped_samples
        dataset.targets = [l for _, l in remapped_samples]
        dataset.class_to_idx = dict(self.class_to_idx)
        dataset.classes = sorted(self.class_to_idx, key=self.class_to_idx.get)
        return dataset

In [6]:
# @title 6. Model Building
import timm

def build_model(model_cfg: Dict[str, Any], num_classes: int) -> torch.nn.Module:
    """Create a timm image classifier with a ``num_classes``-way head.

    Args:
        model_cfg: Mapping with a required ``model_name`` and an optional
            ``pretrained`` flag (defaults to True).
        num_classes: Number of output logits.

    Returns:
        The model returned by ``timm.create_model``.
    """
    name = model_cfg["model_name"]
    use_pretrained = bool(model_cfg.get("pretrained", True))
    print(f"Building model: {name} (pretrained={use_pretrained})")
    return timm.create_model(name, pretrained=use_pretrained, num_classes=num_classes)

In [7]:
# @title 7. Training Components (Loss, Metrics, Optimizer)
from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassAccuracy, MulticlassAUROC, MulticlassF1Score, MulticlassPrecision, MulticlassRecall

def build_loss(training_cfg: Dict[str, Any]) -> nn.Module:
    """Return a cross-entropy criterion configured from the training section of CONFIG.

    ``label_smoothing`` (default 0.0) is always applied. ``class_weights`` are
    used only when ``use_class_weights`` is truthy and the list is non-empty.
    The weight tensor is created without a device (i.e. on the CPU), so move
    the returned module to the model's device when training on a GPU.
    """
    smoothing = float(training_cfg.get("label_smoothing", 0.0))
    weights = training_cfg.get("class_weights") if training_cfg.get("use_class_weights") else None
    weight_tensor = torch.tensor(weights, dtype=torch.float32) if weights else None
    return nn.CrossEntropyLoss(weight=weight_tensor, label_smoothing=smoothing)

def build_metrics(evaluation_cfg: Dict[str, Any], num_classes: int, device: torch.device) -> MetricCollection:
    """Build a torchmetrics collection from the evaluation section of CONFIG.

    Recognised names in ``metrics`` (case-insensitive: accuracy, precision,
    recall, f1) become macro-averaged metrics; unknown names are ignored.
    Macro-averaged accuracy averages per-class scores, so on imbalanced data
    it differs from plain overall accuracy. ``roc_auc=True`` adds a
    ``MulticlassAUROC``. If nothing is selected, a default
    ``MulticlassAccuracy`` is used so the collection is never empty. The
    collection is moved to ``device``.
    """
    factories = {
        "accuracy": MulticlassAccuracy,
        "precision": MulticlassPrecision,
        "recall": MulticlassRecall,
        "f1": MulticlassF1Score,
    }
    requested = (name.lower() for name in evaluation_cfg.get("metrics", []))
    selected = {
        key: factories[key](num_classes=num_classes, average="macro")
        for key in requested
        if key in factories
    }

    if evaluation_cfg.get("roc_auc", False):
        selected["roc_auc"] = MulticlassAUROC(num_classes=num_classes)

    if not selected:
        selected["accuracy"] = MulticlassAccuracy(num_classes=num_classes)

    collection = MetricCollection(selected)
    return collection.to(device)

def build_optimizer(model: torch.nn.Module, training_cfg: Dict[str, Any]) -> torch.optim.Optimizer:
    """Create the optimizer named by ``training_cfg["optimizer"]``.

    Supported names (case-insensitive): ``adamw``, ``adam`` (the default when
    unset) and ``sgd`` (uses ``momentum``, default 0.9). Any other name falls
    back to Adam with a warning, so a typo in the config is visible instead
    of silently changing the optimizer. ``learning_rate`` (default 1e-3) and
    ``weight_decay`` (default 0.0) apply to all of them. Only parameters with
    ``requires_grad=True`` are optimized.
    """
    name = str(training_cfg.get("optimizer") or "adam").lower()
    lr = float(training_cfg.get("learning_rate", 1e-3))
    weight_decay = float(training_cfg.get("weight_decay", 0.0))
    params = [p for p in model.parameters() if p.requires_grad]

    if name == "adamw":
        return torch.optim.AdamW(params, lr=lr, weight_decay=weight_decay)
    if name == "sgd":
        momentum = float(training_cfg.get("momentum", 0.9))
        return torch.optim.SGD(params, lr=lr, momentum=momentum, weight_decay=weight_decay)
    if name != "adam":
        warnings.warn(f"Unknown optimizer {name!r}; falling back to Adam.")
    return torch.optim.Adam(params, lr=lr, weight_decay=weight_decay)

def build_scheduler(optimizer: torch.optim.Optimizer, training_cfg: Dict[str, Any], epochs: int):
    """Create the learning-rate scheduler named by ``training_cfg["scheduler"]``.

    * ``reduce_on_plateau``: ``ReduceLROnPlateau`` in ``"min"`` mode (step it
      with a loss), using ``patience`` (default 3), ``factor`` (default 0.1)
      and ``min_lr`` (default 0.0) as the floor for the learning rate.
    * ``cosine``: ``CosineAnnealingLR`` over ``epochs`` epochs, decaying to
      ``min_lr``; step it once per epoch.
    * anything else, or unset: ``None`` (constant learning rate).
    """
    name = str(training_cfg.get("scheduler") or "none").lower()
    min_lr = float(training_cfg.get("min_lr", 0.0))

    if name == "reduce_on_plateau":
        return torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode="min",
            patience=int(training_cfg.get("patience", 3)),
            factor=float(training_cfg.get("factor", 0.1)),
            min_lr=min_lr,
        )
    if name == "cosine":
        return torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=max(1, int(epochs)), eta_min=min_lr
        )
    return None

In [8]:
# @title 8. Trainer Class
from torch.utils.tensorboard import SummaryWriter
from tqdm.notebook import tqdm

class Trainer:
    """Epoch-based training loop with validation, TensorBoard logging and checkpointing.

    Args:
        model: Network to train; moved to ``device`` on construction.
        device: Device for the model, the loss module and every batch.
        loss_fn: Criterion called as ``loss_fn(logits, targets)``. When it is
            an ``nn.Module`` it is moved to ``device``, so buffers such as
            ``CrossEntropyLoss`` class weights sit next to the logits.
        optimizer: Optimizer over the model's parameters.
        metric_collection: torchmetrics collection used as a template; a
            ``train_``/``val_``-prefixed clone is created for every epoch.
        scheduler: Optional LR scheduler. ``ReduceLROnPlateau`` is stepped with
            the validation loss (the training loss when there is no validation
            loader); any other scheduler is stepped once per epoch.
        epochs: Number of epochs ``fit`` runs.
        checkpoint_dir: Directory for checkpoints; falsy disables saving.
        writer: Optional TensorBoard ``SummaryWriter``.
        monitor: Metric key (e.g. ``"val_accuracy"``, ``"val_f1"``) whose
            improvement saves ``best_model.pth``. With no validation loader
            the matching ``train_`` key is used instead.
        config: Object stored under ``"config"`` in checkpoints; defaults to
            the notebook-level ``CONFIG`` when that name exists.
    """

    def __init__(self, model, device, loss_fn, optimizer, metric_collection, scheduler=None, epochs=25,
                 checkpoint_dir=None, writer=None, monitor="val_accuracy", config=None):
        self.model = model
        self.device = device
        self.loss_fn = loss_fn.to(device) if isinstance(loss_fn, nn.Module) else loss_fn
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.metric_template = metric_collection
        self.epochs = epochs
        self.checkpoint_dir = checkpoint_dir
        self.writer = writer
        self.monitor = monitor
        self.config = config
        self.best_metric = float("-inf")
        self.model.to(self.device)

    def fit(self, train_loader, val_loader=None):
        """Train for ``self.epochs`` epochs.

        Returns:
            Dict mapping ``"train_loss"``, ``"val_loss"`` (None entries when
            there is no validation loader), ``"lr"`` and every metric key
            (e.g. ``"val_accuracy"``) to a list with one value per epoch.
        """
        history = {}
        for epoch in range(1, self.epochs + 1):
            print(f"\nEpoch {epoch}/{self.epochs}")
            train_loss, train_metrics = self._run_epoch(train_loader, train=True)

            val_loss, val_metrics = None, {}
            if val_loader:
                val_loss, val_metrics = self._run_epoch(val_loader, train=False)

            current_lr = self.optimizer.param_groups[0]["lr"]
            self._log_epoch(epoch, train_loss, val_loss, train_metrics, val_metrics, current_lr)

            record = {"train_loss": train_loss, "val_loss": val_loss, "lr": current_lr,
                      **train_metrics, **val_metrics}
            for key, value in record.items():
                history.setdefault(key, []).append(value)

            self._maybe_save_best(epoch, {**train_metrics, **val_metrics})
            self._step_scheduler(val_loss if val_loss is not None else train_loss)
        return history

    def _log_epoch(self, epoch, train_loss, val_loss, train_metrics, val_metrics, lr):
        """Print the epoch summary and mirror losses, LR and all metrics to TensorBoard."""
        val_loss_text = f"{val_loss:.4f}" if val_loss is not None else "n/a"
        print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss_text} | LR: {lr:.2e}")
        print(f"Train Acc: {train_metrics.get('train_accuracy', 0):.4f} | "
              f"Val Acc: {val_metrics.get('val_accuracy', 0):.4f}")

        if self.writer is None:
            return
        self.writer.add_scalar("Loss/train", train_loss, epoch)
        if val_loss is not None:
            self.writer.add_scalar("Loss/val", val_loss, epoch)
        self.writer.add_scalar("LR", lr, epoch)
        for key, value in {**train_metrics, **val_metrics}.items():
            # "train_accuracy" -> "Accuracy/train", matching the original tag names.
            split, _, metric = key.partition("_")
            self.writer.add_scalar(f"{metric.capitalize()}/{split}", value, epoch)

    def _maybe_save_best(self, epoch, metrics):
        """Save a checkpoint when the monitored metric improves on the best so far."""
        key = self.monitor
        if key not in metrics and key.startswith("val_"):
            key = "train_" + key[len("val_"):]
        value = metrics.get(key)
        if value is None:
            warnings.warn(f"Monitored metric {self.monitor!r} not found; no best checkpoint saved.")
            return
        if value > self.best_metric:
            self.best_metric = value
            self._save_checkpoint(epoch, is_best=True)

    def _step_scheduler(self, plateau_value):
        """Advance the scheduler; plateau schedulers receive ``plateau_value``."""
        if self.scheduler is None:
            return
        if isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            self.scheduler.step(plateau_value)
        else:
            self.scheduler.step()

    def _run_epoch(self, loader, train):
        """Run one pass over ``loader``.

        Returns:
            ``(avg_loss, metrics)``. ``avg_loss`` is the batch losses weighted by
            batch size. ``metrics`` maps ``train_*``/``val_*`` names to floats.
        """
        mode = "train" if train else "val"
        if train:
            self.model.train()
        else:
            self.model.eval()

        metric_collection = self.metric_template.clone(prefix=f"{mode}_").to(self.device)
        metric_collection.reset()
        total_loss = 0.0
        total_items = 0

        for images, targets in tqdm(loader, desc=mode, leave=False):
            images, targets = images.to(self.device), targets.to(self.device)

            with torch.set_grad_enabled(train):
                logits = self.model(images)
                loss = self.loss_fn(logits, targets)

            if train:
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            batch_size = images.size(0)
            total_loss += loss.item() * batch_size
            total_items += batch_size
            metric_collection.update(torch.softmax(logits.detach(), dim=1), targets)

        avg_loss = total_loss / max(total_items, 1)
        metrics = {key: value.item() for key, value in metric_collection.compute().items()}
        return avg_loss, metrics

    def _save_checkpoint(self, epoch, is_best=False):
        """Write ``checkpoint_epoch_<epoch>.pth`` (and ``best_model.pth`` when ``is_best``)."""
        if not self.checkpoint_dir:
            return
        checkpoint_dir = Path(self.checkpoint_dir)
        checkpoint_dir.mkdir(parents=True, exist_ok=True)
        config = self.config if self.config is not None else globals().get("CONFIG")
        state = {
            "epoch": epoch,
            "model_state": self.model.state_dict(),
            "optimizer_state": self.optimizer.state_dict(),
            "config": config,
            "monitor": self.monitor,
            "best_metric": self.best_metric,
        }
        torch.save(state, checkpoint_dir / f"checkpoint_epoch_{epoch}.pth")
        if is_best:
            best_path = checkpoint_dir / "best_model.pth"
            torch.save(state, best_path)
            print(f"Saved new best model to {best_path}")

In [None]:
# @title 9. Run Training

# Initialize Data Module
dm_config = DataModuleConfig(
    data_root=Path(CONFIG["paths"]["data_root"]),
    classes=tuple(CONFIG["dataset"]["classes"]),
    batch_size=CONFIG["training"]["batch_size"],
    num_workers=CONFIG["data_module"]["num_workers"],
    pin_memory=CONFIG["data_module"]["pin_memory"],
    image_size=tuple(CONFIG["dataset"]["target_image_size"]),
    normalization_mean=tuple(CONFIG["data_module"]["normalization_mean"]),
    normalization_std=tuple(CONFIG["data_module"]["normalization_std"]),
    augmentation_cfg=CONFIG["augmentation"],
    allowed_exts=tuple(CONFIG["dataset"]["image_exts"]),
    use_weighted_sampler=CONFIG["data_module"]["use_weighted_sampler"]
)

data_module = DataModule(dm_config)
data_module.setup()

train_loader = data_module.train_dataloader()
val_loader = data_module.val_dataloader()
test_loader = data_module.test_dataloader()

print(f"Classes: {data_module.class_to_idx}")
print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")

# Initialize Model & Training Components
model = build_model(CONFIG["training"], num_classes=len(CONFIG["dataset"]["classes"]))
loss_fn = build_loss(CONFIG["training"])
metrics = build_metrics(CONFIG["evaluation"], num_classes=len(CONFIG["dataset"]["classes"]), device=DEVICE)
optimizer = build_optimizer(model, CONFIG["training"])
scheduler = build_scheduler(optimizer, CONFIG["training"], epochs=CONFIG["training"]["epochs"])

# TensorBoard
%load_ext tensorboard
%tensorboard --logdir {CONFIG['paths']['logs_dir']}
writer = SummaryWriter(log_dir=CONFIG["paths"]["logs_dir"])

# Start Training
trainer = Trainer(
    model=model,
    device=DEVICE,
    loss_fn=loss_fn,
    optimizer=optimizer,
    metric_collection=metrics,
    scheduler=scheduler,
    epochs=CONFIG["training"]["epochs"],
    checkpoint_dir=CONFIG["paths"]["models_dir"],
    writer=writer
)

trainer.fit(train_loader, val_loader)

In [None]:
# @title 10. Run Testing

import torch
from pathlib import Path

# Load the best checkpoint saved during training. map_location=DEVICE lets a
# checkpoint written on a GPU runtime load on a CPU runtime (and vice versa),
# so no manual switching is needed.
best_model_path = Path(CONFIG["paths"]["models_dir"]) / "best_model.pth"
checkpoint = torch.load(best_model_path, map_location=DEVICE)
print(f"Loaded {best_model_path} (epoch {checkpoint.get('epoch')})")

model.load_state_dict(checkpoint["model_state"])
model.to(DEVICE)
model.eval();  # trailing ';' suppresses the (very long) model repr

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score
from tqdm.notebook import tqdm

device = DEVICE
model.to(device)
model.eval()

all_preds = []
all_labels = []
all_probs = []

class_names = CONFIG["dataset"]["classes"]
# Explicit label ids keep reports and the confusion matrix aligned with
# class_names even if some class has no test images.
label_ids = list(range(len(class_names)))
print(f"Test started ({len(test_loader.dataset)} images) for evaluation")

with torch.no_grad():
    for images, labels in tqdm(test_loader, desc="Testing"):
        outputs = model(images.to(device))
        probs = torch.softmax(outputs, dim=1)
        preds = outputs.argmax(dim=1)

        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
        all_probs.extend(probs.cpu().numpy())

if not all_labels:
    raise RuntimeError("The test loader yielded no samples; check the test split.")

y_true = np.asarray(all_labels)
y_pred = np.asarray(all_preds)
y_prob = np.asarray(all_probs)

print("\n" + "="*50)
print("Evaluation Report(Best Model)")
print("="*50)
print(classification_report(y_true, y_pred, labels=label_ids, target_names=class_names,
                            digits=4, zero_division=0))

if CONFIG["evaluation"].get("roc_auc", False):
    # One-vs-rest AUC needs at least one test image of every class.
    if len(np.unique(y_true)) == len(class_names):
        auc = roc_auc_score(y_true, y_prob, multi_class="ovr", average="macro", labels=label_ids)
        print(f"Macro one-vs-rest ROC AUC: {auc:.4f}")
    else:
        print("Skipping ROC AUC: not every class is present in the test set.")

# Confusion Matrix
if CONFIG["evaluation"].get("confusion_matrix", True):
    cm = confusion_matrix(y_true, y_pred, labels=label_ids)
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names,
                annot_kws={"size": 14}, ax=ax)
    ax.set_xlabel('Predicted Label', fontsize=12)
    ax.set_ylabel('True Label', fontsize=12)
    ax.set_title('Confusion Matrix - Test Set', fontsize=15)
    plt.show()

# Overall (micro) accuracy. This differs from the macro-averaged accuracy
# metric logged during training when classes are imbalanced.
accuracy = float(np.mean(y_pred == y_true))
print(f"\n✅ Test Accuracy: {accuracy:.2%}")

Test started (1160 images) for evaluation


Testing:   0%|          | 0/73 [00:00<?, ?it/s]

