# Artificial Neural Networks and Deep Learning

---
## Homework 2: Image segmentation of Mars' stones
## Team: The Backpropagators
Arianna Procaccio, Francesco Buccoliero, Kai-Xi Matteo Chen, Luca Capoferri

ariii, frbuccoliero, kaiximatteoc, luke01

246843, 245498, 245523, 259617


## ⚙️ Import Libraries

In [None]:
#from sklearn.metrics import accuracy_score, confusion_matrix, precision_score, recall_score, f1_score
#from sklearn.utils.class_weight import compute_class_weight
from datetime import datetime
from matplotlib import pyplot as plt
from tensorflow import keras as tfk
from tensorflow.keras import layers as tfkl
import albumentations as A
import numpy as np
import tensorflow as tf

In [None]:
# Seed every global RNG the notebook relies on so splits, shuffles and augmentations are reproducible.
import random

seed = 666
# Python's `random` module: older albumentations releases sample augmentation parameters from it,
# so without this the augmented datasets differ between runs.
random.seed(seed)
np.random.seed(seed)      # numpy global RNG (used for the train/val/test shuffle)
tf.random.set_seed(seed)  # TensorFlow global seed

## ⏳ Load, inspect and prepare the data

In [None]:
# .npz archive containing the "training_set" and "test_set" arrays loaded below
DATASET_PATH: str = "dataset.npz"

In [None]:
# The provided test set has no ground-truth masks, so a small labelled subset of the training
# data is held out as a local test set for evaluating inference later.
import math

train_ratio = 0.85
validation_ratio = 0.10
test_ratio = 0.05

# Compare with a tolerance: exact float equality on sums of decimal fractions is fragile
# (e.g. 0.7 + 0.2 + 0.1 != 1.0), so re-tuning the ratios could trigger a spurious failure.
assert math.isclose(train_ratio + validation_ratio + test_ratio, 1.0), "train/validation/test ratios must sum to 1"


In [None]:
# np.load on an .npz archive returns an NpzFile holding an open file handle; the context
# manager closes it once the arrays we need have been read into memory.
with np.load(DATASET_PATH) as data:
    training_set = data["training_set"]
    hidden_X_test = data["test_set"]

# training_set stores (image, mask) pairs along axis 1
X_train = training_set[:, 0]
y_train = training_set[:, 1]

print(f"Training X shape: {X_train.shape}")
print(f"Training y shape: {y_train.shape}")
print(f"Test hidden X shape: {hidden_X_test.shape}")

# Add a trailing color channel and rescale pixels to [0, 1].
# Casting to float32 before dividing avoids materialising a float64 copy of the whole dataset.
X_train = X_train[..., np.newaxis].astype(np.float32) / 255
hidden_X_test = hidden_X_test[..., np.newaxis].astype(np.float32) / 255

input_shape = X_train.shape[1:]
# Number of distinct label values present in the training masks
num_classes = len(np.unique(y_train))

print(f"Input shape: {input_shape}")
print(f"Number of classes: {num_classes}")

In [None]:
# Shuffle once with the seeded numpy RNG, then cut the data into contiguous
# train / validation / local-test slices according to the ratios defined above.
validation_size = int(X_train.shape[0] * validation_ratio)  # not used in this cell

n_total = X_train.shape[0]
# np.random.permutation(int) builds arange(n) and shuffles it in place with the global RNG,
# i.e. exactly the same draw as arange + np.random.shuffle.
indices = np.random.permutation(n_total)
X_train, y_train = X_train[indices], y_train[indices]

# Slice boundaries: [0, a) -> train, [a, b) -> validation, [b, N) -> test
split_indices = [
    int(n_total * train_ratio),
    int(n_total * (train_ratio + validation_ratio)),
]

X_train, X_val, X_test = np.split(X_train, split_indices)
y_train, y_val, y_test = np.split(y_train, split_indices)

for _split_label, _split_array in (
    ("Training X", X_train),
    ("Training y", y_train),
    ("Validation X", X_val),
    ("Validation y", y_val),
    ("Test X", X_test),
    ("Test y", y_test),
):
    print(f"{_split_label} shape: {_split_array.shape}")

In [None]:
# Plot the data. The number of images being displayed are rows X cols
def plot(data, mask=None, num_images=10, rows=4, cols=8):
    """Display images, or image/mask pairs side by side.

    Without ``mask``: a ``rows`` x ``cols`` grid showing the first ``rows * cols`` images;
    surplus grid cells are left empty.
    With ``mask``: one row per sample (image on the left, mask on the right). It shows
    ``num_images`` samples but at least 4, capped at the number of samples available.

    Args:
        data: array-like of images; a trailing channel axis of size 1 is squeezed away.
        mask: optional array-like of masks aligned with ``data``.
        num_images: number of image/mask pairs to show (mask mode only).
        rows, cols: grid layout (image-only mode only).
    """
    data = np.asarray(data)
    if data.shape[-1] == 1:  # grayscale: drop the channel axis for imshow
        data = data.squeeze(axis=-1)

    if mask is None:
        # squeeze=False keeps `axes` 2-D even for a 1x1 grid, so `.flat` always works
        _, axes = plt.subplots(rows, cols, figsize=(12, 6), squeeze=False)
        for i, ax in enumerate(axes.flat):
            if i < len(data):
                ax.imshow(data[i], cmap='gray' if data[i].ndim == 2 else None)
            ax.axis('off')  # hide axes (and empty cells)
    else:
        mask = np.asarray(mask)
        if mask.shape[-1] == 1:
            mask = mask.squeeze(axis=-1)

        # Keep the "at least 4 rows" behaviour, but never index past the available samples
        # (previously raised IndexError when fewer than 4 / num_images samples were passed).
        num_samples = min(max(num_images, 4), len(data), len(mask))
        if num_samples == 0:
            return

        # squeeze=False keeps `axes` 2-D so `axes[i, 0]` also works for a single row
        _, axes = plt.subplots(num_samples, 2, figsize=(8, num_samples * 2), squeeze=False)
        for i in range(num_samples):
            axes[i, 0].imshow(data[i], cmap="gray")
            axes[i, 0].set_title(f"Image {i+1}")
            axes[i, 0].axis("off")

            axes[i, 1].imshow(mask[i], cmap="viridis", alpha=0.8)
            axes[i, 1].set_title(f"Mask {i+1}")
            axes[i, 1].axis("off")

    plt.tight_layout()
    plt.show()

In [None]:
# Outlier samples all share one identical mask, identified by hand and stored in outlier_mask.npy.
# NOTE: despite their names, the `*_outliers_indices` lists hold the indices of the samples to
# KEEP (masks that differ from the template). Outliers are dropped from train, validation and test.
outlier_mask_template = np.load("outlier_mask.npy") # discovered by hand


def _indices_differing_from(masks, template):
    """Return the positions of the masks that are not exactly equal to `template`."""
    return [pos for pos, msk in enumerate(masks) if not np.array_equal(msk, template)]


train_outliers_indices, val_outliers_indices, test_outliers_indices = (
    _indices_differing_from(split_masks, outlier_mask_template)
    for split_masks in (y_train, y_val, y_test)
)
print(f'Total outliers in train set: {len(y_train) - len(train_outliers_indices)}')
print(f'Total outliers in validation set: {len(y_val) - len(val_outliers_indices)}')
print(f'Total outliers in test set: {len(y_test) - len(test_outliers_indices)}')

