In [None]:
# train_brain_tumor_with_earlystopping.py
import os
import random
from pathlib import Path
import math
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision import transforms, models, datasets
from sklearn.metrics import classification_report, confusion_matrix


# --- Paths ---------------------------------------------------------------------
# Dataset root; it must contain "Training" and "Testing" folders, each holding one
# sub-folder per class (read with torchvision ImageFolder). Set the
# BRAIN_TUMOR_DATA_DIR environment variable to run elsewhere without editing the
# source; the default is the original author's location.
ARCHIVE_DIR = Path(os.environ.get("BRAIN_TUMOR_DATA_DIR", r"C:\Users\ACER\Desktop\archive (3)"))
TRAIN_DIR = str(ARCHIVE_DIR / "Training")
# NOTE(review): the "Testing" split is used as the validation set. It drives early
# stopping / best-checkpoint selection AND the final classification report, so the
# reported metrics are optimistically biased. Consider holding out part of
# "Training" for validation and touching "Testing" only once at the end.
VAL_DIR = str(ARCHIVE_DIR / "Testing")

# --- Model / optimisation ------------------------------------------------------
NUM_CLASSES = 4
BATCH_SIZE = 32
EPOCHS = 10
IMAGE_SIZE = 224
BASE_MODEL_NAME = "resnet50"   # one of "resnet50", "resnet18", "efficientnet_b0"
PRETRAINED = True
FREEZE_BACKBONE = False        # True -> only the classification head is trained
LEARNING_RATE = 3e-4
WEIGHT_DECAY = 1e-4
CLIP_GRAD_NORM = 1.0           # max total gradient norm per optimiser step
NUM_WORKERS = min(8, os.cpu_count() or 4)
SEED = 42

# --- Outputs -------------------------------------------------------------------
OUTPUT_DIR = Path("outputs")   # .pth checkpoints and plots, relative to the CWD
OUTPUT_DIR.mkdir(exist_ok=True)
ARCHIVE_DIR.mkdir(parents=True, exist_ok=True)   # pickled full-model (.h5) files go here
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Early stopping / monitoring -----------------------------------------------
MONITOR = "val_loss"   # "val_loss" is minimised; any other value monitors val_acc (maximised)
EARLY_STOPPING_PATIENCE = 5
MIN_DELTA = 1e-4       # smallest change in MONITOR that counts as an improvement


def set_seed(seed=42):
    """Seed Python's, NumPy's and PyTorch's global random generators.

    Args:
        seed: integer seed applied to every generator.
    """
    seeders = [random.seed, np.random.seed, torch.manual_seed]
    if torch.cuda.is_available():
        # Also seed the generator of every visible GPU.
        seeders.append(torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)

# Seed every RNG once at import time, before main() builds the dataloaders and model.
set_seed(SEED)

