In [None]:
# import necessary libraries

import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, CSVLogger, ReduceLROnPlateau
import math
import os
from glob import glob
from  Utils import plot_training_history
from Data_processing_utils import data_generator

### ⚙️ **Architecture Overview**

The **Hybrid 3D U-Net** extends the classic encoder–decoder design by incorporating **MedNeXt-style residual blocks**, **adaptive normalization**, and **lightweight attention mechanisms**, resulting in a more powerful and efficient segmentation model for volumetric data.


#### 🔹 Encoder Path (Contracting)
- Uses **MedNeXt blocks** instead of standard convolutions.  
- Each block includes:
  - **Depthwise/Grouped Convolution** for efficient feature extraction.  
  - **Pointwise Convolution** for channel mixing.  
  - **Adaptive Group Normalization** (auto-selects number of groups, fallback to LayerNorm).  
  - **Squeeze-and-Excitation (SE) attention** to recalibrate channel importance.  
  - **Residual connections** with optional projection for stable training.  
- Downsampling performed with **3D Max Pooling**.


#### 🔹 Bottleneck
- A deep **MedNeXt block** at the lowest resolution.  
- Captures global context and high-level features.  


#### 🔹 Decoder Path (Expanding)
- **Upsampling** with 3D transposed convolutions (Conv3DTranspose).  
- Skip connections concatenate encoder and decoder features to retain spatial detail.  
- Each upsampling stage applies a **MedNeXt block** for refined feature learning.  


#### 🔹 Output Layer
- Final **1×1×1 convolution** maps decoder features to `num_classes`.  
- **Softmax activation** generates voxel-wise class probabilities for segmentation.  


In [None]:
# ---------- Adaptive GroupNorm ----------
def adaptive_gn(x, filters):
    """Apply GroupNormalization with a group count that fits the tensor.

    The group count is the smallest of 8, ``filters`` and the channel
    count of ``x``. If the channel count is not divisible by that value,
    a single group is used instead.

    Args:
        x: Keras tensor whose last axis holds the channels.
        filters: Upper bound on the number of groups.

    Returns:
        The normalized tensor.
    """
    channels = x.shape[-1]
    num_groups = min(8, filters, channels)
    if channels % num_groups:
        # A single group normalizes over all channels at once
        # (the LayerNorm-like fallback).
        num_groups = 1
    normalizer = layers.GroupNormalization(groups=num_groups)
    return normalizer(x)