X_train, y_train = X_train[train_outliers_indices], y_train[train_outliers_indices]
X_val, y_val = X_val[val_outliers_indices], y_val[val_outliers_indices]
X_test, y_test = X_test[test_outliers_indices], y_test[test_outliers_indices]

print(f'Updated train dataset size: {X_train.shape}')
print(f'Updated validation dataset size: {X_val.shape}')
print(f'Updated test dataset size: {X_test.shape}')


In [None]:
# Image-only overview of the first 10 x 8 = 80 training images
plot(data=X_train, cols=8, rows=10)

In [None]:
# Sanity check after outlier removal: none of these image/mask pairs should be an outlier
plot(X_train, y_train, num_images=40)

In [None]:
# Network input resolution as (height, width) -- the dataset images are 64 x 128 --
# and the number of segmentation classes found in the training masks.
IMG_SIZE = 64, 128
NUM_CLASSES = num_classes

In [None]:
def get_dataset(X, y, batch_size=32, augmentations=None, augmentation_repetition=1, **kwargs):
    """Build a batched, prefetched ``tf.data.Dataset`` of (image, mask) pairs.

    Augmentation (if any) runs eagerly in numpy before the dataset is built, so each
    augmented copy is drawn once per call, not re-sampled every epoch.

    Args:
        X: images, e.g. (N, H, W, 1) float arrays.
        y: masks aligned with ``X``, e.g. (N, H, W) arrays of class ids.
        batch_size: batch size of the returned dataset.
        augmentations: optional callable ``(img, mask) -> (img, mask)`` applied to every sample.
        augmentation_repetition: number of independently augmented passes over the data to
            concatenate (only used with ``augmentations``). This is useful when the pipeline's
            steps fire only with some probability.
        **kwargs:
            remove_bg (bool): set image pixels whose mask is 0 (background) to 0 (black),
                both before and after augmentation. The caller's ``X`` is never modified.
            concat_and_shuffle_aug_with_no_aug (bool): with ``augmentations``, also append the
                un-augmented samples and shuffle, for ``(augmentation_repetition + 1) * len(X)``
                samples in total.

    Returns:
        ``tf.data.Dataset`` of batches of (float32 image resized to IMG_SIZE,
        int32 mask resized to IMG_SIZE with a trailing channel axis).
    """
    remove_bg = kwargs.get('remove_bg', False)
    concat_original = kwargs.get('concat_and_shuffle_aug_with_no_aug', False)

    def resize_img_and_mask(img, mask):
        input_img = tf.cast(tf.image.resize(img, IMG_SIZE), tf.float32)

        # Resize needs at least 3 dims, so add a dummy channel axis to the mask.
        target_img = tf.expand_dims(mask, axis=-1)
        # Nearest-neighbour keeps the discrete class labels intact (no interpolated values).
        target_img = tf.image.resize(target_img, IMG_SIZE, method="nearest")
        target_img = tf.cast(target_img, tf.int32)
        return input_img, target_img

    def remove_background(image, mask, background_label=0):
        # Work on a copy. Iterating over an ndarray yields views, so writing into `image`
        # in place would silently black out the background in the caller's X (e.g. X_train).
        image = np.array(image, copy=True)
        image[mask == background_label] = 0
        return image, mask

    def augment_once():
        # One augmented pass over the (possibly background-removed) X, y.
        aug_images, aug_masks = [], []
        for img, msk in zip(X, y):
            aug_img, aug_mask = augmentations(img, msk)
            if remove_bg:
                aug_img, aug_mask = remove_background(aug_img, aug_mask)
            aug_images.append(aug_img)
            aug_masks.append(aug_mask)
        return np.array(aug_images), np.array(aug_masks)

    if remove_bg:
        cleaned = [remove_background(img, msk) for img, msk in zip(X, y)]
        X = np.array([img for img, _ in cleaned])
        y = np.array([msk for _, msk in cleaned])

    if augmentations is None:
        dataset = tf.data.Dataset.from_tensor_slices((X, y))
    else:
        # Augment in numpy (albumentations works on ndarrays) before building the dataset.
        dataset = tf.data.Dataset.from_tensor_slices(augment_once())
        for _ in range(augmentation_repetition - 1):
            dataset = dataset.concatenate(tf.data.Dataset.from_tensor_slices(augment_once()))
        if concat_original:
            dataset = dataset.concatenate(tf.data.Dataset.from_tensor_slices((X, y)))
            dataset = dataset.shuffle(seed=seed, buffer_size=X.shape[0] * (augmentation_repetition + 1))

    dataset = dataset.map(resize_img_and_mask, num_parallel_calls=tf.data.AUTOTUNE)
    return dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)


## 🎲 Define training configuration

In [None]:
# Architecture to train; must be a key of `model_dict`:
# 'U_NET', 'U_NET_XCEPTION', 'UWNet', 'ASPP', 'ROCKSEG'
model_name = "ASPP"

In [None]:
# Training setup: maximum number of epochs (the EarlyStopping callback may end training sooner)
# and mini-batch size.
epochs, batch_size = 1000, 16

In [None]:
# Define optimizer setup
lr = 1e-4
fine_tuning_lr = 1e-4
# One of:
# SGD
# Adam
# AdamW
# Lion
# Ranger
opt_name = "AdamW"
fine_tuning_opt_name = "AdamW"

opt_exp_decay_rate: float | None = None
# Decay at how many epochs
opt_decay_epoch_delta = 7

In [None]:
def dice_loss(y_true, y_pred, smooth=1e-6):
    """Soft multi-class Dice loss: 1 - mean over classes of the per-class Dice score.

    Args:
        y_true: either integer class ids with a trailing singleton axis (one-hot encoded
            here), or masks already one-hot with as many channels as ``y_pred``.
            Presumably (B, H, W, 1) / (B, H, W, C) -- sums run over axes 1 and 2.
        y_pred: per-pixel class probabilities, presumably (B, H, W, C).
        smooth: small constant added to numerator and denominator to avoid 0/0.

    Returns:
        One loss value per sample (the class axis is averaged away).
    """
    num_classes = y_pred.shape[-1]
    if y_true.shape[-1] != num_classes:
        y_true = tf.one_hot(tf.cast(y_true[..., 0], tf.int32), depth=num_classes)
    # Align dtypes: an already one-hot but integer y_true (or a float16 y_pred under mixed
    # precision) would otherwise make `y_true * y_pred` fail.
    y_true = tf.cast(y_true, y_pred.dtype)

    # Per-sample, per-class Dice over the spatial axes
    intersection = tf.reduce_sum(y_true * y_pred, axis=(1, 2))
    union = tf.reduce_sum(y_true + y_pred, axis=(1, 2))
    dice = (2.0 * intersection + smooth) / (union + smooth)

    # Average over classes
    return 1.0 - tf.reduce_mean(dice, axis=-1)

In [None]:
# Masks are integer class indices (see get_dataset), which is what the sparse variant expects.
loss_fn = "sparse_categorical_crossentropy"

In [None]:
# Visualization callback
# Class id -> displayed id. Currently the identity mapping over the five classes, in id order.
_CLASS_NAMES = ("Background", "Soil", "Bedrock", "Sand", "Big Rock")
category_map = {class_id: class_id for class_id, _class_name in enumerate(_CLASS_NAMES)}

