# 🔍 Enhanced MRI Brain Tumor Classification

This notebook builds a ResNet50V2-based classifier that distinguishes between four brain MRI categories:
**Glioma · Meningioma · Pituitary Tumor · Normal**

A key design goal is robustness on *low-quality* images, such as those produced in under-resourced healthcare settings; we simulate these conditions by synthetically degrading the scans (Section 3.3).

---
## Table of Contents
1. [Setup & Imports](#1-setup--imports)
2. [Dataset Analysis](#2-dataset-analysis)
3. [Dataset Preparation](#3-dataset-preparation)
   - 3.1 [Build Image DataFrame](#31-build-image-dataframe)
   - 3.2 [Train / Validation Split](#32-train--validation-split)
   - 3.3 [Image Degradation Augmentation](#33-image-degradation-augmentation)
   - 3.4 [Data Generators](#34-data-generators)
4. [Model Architecture](#4-model-architecture)
5. [Training & Fine-Tuning](#5-training--fine-tuning)
6. [Performance Assessment](#6-performance-assessment)
   - 6.1 [Learning Curves](#61-learning-curves)
   - 6.2 [Classification Report & Confusion Matrix](#62-classification-report--confusion-matrix)
7. [Save Model for Deployment](#7-save-model-for-deployment)


## 1  Setup & Imports

In [None]:
import warnings
warnings.filterwarnings('ignore')

# Standard library
import os
import random
import shutil
import subprocess

# Third-party
import cv2
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap

# Scikit-learn
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix

# TensorFlow / Keras
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications.resnet_v2 import preprocess_input
from tensorflow.keras.applications import ResNet50V2
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense, Dropout
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping
from keras.utils import plot_model

# IPython helpers
from IPython.display import FileLink, display

# Reproducibility: seed every random number generator the notebook draws from.
#   - Python `random`: which degradation methods are applied to each image
#   - NumPy: Gaussian noise and the sample images picked for plots
#   - TensorFlow: weight initialisation of the new Dense layers and Dropout masks
# Note: some GPU kernels are non-deterministic, so runs on GPU may still differ slightly.
SEED = 42
np.random.seed(SEED)
random.seed(SEED)
tf.random.set_seed(SEED)

# Plot styling
sns.set(rc={'axes.facecolor': '#e9eef2'}, style='darkgrid')


## 2  Dataset Analysis

In [None]:
# ── Copy dataset from Kaggle input to working directory ──────────────────
# Section 3.3 writes degraded copies next to the original images, so we work
# on a writable copy. Any copy left over from a previous run is deleted first.
# Otherwise stale `degraded_*` files would be counted again and degraded a
# second time when the notebook is re-run.
SOURCE_DIR = '/kaggle/input/mri-images/Data'
BASE_DIR = '/kaggle/working/Data'

shutil.rmtree(BASE_DIR, ignore_errors=True)
shutil.copytree(SOURCE_DIR, BASE_DIR)


### 2.1  Class distribution

In [None]:
# Collect class names and image counts.
# The names are sorted because os.listdir returns entries in an arbitrary,
# filesystem-dependent order. The class order drives both the bar chart and
# the row order of the DataFrame built in section 3.1, which feeds the
# seeded train/val split.
classes = sorted(
    name for name in os.listdir(BASE_DIR)
    if os.path.isdir(os.path.join(BASE_DIR, name))
)

# Count regular files only, and skip `degraded_*` copies (section 3.3 writes
# them into these same folders). This way the chart always describes the
# original data.
counts = [
    sum(
        1
        for fname in os.listdir(os.path.join(BASE_DIR, cls))
        if os.path.isfile(os.path.join(BASE_DIR, cls, fname))
        and not fname.startswith('degraded_')
    )
    for cls in classes
]

total = sum(counts)
percentages = [c / total * 100 for c in counts]

# Plot
fig, ax = plt.subplots(figsize=(14, 4))
sns.barplot(y=classes, x=counts, orient='h', color='#102C42', ax=ax)

ax.set_xticks(range(0, max(counts) + 300, 100))
ax.set_xlabel('Number of Images', fontsize=13)
ax.set_title('Images per Class', fontsize=16)

# Annotate each bar with its share of the dataset and its absolute count.
for patch, pct, count in zip(ax.patches, percentages, counts):
    ax.text(
        patch.get_width() + 5,
        patch.get_y() + patch.get_height() / 2,
        f'{pct:.1f}%  ({count})',
        va='center', fontsize=13
    )

plt.tight_layout()
plt.show()

print(f"Total images: {total}")


> **Note:** The `normal` class is the smallest — worth monitoring for class imbalance during training.

### 2.2  Image dimensions

In [None]:
# Gather the (height, width) of every image OpenCV can decode, across all classes.
shapes = []
for cls in classes:
    class_dir = os.path.join(BASE_DIR, cls)
    for fname in os.listdir(class_dir):
        image = cv2.imread(os.path.join(class_dir, fname))
        if image is None:
            continue  # not a decodable image
        shapes.append(image.shape[:2])

heights = [h for h, _ in shapes]
widths = [w for _, w in shapes]
unique_dims = set(shapes)

# Either report the single shared size, or summarise the spread of sizes.
if len(unique_dims) != 1:
    print(f"{len(unique_dims)} unique dimension(s) found.")
    print(f"Height — min: {min(heights)}, max: {max(heights)}, mean: {np.mean(heights):.1f}")
    print(f"Width  — min: {min(widths)},  max: {max(widths)},  mean: {np.mean(widths):.1f}")
else:
    print(f"All images share the same dimensions: {next(iter(unique_dims))}")


### 2.3  Visual sample per class

In [None]:
def plot_image_grid(image_paths, title, n_cols=6):
    """Display up to `n_cols` images side by side in one row under a shared title.

    Parameters
    ----------
    image_paths : iterable of str
        Paths of images readable by OpenCV. Paths beyond the first `n_cols`
        are ignored. If fewer are given, the remaining panels stay blank.
    title : str
        Figure-level title.
    n_cols : int
        Number of panels in the row.
    """
    # squeeze=False always returns a 2-D array of Axes, so n_cols == 1 also
    # works (plain subplots(1, 1) returns a bare Axes that cannot be zipped).
    fig, axes = plt.subplots(1, n_cols, figsize=(15, 3), squeeze=False)
    axes = axes.ravel()

    for ax, path in zip(axes, image_paths):
        bgr = cv2.imread(path)
        if bgr is None:
            # Show a placeholder instead of crashing inside cvtColor.
            ax.set_title('unreadable', fontsize=9)
            continue
        # OpenCV loads BGR; matplotlib expects RGB.
        ax.imshow(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))

    # Hide ticks and frames on every panel, including unused ones.
    for ax in axes:
        ax.axis('off')

    fig.suptitle(title, fontsize=16)
    plt.tight_layout()
    plt.show()


for cls in classes:
    folder = os.path.join(BASE_DIR, cls)
    # Sorted listing so the seeded draw picks the same files on every machine.
    # `degraded_*` copies (written in section 3.3) are not originals.
    all_paths = sorted(
        os.path.join(folder, f)
        for f in os.listdir(folder)
        if not f.startswith('degraded_')
    )
    if not all_paths:
        print(f"{cls}: no images to show")
        continue
    # Never request more samples than exist; replace=False would raise.
    n_samples = min(6, len(all_paths))
    sample = np.random.choice(all_paths, n_samples, replace=False)
    plot_image_grid(sample, f"{cls}  —  original samples")


## 3  Dataset Preparation

### 3.1  Build Image DataFrame

In [None]:
# Collect (filepath, label) pairs for every original image in every class.
# - File names are sorted because os.listdir order is filesystem-dependent.
#   A stable row order keeps the seeded train/val split below reproducible.
# - `degraded_*` files are skipped. They are generated per split in section
#   3.3, and picking them up here (e.g. on a re-run) could put an image and
#   its degraded copy on opposite sides of the split.
records = [
    (os.path.join(BASE_DIR, cls, fname), cls)
    for cls in classes
    for fname in sorted(os.listdir(os.path.join(BASE_DIR, cls)))
    if os.path.isfile(os.path.join(BASE_DIR, cls, fname))
    and not fname.startswith('degraded_')
]

df = pd.DataFrame(records, columns=['filepath', 'label'])
print(f"Total images in DataFrame: {len(df)}")
df.head()


### 3.2  Train / Validation Split

In [None]:
# Hold out 20 % for validation, stratified on the label so each class keeps
# the same proportion in both splits.
train_df, val_df = train_test_split(
    df,
    stratify=df['label'],
    random_state=SEED,
    test_size=0.2,
)

for caption, part in (("Train size ", train_df), ("Val size   ", val_df)):
    print(f"{caption}: {len(part)}")

# The unsplit frame is no longer needed.
del df


### 3.3  Image Degradation Augmentation

To make the model robust to poor imaging conditions, we create a synthetically degraded copy of every
image by applying a random combination of one to three of the following techniques:

| Technique | Simulates |
|-----------|-----------|
| **Gaussian noise** | sensor noise / imaging artefacts |
| **Gaussian blur** | patient movement / focus issues |
| **Downsample → upsample** | low-resolution scanners |

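Degraded copies are generated *after* the train/validation split, separately for each split, so every degraded copy stays in the same split as its original. Degrading before splitting could place an image and its degraded twin on opposite sides of the split, which would leak information into validation.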


In [None]:
# ── Degradation helpers ───────────────────────────────────────────────────

def add_gaussian_noise(image, mean=0, std=0.05, max_value=255):
    """Add Gaussian noise whose parameters are relative to the intensity range.

    `mean` and `std` are fractions of `max_value`. With the defaults on 8-bit
    images the noise has a standard deviation of about 12.75 grey levels.
    Applying std=0.05 directly in 0-255 units would add less than one grey
    level, which disappears completely once the image is saved as 8-bit.

    Parameters
    ----------
    image : np.ndarray
        Image with intensities in [0, max_value].
    mean, std : float
        Noise mean and standard deviation, as fractions of `max_value`.
    max_value : float
        Top of the intensity range. The output is clipped to [0, max_value].

    Returns
    -------
    np.ndarray
        Noisy image of the same shape. Integer inputs are rounded and returned
        in their original dtype; other inputs are returned as float64.
    """
    image = np.asarray(image)
    noise = np.random.normal(mean * max_value, std * max_value, image.shape)
    noisy = np.clip(image.astype(np.float64) + noise, 0, max_value)
    if np.issubdtype(image.dtype, np.integer):
        # Keep e.g. uint8 images as uint8 so later OpenCV calls and
        # cv2.imwrite see the same dtype they were given.
        return np.rint(noisy).astype(image.dtype)
    return noisy


def apply_gaussian_blur(image, kernel_size=5):
    """Blur `image` with a square Gaussian kernel of side `kernel_size`.

    Sigma is passed as 0, which tells OpenCV to derive it from the kernel
    size. OpenCV expects `kernel_size` to be a positive odd integer.
    """
    ksize = (kernel_size,) * 2
    blurred = cv2.GaussianBlur(image, ksize, 0)
    return blurred


def downsample_upsample(image, scale_percent=50):
    """Imitate a low-resolution scanner by shrinking and then re-enlarging `image`.

    The image is reduced to `scale_percent` % of its size with area
    interpolation, then resized back to its original (height, width) with
    bilinear interpolation. Fine detail is lost but the shape is unchanged.
    Each reduced side is at least 1 pixel.
    """
    orig_h, orig_w = image.shape[:2]
    reduced = cv2.resize(
        image,
        (
            max(1, int(orig_w * scale_percent / 100)),
            max(1, int(orig_h * scale_percent / 100)),
        ),
        interpolation=cv2.INTER_AREA,
    )
    # cv2.resize takes the target size as (width, height).
    return cv2.resize(reduced, (orig_w, orig_h), interpolation=cv2.INTER_LINEAR)


# Registry mapping each degradation name to a function image -> degraded image.
# augment_with_degraded_copies draws a random subset of these for each image.
DEGRADATION_METHODS = dict(
    noise=add_gaussian_noise,
    blur=apply_gaussian_blur,
    downsample=downsample_upsample,
)


In [None]:
def _apply_random_degradations(image):
    """Apply a random, non-empty subset of DEGRADATION_METHODS to `image`, in random order."""
    method_names = list(DEGRADATION_METHODS.keys())
    chosen = random.sample(method_names, k=random.randint(1, len(method_names)))
    for method_name in chosen:
        image = DEGRADATION_METHODS[method_name](image)
    return image


def _to_uint8(image):
    """Round, clip to [0, 255] and cast to uint8 so cv2.imwrite stores exactly these pixels."""
    return np.clip(np.rint(image), 0, 255).astype(np.uint8)


def augment_with_degraded_copies(dataframe):
    """
    Create one degraded copy of every readable image in `dataframe`.

    Each image is read as greyscale and passed through a random subset (1-3)
    of the degradation methods. The result is written next to the original
    as `degraded_<filename>`. The function returns the original rows followed
    by one row per degraded copy that was written successfully.

    Side effects: writes image files into the directories of the originals.

    Parameters
    ----------
    dataframe : pd.DataFrame
        Must have columns ['filepath', 'label'].

    Returns
    -------
    pd.DataFrame
        Combined original + degraded rows, index reset.
    """
    degraded_records = []
    failed_writes = []

    for filepath, label in zip(dataframe['filepath'], dataframe['label']):
        img = cv2.imread(filepath, cv2.IMREAD_GRAYSCALE)
        if img is None:
            continue  # skip unreadable files

        degraded = _to_uint8(_apply_random_degradations(img))

        directory, filename = os.path.split(filepath)
        degraded_path = os.path.join(directory, f"degraded_{filename}")
        # cv2.imwrite returns False instead of raising on some failures.
        # Never record a path that was not actually written.
        if not cv2.imwrite(degraded_path, degraded):
            failed_writes.append(degraded_path)
            continue

        degraded_records.append({'filepath': degraded_path, 'label': label})

    if failed_writes:
        print(f"Warning: could not write {len(failed_writes)} degraded image(s), "
              f"e.g. {failed_writes[0]}")

    # Explicit columns keep the concat well-formed even if nothing was written.
    degraded_df = pd.DataFrame(degraded_records, columns=['filepath', 'label'])
    return pd.concat([dataframe, degraded_df], ignore_index=True)


In [None]:
print(f"Before augmentation — train: {len(train_df)}, val: {len(val_df)}")

# Each split gets its own degraded copies (train first, then val), so every
# copy stays in the same split as the image it was derived from.
train_df, val_df = (augment_with_degraded_copies(split_df) for split_df in (train_df, val_df))

print(f"After  augmentation — train: {len(train_df)}, val: {len(val_df)}")


In [None]:
# Visual check: confirm tumour features are still visible after degradation.
# Only `degraded_*` files are sampled. The class folders also hold the
# originals, so sampling the whole folder could show no degraded images at all.
for cls in classes:
    folder = os.path.join(BASE_DIR, cls)
    degraded_paths = sorted(
        os.path.join(folder, f)
        for f in os.listdir(folder)
        if f.startswith('degraded_')
    )
    if not degraded_paths:
        print(f"{cls}: no degraded images found")
        continue
    # Never request more samples than exist; replace=False would raise.
    n_samples = min(6, len(degraded_paths))
    sample = np.random.choice(degraded_paths, n_samples, replace=False)
    plot_image_grid(sample, f"{cls}  —  post-degradation samples")


### 3.4  Data Generators

We use Keras's `ImageDataGenerator` for memory-efficient on-the-fly augmentation during training.
The **validation generator** intentionally skips geometric augmentation to give an unbiased estimate of performance.


In [None]:
# ── Hyperparameters ───────────────────────────────────────────────────────
IMAGE_SIZE  = (224, 224)   # ResNet50V2 expects 224×224
BATCH_SIZE  = 32


def create_data_generators(train_df, val_df,
                            preprocessing_fn=None,
                            batch_size=BATCH_SIZE,
                            image_size=IMAGE_SIZE):
    """
    Build and return (train_generator, val_generator).

    The training generator applies light geometric augmentation to improve
    generalisation. The validation generator only applies preprocessing and
    does not shuffle, so its sample order matches `val_generator.classes`.

    Both generators receive the same class list, taken from the training
    labels. They therefore share one label -> index mapping even if the
    validation split happens to miss a class. Rows whose label is not in that
    list are filtered out by Keras.

    Parameters
    ----------
    train_df : pd.DataFrame
        Columns: ['filepath', 'label']
    val_df : pd.DataFrame
        Columns: ['filepath', 'label']
    preprocessing_fn : callable, optional
        Model-specific preprocessing (e.g. ResNet50V2's preprocess_input).
    batch_size : int
    image_size : tuple of (height, width)

    Returns
    -------
    train_generator, val_generator
    """
    # Sorted to match Keras' own (alphabetical) inference order.
    class_list = sorted(train_df['label'].unique())

    shared_flow_kwargs = dict(
        x_col='filepath',
        y_col='label',
        classes=class_list,
        target_size=image_size,
        batch_size=batch_size,
        class_mode='categorical',
        seed=SEED,
    )

    train_datagen = ImageDataGenerator(
        rotation_range=20,
        width_shift_range=0.10,
        height_shift_range=0.10,
        zoom_range=0.10,
        horizontal_flip=True,
        preprocessing_function=preprocessing_fn,
    )

    # No augmentation for validation: only the model-specific preprocessing.
    val_datagen = ImageDataGenerator(preprocessing_function=preprocessing_fn)

    train_gen = train_datagen.flow_from_dataframe(
        dataframe=train_df, shuffle=True, **shared_flow_kwargs
    )
    val_gen = val_datagen.flow_from_dataframe(
        dataframe=val_df, shuffle=False, **shared_flow_kwargs
    )

    return train_gen, val_gen


In [None]:
train_generator, val_generator = create_data_generators(
    train_df,
    val_df,
    preprocessing_fn=preprocess_input,
)

# Pull one training batch to verify its shape: (batch_size, 224, 224, 3).
sample_images, _ = next(train_generator)
print(f"Batch shape: {sample_images.shape}")

# Class names ordered by the integer index the generator assigned to each.
index_by_class = train_generator.class_indices
class_names = [name for name, _ in sorted(index_by_class.items(), key=lambda kv: kv[1])]
print(f"Classes: {class_names}")


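> **Why load greyscale scans as 3-channel images?**  MRI scans are single-channel, but the data generators load them in RGB mode (the `flow_from_dataframe` default), so each image has three identical channels and matches the input format expected by ImageNet pre-trained models.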

## 4  Model Architecture

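We use **ResNet50V2** pre-trained on ImageNet as the backbone and attach a small classification head
(global average pooling → 1024-unit dense layer → dropout → softmax). The backbone is not frozen: all layers
are fine-tuned end-to-end at a low learning rate (1e-4). Transfer learning helps reach good accuracy on a
relatively small dataset and reduces the risk of overfitting.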


In [None]:
# ── Base model ────────────────────────────────────────────────────────────
# ImageNet weights without the original top classifier. No layer is frozen,
# so the whole network is fine-tuned by model.fit.
base_model = ResNet50V2(
    weights='imagenet',
    include_top=False,
    input_shape=(*IMAGE_SIZE, 3)
)

# Size the output layer from the data instead of hard-coding it. This keeps
# the head in sync with the classes the generators (built from train_df) emit.
NUM_CLASSES = train_df['label'].nunique()

# ── Classification head ───────────────────────────────────────────────────
x = GlobalAveragePooling2D()(base_model.output)
x = Dense(1024, activation='relu')(x)
x = Dropout(0.5)(x)                                   # regularisation
output = Dense(NUM_CLASSES, activation='softmax')(x)  # one unit per class (3 tumour types + normal)

model = Model(inputs=base_model.input, outputs=output)

model.compile(
    optimizer=Adam(learning_rate=1e-4),
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

model.summary()


In [None]:
# Render the architecture diagram with tensor shapes (the image is large — scroll down).
plot_model(
    model,
    dpi=150,
    show_layer_names=False,
    show_shapes=True,
)


## 5  Training & Fine-Tuning

In [None]:
# ── Training hyperparameters ──────────────────────────────────────────────
MAX_EPOCHS     = 50
LR_PATIENCE    = 3    # stagnant val_loss epochs before the LR is halved
EARLY_PATIENCE = 15   # stagnant val_loss epochs before training stops


def train_model(model, train_df, val_df,
                preprocessing_fn,
                image_size=IMAGE_SIZE,
                batch_size=BATCH_SIZE,
                max_epochs=MAX_EPOCHS):
    """
    Fit `model` on fresh generators built from the two DataFrames.

    Validation loss drives both the schedule and the stopping rule:
    - the learning rate is multiplied by 0.5 (floor 1e-5) after
      `LR_PATIENCE` epochs without improvement;
    - training stops after `EARLY_PATIENCE` such epochs and, when early
      stopping triggers, the best weights seen are restored.

    Returns
    -------
    (model, history, val_gen)
        The fitted model, its Keras History, and the non-shuffled validation
        generator used during training.
    """
    train_gen, val_gen = create_data_generators(
        train_df,
        val_df,
        preprocessing_fn=preprocessing_fn,
        batch_size=batch_size,
        image_size=image_size,
    )

    halve_lr_on_plateau = ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.5,
        patience=LR_PATIENCE,
        min_lr=1e-5,
    )
    stop_early = EarlyStopping(
        monitor='val_loss',
        mode='min',
        patience=EARLY_PATIENCE,
        restore_best_weights=True,
        verbose=1,
    )

    history = model.fit(
        train_gen,
        epochs=max_epochs,
        steps_per_epoch=len(train_gen),
        validation_data=val_gen,
        validation_steps=len(val_gen),
        callbacks=[halve_lr_on_plateau, stop_early],
    )
    return model, history, val_gen


In [None]:
# Fit the model. train_model also returns the validation generator it built,
# replacing the one created in the sanity-check cell.
model, history, val_generator = train_model(
    model,
    train_df,
    val_df,
    preprocessing_fn=preprocess_input,
)


## 6  Performance Assessment

### 6.1  Learning Curves

In [None]:
def plot_learning_curves(history, skip_first_n=5):
    """
    Plot train / validation loss and accuracy over epochs.

    Parameters
    ----------
    history : keras History object
        Must contain 'loss', 'val_loss', 'accuracy' and 'val_accuracy'.
    skip_first_n : int
        Skip the first N epochs to keep the y-axis scale readable
        (early epochs can have very high loss). Ignored when the run has
        N or fewer epochs, so the plot is never empty.
    """
    hist_df = pd.DataFrame(history.history)
    skip = max(0, skip_first_n)
    if len(hist_df) > skip:
        hist_df = hist_df.iloc[skip:]

    panels = (('loss', 'Loss'), ('accuracy', 'Accuracy'))

    # Scoped style: unlike sns.set, this does not change the look of plots
    # drawn later in the notebook.
    with sns.axes_style('darkgrid', rc={'axes.facecolor': '#f0f0fc'}):
        fig, axes = plt.subplots(1, 2, figsize=(14, 5))

        for ax, (metric, title) in zip(axes, panels):
            sns.lineplot(data=hist_df, x=hist_df.index, y=metric,
                         color='#102C42', label='Train', ax=ax)
            sns.lineplot(data=hist_df, x=hist_df.index, y=f'val_{metric}',
                         color='orangered', linestyle='--', label='Validation', ax=ax)
            ax.set_title(title, fontsize=14)
            ax.set_xlabel('Epoch')

        plt.suptitle('Learning Curves', fontsize=16, y=1.02)
        plt.tight_layout()
        plt.show()


plot_learning_curves(history)


### 6.2  Classification Report & Confusion Matrix

In [None]:
def evaluate_model(model, val_generator):
    """
    Print classification report and plot confusion matrix for the validation set.

    Predictions are matched to `val_generator.classes` by position, so the
    generator must yield samples in file order. The validation flow built
    by `create_data_generators` uses shuffle=False.

    Parameters
    ----------
    model : trained Keras model
    val_generator : validation ImageDataGenerator flow (shuffle=False)

    Raises
    ------
    ValueError
        If `val_generator` shuffles its samples.
    """
    if getattr(val_generator, 'shuffle', False):
        raise ValueError(
            "evaluate_model requires a generator with shuffle=False so that "
            "predictions line up with val_generator.classes"
        )

    class_labels = sorted(
        val_generator.class_indices, key=val_generator.class_indices.get
    )
    # Explicit label ids keep the report and matrix at full size even when a
    # class never appears in the true or predicted labels. Otherwise sklearn
    # infers fewer classes, and target_names / tick labels no longer match.
    label_ids = list(range(len(class_labels)))

    # Predictions
    raw_preds    = model.predict(val_generator, steps=len(val_generator))
    pred_labels  = np.argmax(raw_preds, axis=1)
    true_labels  = val_generator.classes

    # Classification report
    print(classification_report(
        true_labels, pred_labels, labels=label_ids, target_names=class_labels
    ))

    # Confusion matrix (rows = true class, columns = predicted class)
    cm = confusion_matrix(true_labels, pred_labels, labels=label_ids)
    cmap = LinearSegmentedColormap.from_list('navy_white', ['white', '#102C42'])

    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(
        cm, annot=True, fmt='d', cmap=cmap,
        xticklabels=class_labels, yticklabels=class_labels,
        ax=ax
    )
    ax.set_xlabel('Predicted Label', fontsize=12)
    ax.set_ylabel('True Label', fontsize=12)
    ax.set_title('Confusion Matrix', fontsize=14)
    plt.tight_layout()
    plt.show()


evaluate_model(model, val_generator)


## 7  Save Model for Deployment

In [None]:
# Persist the trained model to a single .h5 file in the Kaggle working directory.
MODEL_PATH = os.path.join('/kaggle/working', 'ResNet50V2_brain_tumor.h5')
model.save(MODEL_PATH)
print(f"Model saved to {MODEL_PATH}")


In [None]:
def zip_and_link(file_paths, archive_name):
    """
    Zip `file_paths` into /kaggle/working/`archive_name`.zip and display a
    download link for it.

    The current working directory is changed to /kaggle/working/ so the
    relative link rendered by FileLink points at the archive. On failure,
    the error is printed and no link is shown.

    Parameters
    ----------
    file_paths : iterable of str
        Files to add to the archive.
    archive_name : str
        Archive file name without the `.zip` extension.
    """
    os.chdir('/kaggle/working/')
    zip_path = f"/kaggle/working/{archive_name}.zip"

    # Argument list with no shell: paths containing spaces or shell
    # metacharacters reach `zip` literally instead of being split or
    # interpreted by /bin/sh.
    cmd = ['zip', zip_path, *file_paths]

    try:
        result = subprocess.run(cmd, capture_output=True, text=True)
    except FileNotFoundError as err:
        # Without a shell, a missing `zip` binary raises here instead of
        # producing a non-zero exit status.
        print("Zip failed:", err)
        return

    if result.returncode != 0:
        # Fall back to stdout in case zip wrote its message there.
        print("Zip failed:", result.stderr or result.stdout)
        return

    display(FileLink(f"{archive_name}.zip"))


zip_and_link([MODEL_PATH], 'brain_tumor_model_archive')