# ---------- MedNeXt-style Block ----------
def mednext_block(x, filters, dropout_rate=0.2):
    """Apply one MedNeXt-style residual block with SE channel attention.

    Pipeline: grouped 3x3x3 convolution -> group norm -> ReLU ->
    pointwise 1x1x1 convolution -> group norm -> squeeze-and-excitation
    gating -> dropout -> residual add -> ReLU. All convolutions use
    ``padding="same"`` with the default stride.

    Args:
        x: Input Keras tensor; its last axis is read as the channel count.
        filters: Number of output channels of the block.
        dropout_rate: Rate passed to the ``Dropout`` layer applied to the
            attention-weighted features.

    Returns:
        Keras tensor with ``filters`` channels on the last axis.
    """
    input_channels = x.shape[-1]

    # Depthwise / grouped conv.
    # A grouped convolution needs BOTH the input channels and the output
    # filters to be divisible by `groups`; otherwise Keras raises at layer
    # construction, so fall back to an ordinary (groups=1) convolution.
    groups = min(8, filters, input_channels)
    if input_channels % groups != 0 or filters % groups != 0:
        groups = 1
    x_dw = layers.Conv3D(filters, 3, padding="same", groups=groups, use_bias=False)(x)
    x_dw = adaptive_gn(x_dw, filters)
    x_dw = layers.ReLU()(x_dw)

    # Pointwise conv mixes information across channels.
    x_pw = layers.Conv3D(filters, 1, padding="same", use_bias=False)(x_dw)
    x_pw = adaptive_gn(x_pw, filters)

    # Lightweight attention (Squeeze-and-Excitation): pool each channel to a
    # scalar, pass it through a bottleneck MLP, and use the sigmoid output as
    # a per-channel gate.
    se = layers.GlobalAveragePooling3D()(x_pw)
    se = layers.Dense(max(filters // 4, 1), activation="relu")(se)
    se = layers.Dense(filters, activation="sigmoid")(se)
    # Reshape so the gate broadcasts over the three spatial axes.
    se = layers.Reshape((1, 1, 1, filters))(se)
    x_att = layers.Multiply()([x_pw, se])

    # Dropout
    x_att = layers.Dropout(dropout_rate)(x_att)

    # Residual connection: identity shortcut when the channel counts already
    # match, otherwise a 1x1x1 projection of the input.
    if input_channels == filters:
        x_out = layers.Add()([x, x_att])
    else:
        x_res = layers.Conv3D(filters, 1, padding="same")(x)
        x_out = layers.Add()([x_res, x_att])

    return layers.ReLU()(x_out)

# ---------- Encoder & Decoder ----------
def encoder_block(x, filters):
    """Run one contracting stage: MedNeXt block followed by 2x max pooling.

    Args:
        x: Input Keras tensor.
        filters: Number of output channels for the MedNeXt block.

    Returns:
        Tuple ``(features, pooled)``: the block output before pooling (used
        as the skip connection) and the same tensor after
        ``MaxPooling3D`` with a (2, 2, 2) window.
    """
    features = mednext_block(x, filters)
    pooled = layers.MaxPooling3D(pool_size=(2, 2, 2))(features)
    return features, pooled

def decoder_block(x, skip, filters):
    """Run one expanding stage: upsample, merge the skip tensor, refine.

    Args:
        x: Keras tensor coming from the previous (coarser) stage.
        skip: Encoder tensor concatenated with the upsampled features.
        filters: Channel count for both the transposed convolution and the
            MedNeXt block.

    Returns:
        Output tensor of the MedNeXt block applied to the concatenation.
    """
    # Kernel 2 with stride 2 doubles each spatial dimension.
    upsampled = layers.Conv3DTranspose(
        filters, kernel_size=2, strides=2, padding="same"
    )(x)
    merged = layers.Concatenate()([upsampled, skip])
    return mednext_block(merged, filters)

# ---------- Full Hybrid UNet ----------
def build_hybrid_unet(input_shape=(128,128,128,4), num_classes=5):
    """Assemble the Hybrid 3D U-Net as a ``tf.keras.Model``.

    The network has four encoder stages (32, 64, 128, 256 filters), a
    512-filter bottleneck, four mirrored decoder stages, and a 1x1x1
    softmax convolution that produces per-voxel class probabilities.

    Args:
        input_shape: Shape of one input volume without the batch axis.
            Because the encoder pools by 2 four times and each decoder
            stage concatenates with its encoder skip, the spatial sizes
            are expected to be divisible by 16.
        num_classes: Number of output channels of the final softmax layer.

    Returns:
        An uncompiled ``tf.keras.Model`` mapping ``inputs`` to the softmax
        output.
    """
    inputs = tf.keras.Input(shape=input_shape, dtype=tf.float32)

    # Encoder: keep every pre-pooling tensor for the skip connections.
    skips = []
    x = inputs
    for width in (32, 64, 128, 256):
        skip, x = encoder_block(x, width)
        skips.append(skip)

    # Bottleneck
    x = mednext_block(x, 512)

    # Decoder: consume the skips from deepest to shallowest.
    for width, skip in zip((256, 128, 64, 32), reversed(skips)):
        x = decoder_block(x, skip, width)

    # 1x1x1 convolution maps the decoder features to class probabilities.
    outputs = layers.Conv3D(num_classes, 1, activation="softmax")(x)

    return tf.keras.Model(inputs, outputs)


### 🎯 Dice Coefficient & Loss (Multiclass)

The **Dice Coefficient** measures the overlap between predicted and ground truth segmentations. It’s particularly useful in medical imaging tasks where class imbalance is common.



#### 🔹 Per-Class Dice Coefficient
- Extracts binary masks for a **specific class**:
  - `y_true_c`: ground truth mask for the target class.  
  - `y_pred_c`: predicted mask obtained from `argmax` of the predicted class probabilities (softmax output).  
- Uses `epsilon` (`smooth`) to avoid division by zero.



#### 🔹 Multiclass Dice Coefficient
- Iterates over all `num_classes` (default = 5).  
- Averages the per-class Dice scores.  
- Produces a **single global metric** reflecting segmentation performance across all classes.  



#### 🔹 Dice Loss
- Converts ground truth to **one-hot encoding**.  
- Flattens predictions and labels to compute per-class overlap.  
- Returns the complement of the mean Dice score across classes.  



#### 🔹 Hybrid Loss
- Combines **Sparse Categorical Cross-Entropy (CE)** with **Dice Loss**.  
- Balances **class-wise prediction accuracy** with **segmentation overlap quality**.  


In [None]:
def dice_coefficient_per_class(y_true, y_pred, class_index, smooth=1e-6):
    """Compute the hard Dice score of a single class.

    Args:
        y_true: Tensor of integer class labels.
            NOTE(review): assumed to have the same shape as
            ``argmax(y_pred, axis=-1)`` (no trailing channel axis) so the
            element-wise product lines up; confirm against the generator.
        y_pred: Tensor of per-class scores with classes on the last axis.
        class_index: Label of the class to evaluate.
        smooth: Small constant added to numerator and denominator to avoid
            division by zero.

    Returns:
        Scalar tensor ``(2 * |A ∩ B| + smooth) / (|A| + |B| + smooth)``.
    """
    predicted_labels = tf.argmax(y_pred, axis=-1)

    # Binary masks: 1.0 where the voxel belongs to `class_index`, else 0.0.
    target_mask = tf.cast(tf.equal(y_true, class_index), tf.float32)
    predicted_mask = tf.cast(tf.equal(predicted_labels, class_index), tf.float32)

    overlap = tf.reduce_sum(target_mask * predicted_mask)
    total = tf.reduce_sum(target_mask) + tf.reduce_sum(predicted_mask)

    numerator = 2. * overlap + smooth
    denominator = total + smooth
    return numerator / denominator

def multiclass_dice_coefficient(y_true, y_pred, num_classes=5, smooth=1e-6):
    """Average the per-class hard Dice score over all classes.

    Args:
        y_true: Tensor of integer class labels.
        y_pred: Tensor of per-class scores with classes on the last axis.
        num_classes: Number of classes; indices ``0 .. num_classes - 1``
            are evaluated.
        smooth: Smoothing constant forwarded to the per-class computation.

    Returns:
        Scalar tensor holding the unweighted mean Dice score.
    """
    total = sum(
        dice_coefficient_per_class(y_true, y_pred, class_id, smooth)
        for class_id in range(num_classes)
    )
    return total / num_classes


# ---------- Dice Loss ----------

def dice_loss(y_true, y_pred, num_classes=5, smooth=1e-6):
    """Soft Dice loss averaged over classes.

    Args:
        y_true: Tensor of class labels; cast to ``int32`` and one-hot
            encoded with ``num_classes`` channels.
        y_pred: Tensor of per-class probabilities with ``num_classes``
            values on the last axis.
        num_classes: Depth of the one-hot encoding.
        smooth: Small constant added to numerator and denominator to avoid
            division by zero.

    Returns:
        Scalar tensor ``1 - mean(dice_per_class)``.
    """
    labels = tf.cast(y_true, tf.int32)

    # Collapse every axis except the class axis: both tensors become
    # (num_voxels_in_batch, num_classes), so sums run over the whole batch.
    target = tf.reshape(tf.one_hot(labels, depth=num_classes), [-1, num_classes])
    probs = tf.reshape(y_pred, [-1, num_classes])

    overlap = tf.reduce_sum(target * probs, axis=0)
    total = tf.reduce_sum(target + probs, axis=0)

    per_class_dice = (2. * overlap + smooth) / (total + smooth)
    return 1 - tf.reduce_mean(per_class_dice)

def hybrid_loss(y_true, y_pred):
    """Equal-weight blend of sparse categorical cross-entropy and Dice loss.

    Args:
        y_true: Tensor of integer class labels.
        y_pred: Tensor of per-class probabilities on the last axis.

    Returns:
        Tensor ``0.5 * cross_entropy + 0.5 * dice_loss``; the scalar Dice
        term is broadcast onto the cross-entropy values.
    """
    cross_entropy = tf.keras.losses.sparse_categorical_crossentropy(y_true, y_pred)
    overlap_penalty = dice_loss(y_true, y_pred)
    return 0.5 * cross_entropy + 0.5 * overlap_penalty


----

### 🧪 Model Compilation & Summary

After building the 3D U-Net, the model is compiled with the following settings:

- **Optimizers**:  
  - `Adam` (learning rate = 1e-4) — adaptive learning rate optimization.  
  - `SGD` (learning rate = 1e-4, momentum = 0.9) — stochastic gradient descent with momentum.  

- **Loss Function**: `hybrid_loss` — combines Sparse Categorical Cross-Entropy and Dice Loss for balanced optimization.  

- **Metric**: `multiclass_dice_coefficient` — custom metric to evaluate segmentation overlap across all classes.  

In [None]:
# Build the Hybrid 3D U-Net defined earlier in this notebook.
# (No `build_3d_unet` is defined or imported here; the builder is
# `build_hybrid_unet`.)
model = build_hybrid_unet()

# Compile the model
# NOTE: every compile() call replaces the previous optimizer/loss/metrics,
# so when this cell is run top to bottom only the last call (SGD) is in
# effect. Comment out the configuration you do not want to train with.

# With Adam
model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-4),
    loss=hybrid_loss,   # Soft Dice + CE
    metrics=[multiclass_dice_coefficient])