def apply_category_mapping(label):
    """Remap class ids through `category_map`; ids missing from the map become 0.

    Args:
        label: tensor-like of integer class ids. It is cast to int32 to match the table's key dtype.

    Returns:
        int32 tensor of the same shape with every id replaced by its mapped value.
    """
    keys_tensor = tf.constant(list(category_map.keys()), dtype=tf.int32)
    vals_tensor = tf.constant(list(category_map.values()), dtype=tf.int32)
    table = tf.lookup.StaticHashTable(
        tf.lookup.KeyValueTensorInitializer(keys_tensor, vals_tensor),
        default_value=0,
    )
    # The lookup requires keys of the table's dtype (int32), so cast int64/uint8 labels here.
    # (The former debug print of the dtype was removed: it fired on every visualisation.)
    return table.lookup(tf.cast(label, tf.int32))

def create_segmentation_colormap(num_classes):
    """Sample `num_classes` evenly spaced RGBA colours from matplotlib's 'viridis'.

    viridis is perceptually uniform and friendly to colour-blind viewers.
    """
    sample_points = np.linspace(0, 1, num_classes)
    palette = plt.cm.viridis(sample_points)
    return palette

def apply_colormap(label, colormap=None):
    """Turn an integer label map into an image by indexing a colour palette.

    Args:
        label: array of class ids; singleton axes are squeezed away first.
        colormap: array of colours indexed by class id (e.g. RGBA rows). If None, a
            viridis palette with ``max(label) + 1`` entries is created.

    Returns:
        ``colormap[label]`` -- the label's shape plus the colour axis.
    """
    label = np.squeeze(label).astype(int)

    if colormap is None:
        # Size the palette by the largest id, not by the number of distinct ids: with
        # labels {0, 4} a 2-entry palette would make `colormap[label]` raise IndexError.
        num_classes = int(label.max()) + 1 if label.size else 1
        colormap = create_segmentation_colormap(num_classes)

    return colormap[label]
    
class VizCallback(tf.keras.callbacks.Callback):
    """Keras callback that plots one sample's input, ground truth and prediction.

    Args:
        image: a single input image without batch axis; one is added before predicting.
        label: ground-truth mask for ``image``; stored as an int32 tensor.
        frequency: plot at the end of every epoch where ``epoch % frequency == 0``.
    """

    def __init__(self, image, label, frequency=5):
        super().__init__()
        self.image = image
        self.label = tf.cast(tf.convert_to_tensor(label), tf.int32)
        self.frequency = frequency

    def on_epoch_end(self, epoch, logs=None):
        if epoch % self.frequency != 0:
            return

        mapped_label = apply_category_mapping(self.label)
        batched = tf.expand_dims(self.image, 0)
        probabilities = self.model.predict(batched, verbose=0)
        predicted = tf.math.argmax(probabilities, axis=-1).numpy()

        palette = create_segmentation_colormap(NUM_CLASSES)
        panels = (
            (batched[0], "Input Image", {"cmap": "gray"}),
            (apply_colormap(mapped_label.numpy(), palette), "Ground Truth Mask", {}),
            (apply_colormap(predicted[0], palette), "Predicted Mask", {}),
        )

        plt.figure(figsize=(16, 4))
        for position, (content, title, imshow_kwargs) in enumerate(panels, start=1):
            plt.subplot(1, 3, position)
            plt.imshow(content, **imshow_kwargs)
            plt.title(title)
            plt.axis('off')

        plt.tight_layout()
        plt.show()
        plt.close()

In [None]:
# Callbacks handed to model.fit(); remove (comment out) an entry to disable it.
# viz_callback plots the first validation sample every 5 epochs when enabled.
viz_callback = VizCallback(X_val[0], y_val[0])
model_fit_callbacks = dict(
    ReduceLROnPlateau=tfk.callbacks.ReduceLROnPlateau(
        monitor='val_loss', factor=0.8, patience=25, min_lr=1e-6, verbose=1,
    ),
    EarlyStopping=tfk.callbacks.EarlyStopping(
        monitor='val_accuracy', mode='max', patience=50, restore_best_weights=True, verbose=1,
    ),
    # Viz_callback=viz_callback,
)

In [None]:
# Toggle: whether to free the model's memory or not
FREE_MODEL: bool = False

## 🛠️ Define model, augmentation and utils builders

In [None]:
def build_augmentation():
    """Return the albumentations pipeline used for training-time augmentation.

    Every stochastic step fires with probability 0.7. The final Resize always runs,
    so outputs have the network input size IMG_SIZE.
    """
    p_step = 0.7
    steps = [
        A.RandomRotate90(p=p_step),
        A.HorizontalFlip(p=p_step),  # diverse texture orientation
        A.VerticalFlip(p=p_step),
        A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=p_step),
        A.GaussianBlur(blur_limit=3, p=p_step),  # camera-like blur
        A.CoarseDropout(max_holes=4, max_height=16, max_width=16, p=p_step),  # random occlusions
        A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=45, p=p_step),
        A.ElasticTransform(alpha=1, sigma=50, p=p_step),
        A.GridDistortion(num_steps=5, distort_limit=0.3, p=p_step),
        A.OpticalDistortion(distort_limit=0.2, shift_limit=0.2, p=p_step),
        A.Resize(height=IMG_SIZE[0], width=IMG_SIZE[1], p=1),
    ]
    return A.Compose(steps)

In [None]:
# Augmentation pipeline built lazily on first use; re-running this cell resets it
# (do so after editing build_augmentation).
_augmentation_pipeline = None


def apply_augmentation(img, mask):
    """Apply the training augmentation pipeline to one image/mask pair.

    The ``A.Compose`` from ``build_augmentation()`` is built once and reused, instead of
    being rebuilt for every sample inside get_dataset's per-sample loop.

    Returns:
        (augmented_image, augmented_mask) as produced by albumentations.
    """
    global _augmentation_pipeline
    if _augmentation_pipeline is None:
        _augmentation_pipeline = build_augmentation()
    transformed = _augmentation_pipeline(image=img, mask=mask)
    return transformed["image"], transformed["mask"]

In [None]:
# Preview the augmented dataset on a few training samples
N = 2
preview_repetitions = 4
ds = get_dataset(
    X_train[:N],
    y_train[:N],
    augmentations=apply_augmentation,
    augmentation_repetition=preview_repetitions,
    concat_and_shuffle_aug_with_no_aug=True,
)

# Every sample appears once per augmented pass plus once un-augmented
# (concat_and_shuffle_aug_with_no_aug=True). With the default batch_size they all fit in one batch.
for a, b in ds.take(1):
    plot(a.numpy(), b.numpy(), num_images=N * (preview_repetitions + 1))


