# DeepLabV3+ for Stone Segmentation

This notebook demonstrates how to implement and train a **DeepLabV3+** model for stone segmentation. The workflow is similar to our SegNet implementation and includes:

1. Importing dependencies
2. Loading and splitting the augmented dataset
3. Creating a custom data generator
4. Defining the DeepLabV3+ model using a MobileNetV2 backbone
5. Compiling and training the model
6. Visualizing training curves and sample predictions
7. Saving the trained model
8. Evaluating IoU and Dice Coefficient on the validation set
9. Evaluating on an external test set
10. Saving predicted masks

The notebook is saved as `deeplabv3plus_chambord.ipynb` and assumes the dataset is organized as:
  - Images: `../data/augmented/images`
  - Masks:  `../data/augmented/masks`


In [1]:
# Install TensorFlow if needed
# !pip install tensorflow

import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from sklearn.model_selection import train_test_split
import sklearn

print("TensorFlow version:", tf.__version__)
print("Keras version:", tf.keras.__version__)
print("scikit-learn version:", sklearn.__version__)

TensorFlow version: 2.16.2
Keras version: 3.9.0


## Hyperparameters and Paths

In [None]:
# Define directories for images and masks
IMAGES_DIR = "../data/augmented/images"
MASKS_DIR = "../data/augmented/masks"

# Image details
IMAGE_HEIGHT = 256
IMAGE_WIDTH = 256
NUM_CLASSES = 2  # Binary segmentation: background and stones

# Training parameters
BATCH_SIZE = 4
EPOCHS = 100
VAL_SPLIT = 0.2  # 20% for validation

# For reproducibility
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

## 1. Loading and Splitting the Dataset

- The `images` folder contains the input images.
- The `masks` folder contains corresponding masks (with matching filenames).

We load the filenames, build the full paths, and split the data into training and validation sets.

In [None]:
def load_image(image_path, target_size=(256, 256), grayscale=False):
    """
    Loads an image from disk, resizes it, and converts it to an array.
    """
    img = load_img(image_path, target_size=target_size, color_mode='grayscale' if grayscale else 'rgb')
    return img_to_array(img)

image_filenames = sorted(os.listdir(IMAGES_DIR))

# Build full paths
image_paths = [os.path.join(IMAGES_DIR, f) for f in image_filenames]
mask_paths = [os.path.join(MASKS_DIR, f) for f in image_filenames]  # assume same filenames

print("Total images found:", len(image_paths))
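The mask paths are built by reusing the image filenames, so a missing or misnamed mask would only show up as an error once training starts. A minimal sketch of an upfront check, assuming the one-to-one naming described above; `find_missing_masks` is a hypothetical helper, not part of the original notebook:

```python
import os
import tempfile


def find_missing_masks(image_dir, mask_dir):
    """Return image filenames that have no file with the same name in mask_dir."""
    mask_names = set(os.listdir(mask_dir))
    return [f for f in sorted(os.listdir(image_dir)) if f not in mask_names]


# Self-contained demo on a temporary directory layout
with tempfile.TemporaryDirectory() as root:
    image_dir = os.path.join(root, "images")
    mask_dir = os.path.join(root, "masks")
    os.makedirs(image_dir)
    os.makedirs(mask_dir)
    for name in ["a.png", "b.png", "c.png"]:
        open(os.path.join(image_dir, name), "wb").close()
    for name in ["a.png", "c.png"]:
        open(os.path.join(mask_dir, name), "wb").close()
    missing = find_missing_masks(image_dir, mask_dir)

print(missing)  # ['b.png']
```

In the notebook, the call would be `find_missing_masks(IMAGES_DIR, MASKS_DIR)`; an empty list means every image has a mask.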

### Train/Validation Split

In [None]:
train_images, val_images, train_masks, val_masks = train_test_split(
    image_paths, mask_paths, test_size=VAL_SPLIT, random_state=SEED
)

print("Training set size:", len(train_images))
print("Validation set size:", len(val_images))
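`train_test_split` shuffles both path lists with the same permutation, so each image stays paired with its mask, and with a float `test_size` of 0.2 the validation set receives `ceil(0.2 * n)` samples. The numpy sketch below mimics that behaviour to make the pairing explicit; `paired_split` is a hypothetical helper, and its shuffle order will not match scikit-learn's for the same seed:

```python
import math

import numpy as np


def _take(seq, idx):
    """Select items of a list by an array of integer indices."""
    return [seq[i] for i in idx]


def paired_split(images, masks, val_fraction=0.2, seed=42):
    """Split two aligned lists with one shared permutation so each pair stays together."""
    if len(images) != len(masks):
        raise ValueError("images and masks must have the same length")
    n = len(images)
    n_val = math.ceil(val_fraction * n)  # same rounding rule as scikit-learn for a float test_size
    perm = np.random.default_rng(seed).permutation(n)
    val_idx, train_idx = perm[:n_val], perm[n_val:]
    # Same return order as train_test_split: train_a, val_a, train_b, val_b
    return (_take(images, train_idx), _take(images, val_idx),
            _take(masks, train_idx), _take(masks, val_idx))


images = [f"img_{i}.png" for i in range(10)]
masks = [f"mask_{i}.png" for i in range(10)]
train_imgs, val_imgs, train_msks, val_msks = paired_split(images, masks)
print(len(train_imgs), len(val_imgs))  # 8 2
```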

## 2. Data Generator

We define a custom generator that yields batches of images and masks. Every mask is inverted: the source masks are assumed to show stones in black and joints in white, so flipping them makes stones the foreground class (white, value 1). For binary segmentation, the inverted mask is then thresholded at 0.5.
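The inversion and thresholding can be checked on a toy mask. A minimal numpy sketch, assuming 8-bit masks where stones are 0 (black) and joints are 255 (white); `preprocess_mask` is a hypothetical helper that mirrors the generator's mask steps:

```python
import numpy as np


def preprocess_mask(raw_mask):
    """Scale an 8-bit mask to [0, 1], invert it, and binarize it at 0.5."""
    mask = raw_mask.astype(np.float32) / 255.0
    mask = 1.0 - mask                          # stones (0) become 1, joints (255) become 0
    return (mask > 0.5).astype(np.float32)     # intermediate grays (e.g. from compression) are snapped


# 2x3 toy mask: stones = 0, joints = 255, plus two intermediate gray pixels
raw = np.array([[0, 255, 30],
                [200, 0, 255]], dtype=np.uint8)
print(preprocess_mask(raw))
# [[1. 0. 1.]
#  [0. 1. 0.]]
```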

In [None]:
def data_generator(image_paths, mask_paths, batch_size, num_classes, input_size=(256,256)):
    """
    Yields batches of (images, masks) for training/validation.
    """
    while True:
        indices = np.arange(len(image_paths))
        np.random.shuffle(indices)

        for start in range(0, len(indices), batch_size):
            end = min(start + batch_size, len(indices))
            batch_indices = indices[start:end]

            images = []
            masks = []

            for i in batch_indices:
                img = load_image(image_paths[i], target_size=input_size, grayscale=False)
                mask = load_image(mask_paths[i], target_size=input_size, grayscale=True)

                # Scale to [0,1]
                img = img / 255.0
                mask = mask / 255.0

                # Invert the mask: stones (black in the source masks) become 1
                mask = 1.0 - mask

                # For binary segmentation, threshold the mask
                if num_classes == 2:
                    mask = (mask > 0.5).astype(np.float32)  # shape (H, W, 1)
                else:
                    # For multi-class, implement one-hot encoding as needed
                    pass

                images.append(img)
                masks.append(mask)

            images = np.array(images, dtype=np.float32)
            masks = np.array(masks, dtype=np.float32)
            yield images, masks
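Each pass over the shuffled indices yields `ceil(n / batch_size)` batches, the last of which may be smaller, whereas the training cell sets `steps_per_epoch = n // batch_size`. An epoch and a full pass over the data can therefore differ by one partial batch. The index slicing can be traced with a short sketch; `batch_index_slices` is a hypothetical name, and it uses a seeded `default_rng` instead of the global `np.random.shuffle` so the demo is self-contained:

```python
import numpy as np


def batch_index_slices(n_samples, batch_size, seed=0):
    """Reproduce the generator's shuffle-then-slice batching for one pass over the data."""
    indices = np.arange(n_samples)
    np.random.default_rng(seed).shuffle(indices)
    return [indices[start:start + batch_size] for start in range(0, n_samples, batch_size)]


batches = batch_index_slices(n_samples=10, batch_size=4)
print([len(b) for b in batches])  # [4, 4, 2]
print(10 // 4)                    # 2, the steps_per_epoch the training cell would use
```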

## 3. Defining the DeepLabV3+ Model

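We now define a simplified DeepLabV3+ model on top of a MobileNetV2 backbone pre-trained on ImageNet, followed by an Atrous Spatial Pyramid Pooling (ASPP) module and a decoder. This architecture is widely used for semantic segmentation. The simplification is mainly in the ASPP: it keeps a 1x1 convolution and two 3x3 atrous convolutions (dilation rates 6 and 12), whereas the original DeepLabV3+ also uses a rate-18 branch, an image-level pooling branch, and a 1x1 projection after concatenation.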
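An atrous (dilated) convolution spaces its taps `rate` pixels apart, so a k x k kernel with dilation rate r spans k + (k - 1)(r - 1) pixels per side without adding parameters. A quick calculation for the 3x3 ASPP branches (`effective_kernel_size` is an illustrative helper):

```python
def effective_kernel_size(kernel_size, dilation_rate):
    """Span (in pixels) covered by one side of a dilated convolution kernel."""
    return kernel_size + (kernel_size - 1) * (dilation_rate - 1)


for rate in (1, 6, 12):
    print(rate, effective_kernel_size(3, rate))
# 1 3
# 6 13
# 12 25
```

On the 16x16 feature map that the backbone produces for a 256x256 input, the rate-12 branch spans more than the whole map, so many of its taps land on zero padding.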

In [None]:
def DeepLabV3Plus(input_shape=(256, 256, 3), num_classes=2):
    """
    A simplified DeepLabV3+ architecture using MobileNetV2 as the backbone.
    """
    base_model = tf.keras.applications.MobileNetV2(
        input_shape=input_shape,
        include_top=False,
        weights='imagenet'
    )

    # 'block_13_expand_relu' provides the high-level feature map at output stride 16 (16x16 for a 256x256 input)
    x = base_model.get_layer('block_13_expand_relu').output

    # Atrous Spatial Pyramid Pooling (ASPP) module
    # A simple ASPP with a 1x1 conv and two 3x3 conv layers with different dilation rates
    aspp1 = layers.Conv2D(256, (1,1), padding='same', use_bias=False)(x)
    aspp1 = layers.BatchNormalization()(aspp1)
    aspp1 = layers.Activation('relu')(aspp1)

    aspp2 = layers.Conv2D(256, (3,3), dilation_rate=6, padding='same', use_bias=False)(x)
    aspp2 = layers.BatchNormalization()(aspp2)
    aspp2 = layers.Activation('relu')(aspp2)

    aspp3 = layers.Conv2D(256, (3,3), dilation_rate=12, padding='same', use_bias=False)(x)
    aspp3 = layers.BatchNormalization()(aspp3)
    aspp3 = layers.Activation('relu')(aspp3)

    # Concatenate ASPP features
    x = layers.Concatenate()([aspp1, aspp2, aspp3])

    # Decoder: Upsample and combine with low-level features
    x = layers.UpSampling2D(size=(4,4), interpolation='bilinear')(x)

    # Low-level features at output stride 4 (64x64 for a 256x256 input) help recover boundary detail
    low_level_feat = base_model.get_layer('block_3_expand_relu').output
    low_level_feat = layers.Conv2D(48, (1,1), padding='same', use_bias=False)(low_level_feat)
    low_level_feat = layers.BatchNormalization()(low_level_feat)
    low_level_feat = layers.Activation('relu')(low_level_feat)

    x = layers.Concatenate()([x, low_level_feat])
    x = layers.Conv2D(256, (3,3), padding='same', use_bias=False)(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)

    # Final upsampling
    x = layers.UpSampling2D(size=(4,4), interpolation='bilinear')(x)

    # Final classification layer
    if num_classes == 2:
        # For binary segmentation
        x = layers.Conv2D(1, (1,1), padding='same', activation='sigmoid')(x)
    else:
        # For multi-class segmentation
        x = layers.Conv2D(num_classes, (1,1), padding='same', activation='softmax')(x)

    model = tf.keras.models.Model(inputs=base_model.input, outputs=x)
    return model

# Instantiate the DeepLabV3+ model
model = DeepLabV3Plus(input_shape=(IMAGE_HEIGHT, IMAGE_WIDTH, 3), num_classes=NUM_CLASSES)
model.summary()
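The decoder only works if the upsampled ASPP output and the low-level features have the same spatial size. Assuming `block_13_expand_relu` sits at output stride 16, `block_3_expand_relu` at output stride 4, and the input height and width are multiples of 16, the bookkeeping can be checked with plain arithmetic (`decoder_shapes` is a hypothetical helper):

```python
def decoder_shapes(height, width):
    """Trace spatial sizes through the simplified DeepLabV3+ decoder (no TensorFlow needed)."""
    if height % 16 or width % 16:
        raise ValueError("height and width should be multiples of 16")
    aspp_in = (height // 16, width // 16)      # block_13_expand_relu: output stride 16 (assumed)
    low_level = (height // 4, width // 4)      # block_3_expand_relu: output stride 4 (assumed)
    concat = (aspp_in[0] * 4, aspp_in[1] * 4)  # after the first UpSampling2D(size=(4,4)); equals low_level
    output = (concat[0] * 4, concat[1] * 4)    # after the final UpSampling2D(size=(4,4))
    return {"aspp_input": aspp_in, "low_level": low_level, "concat": concat, "output": output}


print(decoder_shapes(256, 256))
# {'aspp_input': (16, 16), 'low_level': (64, 64), 'concat': (64, 64), 'output': (256, 256)}
```

These values can be compared with the layer shapes reported by `model.summary()`.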

## 4. Compile and Train the Model

We compile the model with the Adam optimizer (learning rate 1e-4) and binary cross-entropy loss for the two-class case. When the `SKIP_TRAINING` flag is `True`, training is skipped and the previously saved model is loaded from `../models/deeplabv3plus_chambord.keras` instead.
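For a sigmoid output p and a binary target y, binary cross-entropy averages -[y log p + (1 - y) log(1 - p)] over all pixels. A numpy sketch of that formula; the clipping mirrors what Keras does to avoid log(0), and the 1e-7 value is Keras's default epsilon:

```python
import numpy as np


def binary_crossentropy(y_true, y_prob, eps=1e-7):
    """Mean pixel-wise binary cross-entropy between a {0, 1} mask and sigmoid probabilities."""
    p = np.clip(y_prob, eps, 1.0 - eps)  # keep probabilities away from exactly 0 and 1
    return float(np.mean(-(y_true * np.log(p) + (1.0 - y_true) * np.log(1.0 - p))))


y_true = np.array([1.0, 0.0, 1.0, 0.0])
y_prob = np.array([0.9, 0.1, 0.9, 0.1])
print(round(binary_crossentropy(y_true, y_prob), 4))  # 0.1054, i.e. -log(0.9)
```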

In [None]:
# Decide whether to train from scratch or load an existing model
SKIP_TRAINING = True  # Set to False to train (the backbone still starts from ImageNet weights)

if not SKIP_TRAINING:
    if NUM_CLASSES == 2:
        loss = 'binary_crossentropy'
    else:
        loss = 'categorical_crossentropy'

    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
        loss=loss,
        metrics=["accuracy"]
    )

    # Create generators
    train_gen = data_generator(
        train_images,
        train_masks,
        BATCH_SIZE,
        NUM_CLASSES,
        input_size=(IMAGE_HEIGHT, IMAGE_WIDTH)
    )

    val_gen = data_generator(
        val_images,
        val_masks,
        BATCH_SIZE,
        NUM_CLASSES,
        input_size=(IMAGE_HEIGHT, IMAGE_WIDTH)
    )

    train_steps = len(train_images) // BATCH_SIZE
    val_steps = len(val_images) // BATCH_SIZE

    history = model.fit(
        train_gen,
        epochs=EPOCHS,
        steps_per_epoch=train_steps,
        validation_data=val_gen,
        validation_steps=val_steps,
        verbose=1
    )
else:
    from tensorflow.keras.models import load_model
    model = load_model("../models/deeplabv3plus_chambord.keras", compile=False)
    print("Model loaded from disk. Training skipped.")

## 5. Visualize Training Curves

If training was performed, we can visualize the loss and accuracy curves. If training was skipped, a message will be printed.

In [None]:
if 'history' in globals():
    plt.figure(figsize=(12,4))
    plt.subplot(1,2,1)
    plt.plot(history.history['loss'], label='Train Loss')
    plt.plot(history.history['val_loss'], label='Val Loss')
    plt.title('Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()

    plt.subplot(1,2,2)
    plt.plot(history.history['accuracy'], label='Train Accuracy')
    plt.plot(history.history['val_accuracy'], label='Val Accuracy')
    plt.title('Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()

    plt.show()
else:
    print("Training was skipped, so no training history available for plotting.")

## 6. Sample Prediction

Let's visualize some sample predictions on the validation set.

In [None]:
def visualize_predictions(model, image_paths, mask_paths, input_size=(IMAGE_HEIGHT, IMAGE_WIDTH), num_samples=3):
    num_samples = min(num_samples, len(image_paths))
    fig, axs = plt.subplots(num_samples, 3, figsize=(12, 4*num_samples))
    if num_samples == 1:
        axs = np.expand_dims(axs, axis=0)
    for i in range(num_samples):
        img = load_image(image_paths[i], target_size=input_size, grayscale=False)
        mask = load_image(mask_paths[i], target_size=input_size, grayscale=True)
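        img_scaled = img / 255.0
        # Invert the ground truth so stones are white, matching the training masks and the predictions
        mask_scaled = 1.0 - mask / 255.0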
        pred = model.predict(np.expand_dims(img_scaled, axis=0))
        if NUM_CLASSES == 2:
            pred_mask = (pred[0, :, :, 0] > 0.5).astype(np.uint8)
        else:
            pred_mask = np.argmax(pred[0], axis=-1)
        axs[i, 0].imshow(img.astype(np.uint8))
        axs[i, 0].set_title('Image')
        axs[i, 0].axis('off')
        axs[i, 1].imshow(mask_scaled[:, :, 0], cmap='gray')
        axs[i, 1].set_title('Ground Truth Mask')
        axs[i, 1].axis('off')
        axs[i, 2].imshow(pred_mask, cmap='gray')
        axs[i, 2].set_title('Predicted Mask')
        axs[i, 2].axis('off')
    plt.tight_layout()
    plt.show()

# Visualize 5 samples from the validation set
visualize_predictions(model, val_images, val_masks, input_size=(IMAGE_HEIGHT, IMAGE_WIDTH), num_samples=5)
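How the network output becomes a mask depends on the head: the two-class sigmoid head is thresholded at 0.5, while a softmax head takes the argmax over the class axis. A small numpy sketch on toy outputs:

```python
import numpy as np

# Sigmoid head: one channel of stone probabilities, shape (H, W, 1)
sigmoid_out = np.array([[[0.2], [0.7]],
                        [[0.5], [0.9]]])
binary_mask = (sigmoid_out[:, :, 0] > 0.5).astype(np.uint8)
print(binary_mask.tolist())  # [[0, 1], [0, 1]]  (0.5 is not > 0.5)

# Softmax head: one probability per class, shape (H, W, num_classes)
softmax_out = np.array([[[0.7, 0.2, 0.1], [0.1, 0.3, 0.6]],
                        [[0.2, 0.5, 0.3], [0.4, 0.4, 0.2]]])
class_mask = np.argmax(softmax_out, axis=-1)
print(class_mask.tolist())   # [[0, 2], [1, 0]]  (ties go to the lowest index)
```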

## 7. Saving the Model

If training was performed, the trained model is saved to disk for later reuse.

In [None]:
if not SKIP_TRAINING:
    SAVE_PATH = "../models/deeplabv3plus_chambord.keras"
    model.save(SAVE_PATH)
    print(f"Model saved to {SAVE_PATH}")
else:
    print("Model not saved because training was skipped.")

## 8. Post-Training Evaluation: Compute IoU and Dice Metrics

We now compute the Intersection over Union (IoU) and Dice Coefficient for each validation image and report their means.
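For a ground-truth mask A and a predicted mask B, IoU = |A ∩ B| / |A ∪ B| and Dice = 2|A ∩ B| / (|A| + |B|); for binary masks the two are related by Dice = 2 IoU / (1 + IoU). A numpy check on a four-pixel example, using the same small smoothing constant as `compute_iou` and `compute_dice` (`iou_dice` is an illustrative helper):

```python
import numpy as np


def iou_dice(y_true, y_pred, eps=1e-7):
    """IoU and Dice for flat binary arrays, smoothed like compute_iou/compute_dice."""
    inter = np.sum(y_true * y_pred)
    total = np.sum(y_true) + np.sum(y_pred)
    iou = (inter + eps) / (total - inter + eps)   # union = |A| + |B| - |A ∩ B|
    dice = (2.0 * inter + eps) / (total + eps)
    return float(iou), float(dice)


y_true = np.array([1, 1, 0, 0], dtype=np.float32)
y_pred = np.array([1, 0, 1, 0], dtype=np.float32)
iou, dice = iou_dice(y_true, y_pred)
print(round(iou, 4), round(dice, 4))   # 0.3333 0.5
print(round(2 * iou / (1 + iou), 4))   # 0.5, matching Dice
```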

In [None]:
def compute_iou(y_true, y_pred):
    """
    Compute the Intersection over Union (IoU) for binary masks.
    """
    y_true_f = tf.reshape(y_true, [-1])
    y_pred_f = tf.reshape(y_pred, [-1])
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
    union = tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) - intersection
    iou = (intersection + 1e-7) / (union + 1e-7)
    return iou.numpy()

def compute_dice(y_true, y_pred):
    """
    Compute the Dice Coefficient for binary masks.
    """
    y_true_f = tf.reshape(y_true, [-1])
    y_pred_f = tf.reshape(y_pred, [-1])
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
    dice = (2.0 * intersection + 1e-7) / (tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + 1e-7)
    return dice.numpy()

iou_scores = []
dice_scores = []

for img_path, mask_path in zip(val_images, val_masks):
    img = load_image(img_path, target_size=(IMAGE_HEIGHT, IMAGE_WIDTH), grayscale=False) / 255.0
    mask = load_image(mask_path, target_size=(IMAGE_HEIGHT, IMAGE_WIDTH), grayscale=True) / 255.0
    # Apply the same inversion and threshold as the data generator,
    # so stones are 1 in both the ground truth and the prediction
    mask = ((1.0 - mask) > 0.5).astype(np.float32)  # shape (H, W, 1)
    pred = model.predict(np.expand_dims(img, axis=0), verbose=0)
    pred_bin = (pred[0, :, :, 0] > 0.5).astype(np.float32)
    pred_bin = np.expand_dims(pred_bin, axis=-1)  # shape (H, W, 1)
    current_iou = compute_iou(mask[np.newaxis, ...], pred_bin[np.newaxis, ...])
    current_dice = compute_dice(mask[np.newaxis, ...], pred_bin[np.newaxis, ...])
    iou_scores.append(current_iou)
    dice_scores.append(current_dice)

mean_iou = np.mean(iou_scores)
mean_dice = np.mean(dice_scores)

print("Validation Mean IoU: {:.4f}".format(mean_iou))
print("Validation Mean Dice: {:.4f}".format(mean_dice))
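The scores above are means of per-image values, so every image counts equally regardless of how large its stones are. Pooling intersections and unions over the whole set instead weights each image by its union size, and the two summaries can differ noticeably. A numpy sketch on two hypothetical 4-pixel masks:

```python
import numpy as np

# Two hypothetical (true, pred) mask pairs, flattened to 4 pixels each
pairs = [
    (np.array([1, 0, 0, 0]), np.array([1, 0, 0, 0])),  # small stone, perfect prediction
    (np.array([1, 1, 1, 1]), np.array([1, 1, 0, 0])),  # large stone, half predicted
]

per_image = []
total_inter, total_union = 0, 0
for y_true, y_pred in pairs:
    inter = np.sum(y_true * y_pred)
    union = np.sum(y_true) + np.sum(y_pred) - inter
    per_image.append(inter / union)
    total_inter += inter
    total_union += union

print(np.mean(per_image))         # 0.75, the per-image mean reported above
print(total_inter / total_union)  # 0.6, the pooled IoU
```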

## 9. Evaluate on External Test Set

We now evaluate the model on an external test set. The test images are located in `../test/images` and the expert-provided ground truth masks are in `../test/gt`. Note that the ground truth masks use the inverse convention (stones as black, joints as white), so we invert them before evaluation.
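Unlike the training data, test images and ground-truth masks are paired by position after sorting, and the assertion in the cell below only checks that the counts match. A sketch of a stricter check that compares filename stems, under the hypothetical assumption that each GT mask shares its image's base name (extensions may differ); `check_pairing` is an illustrative helper:

```python
import os


def check_pairing(image_names, mask_names):
    """Return (image, mask) pairs whose base names differ after sorting.

    Assumes (hypothetically) that each ground-truth mask shares its image's base name.
    """
    mismatches = []
    for img, gt in zip(sorted(image_names), sorted(mask_names)):
        if os.path.splitext(img)[0] != os.path.splitext(gt)[0]:
            mismatches.append((img, gt))
    return mismatches


print(check_pairing(["wall_01.jpg", "wall_02.jpg"], ["wall_02.png", "wall_01.png"]))  # []
print(check_pairing(["wall_01.jpg", "wall_03.jpg"], ["wall_01.png", "wall_02.png"]))
# [('wall_03.jpg', 'wall_02.png')]
```

An empty list means the positional pairing is safe for those filenames.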

In [None]:
TEST_IMAGES_DIR = "../test/images"
TEST_GT_DIR = "../test/gt"

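# Images and GT masks are paired by position after sorting, so their filenames must sort in the same order
test_image_filenames = sorted(os.listdir(TEST_IMAGES_DIR))
test_mask_filenames = sorted(os.listdir(TEST_GT_DIR))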

test_iou_scores = []
test_dice_scores = []

assert len(test_image_filenames) == len(test_mask_filenames), "Mismatch between test images and GT masks!"

for img_fname, gt_fname in zip(test_image_filenames, test_mask_filenames):
    img_path = os.path.join(TEST_IMAGES_DIR, img_fname)
    gt_path = os.path.join(TEST_GT_DIR, gt_fname)
    img = load_image(img_path, target_size=(IMAGE_HEIGHT, IMAGE_WIDTH), grayscale=False) / 255.0
    gt_mask = load_image(gt_path, target_size=(IMAGE_HEIGHT, IMAGE_WIDTH), grayscale=True) / 255.0
    gt_mask_inverted = ((1.0 - gt_mask) > 0.5).astype(np.float32)  # invert and binarize so stones = 1
    pred = model.predict(np.expand_dims(img, axis=0))
    pred_bin = (pred[0, :, :, 0] > 0.5).astype(np.float32)
    pred_bin = np.expand_dims(pred_bin, axis=-1)
    current_iou = compute_iou(gt_mask_inverted[np.newaxis, ...], pred_bin[np.newaxis, ...])
    current_dice = compute_dice(gt_mask_inverted[np.newaxis, ...], pred_bin[np.newaxis, ...])
    test_iou_scores.append(current_iou)
    test_dice_scores.append(current_dice)

mean_test_iou = np.mean(test_iou_scores)
mean_test_dice = np.mean(test_dice_scores)

print(f"Test Mean IoU: {mean_test_iou:.4f}")
print(f"Test Mean Dice: {mean_test_dice:.4f}")

## 10. Saving Predicted Masks

Finally, we save the predicted masks for the first few test images so they can be included in the final report. These masks use the model's output convention (stones as white, joints as black), which is the inverse of the expert ground truth in `../test/gt`.

In [None]:
from PIL import Image

SAVE_MASKS_DIR = "../predicted_images/deeplabv3plus_masks"
os.makedirs(SAVE_MASKS_DIR, exist_ok=True)

num_samples_to_save = 10

for i, (img_fname, gt_fname) in enumerate(zip(test_image_filenames, test_mask_filenames)):
    if i >= num_samples_to_save:
        break
    img_path = os.path.join(TEST_IMAGES_DIR, img_fname)
    img = load_image(img_path, target_size=(IMAGE_HEIGHT, IMAGE_WIDTH), grayscale=False) / 255.0
    pred = model.predict(np.expand_dims(img, axis=0))
    pred_mask = (pred[0, :, :, 0] > 0.5).astype(np.uint8)
    pred_mask_255 = (pred_mask * 255).astype(np.uint8)
    # Replicate the mask into 3 channels so it is saved as a standard RGB JPEG
    rgb_array = np.stack([pred_mask_255]*3, axis=-1)
    mask_img = Image.fromarray(rgb_array, 'RGB')
    base_name = os.path.splitext(img_fname)[0]
    save_name = f"pred_{base_name}.jpg"
    save_path = os.path.join(SAVE_MASKS_DIR, save_name)
    mask_img.save(save_path, format="JPEG", quality=100)
    print(f"Saved predicted mask for {img_fname} as {save_path}")
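JPEG is lossy even at `quality=100`, so a reloaded mask may contain gray values near stone edges instead of pure 0 and 255. If these files are reused for evaluation, re-thresholding them is a simple safeguard; a lossless format such as PNG would avoid the issue altogether. A numpy sketch of the re-thresholding step (`rebinarize` is a hypothetical helper, and 128 is an assumed midpoint threshold):

```python
import numpy as np


def rebinarize(mask_uint8, threshold=128):
    """Snap a possibly compression-blurred 8-bit mask back to {0, 255}."""
    return np.where(mask_uint8 >= threshold, 255, 0).astype(np.uint8)


# Toy row of pixels as they might look after a JPEG round trip
reloaded = np.array([0, 3, 120, 135, 250, 255], dtype=np.uint8)
print(rebinarize(reloaded).tolist())  # [0, 0, 0, 255, 255, 255]
```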