# Then with SGD
model.compile(
    optimizer=tf.keras.optimizers.SGD(learning_rate=0.0001, momentum=0.9),
    loss=hybrid_loss,
    metrics=[multiclass_dice_coefficient])

model.summary()


-----

### 📂 Data Preparation with Generators

- Retrieves all subject folder paths from the `train` and `val` directories using `glob`.
- Sorts them to ensure consistent ordering.
- Initializes data generators for both training and validation sets with a specified `batch_size`.
- These generators will load and yield data in batches during model training.



In [None]:
# Absolute paths of the dataset split folders (presumably a Colab
# `/content` mount — adjust when running elsewhere).
train_dir = "/content/BraTS2024/train"
val_dir = "/content/BraTS2024/val"
# NOTE(review): test_dir is not referenced by any other visible cell.
test_dir = "/content/BraTS2024/test"

In [None]:
# Collect every entry directly under the train and val directories.
# Sorting makes the order reproducible between runs.
train_pattern = os.path.join(train_dir, "*")
val_pattern = os.path.join(val_dir, "*")
train_subject_dirs = sorted(glob(train_pattern))
val_subject_dirs = sorted(glob(val_pattern))

# Both generators share the same batch size; it is also used later to
# derive the number of steps per epoch.
batch_size = 2
train_gen = data_generator(train_subject_dirs, batch_size)
val_gen = data_generator(val_subject_dirs, batch_size)