In [None]:
# taken from https://keras.io/examples/vision/oxford_pets_image_segmentation/
def build_U_NET_XCEPTION(img_size: tuple[int, int, int], num_classes):
    """Xception-style U-Net from the Keras Oxford-pets segmentation example.

    A stride-2 entry conv halves the resolution. Three separable-conv stages each halve it
    again, each with a strided 1x1 residual projection. Four transposed-conv stages each
    double it, each with an upsampled 1x1 residual projection. The output therefore matches
    the input resolution when H and W are divisible by 16.
    """
    def relu_conv_bn(tensor, conv_cls, n_filters):
        # ReLU -> 3x3 'same' conv of the given class -> BatchNorm
        tensor = tfkl.Activation("relu")(tensor)
        tensor = conv_cls(n_filters, 3, padding="same")(tensor)
        return tfkl.BatchNormalization()(tensor)

    inputs = tfk.Input(shape=img_size)

    # Entry block
    x = tfkl.Conv2D(32, 3, strides=2, padding="same")(inputs)
    x = tfkl.BatchNormalization()(x)
    x = tfkl.Activation("relu")(x)
    residual_source = x

    # Downsampling stages (identical apart from feature depth)
    for n_filters in (64, 128, 256):
        x = relu_conv_bn(x, tfkl.SeparableConv2D, n_filters)
        x = relu_conv_bn(x, tfkl.SeparableConv2D, n_filters)
        x = tfkl.MaxPooling2D(3, strides=2, padding="same")(x)

        projected = tfkl.Conv2D(n_filters, 1, strides=2, padding="same")(residual_source)
        x = tfkl.add([x, projected])
        residual_source = x

    # Upsampling stages
    for n_filters in (256, 128, 64, 32):
        x = relu_conv_bn(x, tfkl.Conv2DTranspose, n_filters)
        x = relu_conv_bn(x, tfkl.Conv2DTranspose, n_filters)
        x = tfkl.UpSampling2D(2)(x)

        projected = tfkl.UpSampling2D(2)(residual_source)
        projected = tfkl.Conv2D(n_filters, 1, padding="same")(projected)
        x = tfkl.add([x, projected])
        residual_source = x

    # Per-pixel classification
    outputs = tfkl.Conv2D(num_classes, 3, activation="softmax", padding="same")(x)
    return tfk.Model(inputs, outputs, name='UNetXception')

In [None]:
def build_U_NET(img_size: tuple[int, int, int], num_classes):
    """Two-level U-Net (32 -> 64 -> 128 bottleneck -> 64 -> 32 filters) with a softmax head.

    H and W must be divisible by 4 so the skip connections line up after two poolings.
    """
    def unet_block(input_tensor, filters, kernel_size=3, activation='relu', stack=2, name=''):
        # `stack` x (Conv2D -> BatchNorm -> activation), layers named '<name>conv1', '<name>bn1', ...
        x = input_tensor
        for idx in range(1, stack + 1):
            x = tfkl.Conv2D(filters, kernel_size=kernel_size, padding='same', name=f'{name}conv{idx}')(x)
            x = tfkl.BatchNormalization(name=f'{name}bn{idx}')(x)
            x = tfkl.Activation(activation, name=f'{name}activation{idx}')(x)
        return x

    input_layer = tfkl.Input(shape=img_size, name='input_layer')

    # Contracting path: keep each block's output as a skip connection
    skips = []
    x = input_layer
    for level, width in enumerate((32, 64), start=1):
        x = unet_block(x, width, name=f'down_block{level}_')
        skips.append(x)
        x = tfkl.MaxPooling2D()(x)

    x = unet_block(x, 128, name='bottleneck')

    # Expanding path: upsample, concatenate with the matching skip, refine
    for level, (width, skip) in enumerate(zip((64, 32), reversed(skips)), start=1):
        x = tfkl.UpSampling2D()(x)
        x = tfkl.Concatenate()([x, skip])
        x = unet_block(x, width, name=f'up_block{level}_')

    output_layer = tfkl.Conv2D(num_classes, kernel_size=1, padding='same', activation="softmax", name='output_layer')(x)
    return tf.keras.Model(inputs=input_layer, outputs=output_layer, name='UNet')

