# Brain Tumor Segmentation with U-Net

This notebook trains a 2D U-Net to segment brain tumors from multimodal MRI scans using the **BraTS 2020** dataset. Each patient has four MRI modalities (T1, T1ce, T2, FLAIR) paired with expert-annotated segmentation masks identifying three tumour sub-regions:

| Label | Class | Description |
|-------|-------|-------------|
| 0 | Background | Healthy brain tissue |
| 1 | Necrotic core | Dead or non-enhancing tumour core |
| 2 | Edema | Peritumoral swelling |
| 4→3 | Enhancing tumour | Gadolinium-enhancing active tumour |

> **Label 3 is absent** from the dataset; we remap label 4 → 3 for contiguous one-hot encoding.
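A minimal NumPy sketch of the remap and why it matters. Here `np.eye(N)[labels]` stands in for `tf.one_hot`. It only has rows for indices 0 to N-1, so with `N_CLASSES = 4` a raw label 4 raises `IndexError`, whereas `tf.one_hot` would silently return an all-zero vector for it. The helper names are hypothetical.

```python
import numpy as np

N_CLASSES = 4  # background, necrotic, edema, enhancing

def remap_brats_labels(seg):
    """Map the BraTS label set {0, 1, 2, 4} onto contiguous indices {0, 1, 2, 3}."""
    seg = np.asarray(seg).astype(np.int64)
    return np.where(seg == 4, 3, seg)

def one_hot(labels, n_classes=N_CLASSES):
    """One-hot encode integer labels; raises IndexError for labels >= n_classes."""
    return np.eye(n_classes, dtype=np.float32)[labels]

raw = np.array([[0, 1], [2, 4]])        # toy 2x2 mask containing every BraTS label
mask = remap_brats_labels(raw)          # labels are now 0..3
encoded = one_hot(mask)                 # shape (2, 2, 4)
print(mask.tolist(), encoded.shape)     # [[0, 1], [2, 3]] (2, 2, 4)
```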

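We use only **FLAIR + T1ce**: FLAIR makes peritumoral edema stand out, while T1ce highlights the gadolinium-enhancing tumour and the tumour core. Using two of the four modalities halves the input channels and the per-patient loading cost.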

---
## 1. Environment Setup

In [None]:
# Install Kaggle API (only needed once in Colab/Kaggle)
!pip install kaggle -q

In [None]:
# Authenticate and download the BraTS 2020 dataset
import os, shutil

os.makedirs(os.path.expanduser('~/.kaggle'), exist_ok=True)
shutil.copy('kaggle.json', os.path.expanduser('~/.kaggle/kaggle.json'))
os.chmod(os.path.expanduser('~/.kaggle/kaggle.json'), 0o600)

!kaggle datasets download awsaf49/brats20-dataset-training-validation -q
!unzip -q brats20-dataset-training-validation.zip

---
## 2. Imports & Configuration

In [None]:
# ── Standard library ──────────────────────────────────────────────────────
import os
import glob
import random

# ── Scientific computing ───────────────────────────────────────────────────
import numpy as np
import pandas as pd

# ── Image processing ───────────────────────────────────────────────────────
import cv2
import nibabel as nib
from skimage.util import montage
from skimage.transform import rotate

# ── Visualisation ──────────────────────────────────────────────────────────
import matplotlib
import matplotlib.pyplot as plt

# ── Machine learning ───────────────────────────────────────────────────────
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split

# ── Deep learning ──────────────────────────────────────────────────────────
import tensorflow as tf
import keras
import keras.backend as K
from keras.callbacks import CSVLogger
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    Input, Conv2D, MaxPooling2D, UpSampling2D,
    concatenate, Dropout
)
from tensorflow.keras.utils import plot_model

In [None]:
# ── Global constants ──────────────────────────────────────────────────────
TRAIN_DATASET_PATH = '/content/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData/'

IMG_SIZE       = 128   # Resize each 2D slice to this resolution
VOLUME_SLICES  = 100   # Number of axial slices to use per patient
VOLUME_START   = 22    # First slice index (skip uninformative edge slices)
N_CLASSES      = 4     # Background, Necrotic, Edema, Enhancing
BATCH_SIZE     = 1     # Patients per batch (each patient contributes VOLUME_SLICES slices)

SEGMENT_CLASSES = {
    0: 'Background',
    1: 'Necrotic / Core',
    2: 'Edema',
    3: 'Enhancing tumour',  # original label 4, remapped to 3
}

# Colourmap for 4-class segmentation masks
SEG_CMAP = matplotlib.colors.ListedColormap(['#440054', '#3b528b', '#18b880', '#e6d74f'])
SEG_NORM = matplotlib.colors.BoundaryNorm([-0.5, 0.5, 1.5, 2.5, 3.5], SEG_CMAP.N)

scaler = MinMaxScaler()
print('✓ Configuration loaded')

---
## 3. Fix Malformed File in the Dataset

One patient folder (`BraTS20_Training_355`) contains an incorrectly named segmentation file. We rename it to match the expected naming convention before proceeding.

In [None]:
old_name = os.path.join(TRAIN_DATASET_PATH, 'BraTS20_Training_355', 'W39_1998.09.19_Segm.nii')
new_name = os.path.join(TRAIN_DATASET_PATH, 'BraTS20_Training_355', 'BraTS20_Training_355_seg.nii')

try:
    os.rename(old_name, new_name)
    print('✓ Segmentation file renamed successfully.')
except FileNotFoundError:
    print('✓ File already renamed — nothing to do.')

---
## 4. Explore the Dataset

### 4.1 Load a Sample Patient

In [None]:
def load_and_scale(filepath):
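    """Load a NIfTI volume and min-max scale it to [0, 1], slice by slice.

    MinMaxScaler scales each column separately; after the reshape below each
    column is one axial slice, so every slice gets its own min and max.
    """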
    vol = nib.load(filepath).get_fdata()
    return scaler.fit_transform(vol.reshape(-1, vol.shape[-1])).reshape(vol.shape)

sample_id   = 'BraTS20_Training_355'
sample_dir  = os.path.join(TRAIN_DATASET_PATH, sample_id)

flair_sample = load_and_scale(os.path.join(sample_dir, f'{sample_id}_flair.nii'))
t1_sample    = load_and_scale(os.path.join(sample_dir, f'{sample_id}_t1.nii'))
t1ce_sample  = load_and_scale(os.path.join(sample_dir, f'{sample_id}_t1ce.nii'))
t2_sample    = load_and_scale(os.path.join(sample_dir, f'{sample_id}_t2.nii'))
seg_sample   = nib.load(os.path.join(sample_dir, f'{sample_id}_seg.nii')).get_fdata()

print(f'Modality shape : {flair_sample.shape}')
print(f'Mask shape     : {seg_sample.shape}')
print(f'Intensity range after scaling: [{flair_sample.min():.3f}, {flair_sample.max():.3f}]')
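`MinMaxScaler` fits one min/max per column, and `load_and_scale` reshapes the volume to `(H·W, n_slices)`. Each axial slice is therefore scaled on its own, so intensities are not comparable across slices. Below is a NumPy sketch of the same computation using a hypothetical helper, `minmax_per_slice`. It assumes scikit-learn's convention that a constant column, such as an all-black edge slice, maps to 0.

```python
import numpy as np

def minmax_per_slice(vol):
    """Scale each slice along the last axis to [0, 1] independently.

    Mirrors MinMaxScaler().fit_transform(vol.reshape(-1, vol.shape[-1])):
    a constant slice (e.g. an all-zero edge slice) maps to 0.
    """
    flat = vol.reshape(-1, vol.shape[-1]).astype(np.float64)
    lo = flat.min(axis=0)
    rng = flat.max(axis=0) - lo
    rng[rng == 0] = 1.0            # avoid division by zero for constant slices
    return ((flat - lo) / rng).reshape(vol.shape)

vol = np.zeros((4, 4, 3))
vol[..., 1] = np.arange(16).reshape(4, 4)        # slice 1 spans 0..15
vol[..., 2] = 100 + np.arange(16).reshape(4, 4)  # slice 2 spans 100..115
scaled = minmax_per_slice(vol)
print(scaled[..., 0].max(), scaled[..., 1].max(), scaled[..., 2].min())  # 0.0 1.0 0.0
```

Slices 1 and 2 both end up spanning [0, 1], even though slice 2 is far brighter in the raw volume.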

### 4.2 Visualise All Four Modalities

In [None]:
EXAMPLE_SLICE = 95

fig, axes = plt.subplots(2, 3, figsize=(14, 9))
modalities = [
    (t1_sample,    'T1',   'gray'),
    (t1ce_sample,  'T1ce', 'gray'),
    (t2_sample,    'T2',   'gray'),
    (flair_sample, 'FLAIR','gray'),
    (seg_sample,   'Mask', None),
]

for ax, (vol, title, cmap) in zip(axes.flat, modalities):
    kwargs = dict(cmap=SEG_CMAP, norm=SEG_NORM) if cmap is None else dict(cmap=cmap)
    ax.imshow(vol[:, :, EXAMPLE_SLICE], **kwargs)
    ax.set_title(title, fontsize=13, fontweight='bold')
    ax.axis('off')

axes.flat[-1].set_visible(False)  # hide empty subplot
plt.suptitle(f'Sample patient — axial slice {EXAMPLE_SLICE}', fontsize=15)
plt.tight_layout()
plt.show()

### 4.3 Three Anatomical Planes

Each 3D volume can be sliced along three orthogonal planes: **axial** (top-down), **coronal** (front-back) and **sagittal** (left-right). The coronal and sagittal slices are rotated by 90° so they render upright; the axial slice is shown as stored.

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(14, 5))
views = [
    (t1ce_sample[:, :, EXAMPLE_SLICE],          'Axial'),
    (rotate(t1ce_sample[:, EXAMPLE_SLICE, :], 90, resize=True), 'Coronal'),
    (rotate(t1ce_sample[EXAMPLE_SLICE, :, :], 90, resize=True), 'Sagittal'),
]

for ax, (img, title) in zip(axes, views):
    ax.imshow(img, cmap='gray')
    ax.set_title(title, fontsize=13, fontweight='bold')
    ax.axis('off')

plt.suptitle('T1ce — three anatomical planes', fontsize=15)
plt.tight_layout()
plt.show()
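The indexing in the cell above fixes one array axis per plane. The sketch below assumes the `(x, y, z)` storage order the notebook relies on for BraTS volumes of shape `(240, 240, 155)`; `orthogonal_slices` is a hypothetical helper. It also shows why only the coronal and sagittal views need rotating: their second array axis is `z`, so the head-to-foot direction lands on the horizontal image axis.

```python
import numpy as np

def orthogonal_slices(vol, i):
    """Return (axial, coronal, sagittal) 2D slices through index i.

    Assumes the volume is stored as (x, y, z), as the notebook's indexing does.
    """
    axial    = vol[:, :, i]   # fix z: axial plane
    coronal  = vol[:, i, :]   # fix y: coronal plane
    sagittal = vol[i, :, :]   # fix x: sagittal plane
    return axial, coronal, sagittal

vol = np.zeros((240, 240, 155))   # BraTS 2020 volume shape
axial, coronal, sagittal = orthogonal_slices(vol, 95)
print(axial.shape, coronal.shape, sagittal.shape)
# (240, 240) (240, 155) (240, 155)
```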

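### 4.4 Montage: All Sagittal Slices

`skimage.util.montage` tiles an array along its **first** axis. For a BraTS volume of shape (240, 240, 155), that axis runs left-right, so each tile below is a sagittal slice. Many edge slices are empty (pure black), so we show the full stack and then the central window (50 : -50).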

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(18, 8))

axes[0].imshow(rotate(montage(t1ce_sample[:, :, :]), 90, resize=True), cmap='gray')
axes[0].set_title('All slices', fontsize=13)
axes[0].axis('off')

axes[1].imshow(rotate(montage(t1ce_sample[50:-50, :, :]), 90, resize=True), cmap='gray')
axes[1].set_title('Slices 50 : -50 (informative range)', fontsize=13)
axes[1].axis('off')

plt.tight_layout()
plt.show()
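To tile axial slices instead, move the slice axis to the front before calling `montage`, e.g. `montage(np.moveaxis(t1ce_sample, -1, 0))`. The sketch below uses a simplified NumPy stand-in, `tile_first_axis`, to show the tiling on a dummy volume. The helper is hypothetical: unlike `skimage.util.montage`, it fills empty grid cells with a constant instead of the array mean and has no padding options.

```python
import math
import numpy as np

def tile_first_axis(stack, fill=0.0):
    """Tile a (K, M, N) stack into a square grid along axis 0, like montage."""
    k, m, n = stack.shape
    side = math.ceil(math.sqrt(k))                    # grid is side x side tiles
    grid = np.full((side * m, side * n), fill, dtype=stack.dtype)
    for idx in range(k):
        r, c = divmod(idx, side)                      # row-major tile position
        grid[r * m:(r + 1) * m, c * n:(c + 1) * n] = stack[idx]
    return grid

# Dummy (x, y, z) volume with a distinct value in every voxel
vol = np.arange(240 * 240 * 155, dtype=np.float32).reshape(240, 240, 155)
axial_stack = np.moveaxis(vol, -1, 0)     # (155, 240, 240): one axial slice per entry
grid = tile_first_axis(axial_stack)
print(axial_stack.shape, grid.shape)      # (155, 240, 240) (3120, 3120)
```

With 155 slices the grid is 13 × 13 tiles. The last 14 cells stay empty, which matches the trailing padding visible at the end of a montage.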

### 4.5 Segmentation Class Breakdown

In [None]:
# Isolate each class by masking non-matching labels to NaN
def isolate_class(seg, label):
    masked = seg.copy().astype(float)
    masked[masked != label] = np.nan
    return masked

fig, axes = plt.subplots(1, 5, figsize=(22, 5))
class_labels = [0, 1, 2, 4]
class_titles = ['Original', 'Background (0)', 'Necrotic (1)', 'Edema (2)', 'Enhancing (4)']

axes[0].imshow(seg_sample[:, :, EXAMPLE_SLICE], cmap=SEG_CMAP, norm=SEG_NORM)
axes[0].set_title('Original', fontweight='bold')

legend_patches = [
    plt.Rectangle((0,0), 1, 1, color=SEG_CMAP(i), label=f'Class {class_labels[i]}')
    for i in range(len(class_labels))
]
axes[0].legend(handles=legend_patches, loc='lower left', fontsize=8)

for ax, label, title in zip(axes[1:], class_labels, class_titles[1:]):
    ax.imshow(isolate_class(seg_sample, label)[:, :, EXAMPLE_SLICE], cmap=SEG_CMAP, norm=SEG_NORM)
    ax.set_title(title, fontweight='bold')

for ax in axes:
    ax.axis('off')

plt.suptitle('Segmentation class isolation', fontsize=14)
plt.tight_layout()
plt.show()

---
## 5. Split the Dataset

We randomly split the 369 patient IDs into **train / validation / test** sets (≈71 / 17 / 12 %) with a fixed seed for reproducibility. The split is not stratified. The validation set takes 19 % of the remaining 88 %, which is about 17 % of all patients.

In [None]:
RANDOM_SEED = 42

# Sort the IDs: os.scandir order depends on the filesystem, so without sorting
# the seeded split would not be reproducible across machines.
all_ids = sorted(f.name for f in os.scandir(TRAIN_DATASET_PATH) if f.is_dir())

train_val_ids, test_ids  = train_test_split(all_ids,  test_size=0.12, random_state=RANDOM_SEED)
train_ids,     val_ids   = train_test_split(train_val_ids, test_size=0.19, random_state=RANDOM_SEED)

print(f'Train      : {len(train_ids):>3} patients')
print(f'Validation : {len(val_ids):>3} patients')
print(f'Test       : {len(test_ids):>3} patients')
print(f'Total      : {len(all_ids):>3} patients')

# Bar chart
fig, ax = plt.subplots(figsize=(6, 4))
ax.bar(['Train', 'Validation', 'Test'],
       [len(train_ids), len(val_ids), len(test_ids)],
       color=['#2ecc71', '#e74c3c', '#3498db'])
ax.set_ylabel('Patients')
ax.set_title('Dataset split')
plt.tight_layout()
plt.show()
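The split percentages follow from scikit-learn's rounding rule for a float `test_size`: the held-out part gets `ceil(test_size * n)` samples and the rest go to the other part. The pure-Python sketch below reproduces the patient counts for the two-stage split above, assuming all 369 folders are present. `split_sizes` is a hypothetical helper.

```python
import math

def split_sizes(n, test_size, val_size):
    """Sizes produced by two successive train_test_split calls with float sizes.

    Assumes scikit-learn's rule: held-out part = ceil(size * n), rest = n - that.
    """
    n_test = math.ceil(test_size * n)
    n_train_val = n - n_test
    n_val = math.ceil(val_size * n_train_val)
    return n_train_val - n_val, n_val, n_test

train, val, test = split_sizes(369, test_size=0.12, val_size=0.19)
print(train, val, test)                                       # 262 62 45
print([round(100 * k / 369, 1) for k in (train, val, test)])  # [71.0, 16.8, 12.2]
```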

---
## 6. Data Generator

The `BraTSDataGenerator` streams preprocessed 2D slices directly from disk, avoiding out-of-memory issues with the full 3D volumes.

**Per batch, for each patient:**
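1. Load the FLAIR, T1ce and segmentation volumes.
2. Extract `VOLUME_SLICES` axial slices starting at `VOLUME_START`.
3. Resize each image slice to `IMG_SIZE × IMG_SIZE`.
4. Remap label 4 → 3 (no label 3 exists in BraTS 2020).
5. One-hot encode the mask into `N_CLASSES` channels, then resize it to `IMG_SIZE × IMG_SIZE` (bilinear, so pixels on class boundaries get soft labels).
6. Normalise X to `[0, 1]` by dividing by the batch maximum, shared across both modalities.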

In [None]:
class BraTSDataGenerator(keras.utils.Sequence):
    """Keras-compatible generator for the BraTS 2020 dataset.

    Yields batches of (X, Y) where:
        X : float32 array of shape (VOLUME_SLICES, IMG_SIZE, IMG_SIZE, 2)
            Channel 0 = FLAIR, Channel 1 = T1ce
        Y : float32 one-hot array of shape (VOLUME_SLICES, IMG_SIZE, IMG_SIZE, N_CLASSES)
    """

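    def __init__(self, patient_ids, img_size=IMG_SIZE, batch_size=BATCH_SIZE,
                 n_channels=2, shuffle=True, **kwargs):
        super().__init__(**kwargs)  # expected by Keras 3's PyDataset-based Sequence; harmless on Keras 2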
        self.patient_ids = patient_ids
        self.img_size    = img_size
        self.batch_size  = batch_size
        self.n_channels  = n_channels
        self.shuffle     = shuffle
        self.on_epoch_end()

    # ── Keras Sequence API ────────────────────────────────────────────────

    def __len__(self):
        """Number of batches per epoch."""
        return len(self.patient_ids) // self.batch_size

    def __getitem__(self, index):
        """Return one batch of (X, Y)."""
        batch_ids = [
            self.patient_ids[k]
            for k in self.indexes[index * self.batch_size : (index + 1) * self.batch_size]
        ]
        return self._load_batch(batch_ids)

    def on_epoch_end(self):
        """Shuffle patient order after every epoch (if enabled)."""
        self.indexes = np.arange(len(self.patient_ids))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    # ── Internal helpers ──────────────────────────────────────────────────

    def _load_volume(self, case_dir, case_id, modality):
        """Load a single NIfTI modality volume as a numpy array."""
        path = os.path.join(case_dir, f'{case_id}_{modality}.nii')
        return nib.load(path).get_fdata()

    def _load_batch(self, batch_ids):
        """Load and preprocess all patients in a batch."""
        n_slices = self.batch_size * VOLUME_SLICES
        dim      = (self.img_size, self.img_size)

        X = np.zeros((n_slices, *dim, self.n_channels), dtype=np.float32)
        y = np.zeros((n_slices, 240, 240),               dtype=np.float32)

        for patient_idx, patient_id in enumerate(batch_ids):
            case_dir = os.path.join(TRAIN_DATASET_PATH, patient_id)

            flair = self._load_volume(case_dir, patient_id, 'flair')
            t1ce  = self._load_volume(case_dir, patient_id, 't1ce')
            seg   = self._load_volume(case_dir, patient_id, 'seg')

            slice_offset = patient_idx * VOLUME_SLICES
            for j in range(VOLUME_SLICES):
                z = j + VOLUME_START
                X[slice_offset + j, :, :, 0] = cv2.resize(flair[:, :, z], dim)
                X[slice_offset + j, :, :, 1] = cv2.resize(t1ce[:, :, z], dim)
                y[slice_offset + j]           = seg[:, :, z]

        # Remap label 4 → 3 (label 3 is absent in BraTS 2020)
        y[y == 4] = 3

        # One-hot encode and resize masks
        Y = tf.image.resize(tf.one_hot(y.astype(np.int32), N_CLASSES), dim)

        # Normalise images to [0, 1]
        x_max = X.max()
        return (X / x_max if x_max > 0 else X), Y


# Instantiate generators
training_generator   = BraTSDataGenerator(train_ids, shuffle=True)
validation_generator = BraTSDataGenerator(val_ids,   shuffle=False)
test_generator       = BraTSDataGenerator(test_ids,  shuffle=False)

print(f'Training batches   : {len(training_generator)}')
print(f'Validation batches : {len(validation_generator)}')
print(f'Test batches       : {len(test_generator)}')
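The generator resizes the one-hot tensor (`tf.image.resize(tf.one_hot(...))`) rather than the integer label map, so pixels on class boundaries become soft labels. The sketch below uses a 2× average-pool as a stand-in for that bilinear resize. Both produce convex combinations of neighbouring one-hot vectors. The toy 4 × 4 mask and the helper `downsample_2x_mean` are made up for illustration.

```python
import numpy as np

N_CLASSES = 4

def downsample_2x_mean(arr):
    """Average non-overlapping 2x2 blocks of an (H, W, C) array (H, W even)."""
    h, w, c = arr.shape
    return arr.reshape(h // 2, 2, w // 2, 2, c).mean(axis=(1, 3))

mask = np.array([[0, 0, 2, 2],
                 [0, 3, 2, 2],
                 [1, 1, 0, 0],
                 [1, 1, 0, 0]])
one_hot = np.eye(N_CLASSES, dtype=np.float32)[mask]   # (4, 4, 4)
soft = downsample_2x_mean(one_hot)                    # (2, 2, 4)

print(soft[0, 0])              # 0.75 background + 0.25 enhancing: a soft boundary label
print(soft.sum(axis=-1))       # every pixel still sums to 1
print(soft.argmax(axis=-1))    # argmax recovers hard labels [[0, 2], [1, 0]]
```

This is also why the sanity-check cell below applies `np.argmax` before plotting the mask.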

#### Sanity check — visualise a batch sample

In [None]:
def plot_slice_triplet(flair_slice, t1ce_slice, mask_slice, title=''):
    """Display FLAIR, T1ce and the segmentation mask side by side."""
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    axes[0].imshow(flair_slice, cmap='gray');  axes[0].set_title('FLAIR')
    axes[1].imshow(t1ce_slice,  cmap='gray');  axes[1].set_title('T1ce')
    axes[2].imshow(mask_slice,  cmap=SEG_CMAP, norm=SEG_NORM)
    axes[2].set_title('Segmentation mask')
    for ax in axes:
        ax.axis('off')
    if title:
        plt.suptitle(title, fontsize=13)
    plt.tight_layout()
    plt.show()

# Pull a batch and display one slice
X_batch, Y_batch = training_generator[8]
masks = np.argmax(Y_batch, axis=-1)   # one-hot → class index

DISPLAY_SLICE = 60
plot_slice_triplet(
    X_batch[DISPLAY_SLICE, :, :, 0],
    X_batch[DISPLAY_SLICE, :, :, 1],
    masks[DISPLAY_SLICE],
    title=f'Training batch 8, slice {DISPLAY_SLICE}',
)

---
## 7. Loss Function & Evaluation Metrics

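We train with **categorical cross-entropy** (pixel-level class probabilities) and track the **Dice coefficient** (overlap-based) as a metric rather than as a loss term. Per-class Dice scores track each tumour sub-region separately. Precision, sensitivity and specificity are pooled over all pixels and all four channels, background included, so they are dominated by the large background class.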

In [None]:
# ── Dice coefficient (mean across all 4 classes) ──────────────────────────
# Metrics use tf ops: the keras.backend math helpers (K.flatten, K.sum, ...)
# were removed in Keras 3, while tf ops work under both Keras 2 and Keras 3.
def dice_coef(y_true, y_pred, smooth=1.0):
    """Macro-averaged Dice coefficient over all N_CLASSES channels."""
    total = 0.0
    for i in range(N_CLASSES):
        y_t = tf.reshape(y_true[:, :, :, i], [-1])
        y_p = tf.reshape(y_pred[:, :, :, i], [-1])
        intersection = tf.reduce_sum(y_t * y_p)
        total += (2.0 * intersection + smooth) / (tf.reduce_sum(y_t) + tf.reduce_sum(y_p) + smooth)
    return total / N_CLASSES


# ── Per-class Dice helpers ────────────────────────────────────────────────
def _dice_for_class(y_true, y_pred, class_idx, epsilon=1e-6):
    """Dice for a single class, with squared terms in the denominator (V-Net-style soft Dice)."""
    y_t = y_true[:, :, :, class_idx]
    y_p = y_pred[:, :, :, class_idx]
    intersection = tf.reduce_sum(tf.abs(y_t * y_p))
    return (2.0 * intersection) / (tf.reduce_sum(tf.square(y_t)) + tf.reduce_sum(tf.square(y_p)) + epsilon)

def dice_coef_necrotic(y_true, y_pred):
    return _dice_for_class(y_true, y_pred, class_idx=1)

def dice_coef_edema(y_true, y_pred):
    return _dice_for_class(y_true, y_pred, class_idx=2)

def dice_coef_enhancing(y_true, y_pred):
    return _dice_for_class(y_true, y_pred, class_idx=3)


# ── Classification metrics (pooled over all pixels and channels) ──────────
def precision(y_true, y_pred):
    """Positive predictive value."""
    tp = tf.reduce_sum(tf.round(tf.clip_by_value(y_true * y_pred, 0.0, 1.0)))
    pp = tf.reduce_sum(tf.round(tf.clip_by_value(y_pred, 0.0, 1.0)))
    return tp / (pp + K.epsilon())

def sensitivity(y_true, y_pred):
    """True positive rate (recall)."""
    tp = tf.reduce_sum(tf.round(tf.clip_by_value(y_true * y_pred, 0.0, 1.0)))
    ap = tf.reduce_sum(tf.round(tf.clip_by_value(y_true, 0.0, 1.0)))
    return tp / (ap + K.epsilon())

def specificity(y_true, y_pred):
    """True negative rate."""
    tn = tf.reduce_sum(tf.round(tf.clip_by_value((1 - y_true) * (1 - y_pred), 0.0, 1.0)))
    an = tf.reduce_sum(tf.round(tf.clip_by_value(1 - y_true, 0.0, 1.0)))
    return tn / (an + K.epsilon())


ALL_METRICS = [
    'accuracy',
    # MeanIoU expects integer class labels; OneHotMeanIoU argmaxes the one-hot
    # targets and softmax outputs first. name= keeps the 'mean_io_u' log column.
    tf.keras.metrics.OneHotMeanIoU(num_classes=N_CLASSES, name='mean_io_u'),
    dice_coef, precision, sensitivity, specificity,
    dice_coef_necrotic, dice_coef_edema, dice_coef_enhancing,
]
print('✓ Metrics defined')
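`dice_coef` and `_dice_for_class` use different denominators: plain sums versus sums of squares. Ignoring the smoothing terms, the two agree on binary masks (where x² = x) but diverge for soft predictions. The macro Dice and the per-class Dice scores are therefore not directly comparable. Below is a NumPy sketch with made-up toy vectors; the two helper names are hypothetical.

```python
import numpy as np

def dice_linear(t, p, eps=1e-6):
    """2|T∩P| / (|T| + |P|): the denominator form used by dice_coef."""
    return float(2 * np.sum(t * p) / (np.sum(t) + np.sum(p) + eps))

def dice_squared(t, p, eps=1e-6):
    """2|T∩P| / (|T|² + |P|²): the denominator form used by _dice_for_class."""
    return float(2 * np.sum(t * p) / (np.sum(t ** 2) + np.sum(p ** 2) + eps))

t = np.array([1.0, 1.0, 0.0, 0.0])
hard = np.array([1.0, 0.0, 0.0, 0.0])    # binary prediction
soft = np.array([0.6, 0.6, 0.4, 0.4])    # uncertain softmax-style prediction

print(round(dice_linear(t, hard), 4), round(dice_squared(t, hard), 4))  # 0.6667 0.6667
print(round(dice_linear(t, soft), 4), round(dice_squared(t, soft), 4))  # 0.6 0.7895
```

The squared form is more lenient towards low-confidence predictions, because squaring shrinks the predicted mass in the denominator.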

---
## 8. U-Net Architecture

Our 2D U-Net follows the classic encoder → bottleneck → decoder design with skip connections. Each encoder stage halves spatial resolution while doubling filter count; the decoder mirrors this in reverse.

| Stage | Filters | Spatial size |
|-------|---------|-------------|
| Encoder 1 | 32 | 128 × 128 |
| Encoder 2 | 64 | 64 × 64 |
| Encoder 3 | 128 | 32 × 32 |
| Encoder 4 | 256 | 16 × 16 |
| Bottleneck | 512 | 8 × 8 |
| Decoder | 256 → 32 | 16 × 16 → 128 × 128 |
| Output (1 × 1 conv, softmax) | 4 | 128 × 128 |

In [None]:
def conv_block(x, filters, kernel_init):
    """Two consecutive Conv2D-ReLU layers (the core U-Net building block)."""
    x = Conv2D(filters, 3, activation='relu', padding='same', kernel_initializer=kernel_init)(x)
    x = Conv2D(filters, 3, activation='relu', padding='same', kernel_initializer=kernel_init)(x)
    return x

def upsample_block(x, filters, kernel_init):
    """Upsample by 2x then refine with a single Conv2D."""
    x = UpSampling2D(size=(2, 2))(x)
    x = Conv2D(filters, 2, activation='relu', padding='same', kernel_initializer=kernel_init)(x)
    return x


def build_unet(input_shape=(IMG_SIZE, IMG_SIZE, 2),
               kernel_init='he_normal',
               dropout_rate=0.2):
    """Build and return a 2D U-Net model.

    Args:
        input_shape  : (height, width, channels) — default (128, 128, 2).
        kernel_init  : Weight initialiser for Conv2D layers.
        dropout_rate : Dropout probability applied at the bottleneck.

    Returns:
        An uncompiled Keras Model (call `model.compile` before training).
    """
    inputs = Input(input_shape)

    # ── Encoder ───────────────────────────────────────────────────────────
    enc1 = conv_block(inputs, 32,  kernel_init)
    enc2 = conv_block(MaxPooling2D()(enc1), 64,  kernel_init)
    enc3 = conv_block(MaxPooling2D()(enc2), 128, kernel_init)
    enc4 = conv_block(MaxPooling2D()(enc3), 256, kernel_init)

    # ── Bottleneck ────────────────────────────────────────────────────────
    bridge = conv_block(MaxPooling2D()(enc4), 512, kernel_init)
    bridge = Dropout(dropout_rate)(bridge)

    # ── Decoder ───────────────────────────────────────────────────────────
    dec4 = conv_block(concatenate([upsample_block(bridge, 256, kernel_init), enc4]), 256, kernel_init)
    dec3 = conv_block(concatenate([upsample_block(dec4,   128, kernel_init), enc3]), 128, kernel_init)
    dec2 = conv_block(concatenate([upsample_block(dec3,    64, kernel_init), enc2]),  64, kernel_init)
    dec1 = conv_block(concatenate([upsample_block(dec2,    32, kernel_init), enc1]),  32, kernel_init)

    # ── Output ────────────────────────────────────────────────────────────
    outputs = Conv2D(N_CLASSES, (1, 1), activation='softmax')(dec1)

    return Model(inputs=inputs, outputs=outputs, name='UNet_BraTS')


model = build_unet()
model.summary(line_length=90)
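The spatial sizes in the architecture table follow from four 2× max-pooling steps. This is also why `IMG_SIZE` must be divisible by 2⁴ = 16: otherwise the upsampled decoder maps and the encoder skip connections would not line up in `concatenate`. The sketch below derives the table using a hypothetical helper, `unet_stage_shapes`.

```python
def unet_stage_shapes(img_size=128, base_filters=32, depth=4):
    """List (stage, filters, spatial size) for a U-Net that halves resolution
    and doubles filters at each of `depth` pooling steps."""
    if img_size % (2 ** depth):
        raise ValueError(f'img_size must be divisible by {2 ** depth}')
    stages = [(f'Encoder {i + 1}', base_filters * 2 ** i, img_size // 2 ** i)
              for i in range(depth)]
    stages.append(('Bottleneck', base_filters * 2 ** depth, img_size // 2 ** depth))
    return stages

for name, filters, size in unet_stage_shapes():
    print(f'{name:<11} {filters:>4} filters  {size} x {size}')
# Filters go 32 -> 512 while the feature maps shrink 128 -> 8, as in the table.
```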

In [None]:
plot_model(model, show_shapes=True, show_layer_names=True,
           rankdir='TB', expand_nested=False, dpi=70)

---
## 9. Callbacks & Training

In [None]:
callbacks = [
    keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss', factor=0.2, patience=2,
        min_lr=1e-6, verbose=1,
    ),
    keras.callbacks.ModelCheckpoint(
        filepath='model_epoch{epoch:02d}_valloss{val_loss:.6f}.weights.h5',
        monitor='val_loss', save_best_only=True, save_weights_only=True, verbose=1,
    ),
    CSVLogger('training.log', separator=',', append=False),
]
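Each time `val_loss` fails to improve for `patience=2` epochs, `ReduceLROnPlateau` multiplies the learning rate by `factor=0.2`, clamping it at `min_lr=1e-6`. The sketch below lists the levels the schedule can reach from Adam's initial 1e-3. `plateau_lr_levels` is a hypothetical helper, and it ignores how many epochs each level lasts.

```python
def plateau_lr_levels(initial_lr=1e-3, factor=0.2, min_lr=1e-6):
    """Learning rates reachable by repeated plateau reductions, clamped at min_lr."""
    levels = [initial_lr]
    while levels[-1] > min_lr:
        levels.append(max(levels[-1] * factor, min_lr))
    return levels

print(['%.1e' % lr for lr in plateau_lr_levels()])
# ['1.0e-03', '2.0e-04', '4.0e-05', '8.0e-06', '1.6e-06', '1.0e-06']
```

After five reductions the floor is reached, so any further plateaus leave the learning rate unchanged.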

In [None]:
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    loss='categorical_crossentropy',
    metrics=ALL_METRICS,
)

history = model.fit(
    training_generator,
    epochs=35,
    steps_per_epoch=len(training_generator),  # batches per epoch, correct for any BATCH_SIZE
    validation_data=validation_generator,
    callbacks=callbacks,
)
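The cell above compiles the model with categorical cross-entropy as its only loss. The NumPy sketch below shows a combined cross-entropy + Dice objective, the idea listed under next steps: it adds `dice_weight * (1 - soft Dice)` to the cross-entropy, with the Dice term computed by the same formula as `dice_coef`. `ce_plus_dice_loss` and the weight of 1.0 are assumptions, not tuned values. A Keras version would express the same arithmetic with `tf` tensor ops and be passed as `loss=` to `model.compile`.

```python
import numpy as np

def soft_dice(y_true, y_pred, smooth=1.0):
    """Macro soft Dice over the last (class) axis, same formula as dice_coef."""
    axes = tuple(range(y_true.ndim - 1))                 # sum over batch and pixels
    inter = np.sum(y_true * y_pred, axis=axes)
    denom = np.sum(y_true, axis=axes) + np.sum(y_pred, axis=axes)
    return float(np.mean((2.0 * inter + smooth) / (denom + smooth)))

def ce_plus_dice_loss(y_true, y_pred, dice_weight=1.0, eps=1e-7):
    """Mean categorical cross-entropy plus dice_weight * (1 - soft Dice)."""
    p = np.clip(y_pred, eps, 1.0 - eps)                  # avoid log(0)
    ce = -np.mean(np.sum(y_true * np.log(p), axis=-1))   # average over pixels
    return float(ce + dice_weight * (1.0 - soft_dice(y_true, y_pred)))

rng = np.random.default_rng(0)
labels = rng.integers(0, 4, size=(2, 8, 8))
y_true = np.eye(4)[labels]                               # (2, 8, 8, 4) one-hot
perfect = ce_plus_dice_loss(y_true, y_true)
uniform = ce_plus_dice_loss(y_true, np.full_like(y_true, 0.25))
print(perfect < 1e-3, uniform > perfect)                 # True True
```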

In [None]:
model.save('brain_tumor_unet.keras')
print('✓ Model saved.')

---
## 10. Load a Saved Model

Use this cell to reload the final model or a specific checkpoint (e.g. the best epoch by validation loss).

In [None]:
CUSTOM_OBJECTS = {
    'dice_coef'          : dice_coef,
    'precision'          : precision,
    'sensitivity'        : sensitivity,
    'specificity'        : specificity,
    'dice_coef_necrotic' : dice_coef_necrotic,
    'dice_coef_edema'    : dice_coef_edema,
    'dice_coef_enhancing': dice_coef_enhancing,
}

# ── Option A: load the full saved model ───────────────────────────────────
model = keras.models.load_model('brain_tumor_unet.keras',
                                custom_objects=CUSTOM_OBJECTS,
                                compile=False)

# ── Option B: load best checkpoint weights ────────────────────────────────
# model = build_unet()
# model.load_weights('model_epoch19_valloss0.021449.weights.h5')

model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    loss='categorical_crossentropy',
    metrics=ALL_METRICS,
)
print('✓ Model loaded and compiled.')

---
## 11. Training Curves

In [None]:
log = pd.read_csv('training.log')

fig, axes = plt.subplots(1, 4, figsize=(20, 5))

metric_pairs = [
    ('accuracy',  'val_accuracy',  'Accuracy'),
    ('loss',      'val_loss',      'Loss'),
    ('dice_coef', 'val_dice_coef', 'Dice Coefficient'),
    ('mean_io_u', 'val_mean_io_u', 'Mean IoU'),
]

epochs = range(1, len(log) + 1)

for ax, (train_col, val_col, title) in zip(axes, metric_pairs):
    ax.plot(epochs, log[train_col], 'b-o', markersize=3, label='Train')
    ax.plot(epochs, log[val_col],   'r-o', markersize=3, label='Validation')
    ax.set_title(title, fontsize=12, fontweight='bold')
    ax.set_xlabel('Epoch')
    ax.legend()
    ax.grid(alpha=0.3)

plt.suptitle('Training history', fontsize=14)
plt.tight_layout()
plt.show()

---
## 12. Predictions on the Test Set

In [None]:
def predict_patient(patient_id):
    """Run model inference on all VOLUME_SLICES axial slices for one patient.

    Returns:
        p : float32 array of shape (VOLUME_SLICES, IMG_SIZE, IMG_SIZE, N_CLASSES)
            Softmax probabilities for every class.
    """
    case_dir = os.path.join(TRAIN_DATASET_PATH, patient_id)
    dim = (IMG_SIZE, IMG_SIZE)
    X = np.zeros((VOLUME_SLICES, *dim, 2), dtype=np.float32)

    flair = nib.load(os.path.join(case_dir, f'{patient_id}_flair.nii')).get_fdata()
    t1ce  = nib.load(os.path.join(case_dir, f'{patient_id}_t1ce.nii')).get_fdata()

    for j in range(VOLUME_SLICES):
        z = j + VOLUME_START
        X[j, :, :, 0] = cv2.resize(flair[:, :, z], dim)
        X[j, :, :, 1] = cv2.resize(t1ce[:, :, z],  dim)

    x_max = X.max()
    return model.predict(X / x_max if x_max > 0 else X, verbose=0)
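`predict_patient` returns per-class probabilities. To get a label map in the original BraTS convention, take the argmax and undo the 4 → 3 remap. The NumPy sketch below does this with a hypothetical helper, `probs_to_brats_labels`. The result still covers only the `VOLUME_SLICES` window at `IMG_SIZE × IMG_SIZE`; comparing it with the full 240 × 240 × 155 ground truth would also require nearest-neighbour resizing and padding back to full depth.

```python
import numpy as np

def probs_to_brats_labels(p):
    """Convert softmax output (..., 4) to a BraTS-style label map.

    argmax picks the most likely class per pixel; index 3 is mapped back to
    the original BraTS label 4 (enhancing tumour), undoing the training remap.
    """
    labels = np.argmax(p, axis=-1).astype(np.uint8)
    labels[labels == 3] = 4
    return labels

p = np.zeros((1, 2, 2, 4), dtype=np.float32)
p[0, 0, 0] = [0.7, 0.1, 0.1, 0.1]   # background
p[0, 0, 1] = [0.1, 0.6, 0.2, 0.1]   # necrotic
p[0, 1, 0] = [0.1, 0.2, 0.6, 0.1]   # edema
p[0, 1, 1] = [0.1, 0.1, 0.2, 0.6]   # enhancing
print(probs_to_brats_labels(p)[0].tolist())   # [[0, 1], [2, 4]]
```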

In [None]:
def show_prediction(patient_id, slice_idx=60):
    """Overlay predicted tumour classes on the FLAIR background for one slice.

    Args:
        patient_id : BraTS patient ID string (e.g. 'BraTS20_Training_042').
        slice_idx  : Index into the VOLUME_SLICES window to display.
    """
    case_dir = os.path.join(TRAIN_DATASET_PATH, patient_id)
    dim = (IMG_SIZE, IMG_SIZE)

    # Load ground truth and background image
    flair_vol = nib.load(os.path.join(case_dir, f'{patient_id}_flair.nii')).get_fdata()
    gt_vol    = nib.load(os.path.join(case_dir, f'{patient_id}_seg.nii')).get_fdata()

    bg  = cv2.resize(flair_vol[:, :, slice_idx + VOLUME_START], dim)
    gt  = cv2.resize(gt_vol[:,   :, slice_idx + VOLUME_START], dim, interpolation=cv2.INTER_NEAREST)

    p = predict_patient(patient_id)

    fig, axes = plt.subplots(1, 6, figsize=(22, 5))

    panels = [
        (bg,                            None,   None,    'FLAIR (original)'),
        (gt,                            SEG_CMAP, SEG_NORM, 'Ground truth'),
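        # A 3-channel slice is drawn as RGB (R = necrotic, G = edema, B = enhancing); cmap is ignored
        (p[slice_idx, :, :, 1:4],       'Reds', None,    'All tumour classes (RGB)'),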
        (p[slice_idx, :, :, 1],         'OrRd', None,    f'{SEGMENT_CLASSES[1]} (pred)'),
        (p[slice_idx, :, :, 2],         'OrRd', None,    f'{SEGMENT_CLASSES[2]} (pred)'),
        (p[slice_idx, :, :, 3],         'OrRd', None,    f'{SEGMENT_CLASSES[3]} (pred)'),
    ]

    for ax, (img, cmap, norm, title) in zip(axes, panels):
        # Show FLAIR background first, then overlay prediction with transparency
        if isinstance(cmap, str):          # probability overlays ('Reds' / 'OrRd')
            ax.imshow(bg, cmap='gray')
            ax.imshow(img, cmap=cmap, alpha=0.4, interpolation='none')
        elif norm is not None:
            ax.imshow(bg, cmap='gray')
            ax.imshow(img, cmap=cmap, norm=norm, alpha=0.4, interpolation='none')
        else:
            ax.imshow(img, cmap='gray')
        ax.set_title(title, fontsize=10, fontweight='bold')
        ax.axis('off')

    plt.suptitle(f'Patient {patient_id} — slice {slice_idx}', fontsize=12)
    plt.tight_layout()
    plt.show()


# Display predictions for the first 7 test patients
for pid in test_ids[:7]:
    show_prediction(pid, slice_idx=60)

---
## 13. Final Evaluation on the Test Set

In [None]:
# The generator already defines batching, so no batch_size argument here.
# return_dict=True pairs each value with its metric name, avoiding a
# hard-coded name list that can drift from the compiled metric order.
results = model.evaluate(test_generator, verbose=1, return_dict=True)

print('\n' + '='*45)
print('  Model evaluation: test set')
print('='*45)
for name, value in results.items():
    print(f'  {name:<22} : {value:.4f}')
print('='*45)

---
## 14. Conclusion

This notebook walked through the full pipeline for brain tumour segmentation with U-Net on BraTS 2020:

- **Data exploration** — understood the four MRI modalities and three anatomical planes.
- **Preprocessing** — normalised intensities, selected informative axial slices, and one-hot encoded masks.
- **Data generator** — streamed batches from disk to avoid memory overflow.
- **U-Net** — implemented a modular encoder-decoder with skip connections.
- **Training** — used adaptive learning rate, best-model checkpointing, and CSV logging.
- **Evaluation** — measured Dice, IoU, sensitivity, and specificity on a held-out test set.

**Potential next steps:**
- Try a 3D U-Net to exploit inter-slice spatial context.
- Add data augmentation (flips, rotations, intensity jitter).
- Experiment with combined Dice + cross-entropy loss weighting.
- Use all four modalities instead of two.
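For the augmentation idea, the key constraint is that geometric transforms must be applied identically to the image and the mask, while intensity jitter touches the image only. Below is a minimal NumPy sketch assuming per-slice `(H, W, C)` inputs in [0, 1]. `augment_pair` and its jitter ranges are assumptions, and rotations are omitted for brevity.

```python
import numpy as np

def augment_pair(x, y, rng):
    """Randomly flip an image/mask pair together, then jitter image intensity.

    x : (H, W, C) image in [0, 1];  y : (H, W, N_CLASSES) one-hot mask.
    """
    if rng.random() < 0.5:
        x, y = x[:, ::-1], y[:, ::-1]            # left-right flip, both arrays
    if rng.random() < 0.5:
        x, y = x[::-1], y[::-1]                  # up-down flip, both arrays
    scale = rng.uniform(0.9, 1.1)                # assumed jitter ranges
    shift = rng.uniform(-0.05, 0.05)
    x = np.clip(x * scale + shift, 0.0, 1.0)     # image only; mask untouched
    return x, np.ascontiguousarray(y)

rng = np.random.default_rng(0)
x = rng.random((128, 128, 2)).astype(np.float32)
labels = rng.integers(0, 4, size=(128, 128))
y = np.eye(4, dtype=np.float32)[labels]
xa, ya = augment_pair(x, y, rng)
print(xa.shape, ya.shape)                        # (128, 128, 2) (128, 128, 4)
```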