def get_transforms(image_size=224):
    """Build the (train, validation) torchvision preprocessing pipelines.

    Training applies light augmentation:
    - a random resized crop keeping 80-100% of the area;
    - a horizontal flip;
    - a +/-10 degree rotation;
    - mild colour jitter.

    Validation is deterministic: a resize to ``int(image_size * 1.14)`` followed by a
    centre crop. Both pipelines finish with ``ToTensor`` and normalisation by the
    standard ImageNet channel mean/std.

    Args:
        image_size: side length of the square crop fed to the network.

    Returns:
        (train_transform, val_transform)
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    def to_normalized_tensor():
        # A fresh ToTensor + Normalize pair appended to each pipeline.
        return [
            transforms.ToTensor(),
            transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
        ]

    augmentations = [
        transforms.RandomResizedCrop(image_size, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.15, contrast=0.15, saturation=0.1),
    ]
    resize_and_crop = [
        transforms.Resize(int(image_size * 1.14)),
        transforms.CenterCrop(image_size),
    ]

    train_transform = transforms.Compose(augmentations + to_normalized_tensor())
    val_transform = transforms.Compose(resize_and_crop + to_normalized_tensor())
    return train_transform, val_transform

def make_dataloaders(train_dir, val_dir, batch_size, num_workers):
    """Create a class-balanced training loader and a fixed-order validation loader.

    Both folders are read with ``torchvision.datasets.ImageFolder`` (one sub-folder
    per class).

    The training loader draws samples with a ``WeightedRandomSampler``. Each sample's
    weight is 1 / (number of training images in its class), so every class is drawn
    roughly equally often. The validation loader iterates without shuffling.

    Args:
        train_dir: root folder of the training images.
        val_dir: root folder of the validation images.
        batch_size: batch size for both loaders.
        num_workers: DataLoader worker processes for both loaders.

    Returns:
        (train_loader, val_loader, classes), where ``classes`` is the training
        dataset's list of class names.

    Raises:
        ValueError: if the two folders do not contain the same class sub-folders.
            Otherwise the same label index would silently mean different classes
            in training and validation.
    """
    train_tf, val_tf = get_transforms(IMAGE_SIZE)
    train_ds = datasets.ImageFolder(train_dir, transform=train_tf)
    val_ds = datasets.ImageFolder(val_dir, transform=val_tf)

    if train_ds.classes != val_ds.classes:
        raise ValueError(
            f"Class folders differ: {train_dir!r} has {train_ds.classes}, "
            f"{val_dir!r} has {val_ds.classes}"
        )

    targets = np.asarray([s[1] for s in train_ds.samples], dtype=np.int64)
    class_sample_count = np.bincount(targets, minlength=len(train_ds.classes))
    # Clamp empty classes to 1 only for the weight computation (avoids 1/0);
    # the raw counts are what get reported below.
    class_weights = 1.0 / np.maximum(class_sample_count, 1)
    samples_weight = torch.from_numpy(class_weights[targets]).double()
    sampler = WeightedRandomSampler(samples_weight, num_samples=len(samples_weight), replacement=True)

    # Pinned host memory only speeds up host->GPU copies; without CUDA it just warns.
    pin_memory = torch.cuda.is_available()
    train_loader = DataLoader(train_ds, batch_size=batch_size, sampler=sampler,
                              num_workers=num_workers, pin_memory=pin_memory)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers, pin_memory=pin_memory)

    print(f"Classes found: {train_ds.classes}")
    # int() keeps the printout as plain numbers (numpy >= 2 would show np.int64(...)).
    print("Training samples per class:",
          {name: int(count) for name, count in zip(train_ds.classes, class_sample_count)})
    return train_loader, val_loader, train_ds.classes

def create_model(num_classes=4, backbone="resnet50", pretrained=True, freeze_backbone=False):
    """Build a torchvision classifier whose last layer outputs ``num_classes`` logits.

    Args:
        num_classes: output size of the new classification layer.
        backbone: "resnet50", "resnet18" or "efficientnet_b0".
        pretrained: load torchvision's pretrained weights. This uses the legacy
            ``pretrained=`` flag, which newer torchvision versions deprecate in favour
            of ``weights=``.
        freeze_backbone: if True, only the classification head keeps
            ``requires_grad=True``.

    Returns:
        The model, with a freshly initialised classification layer.

    Raises:
        ValueError: for an unsupported ``backbone``.
    """
    if backbone in ("resnet50", "resnet18"):
        model = getattr(models, backbone)(pretrained=pretrained)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
        head = model.fc
    elif backbone == "efficientnet_b0":
        model = models.efficientnet_b0(pretrained=pretrained)
        model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
        head = model.classifier
    else:
        raise ValueError("Unsupported backbone")

    if freeze_backbone:
        # Select the head by module, not by a parameter-name substring. EfficientNet's
        # squeeze-excitation blocks contain layers named "fc1"/"fc2", so a
        # `"fc" in name` test would leave much of the backbone trainable.
        for param in model.parameters():
            param.requires_grad = False
        for param in head.parameters():
            param.requires_grad = True

    return model

class EarlyStopping:
    """Stop training once a monitored metric stops improving, saving every new best.

    Call :meth:`step` once per epoch with the latest metric value. A value is an
    improvement when it beats the best so far by more than ``min_delta`` (lower is
    better in ``"min"`` mode, higher in ``"max"`` mode). On each improvement two files
    are written:

    * ``checkpoint_dir / checkpoint_name``: a dict with the epoch, the model and
      optimizer ``state_dict``s, the metric value and the class names.
    * ``archive_dir / fullmodel_name``: the whole model object pickled with
      ``torch.save``. Despite the default ``.h5`` suffix this is NOT an HDF5/Keras
      file. Loading it needs ``torch.load`` with the model's classes importable; on
      PyTorch >= 2.6 you must also pass ``weights_only=False``.

    Args:
        patience: number of consecutive non-improving steps after which ``step``
            returns True.
        min_delta: minimum change that counts as an improvement.
        mode: ``"min"`` or ``"max"``.
        checkpoint_dir: folder for the checkpoint file; created on first save if missing.
        archive_dir: folder for the full-model file; created on first save if missing.
        checkpoint_name: checkpoint file name inside ``checkpoint_dir``.
        fullmodel_name: full-model file name inside ``archive_dir``.
        monitor_name: metric name used in log messages. Defaults to the module-level
            ``MONITOR`` setting.

    Raises:
        ValueError: if ``mode`` is neither ``"min"`` nor ``"max"``.
    """
    def __init__(self, patience=7, min_delta=1e-4, mode="min",
                 checkpoint_dir=OUTPUT_DIR, archive_dir=ARCHIVE_DIR,
                 checkpoint_name="best_brain_tumor_model.pth",
                 fullmodel_name="trained_brain_tumor_model.h5",
                 monitor_name=None):
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        # Only used in log lines.
        self.monitor_name = MONITOR if monitor_name is None else monitor_name
        self.checkpoint_dir = Path(checkpoint_dir)
        self.archive_dir = Path(archive_dir)
        self.checkpoint_path = self.checkpoint_dir / checkpoint_name
        self.fullmodel_path = self.archive_dir / fullmodel_name
        self.counter = 0          # consecutive non-improving steps
        self.best_epoch = None    # epoch index of the best value; None until the first improvement
        # Worst possible start, so the first finite value always counts as an improvement.
        self.best_score = math.inf if mode == "min" else -math.inf

    def _is_improvement(self, current):
        """Return True if ``current`` beats ``best_score`` by more than ``min_delta``.

        Any comparison with NaN is False, so a NaN metric never counts as an improvement.
        """
        if self.mode == "min":
            return (self.best_score - current) > self.min_delta
        return (current - self.best_score) > self.min_delta

    def _save(self, current, model, optimizer, epoch, classes):
        """Write the checkpoint dict and the pickled model, creating folders if needed."""
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.archive_dir.mkdir(parents=True, exist_ok=True)
        torch.save({
            "epoch": epoch,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "monitor_value": current,
            "classes": classes
        }, str(self.checkpoint_path))
        torch.save(model, str(self.fullmodel_path))

    def step(self, current, model, optimizer, epoch, classes):
        """Record one validation result and report whether training should stop.

        Args:
            current: latest value of the monitored metric.
            model: saved (see the class docstring) when ``current`` improves.
            optimizer: its state is saved in the checkpoint on improvement.
            epoch: 0-based epoch index (logged as ``epoch + 1``).
            classes: class names stored in the checkpoint.

        Returns:
            True when the count of consecutive non-improving calls has reached
            ``patience``.
        """
        if self._is_improvement(current):
            self.best_score = current
            self.best_epoch = epoch
            self.counter = 0
            self._save(current, model, optimizer, epoch, classes)
            direction = 'lowest' if self.mode == 'min' else 'highest'
            print(f">>> New best {direction} {self.monitor_name}: {current:.6f} at epoch {epoch+1}")
            print(f"Saved checkpoint: {self.checkpoint_path}")
            print(f"Saved full model object: {self.fullmodel_path}")
        else:
            self.counter += 1
            print(f"No improvement in {self.monitor_name}. EarlyStopping counter: {self.counter}/{self.patience}")
        return self.counter >= self.patience

def train_one_epoch(model, loader, criterion, optimizer, scaler, device):
    """Train ``model`` for one pass over ``loader``.

    Gradients are clipped to a total norm of ``CLIP_GRAD_NORM`` before every optimiser
    step.

    Args:
        model: network updated in place (switched to train mode).
        loader: iterable of ``(inputs, labels)`` batches.
        criterion: loss applied to ``(outputs, labels)``.
        optimizer: optimiser stepped once per batch.
        scaler: either a ``GradScaler`` or None. With a scaler, the forward pass runs
            under CUDA autocast and the loss is scaled for backward. With None,
            training runs in plain full precision.
        device: device each batch is moved to.

    Returns:
        (mean per-sample loss, accuracy in percent) over the epoch.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    from contextlib import nullcontext

    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    pbar = tqdm(loader, desc="Train", leave=False)
    for inputs, labels in pbar:
        inputs = inputs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad()
        # torch.autocast("cuda") replaces the deprecated torch.cuda.amp.autocast.
        # Without a scaler, autocast is skipped entirely; this also avoids the
        # "CUDA is not available" warning on CPU-only hosts.
        amp_ctx = torch.autocast(device_type="cuda") if scaler is not None else nullcontext()
        with amp_ctx:
            outputs = model(inputs)
            loss = criterion(outputs, labels)

        if scaler is not None:
            scaler.scale(loss).backward()
            # Unscale first so clipping acts on the true gradient magnitudes.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP_GRAD_NORM)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP_GRAD_NORM)
            optimizer.step()

        running_loss += loss.item() * inputs.size(0)
        _, preds = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (preds == labels).sum().item()
        pbar.set_postfix(loss=f"{running_loss/total:.4f}", acc=f"{100.*correct/total:.2f}%")

    if total == 0:
        raise ValueError("train_one_epoch: loader yielded no samples")
    epoch_loss = running_loss / total
    epoch_acc = 100. * correct / total
    return epoch_loss, epoch_acc

