In [None]:
# import necessary libraries

import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, CSVLogger, ReduceLROnPlateau
import math
import os
from glob import glob
from  Utils import plot_training_history
from Data_processing_utils import data_generator


-----

### 📦 Architecture Overview

The **3D U-Net** follows an encoder-decoder structure with skip connections and is designed to capture both **spatial context** and **fine-grained localization** in volumetric data.

#### 🔹 Encoder Path (Contracting)
- Extracts high-level features while reducing spatial resolution.
- Each block applies two 3×3×3 convolutions, each followed by batch normalization and ReLU activation.
- Downsampling is done with 2×2×2 3D max pooling, which halves each spatial dimension.

#### 🔹 Bottleneck
- The deepest part of the network.
- Contains convolutional layers without pooling or upsampling.
- Acts as a bridge between encoder and decoder.

#### 🔹 Decoder Path (Expanding)
- Gradually restores spatial dimensions using 3D upsampling.
- Each upsampling block is followed by concatenation with the corresponding encoder feature map (skip connection), helping retain spatial detail.
- Then it applies convolutional layers to refine the merged features.

#### 🔹 Output Layer
- A final 1×1×1 3D convolution with softmax activation outputs, for every voxel, a probability for each class.

In [None]:
def conv_block(x, filters):
    """Apply two (Conv3D 3x3x3 'same' -> BatchNormalization -> ReLU) stages.

    Args:
        x: Input Keras tensor.
        filters: Number of output channels used by both convolutions.

    Returns:
        The Keras tensor produced by the second ReLU.
    """
    out = x
    for _ in range(2):
        out = layers.Conv3D(filters, 3, padding="same")(out)
        out = layers.BatchNormalization()(out)
        out = layers.ReLU()(out)
    return out

def encoder_block(x, filters):
    """One contracting step of the U-Net: conv_block followed by 2x2x2 max pooling.

    Args:
        x: Input Keras tensor.
        filters: Number of channels for the conv_block.

    Returns:
        (features, pooled): ``features`` is the full-resolution conv_block
        output, kept for the decoder's skip connection. ``pooled`` is its
        max-pooled version, which feeds the next level down.
    """
    features = conv_block(x, filters)
    downsample = layers.MaxPooling3D(pool_size=(2, 2, 2))
    return features, downsample(features)

def decoder_block(x, skip, filters):
    """One expanding step of the U-Net.

    Upsamples ``x`` by 2 along each spatial axis and concatenates it with the
    matching encoder features ``skip`` (along the last axis). The merged
    tensor is then refined with a conv_block.

    Args:
        x: Keras tensor coming from the level below.
        skip: Encoder feature map at the target resolution.
        filters: Number of channels for the conv_block.

    Returns:
        The conv_block output at the upsampled resolution.
    """
    upsampled = layers.UpSampling3D(size=(2, 2, 2))(x)
    merged = layers.Concatenate()([upsampled, skip])
    return conv_block(merged, filters)

def build_3d_unet(input_shape=(128, 128, 128, 4), num_classes=5):
    """Build a 4-level 3D U-Net for per-voxel multi-class segmentation.

    Args:
        input_shape: Shape of one input volume without the batch dimension:
            three spatial sizes followed by the channel count (channels-last).
            Each spatial size must be divisible by 16 (= 2**4, one halving
            per pooling stage) or be None. Otherwise the upsampled decoder
            tensors cannot be concatenated with their skip connections.
        num_classes: Number of output channels of the final softmax layer.

    Returns:
        An uncompiled ``tf.keras.Model``. Its output has the same spatial
        sizes as the input and ``num_classes`` softmax channels per voxel.

    Raises:
        ValueError: If ``input_shape`` does not have 4 entries, or a spatial
            size is not divisible by 16.
    """
    # Four encoder levels, each ending in a 2x2x2 max pool.
    downsample_factor = 2 ** 4
    if len(input_shape) != 4:
        raise ValueError(
            f"input_shape must be (spatial1, spatial2, spatial3, channels), got {input_shape!r}"
        )
    for dim in input_shape[:3]:
        # Checking here gives a clear error instead of an opaque
        # Concatenate shape mismatch deep inside the decoder.
        if dim is not None and dim % downsample_factor != 0:
            raise ValueError(
                f"Spatial sizes must be divisible by {downsample_factor}, got {input_shape!r}"
            )

    inputs = tf.keras.Input(shape=input_shape)

    # Encoder: s* are full-resolution skip features, p* the pooled outputs.
    s1, p1 = encoder_block(inputs, 32)
    s2, p2 = encoder_block(p1, 64)
    s3, p3 = encoder_block(p2, 128)
    s4, p4 = encoder_block(p3, 256)

    # Bottleneck (lowest resolution, no pooling/upsampling).
    b = conv_block(p4, 512)

    # Decoder: upsample and merge with the matching encoder level.
    d1 = decoder_block(b, s4, 256)
    d2 = decoder_block(d1, s3, 128)
    d3 = decoder_block(d2, s2, 64)
    d4 = decoder_block(d3, s1, 32)

    # 1x1x1 conv maps features to per-voxel class probabilities.
    outputs = layers.Conv3D(num_classes, 1, activation='softmax')(d4)

    return tf.keras.Model(inputs, outputs)


----

### 🎯 Dice Coefficient (Multiclass)

The **Dice Coefficient** measures the overlap between predicted and ground truth segmentations. It’s especially useful in medical imaging tasks.


#### 🔹 Per-Class Dice

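- Takes the argmax of the predicted class probabilities, then converts both the ground-truth labels and the predicted labels into binary masks for the target class $c$.
- Calculates:

  $$
  \text{Dice}_c = \frac{2\,|P_c \cap G_c| + \epsilon}{|P_c| + |G_c| + \epsilon}
  $$

  where $P_c$ and $G_c$ are the sets of voxels predicted and labelled as class $c$.
- $\epsilon$ (the `smooth` argument, `1e-6` by default) avoids division by zero. If class $c$ is absent from both masks, the score is 1.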


#### 🔹 Multiclass Dice

- Averages the per-class Dice scores over all classes (5 by default, including class 0).
- Provides a single score reflecting overall segmentation performance.



In [None]:
def dice_coefficient_per_class(y_true, y_pred, class_index, smooth=1e-6):
    """Hard (argmax-based) Dice score for one class, pooled over the whole batch.

    Args:
        y_true: Integer-valued class labels. The shape is either that of
            ``y_pred`` without its last axis, or the same plus a trailing
            axis of size 1. Any numeric dtype is accepted; values are cast
            to the argmax dtype.
        y_pred: Class probabilities with the class axis last.
        class_index: Class whose Dice score is computed.
        smooth: Constant added to numerator and denominator to avoid
            division by zero. A class absent from both masks scores 1.

    Returns:
        Scalar float32 tensor: (2*|P & G| + smooth) / (|P| + |G| + smooth).
    """
    # Hard predicted label for every voxel.
    pred_labels = tf.argmax(y_pred, axis=-1)

    # Align y_true with the label map. This drops a trailing singleton axis
    # if present, so the mask product below cannot broadcast across
    # mismatched ranks. It also unifies the dtype for the equality test.
    true_labels = tf.cast(tf.reshape(y_true, tf.shape(pred_labels)), pred_labels.dtype)

    # Binary masks for this class.
    true_mask = tf.cast(tf.equal(true_labels, class_index), tf.float32)
    pred_mask = tf.cast(tf.equal(pred_labels, class_index), tf.float32)

    intersection = tf.reduce_sum(true_mask * pred_mask)
    total = tf.reduce_sum(true_mask) + tf.reduce_sum(pred_mask)

    return (2. * intersection + smooth) / (total + smooth)

def multiclass_dice_coefficient(y_true, y_pred, num_classes=5, smooth=1e-6):
    """Mean of the per-class hard Dice scores for classes 0 .. num_classes-1.

    Args:
        y_true: Ground-truth labels, passed through to dice_coefficient_per_class.
        y_pred: Predicted class probabilities, passed through likewise.
        num_classes: Number of classes to average over (class 0 included).
        smooth: Smoothing constant forwarded to every per-class score.

    Returns:
        The unweighted average of the per-class Dice scores.
    """
    per_class_scores = (
        dice_coefficient_per_class(y_true, y_pred, class_id, smooth)
        for class_id in range(num_classes)
    )
    return sum(per_class_scores) / num_classes


-----

### 🧪 Model Compilation & Summary

After building the 3D U-Net, the model is compiled with the following settings:

- **Optimizer**: `Adam` — adaptive learning rate optimization.
- **Loss Function**: `sparse_categorical_crossentropy` — used for multi-class segmentation with integer labels.
- **Metric**: `multiclass_dice_coefficient` — custom metric to evaluate segmentation overlap across all classes.

In [None]:
# Build the network with its defaults (128x128x128 input with 4 channels,
# 5 output classes) and compile it for integer-labelled segmentation.
model = build_3d_unet()

compile_config = {
    "optimizer": "adam",
    "loss": "sparse_categorical_crossentropy",
    "metrics": [multiclass_dice_coefficient],
}
model.compile(**compile_config)

model.summary()


-----

### 📂 Data Preparation with Generators

- Retrieves all subject folder paths from the `train` and `val` directories using `glob`.
- Sorts them to ensure consistent ordering.
- Initializes data generators for both the training and validation sets with the specified `batch_size` (2).
- These generators will load and yield data in batches during model training.



In [None]:
# Root folders of the BraTS 2024 splits, one sub-folder per subject.
train_dir, val_dir, test_dir = (
    "/content/BraTS2024/Train",
    "/content/BraTS2024/Val",
    "/content/BraTS2024/Test",
)

In [None]:
def list_subject_dirs(root):
    """Return the sorted subject sub-directories of ``root``.

    Only directories are kept, so a stray file in the split folder is not
    handed to the data generator as if it were a subject.

    Raises:
        FileNotFoundError: If ``root`` contains no sub-directories. An empty
            split would otherwise give 0 steps per epoch in the training cell.
    """
    subject_dirs = sorted(
        path for path in glob(os.path.join(root, "*")) if os.path.isdir(path)
    )
    if not subject_dirs:
        raise FileNotFoundError(f"No subject folders found in {root!r}")
    return subject_dirs


# Get list of all subject folders in train and val directories
train_subject_dirs = list_subject_dirs(train_dir)
val_subject_dirs = list_subject_dirs(val_dir)

batch_size = 2
train_gen = data_generator(train_subject_dirs, batch_size)
val_gen = data_generator(val_subject_dirs, batch_size)


-----

### 🚀 Model Training with Callbacks

The model is trained using `model.fit()` with the following configurations:

- **Training & Validation Generators**: Yield batches from the subject folders.
- **Epochs**: up to 50 (early stopping may end training sooner).
- **Steps**: `ceil(number of subjects / batch_size)` per epoch, for both training and validation.

####  Callbacks Used:
- **EarlyStopping**: Stops training if `val_loss` doesn't improve for 10 epochs; restores best weights.
- **ModelCheckpoint**: Saves the model with the lowest `val_loss` to `/Outputs/best_model.h5`.
- **CSVLogger**: Logs training history to `/Outputs/training_log.csv`.
- **ReduceLROnPlateau**: Halves the learning rate when `val_loss` has not improved for 3 consecutive epochs (minimum LR = 1e-7).

In [None]:
# All training artifacts (best checkpoint, CSV log) are written here.
# CSVLogger opens its file when training starts, so create the directory
# up front rather than relying on any callback to do it.
OUTPUT_DIR = "/Outputs"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Stop after 10 epochs without val_loss improvement and roll back to the best weights.
early_stop = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)

# Keep only the checkpoint with the lowest val_loss.
checkpoint = ModelCheckpoint(f"{OUTPUT_DIR}/best_model.h5", save_best_only=True, monitor='val_loss')

# Per-epoch metrics log; overwritten on every run.
csv_logger = CSVLogger(f"{OUTPUT_DIR}/training_log.csv", append=False)

# Number of batches per epoch. Assumes data_generator yields batch_size
# subjects per batch (the last batch may be partial) -- TODO confirm.
training_steps = math.ceil(len(train_subject_dirs) / batch_size)
validation_steps = math.ceil(len(val_subject_dirs) / batch_size)

reduce_lr = ReduceLROnPlateau(
    monitor='val_loss',    # metric to monitor
    factor=0.5,            # factor to reduce LR by, new_lr = lr * factor
    patience=3,            # number of epochs with no improvement before reducing LR
    verbose=1,
    min_lr=1e-7            # lower bound on LR
)

history = model.fit(
    train_gen,
    steps_per_epoch=training_steps,
    validation_data=val_gen,
    validation_steps=validation_steps,
    epochs=50,
    callbacks=[early_stop, checkpoint, csv_logger, reduce_lr]
)

# NOTE(review): plot_training_history is called without arguments. Confirm
# whether it expects `history` (or reads the CSV log) -- its signature
# lives in Utils and is not visible here.
plot_training_history()