-----

### 🚀 Model Training with Callbacks

The model is trained using `model.fit()` with the following configurations:

- **Training & Validation Generators**: Yield batches from the subject folders.
- **Epochs**: 50 total.
- **Steps**: Calculated based on dataset size and batch size.

####  Callbacks Used:
- **EarlyStopping**: Stops training if `val_multiclass_dice_coefficient` doesn't improve for 10 epochs; restores best weights.
- **ModelCheckpoint**: Saves the model with the highest `val_multiclass_dice_coefficient` to `/Outputs/best_model.h5`.
- **CSVLogger**: Logs training history to `/Outputs/training_log.csv`.
- **ReduceLROnPlateau**: Halves the learning rate after 3 epochs without improvement in `val_multiclass_dice_coefficient` (min LR = 1e-7).

In [None]:
# All callbacks track the validation Dice score, which is better when
# higher, so each of them must run in 'max' mode.
# Make sure the output folder exists before any callback writes into it.
os.makedirs("/Outputs", exist_ok=True)

# Stop after 10 epochs without a better validation Dice and roll back to
# the best weights seen.
early_stop = EarlyStopping(
    monitor='val_multiclass_dice_coefficient',
    patience=10,
    mode='max',
    restore_best_weights=True,
)

# Keep only the model with the highest validation Dice.
checkpoint = ModelCheckpoint(
    "/Outputs/best_model.h5",
    save_best_only=True,
    mode='max',
    monitor='val_multiclass_dice_coefficient',
)

# append=True keeps earlier rows when training is resumed.
csv_logger = CSVLogger('/Outputs/training_log.csv', append=True)

# One step per batch; ceil so a final partial batch is still counted.
training_steps = math.ceil(len(train_subject_dirs) / batch_size)
validation_steps = math.ceil(len(val_subject_dirs) / batch_size)

reduce_lr = ReduceLROnPlateau(
    monitor='val_multiclass_dice_coefficient',    # metric to monitor
    mode='max',            # Dice is maximized; the default 'auto' would treat it as a loss to minimize
    factor=0.5,            # factor to reduce LR by, new_lr = lr * factor
    patience=3,            # number of epochs with no improvement before reducing LR
    verbose=1,
    min_lr=1e-7            # lower bound on LR
)

history = model.fit(
    train_gen,
    steps_per_epoch=training_steps,
    validation_data=val_gen,
    validation_steps=validation_steps,
    epochs=50,
    callbacks=[early_stop, checkpoint, csv_logger, reduce_lr]
)

# NOTE(review): `history` is not passed to the plotting helper; confirm
# whether plot_training_history expects it (or reads the CSV log instead).
plot_training_history()