def validate_epoch(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` in eval mode without gradient tracking.

    Args:
        model: network to evaluate (switched to eval mode).
        loader: iterable of ``(inputs, labels)`` batches.
        criterion: loss applied to ``(outputs, labels)``.
        device: device each batch is moved to.

    Returns:
        A 4-tuple:
        - mean per-sample loss;
        - accuracy in percent;
        - predicted class indices, as a list of Python ints in loader order;
        - true class indices, as a list of Python ints in loader order.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    predictions = []
    targets = []

    with torch.no_grad():
        progress = tqdm(loader, desc="Val", leave=False)
        for batch_x, batch_y in progress:
            batch_x = batch_x.to(device, non_blocking=True)
            batch_y = batch_y.to(device, non_blocking=True)

            logits = model(batch_x)
            batch_loss = criterion(logits, batch_y)
            batch_pred = torch.max(logits, 1).indices

            loss_sum += batch_loss.item() * batch_x.size(0)
            n_seen += batch_y.size(0)
            n_correct += (batch_pred == batch_y).sum().item()
            predictions += batch_pred.cpu().tolist()
            targets += batch_y.cpu().tolist()

            progress.set_postfix(loss=f"{loss_sum/n_seen:.4f}", acc=f"{100.*n_correct/n_seen:.2f}%")

    return loss_sum / n_seen, 100. * n_correct / n_seen, predictions, targets

def plot_history(train_losses, val_losses, train_accs, val_accs, out_path):
    """Save train/validation loss and accuracy curves side by side in one image.

    Args:
        train_losses: per-epoch mean training losses.
        val_losses: per-epoch mean validation losses.
        train_accs: per-epoch training accuracies in percent.
        val_accs: per-epoch validation accuracies in percent.
        out_path: destination file passed to ``savefig``.

    Epochs are numbered from 1 on the x-axis, matching the "Epoch [k/N]" lines that
    ``main`` prints. Plotting without x values would start the axis at 0.
    """
    def epochs(values):
        return range(1, len(values) + 1)

    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 4))
    try:
        loss_ax.plot(epochs(train_losses), train_losses, label='train_loss')
        loss_ax.plot(epochs(val_losses), val_losses, label='val_loss')
        loss_ax.set_xlabel('epoch')
        loss_ax.set_ylabel('loss')
        loss_ax.legend()
        loss_ax.set_title('Loss')

        acc_ax.plot(epochs(train_accs), train_accs, label='train_acc')
        acc_ax.plot(epochs(val_accs), val_accs, label='val_acc')
        acc_ax.set_xlabel('epoch')
        acc_ax.set_ylabel('accuracy (%)')
        acc_ax.legend()
        acc_ax.set_title('Accuracy')

        fig.tight_layout()
        fig.savefig(out_path)
    finally:
        # Release the figure even if saving fails, so repeated calls don't leak figures.
        plt.close(fig)

def save_confusion_matrix(y_true, y_pred, classes, out_path):
    """Plot a confusion matrix heatmap and save it to ``out_path``.

    Args:
        y_true: true labels. Assumed to be integer indices into ``classes``; this
            holds for the ImageFolder labels that ``main`` passes.
        y_pred: predicted labels, as integer indices into ``classes``.
        classes: class names used as row/column tick labels.
        out_path: destination file passed to ``savefig``.
    """
    # Pass every label explicitly so the matrix is always len(classes) x len(classes).
    # Without this, a class absent from both y_true and y_pred would shrink the matrix
    # and misalign it with the tick labels.
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(classes))))
    fig = plt.figure(figsize=(8, 6))
    try:
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=classes, yticklabels=classes)
        plt.xlabel('Predicted')
        plt.ylabel('True')
        plt.title('Confusion Matrix')
        plt.tight_layout()
        plt.savefig(out_path)
    finally:
        plt.close(fig)

def main():
    """Train the classifier end to end, then save models, plots and a report.

    Steps:
    1. Build the data loaders and the model.
    2. Train for up to ``EPOCHS`` epochs with AdamW and a cosine learning-rate
       schedule. Mixed precision is used when CUDA is available.
    3. After every epoch, write ``last_checkpoint.pth``.
    4. Stop early via ``EarlyStopping`` on ``MONITOR``.
    5. Save the last-epoch model, the training curves, a classification report and
       a confusion matrix.

    Raises:
        ValueError: if the number of class folders found differs from ``NUM_CLASSES``.
    """
    print("Device:", DEVICE)
    train_loader, val_loader, classes = make_dataloaders(TRAIN_DIR, VAL_DIR, BATCH_SIZE, NUM_WORKERS)
    # Fail fast on a head/dataset size mismatch. Too few outputs crash the loss;
    # too many let the model predict labels that do not exist.
    if len(classes) != NUM_CLASSES:
        raise ValueError(
            f"NUM_CLASSES is {NUM_CLASSES} but {len(classes)} class folders were found: {classes}"
        )

    model = create_model(num_classes=NUM_CLASSES, backbone=BASE_MODEL_NAME, pretrained=PRETRAINED,
                         freeze_backbone=FREEZE_BACKBONE)
    model = model.to(DEVICE)

    # Optimise only the parameters left trainable (just the head when FREEZE_BACKBONE).
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.AdamW(params, lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=1e-6)
    criterion = nn.CrossEntropyLoss()
    # A scaler switches train_one_epoch to mixed precision; None keeps full precision.
    scaler = torch.cuda.amp.GradScaler() if torch.cuda.is_available() else None

    # Early stopping instance: val_loss is minimised, anything else (val_acc) maximised.
    es_mode = "min" if MONITOR == "val_loss" else "max"
    early_stopper = EarlyStopping(patience=EARLY_STOPPING_PATIENCE,
                                  min_delta=MIN_DELTA,
                                  mode=es_mode,
                                  checkpoint_dir=OUTPUT_DIR,
                                  archive_dir=ARCHIVE_DIR)

    def best_epoch_label():
        # 1-based best epoch, or 'N/A' when nothing ever improved (e.g. a NaN loss
        # from the first epoch leaves best_epoch as None).
        return early_stopper.best_epoch + 1 if early_stopper.best_epoch is not None else 'N/A'

    best_val_acc = 0.0
    history = {"train_loss": [], "val_loss": [], "train_acc": [], "val_acc": []}
    last_val_preds, last_val_labels = [], []

    for epoch in range(EPOCHS):
        print(f"\nEpoch [{epoch+1}/{EPOCHS}]")
        train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, scaler, DEVICE)
        val_loss, val_acc, val_preds, val_labels = validate_epoch(model, val_loader, criterion, DEVICE)

        # Choose the metric for monitoring.
        monitored_value = val_loss if MONITOR == "val_loss" else val_acc

        scheduler.step()

        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        history["train_acc"].append(train_acc)
        history["val_acc"].append(val_acc)

        print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
        print(f"Val   Loss: {val_loss:.4f} | Val   Acc: {val_acc:.2f}%")

        # Save the last checkpoint (.pth), overwritten every epoch.
        torch.save({
            "epoch": epoch,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "val_acc": val_acc,
            "classes": classes
        }, OUTPUT_DIR / "last_checkpoint.pth")

        # Step early stopping (also saves the best checkpoint on improvement).
        stop = early_stopper.step(monitored_value, model, optimizer, epoch, classes)

        # Also keep track of the best val_acc, printed at the end.
        if val_acc > best_val_acc:
            best_val_acc = val_acc

        last_val_preds, last_val_labels = val_preds, val_labels

        if stop:
            print(f"Early stopping triggered at epoch {epoch+1}. Best epoch: {best_epoch_label()}")
            break

    # Save the final (last-epoch) model object to the archive folder as well.
    # It is a torch pickle despite the .h5 suffix.
    final_h5_path = ARCHIVE_DIR / "final_brain_tumor_model.h5"
    torch.save(model, str(final_h5_path))
    print(f"Saved final model object to: {final_h5_path}")

    # Save training history & plots.
    plot_history(history["train_loss"], history["val_loss"],
                 history["train_acc"], history["val_acc"],
                 OUTPUT_DIR / "training_history.png")
    print("Training history saved to", OUTPUT_DIR / "training_history.png")

    # Classification report & confusion matrix on the last validation predictions.
    # NOTE(review): these come from the last epoch's weights. After early stopping,
    # those are not the best checkpoint's weights. VAL_DIR is also the "Testing"
    # split that drove model selection, so this is not an unbiased held-out estimate.
    if last_val_labels:
        # Explicit labels keep one row per class even if a class never appears,
        # instead of raising a target_names size mismatch.
        label_ids = list(range(len(classes)))
        print("\nClassification Report:")
        print(classification_report(last_val_labels, last_val_preds, labels=label_ids,
                                    target_names=classes, digits=4))
        save_confusion_matrix(last_val_labels, last_val_preds, classes, OUTPUT_DIR / "confusion_matrix.png")
        print("Confusion matrix saved to", OUTPUT_DIR / "confusion_matrix.png")
    else:
        print("No validation results to report.")

    print(f"\nTraining completed. Best {MONITOR}: {early_stopper.best_score}. Best epoch: {best_epoch_label()}")
    print(f"Best val_acc: {best_val_acc:.2f}%")
    print("Checkpoints in", OUTPUT_DIR, "and full .h5 models in", ARCHIVE_DIR)

if __name__ == "__main__":
    main()


Device: cpu
Classes found: ['glioma', 'meningioma', 'notumor', 'pituitary']
Training samples per class: {'glioma': 1321, 'meningioma': 1339, 'notumor': 1595, 'pituitary': 1457}





Epoch [1/10]


                                                                                                                       

Train Loss: 0.2864 | Train Acc: 90.06%
Val   Loss: 0.2286 | Val   Acc: 92.60%
>>> New best lowest val_loss: 0.228582 at epoch 1
Saved checkpoint: outputs\best_brain_tumor_model.pth
Saved full model object: C:\Users\ACER\Desktop\archive (3)\trained_brain_tumor_model.h5

Epoch [2/10]


                                                                                                                       

Train Loss: 0.1747 | Train Acc: 94.36%
Val   Loss: 0.3274 | Val   Acc: 89.86%
No improvement in val_loss. EarlyStopping counter: 1/3

Epoch [3/10]


                                                                                                                       

Train Loss: 0.1241 | Train Acc: 95.66%
Val   Loss: 0.2484 | Val   Acc: 91.84%
No improvement in val_loss. EarlyStopping counter: 2/3

Epoch [4/10]


                                                                                                                       

Train Loss: 0.1109 | Train Acc: 96.29%
Val   Loss: 0.2357 | Val   Acc: 92.30%
No improvement in val_loss. EarlyStopping counter: 3/3
Early stopping triggered at epoch 4. Best epoch: 1
Saved final model object to: C:\Users\ACER\Desktop\archive (3)\final_brain_tumor_model.h5
Training history saved to outputs\training_history.png

Classification Report:
              precision    recall  f1-score   support

      glioma     0.9961    0.8533    0.9192       300
  meningioma     0.7972    0.9379    0.8619       306
     notumor     0.9463    1.0000    0.9724       405
   pituitary     0.9850    0.8733    0.9258       300

    accuracy                         0.9230      1311
   macro avg     0.9311    0.9161    0.9198      1311
weighted avg     0.9317    0.9230    0.9238      1311

Confusion matrix saved to outputs\confusion_matrix.png

Training completed. Best val_loss: 0.22858198353783443. Best epoch: 1
Checkpoints in outputs and full .h5 models in C:\Users\ACER\Desktop\archive (3)