In [None]:
def build_ATTENTION_UW_NET(img_size: tuple[int, int, int], num_classes):
    """Build a U-Net with additive attention gates on every skip connection.

    Encoder: four stages of two 3x3 Conv-BN-ReLU layers (64, 128, 256, 512 filters), each
    followed by 2x2 max pooling, then a 1024-filter bridge. Decoder: at each level the
    features are upsampled 2x and projected by a 2x2 Conv-BN-ReLU. The result gates the
    matching encoder skip through ``attention_block``, is concatenated with the gated skip,
    and is refined by two Conv-BN-ReLU layers. A 1x1 softmax conv gives per-pixel class
    probabilities.

    H and W must be divisible by 16 (four 2x poolings) so the skips line up, e.g. 64x128.

    Args:
        img_size: input shape (height, width, channels).
        num_classes: number of output classes.

    Returns:
        An uncompiled ``tfk.Model`` named 'AttentionUWNet'.
    """
    def attention_block(x, g, inter_channel):
        # Additive attention gate: re-weights skip features `x` with gating signal `g`
        # (both must have the same spatial size).
        # theta_x, phi_g: (bs, h, w, inter_channel)
        theta_x = tfkl.Conv2D(inter_channel, [1, 1], strides=[1, 1])(x)
        phi_g = tfkl.Conv2D(inter_channel, [1, 1], strides=[1, 1])(g)

        # f: (bs, h, w, inter_channel) -> psi_f: (bs, h, w, 1)
        f = tfkl.Activation('relu')(tfkl.Add()([theta_x, phi_g]))
        psi_f = tfkl.Conv2D(1, [1, 1], strides=[1, 1])(f)

        # attention coefficients in [0, 1]: (bs, h, w, 1)
        sigmoid_psi_f = tfkl.Activation('sigmoid')(psi_f)

        # gated skip, coefficients broadcast over channels: same shape as x, (bs, h, w, C_x)
        return tfkl.multiply([x, sigmoid_psi_f])

    def double_conv(x, filters):
        # 2 x (3x3 Conv -> BatchNorm -> ReLU), 'same' padding
        for _ in range(2):
            x = tfkl.Conv2D(filters, 3, padding='same')(x)
            x = tfkl.BatchNormalization()(x)
            x = tfkl.Activation('relu')(x)
        return x

    def up_conv(x, filters):
        # 2x upsampling followed by a 2x2 Conv -> BatchNorm -> ReLU projection
        x = tfkl.Conv2D(filters, 2, padding='same')(tfkl.UpSampling2D(size=(2, 2))(x))
        x = tfkl.BatchNormalization()(x)
        return tfkl.Activation('relu')(x)

    inputs = tfkl.Input(shape=img_size)

    # Encoder: keep each stage's pre-pool features as a skip connection
    skips = []
    x = inputs
    for filters in (64, 128, 256, 512):
        x = double_conv(x, filters)
        skips.append(x)
        x = tfkl.MaxPooling2D(pool_size=(2, 2))(x)

    # Bridge
    x = double_conv(x, 1024)

    # Decoder with attention-gated skips (deepest skip first)
    for filters, skip in zip((512, 256, 128, 64), reversed(skips)):
        up = up_conv(x, filters)
        gated_skip = attention_block(skip, up, inter_channel=filters // 2)
        merged = tfkl.concatenate([gated_skip, up], axis=3)
        x = double_conv(merged, filters)

    outputs = tfkl.Conv2D(num_classes, 1, activation='softmax')(x)
    return tfk.Model(inputs=inputs, outputs=outputs, name='AttentionUWNet')

In [None]:
def build_ASPP_model(img_size: tuple[int, int, int], num_classes: int):
    """Build a U-Net-like encoder/decoder with an ASPP (atrous spatial pyramid pooling) bottleneck.

    Encoder: four stages of two Conv-BN-ReLU-SpatialDropout blocks (64, 128, 256, 512
    filters), each followed by 2x2 max pooling. Bottleneck: a 1024-filter conv block, then
    ASPP. ASPP combines an image-level pooling branch, a 1x1 conv and 3x3 convs with
    dilation 6/12/18, concatenated and fused by a 1x1 conv. Decoder: four stages of a
    stride-2 transposed conv, concatenation with the matching encoder skip, and one conv
    block. A 1x1 softmax conv produces per-pixel class probabilities.

    All kernels use He-normal initialisation and L2(1e-4) regularisation. H and W must be
    static (the pooling branch upsamples by the bottleneck's static size) and divisible by
    16 so the skips line up.

    Args:
        img_size: input shape (height, width, channels).
        num_classes: number of output classes.

    Returns:
        An uncompiled ``tf.keras.Model``.
    """
    regularizer = tf.keras.regularizers.l2(1e-4)

    def he_normal():
        # A fresh initializer per layer. Keras documents that an unseeded initializer
        # instance produces the same values every time it is called. The single shared
        # HeNormal() used before therefore gave every layer with the same kernel shape
        # identical initial weights (e.g. the three dilated ASPP branches).
        return tf.keras.initializers.HeNormal()

    def conv(filters, kernel_size, **kwargs):
        # 'same'-padded Conv2D with the model-wide initializer / regularizer
        return tfkl.Conv2D(
            filters,
            kernel_size,
            padding="same",
            kernel_initializer=he_normal(),
            kernel_regularizer=regularizer,
            **kwargs,
        )

    inputs = tfkl.Input(shape=img_size)

    def conv_block(x, filters, kernel_size=(3, 3), activation="relu", batch_norm=True, dropout_rate=0.2):
        # Conv -> [BatchNorm] -> activation -> [SpatialDropout2D]
        x = conv(filters, kernel_size)(x)
        if batch_norm:
            x = tfkl.BatchNormalization()(x)
        x = tfkl.Activation(activation)(x)
        if dropout_rate > 0:
            x = tfkl.SpatialDropout2D(dropout_rate)(x)
        return x

    def encoder_block(x, filters, dropout_rate=0.2):
        # Two conv blocks; returns (pre-pool features for the skip, pooled features)
        x = conv_block(x, filters, dropout_rate=dropout_rate)
        x = conv_block(x, filters, dropout_rate=dropout_rate)
        return x, tfkl.MaxPooling2D(2)(x)

    def atrous_spatial_pyramid_pooling(x, dropout_rate=0.3):
        dims = x.shape[1:3]
        # Image-level branch: global average -> 1x1 conv -> upsampled back to (H, W)
        pool = tfkl.GlobalAveragePooling2D()(x)
        pool = tfkl.Reshape((1, 1, x.shape[-1]))(pool)
        pool = conv(256, 1)(pool)
        pool = tfkl.UpSampling2D(size=dims, interpolation="bilinear")(pool)
        pool = tfkl.SpatialDropout2D(dropout_rate)(pool)

        conv_1x1 = conv(256, 1)(x)
        dilated = [conv(256, 3, dilation_rate=rate)(x) for rate in (6, 12, 18)]

        x = tfkl.Concatenate()([pool, conv_1x1, *dilated])
        x = conv(256, 1)(x)
        return tfkl.SpatialDropout2D(dropout_rate)(x)

    def decoder_block(x, skip, filters, dropout_rate=0.2):
        # 2x learned upsampling, concatenate the skip, one conv block
        x = tfkl.Conv2DTranspose(
            filters,
            2,
            strides=2,
            padding="same",
            kernel_initializer=he_normal(),
            kernel_regularizer=regularizer,
        )(x)
        x = tfkl.Concatenate()([x, skip])
        return conv_block(x, filters, dropout_rate=dropout_rate)

    # Encoder
    skips = []
    x = inputs
    for f in [64, 128, 256, 512]:
        skip, x = encoder_block(x, f, dropout_rate=0.2)
        skips.append(skip)

    # Bottleneck with ASPP
    x = conv_block(x, 1024, dropout_rate=0.3)
    x = atrous_spatial_pyramid_pooling(x, dropout_rate=0.3)

    # Decoder (deepest skip first)
    for skip, f in zip(reversed(skips), [512, 256, 128, 64]):
        x = decoder_block(x, skip, f, dropout_rate=0.2)

    # Final per-pixel classifier
    outputs = tfkl.Conv2D(
        num_classes,
        1,
        activation="softmax",
        kernel_initializer=he_normal(),
        kernel_regularizer=regularizer,
    )(x)

    return tf.keras.Model(inputs, outputs)

In [None]:
def build_RockSeg(img_size: tuple[int, int, int], num_classes: int):
    """
    Builds the RockSeg model, adjusted for input image size 64x128.

    Pipeline: stride-2 stem conv -> max pool -> two conv stacks -> transformer block
    (1x1 projection, 2x average pooling, self-attention) -> bottleneck conv -> multi-scale
    fusion of the two conv stacks and the bottleneck -> two upsample/concat/conv decoder
    stages (skips from the first conv stack and the stem) -> 1x1 softmax classifier.

    Parameters:
        img_size (tuple): Input image dimensions (height, width, channels).
        num_classes (int): Number of output classes.

    Returns:
        tf.keras.Model: RockSeg model instance.
    """
    regularizer = tf.keras.regularizers.l2(1e-4)  # L2 regularization with strength 1e-4

    def he_normal():
        # A fresh initializer per layer. Keras documents that an unseeded initializer
        # instance produces the same values every time it is called, so a single shared
        # HeNormal() gave every layer with the same kernel shape identical initial weights.
        return tf.keras.initializers.HeNormal()

    def resnet_block(input_tensor, filters, kernel_size=3, activation='relu', stack=2, name=''):
        # NOTE: despite the name there is no residual connection; this is a plain stack of
        # `stack` x (Conv2D -> BatchNorm -> activation).
        x = input_tensor
        for i in range(stack):
            x = tfkl.Conv2D(filters, kernel_size=kernel_size, padding='same', kernel_initializer=he_normal(), kernel_regularizer=regularizer, name=name + f'conv{i + 1}')(x)
            x = tfkl.BatchNormalization(name=name + f'bn{i + 1}')(x)
            x = tfkl.Activation(activation, name=name + f'activation{i + 1}')(x)
        return x

    def transformer_block(input_tensor, embed_dim, num_heads, name=''):
        # 1x1 projection to `embed_dim` channels. This was hard-coded to 256, which broke the
        # residual Add with the Dense(embed_dim) output for any other embed_dim.
        x = tfkl.Conv2D(filters=embed_dim, kernel_size=(1, 1), padding='same', kernel_initializer=he_normal(), kernel_regularizer=regularizer, name=name + 'transformer_11')(input_tensor)
        # Halve H and W before attention to reduce the number of attended positions
        x = tfkl.AveragePooling2D(pool_size=(2, 2), strides=(2, 2), padding="same", name=name + 'avg_pool')(x)
        x = tf.keras.layers.LayerNormalization(name=name + 'ln')(x)
        # Self-attention over the spatial positions, then residual
        attention_output = tf.keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim, name=name + 'mha')(x, x)
        x = tf.keras.layers.Add(name=name + 'skip1')([x, attention_output])
        # Pointwise feed-forward, then residual
        feed_forward = tfkl.Dense(embed_dim, activation='relu', kernel_initializer=he_normal(), kernel_regularizer=regularizer, name=name + 'dense')(x)
        x = tf.keras.layers.Add(name=name + 'skip2')([x, feed_forward])
        return x

    def multiscale_feature_fusion(feature_maps, name=''):
        # Bring every map to the channel count and spatial size of the middle map, then fuse.
        base_channels = feature_maps[len(feature_maps) // 2].shape[-1]
        consistent_features = [
            tfkl.Conv2D(base_channels, kernel_size=(1, 1), padding='same', kernel_initializer=he_normal(), kernel_regularizer=regularizer, name=name + f'conv_{i}')(fm) if fm.shape[-1] != base_channels else fm
            for i, fm in enumerate(feature_maps)
        ]
        base_height, base_width = consistent_features[len(consistent_features) // 2].shape[1:3]
        # Smaller maps are upsampled; same-size or larger maps are max-pooled (a (1, 1) pool is a no-op).
        resized_features = [
            tfkl.UpSampling2D(size=(base_height // fm.shape[1], base_width // fm.shape[2]), interpolation='bilinear', name=name + f'up_{i}')(fm)
            if fm.shape[1] < base_height else
            tfkl.MaxPooling2D(pool_size=(fm.shape[1] // base_height, fm.shape[2] // base_width), name=name + f'pool_{i}')(fm)
            for i, fm in enumerate(consistent_features)
        ]
        fused = tfkl.Concatenate(name=name + 'concat')(resized_features)
        return tfkl.Conv2D(base_channels, kernel_size=1, padding='same', activation='relu', kernel_initializer=he_normal(), kernel_regularizer=regularizer, name=name + 'conv_fused')(fused)

    input_layer = tfkl.Input(shape=img_size, name='input_layer')

    # Stem: stride-2 conv; its 2x-upsampled copy is the full-resolution skip
    conv1 = tfkl.Conv2D(filters=64, kernel_size=(4, 4), strides=(2, 2), padding='same', kernel_initializer=he_normal(), kernel_regularizer=regularizer, name='conv1')(input_layer)
    conv1_upsampled = tfkl.UpSampling2D(size=(2, 2), interpolation='bilinear', name='conv1_upsampled')(conv1)
    maxpool1 = tfkl.MaxPooling2D(pool_size=(2, 2), name='pool1')(conv1)

    resnet1 = resnet_block(maxpool1, filters=64, stack=2, name='resnet1_')
    resnet2 = resnet_block(resnet1, filters=128, stack=2, name='resnet2_')

    transformer = transformer_block(resnet2, embed_dim=256, num_heads=4, name='transformer_')

    bottleneck = tfkl.Conv2D(filters=512, kernel_size=3, padding='same', activation='relu', kernel_initializer=he_normal(), kernel_regularizer=regularizer, name='bottleneck')(transformer)

    msf1 = multiscale_feature_fusion([resnet1, resnet2, bottleneck], name='msf1_')
    msf1_upsampled = tfkl.UpSampling2D(size=(2, 2), interpolation='bilinear', name='msf1_upsample')(msf1)

    resnet1_upsampled = tfkl.UpSampling2D(size=(2, 2), interpolation='bilinear', name='resnet1_upsample')(resnet1)

    concat1 = tfkl.Concatenate(name='concat1')([msf1_upsampled, resnet1_upsampled])
    decoder1 = tfkl.Conv2D(filters=128, kernel_size=3, padding='same', activation='relu', kernel_initializer=he_normal(), kernel_regularizer=regularizer, name='decoder1')(concat1)

    upsample2 = tfkl.UpSampling2D(size=(2, 2), interpolation='bilinear', name='upsample2')(decoder1)

    concat2 = tfkl.Concatenate(name='concat2')([upsample2, conv1_upsampled])
    decoder2 = tfkl.Conv2D(filters=64, kernel_size=3, padding='same', activation='relu', kernel_initializer=he_normal(), kernel_regularizer=regularizer, name='decoder2')(concat2)

    # (1, 1) upsampling is a no-op kept for layer-name compatibility
    upsample3 = tfkl.UpSampling2D(size=(1, 1), interpolation='bilinear', name='upsample3')(decoder2)

    output_layer = tfkl.Conv2D(num_classes, kernel_size=1, padding='same', activation="softmax", kernel_initializer=he_normal(), name='output_layer')(upsample3)

    return tf.keras.Model(inputs=input_layer, outputs=output_layer, name='RockSeg')

In [None]:
# Registry: every accepted `model_name` value -> the function that builds that architecture
model_dict = dict(
    U_NET=build_U_NET,
    U_NET_XCEPTION=build_U_NET_XCEPTION,
    UWNet=build_ATTENTION_UW_NET,
    ASPP=build_ASPP_model,
    ROCKSEG=build_RockSeg,
)

In [None]:
def get_callbacks():
    """Return the configured training callbacks as a new list.

    The callbacks are the values of the module-level ``model_fit_callbacks``
    dict, in insertion order. A fresh list is built on every call, so callers
    may modify it without touching the registry.
    """
    return list(model_fit_callbacks.values())

In [None]:
def fit_model(model, data_loader=None, validation_data_loader=None):
    """Train ``model`` and return its Keras history dict.

    Uses the module-level ``epochs`` setting and the callbacks returned by
    ``get_callbacks()``.

    Args:
        model: a compiled Keras model.
        data_loader: training dataset passed as the first argument of ``model.fit``.
        validation_data_loader: dataset passed as ``validation_data``.

    Returns:
        The ``history`` attribute of the object returned by ``model.fit``.

    Raises:
        ValueError: if either loader is missing.
    """
    # Explicit checks instead of `assert`: asserts are stripped under `python -O`.
    if data_loader is None:
        raise ValueError("fit_model requires a training data_loader")
    if validation_data_loader is None:
        raise ValueError("fit_model requires a validation_data_loader")

    fit_result = model.fit(
        data_loader,
        epochs=epochs,
        validation_data=validation_data_loader,
        callbacks=get_callbacks(),
    )
    return fit_result.history

In [None]:
# Taken from https://github.com/SeanSdahl/RangerOptimizerTensorflow/blob/master/module.py
def build_ranger(lr=1e-3, weight_decay=0.0):
    """Build a Ranger optimizer: RectifiedAdam wrapped in Lookahead.

    Args:
        lr: learning rate of the inner RectifiedAdam optimizer.
        weight_decay: weight decay of the inner RectifiedAdam optimizer.

    Returns:
        ``tfa.optimizers.Lookahead`` wrapping a ``tfa.optimizers.RectifiedAdam``.

    Raises:
        ImportError: if ``tensorflow_addons`` cannot be imported.
    """
    # Only a missing package is translated; any other failure inside
    # tensorflow_addons propagates unchanged instead of being masked.
    try:
        import tensorflow_addons as tfa
    except ImportError as err:
        raise ImportError(
            "You have to install tensorflow_addons package for Ranger. Please note that this package is available up to tensorflow==2.14"
        ) from err

    # Keyword arguments avoid silently mis-assigning values if the positional
    # order of the tfa signature ever differs from what is expected here.
    inner = tfa.optimizers.RectifiedAdam(
        learning_rate=lr,
        beta_1=0.9,
        beta_2=0.999,
        epsilon=1e-7,
        weight_decay=weight_decay,
        amsgrad=False,
        sma_threshold=5.0,
        total_steps=0,
        warmup_proportion=0.1,
        min_lr=0.0,
        name="Ranger",
    )
    return tfa.optimizers.Lookahead(
        inner,
        sync_period=6,
        slow_step_size=0.5,
        name="Ranger",
    )

In [None]:
def _optimizer_learning_rate(base_lr, decay):
    """Return ``base_lr`` unchanged, or a staircase ExponentialDecay schedule starting at it."""
    if decay is None:
        return base_lr
    # One decay step every `opt_decay_epoch_delta` epochs, with steps per epoch
    # estimated from the raw training-set size. max(1, ...) prevents
    # decay_steps == 0 when the training set is smaller than one batch.
    # NOTE(review): if the training loader yields more batches than
    # X_train.shape[0] // batch_size (e.g. augmented copies concatenated by
    # get_dataset), decay happens more often than intended — TODO confirm.
    steps_per_epoch = max(1, X_train.shape[0] // batch_size)
    return tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate=base_lr,
        decay_steps=opt_decay_epoch_delta * steps_per_epoch,
        decay_rate=decay,
        staircase=True,
    )


def get_optimizer(is_fine_tuning=False, use_decay_fine_tuning=False, **kwargs):
    """Build the optimizer described by the notebook's global configuration.

    Reads ``opt_name`` / ``fine_tuning_opt_name``, ``lr`` / ``fine_tuning_lr``,
    ``opt_exp_decay_rate`` and, when a decay schedule is used,
    ``opt_decay_epoch_delta``, ``X_train`` and ``batch_size``.

    Args:
        is_fine_tuning: use the fine-tuning optimizer name and learning rate.
        use_decay_fine_tuning: when fine-tuning, keep the exponential decay
            schedule; otherwise fine-tuning uses a constant learning rate.
        **kwargs: optional ``momentum`` (SGD only, default 0.9) and
            ``weight_decay`` (Adam, AdamW, Lion; Ranger defaults it to 0.0).

    Returns:
        The configured optimizer instance.

    Raises:
        RuntimeError: if a decay schedule is requested for Ranger.
        ValueError: if the configured optimizer name is not supported.
    """
    decay = opt_exp_decay_rate
    if is_fine_tuning and not use_decay_fine_tuning:
        decay = None

    opt = fine_tuning_opt_name if is_fine_tuning else opt_name
    base_lr = fine_tuning_lr if is_fine_tuning else lr
    prefix = "Finetuning: " if is_fine_tuning else "NotFinetuning: "

    if opt == "Ranger":
        # Fail fast: a schedule is not supported for Ranger, so don't build it first.
        if decay is not None:
            raise RuntimeError("Not supported")
        optimizer = build_ranger(lr=base_lr, weight_decay=kwargs.get('weight_decay', 0.0))
        print(f'\n\n{prefix}using {opt} optimizer\n\n')
        return optimizer

    if opt not in ("SGD", "Adam", "AdamW", "Lion"):
        # Previously an unknown name fell through and returned None.
        raise ValueError(f"Unsupported optimizer: {opt!r}")

    learning_rate = _optimizer_learning_rate(base_lr, decay)

    if opt == "SGD":
        optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate, momentum=kwargs.get('momentum', 0.9))
        if decay is not None:
            print(f'\n\n{prefix}using {opt} optimizer with exp decay {decay} (momentum = {optimizer.momentum})\n\n')
        else:
            print(f'\n\n{prefix}using {opt} optimizer (momentum = {optimizer.momentum})\n\n')
        return optimizer

    # Forward weight_decay only when given, so each optimizer otherwise keeps
    # its own library default.
    extra = {'weight_decay': kwargs['weight_decay']} if 'weight_decay' in kwargs else {}
    # Looked up lazily so an optimizer missing from older TF versions only
    # matters when it is actually selected.
    optimizer_cls = getattr(tf.keras.optimizers, opt)
    optimizer = optimizer_cls(learning_rate=learning_rate, **extra)
    if decay is not None:
        print(f'\n\n{prefix}using {opt} optimizer with exp decay of {decay} weight decay = {optimizer.weight_decay}\n\n')
    else:
        print(f'\n\n{prefix}using {opt} optimizer (weight decay = {optimizer.weight_decay})\n\n')
    return optimizer

In [None]:
def display_model(model):
    """Print a nested summary of ``model`` and return its architecture diagram.

    The diagram produced by ``tfk.utils.plot_model`` is returned rather than
    discarded. A value computed inside a function is never rendered by Jupyter
    on its own, so returning it lets a cell ending in ``display_model(model)``
    show the plot.
    """
    # Display a summary of the model architecture
    model.summary(expand_nested=True)

    # Architecture diagram with layer shapes and trainable flags
    return tfk.utils.plot_model(model, expand_nested=True, show_trainable=True, show_shapes=True, dpi=70)

## 🍣 Define and display model

In [None]:
# Define custom Mean Intersection Over Union metric: the competition excludes the background class
class MeanIntersectionOverUnion(tf.keras.metrics.MeanIoU):
    """Mean IoU computed only over pixels whose true label is not excluded.

    Predictions are per-class scores and are reduced with ``argmax`` on the
    last axis before being compared with the ground-truth labels.

    Args:
        num_classes: number of classes, forwarded to ``MeanIoU``.
        labels_to_exclude: true labels whose pixels are ignored; ``None``
            means ``[0]``, an empty list means nothing is excluded.
        name: metric name (used as the key in the training history).
        dtype: forwarded to ``MeanIoU``.
    """

    def __init__(self, num_classes, labels_to_exclude=None, name="mean_iou", dtype=None):
        super().__init__(num_classes=num_classes, name=name, dtype=dtype)
        if labels_to_exclude is None:
            labels_to_exclude = [0]  # Default to excluding label 0
        # Copy so later changes to the caller's list cannot alter the metric.
        self.labels_to_exclude = list(labels_to_exclude)

    def update_state(self, y_true, y_pred, sample_weight=None):
        # Convert predictions to class labels
        y_pred = tf.math.argmax(y_pred, axis=-1)

        # Flatten the tensors so every pixel is one element
        y_true = tf.reshape(y_true, [-1])
        y_pred = tf.reshape(y_pred, [-1])

        # Combine all exclusions into one keep-mask and apply it once
        keep = tf.ones_like(y_true, dtype=tf.bool)
        for label in self.labels_to_exclude:
            keep = tf.logical_and(keep, tf.not_equal(y_true, label))
        y_true = tf.boolean_mask(y_true, keep)
        y_pred = tf.boolean_mask(y_pred, keep)

        if sample_weight is not None:
            # Weights must be filtered with the same mask, otherwise their length
            # no longer matches the masked labels.
            # NOTE(review): assumes one weight per pixel — TODO confirm with callers.
            sample_weight = tf.boolean_mask(tf.reshape(sample_weight, [-1]), keep)

        # Update the state
        return super().update_state(y_true, y_pred, sample_weight)

    def get_config(self):
        """Return a config accepted by ``__init__`` so the metric can be re-created when a model is loaded.

        The inherited ``MeanIoU`` config contains keys this ``__init__`` does not
        accept, so only this class's own constructor arguments are emitted.
        """
        base_config = super().get_config()
        return {
            "num_classes": self.num_classes,
            "labels_to_exclude": list(self.labels_to_exclude),
            "name": base_config.get("name", self.name),
            "dtype": base_config.get("dtype"),
        }

In [None]:
# Instantiate the architecture selected by `model_name` for single-channel images
model = model_dict[model_name](IMG_SIZE + (1,), NUM_CLASSES)

# The IoU metric ignores the background label (0), as the competition scoring does
model.compile(
    loss=loss_fn,
    optimizer=get_optimizer(is_fine_tuning=False),
    metrics=[
        'accuracy',
        MeanIntersectionOverUnion(num_classes=NUM_CLASSES, labels_to_exclude=[0]),
    ],
)
display_model(model)

## 🧗🏻‍♂️ Train and save

In [None]:
# Train the compiled model: augmented training split, plain validation split
print('\n\nFitting model\n\n')
fit_history = fit_model(
    model,
    data_loader=get_dataset(
        X_train,
        y_train,
        batch_size=batch_size,
        augmentations=apply_augmentation,
        concat_and_shuffle_aug_with_no_aug=True,
    ),
    validation_data_loader=get_dataset(X_val, y_val, batch_size=batch_size),
)

# Best (maximum over all epochs) validation mean IoU, as a percentage with 2 decimals
final_val_meanIoU = round(100 * max(fit_history['val_mean_iou']), 2)
print(f'Final validation Mean Intersection Over Union: {final_val_meanIoU}%')

# Save the trained model; the file name records architecture, score and timestamp
model_filename = f'{model_name}-{final_val_meanIoU!s}-{datetime.now().strftime("%y%m%d_%H%M")}.keras'
model.save(model_filename)

# Drop the in-memory model when FREE_MODEL is set, to release memory
if FREE_MODEL:
    del model

In [None]:
def plot_trainig(fit):
    """Plot training vs. validation curves for loss, accuracy and mean IoU.

    Args:
        fit: a Keras history dict (e.g. the return value of ``fit_model``).
            Each of ``loss``, ``accuracy`` and ``mean_iou`` is drawn in its own
            figure together with its ``val_`` counterpart. Series missing from
            ``fit`` are skipped, so histories from models compiled with other
            metrics can still be plotted.
    """
    panels = (
        ('loss', 'Cross Entropy'),
        ('accuracy', 'Accuracy'),
        ('mean_iou', 'Mean Intersection over Union'),
    )
    for key, title in panels:
        train_curve = fit.get(key)
        val_curve = fit.get(f'val_{key}')
        if train_curve is None and val_curve is None:
            continue

        plt.figure(figsize=(18, 3))
        if train_curve is not None:
            plt.plot(train_curve, label='Training', alpha=0.8, color='#ff7f0e', linewidth=2)
        if val_curve is not None:
            plt.plot(val_curve, label='Validation', alpha=0.9, color='#5a9aa5', linewidth=2)
        plt.title(title)
        plt.legend()
        plt.grid(alpha=0.3)
        plt.show()


# Correctly spelled alias; `plot_trainig` is kept for existing callers.
plot_training = plot_trainig

In [None]:
# Visualise the learning curves of the training run above
plot_trainig(fit=fit_history)

## ✍🏿 Make prediction on test set

In [None]:
# Batch the held-out split (no augmentation) for evaluation and plotting
test_dataset = get_dataset(
    X_test,
    y_test,
    batch_size=8,
)

In [None]:
# Reload a saved UNet checkpoint without compiling it (it is recompiled below)
model = tfk.models.load_model('UNet_59.26.keras', compile=False)

# Recompile so evaluate() reports loss, accuracy and the background-excluded mean IoU
model.compile(
    loss=loss_fn,
    optimizer=get_optimizer(),
    metrics=[
        "accuracy",
        MeanIntersectionOverUnion(num_classes=NUM_CLASSES, labels_to_exclude=[0]),
    ],
)

display_model(model)

# evaluate() yields the loss followed by the compiled metrics, in compile order
test_loss, test_accuracy, test_mean_iou = model.evaluate(
    test_dataset,
    verbose=1,
)
print(f'Test Accuracy: {round(test_accuracy, 4)}')
print(f'Test Mean Intersection over Union: {round(test_mean_iou, 4)}')

In [None]:
def create_segmentation_colormap(num_classes):
    """
    Return one 'viridis' colour per class, sampled evenly along the colormap.

    Row k of the returned array is the colour for class k. Viridis is used
    because it is perceptually uniform and readable with colour blindness.
    """
    positions = np.linspace(0, 1, num_classes)
    palette = plt.cm.viridis
    return palette(positions)

def apply_colormap(label, colormap=None):
    """
    Colour a label image by looking up each pixel's class id in a palette.

    Args:
        label: array-like of class ids; singleton axes (e.g. batch or channel
            of size 1) are squeezed away first.
        colormap: array indexed by class id (row k = colour of class k).
            When None, a palette covering ids 0..label.max() is created.

    Returns:
        ``colormap[label]``: the squeezed label array with a colour per element.
    """
    # Drop singleton axes and use integer ids for indexing
    label = np.squeeze(label).astype(int)

    if colormap is None:
        # Size the palette by the largest id, not by the number of distinct ids:
        # labels {0, 3} need 4 entries, otherwise id 3 would be out of range.
        num_classes = int(label.max()) + 1 if label.size else 1
        colormap = create_segmentation_colormap(num_classes)

    return colormap[label]

def plot_triptychs(dataset, model, num_samples=1, num_classes=None):
    """
    Plot triptychs (original image, true mask, predicted mask) for samples from a tf.data.Dataset.

    Parameters:
    dataset: tf.data.Dataset - The dataset containing image-label pairs
    model: tf.keras.Model - The trained model to generate predictions
    num_samples: int - Number of dataset elements to plot (first image of each batch)
    num_classes: int or None - Size of the colour palette. Pass the model's number
        of classes so a colour means the same class in every figure; when None it
        is inferred per sample from the largest id in the true or predicted mask.
    """
    for images, labels in dataset.take(num_samples):
        # If we have a batch, take the first example
        if len(images.shape) == 4:  # Batch of images
            images = images[0:1]
            labels = labels[0:1]

        # Predicted class id per pixel
        pred = np.argmax(model.predict(images, verbose=0), axis=-1)
        labels_np = np.asarray(labels).astype(int)

        # One palette shared by both masks and large enough for every id in
        # either of them. Counting distinct true labels instead could be
        # smaller than the largest (true or predicted) id and raise IndexError.
        palette_size = num_classes
        if palette_size is None:
            palette_size = int(max(labels_np.max(), pred.max())) + 1
        colormap = create_segmentation_colormap(palette_size)

        fig, axes = plt.subplots(1, 3, figsize=(20, 4))

        # Plot original image; inputs carry a trailing channel axis, which is
        # squeezed so a single-channel image is drawn in grey.
        axes[0].set_title("Original Image")
        axes[0].imshow(np.squeeze(np.asarray(images[0])), cmap='gray')
        axes[0].axis('off')

        # Plot original mask
        axes[1].set_title("Original Mask")
        axes[1].imshow(apply_colormap(labels_np[0], colormap))
        axes[1].axis('off')

        # Plot predicted mask
        axes[2].set_title("Predicted Mask")
        axes[2].imshow(apply_colormap(pred[0], colormap))
        axes[2].axis('off')

        plt.tight_layout()
        plt.show()
        plt.close(fig)

# Show two held-out examples next to their true and predicted masks
plot_triptychs(dataset=test_dataset, model=model, num_samples=2)


## 🎰 Make prediction on competition test set and create csv

**Note:** we use the model already in memory instead of reloading it from disk, because the custom `MeanIntersectionOverUnion` metric is currently not serializable and the saved model therefore cannot be loaded.

In [None]:
import pandas as pd

# Predict the hidden competition images and keep the most likely class per pixel
print(hidden_X_test.shape)
preds = np.argmax(model.predict(hidden_X_test), axis=-1)
print(f"Predictions shape: {preds.shape}")

def y_to_df(y) -> pd.DataFrame:
    """Flatten each predicted mask into one row, preceded by its sample index.

    The first column ``id`` holds 0..n-1; the remaining integer-named columns
    are the flattened pixel values (Kaggle submission layout).
    """
    n_samples = len(y)
    table = pd.DataFrame(y.reshape(n_samples, -1))
    table.insert(0, "id", np.arange(n_samples))
    return table

# Write the Kaggle submission, stamped with the current date and time
submission_filename = 'submission_{}.csv'.format(datetime.now().strftime("%y%m%d_%H%M"))
submission_df = y_to_df(preds)
submission_df.to_csv(path_or_buf=submission_filename, index=False)