# Artificial Neural Networks and Deep Learning

---
## Homework 2: Image segmentation of Mars' stones
## Team: The Backpropagators
Arianna Procaccio, Francesco Buccoliero, Kai-Xi Matteo Chen, Luca Capoferri

ariii, frbuccoliero, kaiximatteoc, luke01

246843, 245498, 245523, 259617


## ⚙️ Import Libraries

In [None]:
#from sklearn.metrics import accuracy_score, confusion_matrix, precision_score, recall_score, f1_score
#from sklearn.utils.class_weight import compute_class_weight
from datetime import datetime
from matplotlib import pyplot as plt
from tensorflow import keras as tfk
from tensorflow.keras import layers as tfkl
import albumentations as A
import numpy as np
import tensorflow as tf

In [None]:
seed = 666
# Seed Python's `random`, NumPy and TensorFlow with a single call. The previous
# setup seeded only NumPy and TensorFlow, leaving any code that draws from
# Python's built-in `random` module non-reproducible.
tfk.utils.set_random_seed(seed)

## ⏳ Load, inspect and prepare the data

In [None]:
# Path of the .npz archive loaded with np.load; the notebook reads its
# "training_set" and "test_set" entries.
DATASET_PATH = "dataset.npz"

In [None]:
# The provided test set comes without ground-truth masks, so a small labelled
# subset of the training data is held out as a local test set for inference checks.
train_ratio = 0.85
validation_ratio = 0.10
test_ratio = 0.05

# Compare with a tolerance: sums of decimal fractions are not guaranteed to be
# exactly 1 in binary floating point, so `==` could reject valid ratio choices.
assert np.isclose(train_ratio + validation_ratio + test_ratio, 1.0), "split ratios must sum to 1"


In [None]:
data = np.load(DATASET_PATH)

# Along axis 1 of "training_set", index 0 is used as the image and index 1 as its mask
training_set = data["training_set"]
X_train, y_train = training_set[:, 0], training_set[:, 1]

# Images of the provided test set (no masks are read for it)
hidden_X_test = data["test_set"]

print(f"Training X shape: {X_train.shape}")
print(f"Training y shape: {y_train.shape}")
print(f"Test hidden X shape: {hidden_X_test.shape}")

# Append a trailing channel axis and divide by 255 to rescale the pixel values
X_train = np.expand_dims(X_train, axis=-1) / 255
hidden_X_test = np.expand_dims(hidden_X_test, axis=-1) / 255

# Per-sample input shape and number of distinct label values found in the masks
input_shape = X_train.shape[1:]
num_classes = np.unique(y_train).size

print(f"Input shape: {input_shape}")
print(f"Number of classes: {num_classes}")

In [None]:
# Shuffle the labelled data once, then cut it into train / validation / test parts
validation_size = int(X_train.shape[0] * validation_ratio)

# Apply one random permutation to images and masks so the pairs stay aligned
indices = np.arange(X_train.shape[0])
np.random.shuffle(indices)
X_train, y_train = X_train[indices], y_train[indices]

# Cut points: end of the training part and end of the validation part
split_indices = [
  int(X_train.shape[0] * ratio)
  for ratio in (train_ratio, train_ratio + validation_ratio)
]

X_train, X_val, X_test = np.split(X_train, split_indices)
y_train, y_val, y_test = np.split(y_train, split_indices)

for _name, _X, _y in (
  ("Training", X_train, y_train),
  ("Validation", X_val, y_val),
  ("Test", X_test, y_test),
):
  print(f"{_name} X shape: {_X.shape}")
  print(f"{_name} y shape: {_y.shape}")

In [None]:
# Plot the data. Without a mask, a rows X cols grid of images is shown; with a
# mask, each row shows one image next to its mask.
def plot(data, mask=None, num_images=10, rows=4, cols=8):
  """Display images, optionally side by side with their masks.

  Args:
    data: array of images indexed along the first axis; a trailing channel
      axis of size 1 is dropped before plotting.
    mask: optional array of masks aligned with `data`. When given, one
      (image, mask) pair is drawn per row.
    num_images: number of pairs to draw when `mask` is given. At least 4 rows
      are requested (as before), but never more than the pairs available.
    rows, cols: grid size used when `mask` is None; at most rows * cols
      images are drawn and unused cells are left blank.
  """
  # Reshape if needed (e.g., remove channel dimension for grayscale images)
  if data.shape[-1] == 1:
    data = data.squeeze(axis=-1)

  if mask is None:
    # squeeze=False keeps `axes` a 2-D array even for a single row/column,
    # so `axes.flat` is always available
    _, axes = plt.subplots(rows, cols, figsize=(12, 6), squeeze=False)

    for i, ax in enumerate(axes.flat):
      if i < len(data):  # Leave the cell empty when images run out
        ax.imshow(data[i], cmap='gray' if len(data[i].shape) == 2 else None)
      ax.axis('off')

    plt.tight_layout()
    plt.show()
    return

  # Keep the historical floor of 4 rows, but never index past the available
  # images or masks (this used to raise IndexError on short inputs)
  num_samples = min(max(num_images, 4), len(data), len(mask))
  if num_samples == 0:
    return  # Nothing to draw; plt.subplots rejects a grid with zero rows

  _, axes = plt.subplots(num_samples, 2, figsize=(8, num_samples * 2), squeeze=False)

  for i in range(num_samples):
    # Original image
    axes[i, 0].imshow(data[i], cmap="gray")
    axes[i, 0].set_title(f"Image {i+1}")
    axes[i, 0].axis("off")

    # Corresponding mask
    axes[i, 1].imshow(mask[i], cmap="viridis", alpha=0.8)
    axes[i, 1].set_title(f"Mask {i+1}")
    axes[i, 1].axis("off")

  plt.tight_layout()
  plt.show()

In [None]:
# All outliers share one identical mask, so they are detected by comparing masks
outlier_mask_template = np.load("outlier_mask.npy") # discovered by hand

# NOTE: despite their names, the `*_outliers_indices` lists hold the positions
# of the samples to KEEP (those whose mask differs from the outlier template)
train_outliers_indices = [
  idx for idx, sample_mask in enumerate(y_train)
  if not np.array_equal(sample_mask, outlier_mask_template)
]
val_outliers_indices = [
  idx for idx, sample_mask in enumerate(y_val)
  if not np.array_equal(sample_mask, outlier_mask_template)
]
test_outliers_indices = [
  idx for idx, sample_mask in enumerate(y_test)
  if not np.array_equal(sample_mask, outlier_mask_template)
]

# Outlier count = samples in the split minus samples kept
print(f'Total outliers in train set: {y_train.shape[0] - len(train_outliers_indices)}')
print(f'Total outliers in validation set: {y_val.shape[0] - len(val_outliers_indices)}')
print(f'Total outliers in test set: {y_test.shape[0] - len(test_outliers_indices)}')

# Drop the outliers from the train, validation and test sets
X_train, y_train = X_train[train_outliers_indices], y_train[train_outliers_indices]
X_val, y_val = X_val[val_outliers_indices], y_val[val_outliers_indices]
X_test, y_test = X_test[test_outliers_indices], y_test[test_outliers_indices]

print(f'Updated train dataset size: {X_train.shape}')
print(f'Updated validation dataset size: {X_val.shape}')
print(f'Updated test dataset size: {X_test.shape}')


In [None]:
# Preview the training images in a 10 x 8 grid (up to 80 images)
plot(X_train, rows=10, cols=8)

In [None]:
# Sanity check after the outlier removal: show 40 image/mask pairs side by side;
# none of them should be an outlier
plot(X_train, mask=y_train, num_images=40)

In [None]:
# Image size fed to the network (the dataset images are 64 X 128). The tuple is
# passed as `size` to tf.image.resize, i.e. (height, width).
IMG_SIZE = (64, 128)
# Number of segmentation classes, taken from the distinct labels found in the training masks
NUM_CLASSES = num_classes

In [None]:
# With `concat_and_shuffle_aug_with_no_aug=True` the result holds the augmented
# AND the original samples, i.e. twice as many samples as X
def get_dataset(X, y, batch_size=32, augmentations=None, **kwargs):
  """Build a batched, prefetched tf.data pipeline of (image, mask) pairs.

  Args:
    X: array of images, sliced along the first axis.
    y: array of masks aligned with `X`.
    batch_size: number of samples per batch.
    augmentations: optional callable `(image, mask) -> (image, mask)` applied
      once to every sample, eagerly, before the dataset is created.
    **kwargs: only `concat_and_shuffle_aug_with_no_aug` is read. When truthy
      (and augmentations are given) the original samples are appended to the
      augmented ones and the result is shuffled with the notebook `seed`.

  Returns:
    A tf.data.Dataset yielding float32 images and int32 masks resized to IMG_SIZE.
  """
  def resize_img_and_mask(img, mask):
    # tf.image.resize needs at least 3 dims, so give the mask a dummy channel axis
    mask = tf.expand_dims(mask, axis=-1)
    # Nearest-neighbour resizing keeps the discrete class labels intact;
    # interpolating would invent label values that do not exist
    mask = tf.image.resize(mask, IMG_SIZE, method="nearest")
    img = tf.image.resize(img, IMG_SIZE)
    return tf.cast(img, tf.float32), tf.cast(mask, tf.int32)

  if augmentations is None:
    dataset = tf.data.Dataset.from_tensor_slices((X, y))
  else:
    # Augment in plain Python/NumPy before building the dataset: albumentations
    # works on NumPy arrays, so this avoids tensor <-> array conversions
    augmented = [augmentations(img, mask) for img, mask in zip(X, y)]
    aug_X = np.array([pair[0] for pair in augmented])
    aug_y = np.array([pair[1] for pair in augmented])
    dataset = tf.data.Dataset.from_tensor_slices((aug_X, aug_y))

    if kwargs.get('concat_and_shuffle_aug_with_no_aug', False):
      originals = tf.data.Dataset.from_tensor_slices((X, y))
      # Buffer covers the whole doubled dataset
      dataset = dataset.concatenate(originals).shuffle(seed=seed, buffer_size=X.shape[0] * 2)

  dataset = dataset.map(resize_img_and_mask, num_parallel_calls=tf.data.AUTOTUNE)
  return dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)

In [None]:
# Name of the architecture to train: used as the lookup key in `model_dict`
# and as the prefix of the saved model filename
model_name = 'U_NET'

In [None]:
# Define training setup
# `epochs` is the value passed to model.fit; an enabled EarlyStopping callback may end training sooner
epochs = 1000
# Batch size used for the training and validation datasets
batch_size = 16

In [None]:
# Define optimizer setup
# Learning rates for the initial training and for the fine-tuning phase
lr = 1e-4
fine_tuning_lr = 1e-4
# Optimizer names, one of:
# SGD
# Adam
# AdamW
# Lion
# Ranger (needs the tensorflow_addons package)
opt_name = "AdamW"
fine_tuning_opt_name = "AdamW"

# Rate of the exponential learning-rate decay; None disables the decay schedule.
# NOTE: the `float | None` annotation requires Python 3.10+ (or deferred annotations)
opt_exp_decay_rate: float | None = None
# Number of epochs between two decay steps (only used when a decay rate is set)
opt_decay_epoch_delta = 7

In [None]:
# The masks hold integer class indices (not one-hot vectors), hence the sparse
# variant of categorical cross-entropy
loss_fn = 'sparse_categorical_crossentropy'

In [None]:
# Callbacks passed to model.fit, keyed by name. Comment out the unwanted entries.
# NOTE(review): EarlyStopping monitors 'val_accuracy' and restores the weights of
# the best-accuracy epoch, while the run is later summarised by its best
# 'val_mean_iou' — confirm that accuracy is the intended stopping criterion.
model_fit_callbacks = {
	#'ReduceLROnPlateau': tfk.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=5, min_lr=1e-6, verbose=1),
	'EarlyStopping': tfk.callbacks.EarlyStopping(monitor='val_accuracy', mode='max', patience=15, restore_best_weights=True, verbose=1) #https://colab.research.google.com/drive/15h-47mevDv3hFXq5LUBh3XxyTTf9ZDx8#scrollTo=H6J65MMp4pA8&line=4&uniqifier=1
}

In [None]:
# When True, the training cell deletes the `model` variable after saving it, to free memory
FREE_MODEL = False

## 🛠️ Define model, augmentation and utils builders

In [None]:
def build_augmentation():
  """Return the albumentations pipeline used for training-time augmentation.

  The transforms run in order: GridElasticDeform (8 x 8 grid, magnitude 10),
  XYMasking and ShiftScaleRotate, the last two with their default settings.
  """
  # A coarser alternative that was tried:
  # A.GridElasticDeform(num_grid_xy=(4, 4), magnitude=10, p=0.5)
  steps = [
      A.GridElasticDeform(num_grid_xy=(8, 8), magnitude=10),
      A.XYMasking(),
      A.ShiftScaleRotate(),
  ]
  return A.Compose(steps)

In [None]:
def apply_augmentation(img, mask):
  """Augment one image together with its mask.

  A fresh pipeline is built on every call and applied to both inputs, so the
  same transformation is applied to the image and its mask.

  Returns:
    The `(image, mask)` pair read from the pipeline's result.
  """
  result = build_augmentation()(image=img, mask=mask)
  return result["image"], result["mask"]

In [None]:
# Try the augmented dataset on a handful of samples
N = 5
ds = get_dataset(
  X_train[:N],
  y_train[:N],
  augmentations=apply_augmentation,
  concat_and_shuffle_aug_with_no_aug=True,
)

# take(1) yields at most one batch, so no explicit `break` is needed
for batch in ds.take(1):
  a, b = batch
  # N * 2 pairs, because `concat_and_shuffle_aug_with_no_aug` doubles the samples
  plot(a.numpy(), b.numpy(), num_images=N * 2)


In [None]:
# taken from https://keras.io/examples/vision/oxford_pets_image_segmentation/
def build_U_NET_XCEPTION(img_size: tuple[int, int, int], num_classes):
  """Build an Xception-style encoder/decoder segmentation network with residual connections.

  The encoder halves the resolution four times (strided entry convolution plus
  three pooled blocks); the decoder upsamples by 2 four times.

  Args:
    img_size: input shape passed to `tfk.Input`.
    num_classes: number of channels of the final per-pixel softmax layer.

  Returns:
    An uncompiled `tfk.Model` named 'UNetXception'.
  """
  inputs = tfk.Input(shape=img_size) # One channel input

  ### [First half of the network: downsampling inputs] ###

  # Entry block: the stride-2 convolution already halves the resolution
  x = tfkl.Conv2D(32, 3, strides=2, padding="same")(inputs)
  x = tfkl.BatchNormalization()(x)
  x = tfkl.Activation("relu")(x)

  previous_block_activation = x  # Set aside residual

  # Blocks 1, 2, 3 are identical apart from the feature depth.
  for filters in [64, 128, 256]:
    x = tfkl.Activation("relu")(x)
    x = tfkl.SeparableConv2D(filters, 3, padding="same")(x)
    x = tfkl.BatchNormalization()(x)

    x = tfkl.Activation("relu")(x)
    x = tfkl.SeparableConv2D(filters, 3, padding="same")(x)
    x = tfkl.BatchNormalization()(x)

    # Stride-2 pooling halves the resolution of the main branch
    x = tfkl.MaxPooling2D(3, strides=2, padding="same")(x)

    # Project residual: a strided 1x1 convolution matches the pooled branch
    # in both resolution and channel count
    residual = tfkl.Conv2D(filters, 1, strides=2, padding="same")(
    previous_block_activation
    )
    x = tfkl.add([x, residual])  # Add back residual
    previous_block_activation = x  # Set aside next residual

  ### [Second half of the network: upsampling inputs] ###

  # Four blocks, each doubling the resolution
  for filters in [256, 128, 64, 32]:
    x = tfkl.Activation("relu")(x)
    x = tfkl.Conv2DTranspose(filters, 3, padding="same")(x)
    x = tfkl.BatchNormalization()(x)

    x = tfkl.Activation("relu")(x)
    x = tfkl.Conv2DTranspose(filters, 3, padding="same")(x)
    x = tfkl.BatchNormalization()(x)

    x = tfkl.UpSampling2D(2)(x)

    # Project residual: upsample, then a 1x1 convolution to match the channel count
    residual = tfkl.UpSampling2D(2)(previous_block_activation)
    residual = tfkl.Conv2D(filters, 1, padding="same")(residual)
    x = tfkl.add([x, residual])  # Add back residual
    previous_block_activation = x  # Set aside next residual

  # Add a per-pixel classification layer
  outputs = tfkl.Conv2D(num_classes, 3, activation="softmax", padding="same")(x)

  # Define the model
  model = tfk.Model(inputs, outputs, name='UNetXception')
  return model

In [None]:
def build_U_NET(img_size: tuple[int, int, int], num_classes):
  """Build a small U-Net: two encoder stages, a bottleneck and two decoder stages.

  Args:
    img_size: input shape passed to the input layer.
    num_classes: number of channels of the final per-pixel softmax layer.

  Returns:
    An uncompiled Keras model named 'UNet'.
  """
  def unet_block(input_tensor, filters, kernel_size=3, activation='relu', stack=2, name=''):
    # `stack` repetitions of Conv2D -> BatchNormalization -> Activation,
    # with layer names numbered from 1 and prefixed by `name`
    x = input_tensor
    for idx in range(1, stack + 1):
      x = tfkl.Conv2D(filters, kernel_size=kernel_size, padding='same', name=f'{name}conv{idx}')(x)
      x = tfkl.BatchNormalization(name=f'{name}bn{idx}')(x)
      x = tfkl.Activation(activation, name=f'{name}activation{idx}')(x)
    return x

  input_layer = tfkl.Input(shape=img_size, name='input_layer')

  # Encoder: the block outputs are kept as skip connections, then pooled
  skip_1 = unet_block(input_layer, 32, name='down_block1_')
  pooled_1 = tfkl.MaxPooling2D()(skip_1)

  skip_2 = unet_block(pooled_1, 64, name='down_block2_')
  pooled_2 = tfkl.MaxPooling2D()(skip_2)

  # Bottleneck at the lowest resolution
  bottleneck = unet_block(pooled_2, 128, name='bottleneck')

  # Decoder: upsample, concatenate the matching encoder output, refine
  up_1 = tfkl.UpSampling2D()(bottleneck)
  up_1 = tfkl.Concatenate()([up_1, skip_2])
  up_1 = unet_block(up_1, 64, name='up_block1_')

  up_2 = tfkl.UpSampling2D()(up_1)
  up_2 = tfkl.Concatenate()([up_2, skip_1])
  up_2 = unet_block(up_2, 32, name='up_block2_')

  # 1x1 convolution producing the per-pixel class probabilities
  output_layer = tfkl.Conv2D(num_classes, kernel_size=1, padding='same', activation="softmax", name='output_layer')(up_2)

  return tf.keras.Model(inputs=input_layer, outputs=output_layer, name='UNet')

In [None]:
# Maps a `model_name` value to the function that builds that architecture
model_dict = {
	'U_NET': build_U_NET,
	'U_NET_XCEPTION': build_U_NET_XCEPTION,
	# Misspelled key kept as an alias so existing uses of the old name keep working
	'U_NET_XCEPTIO': build_U_NET_XCEPTION,
}

In [None]:
def get_callbacks():
  """Return the configured fit callbacks (the values of `model_fit_callbacks`) as a new list."""
  return list(model_fit_callbacks.values())

In [None]:
def fit_model(model, data_loader=None, validation_data_loader=None):
  """Train `model` with the notebook-level `epochs` and callbacks.

  Args:
    model: compiled model exposing `fit`.
    data_loader: training data; must not be None.
    validation_data_loader: validation data; must not be None.

  Returns:
    The `history` attribute of the object returned by `model.fit`.
  """
  assert data_loader is not None
  assert validation_data_loader is not None

  training_run = model.fit(
    data_loader,
    validation_data=validation_data_loader,
    epochs=epochs,
    callbacks=get_callbacks(),
  )
  return training_run.history

In [None]:
# Taken from https://github.com/SeanSdahl/RangerOptimizerTensorflow/blob/master/module.py
def build_ranger(lr=1e-3, weight_decay=0.0):
  """Build the Ranger optimizer: RectifiedAdam wrapped in Lookahead.

  Args:
    lr: learning rate of the inner RectifiedAdam optimizer.
    weight_decay: weight decay of the inner RectifiedAdam optimizer.

  Returns:
    A `tfa.optimizers.Lookahead` instance wrapping `tfa.optimizers.RectifiedAdam`.

  Raises:
    ImportError: if the tensorflow_addons package cannot be imported.
  """
  # Imported lazily so the notebook runs without tensorflow_addons unless Ranger is selected
  try:
    import tensorflow_addons as tfa
  except ImportError as err:
    # Only import failures are translated; anything else propagates unchanged
    raise ImportError("You have to install tensorflow_addons package for Ranger. Please note that this package is available up to tensorflow==2.14") from err

  # Keyword arguments keep every value bound to the intended parameter
  inner = tfa.optimizers.RectifiedAdam(
    learning_rate=lr,
    beta_1=0.9,
    beta_2=0.999,
    epsilon=1e-7,
    weight_decay=weight_decay,
    amsgrad=False,
    sma_threshold=5.0,
    total_steps=0,
    warmup_proportion=0.1,
    min_lr=0.,
    name="Ranger",
  )
  return tfa.optimizers.Lookahead(inner, sync_period=6, slow_step_size=0.5, name="Ranger")

In [None]:
def get_optimizer(is_fine_tuning = False, use_decay_fine_tuning = False, **kwargs):
  """Build the optimizer selected by the notebook-level settings.

  The optimizer name comes from `opt_name` (or `fine_tuning_opt_name` when
  fine-tuning) and the learning rate from `lr` (or `fine_tuning_lr`). When
  `opt_exp_decay_rate` is set, the learning rate follows a staircase
  exponential decay schedule.

  Args:
    is_fine_tuning: select the fine-tuning optimizer name and learning rate.
    use_decay_fine_tuning: keep the decay schedule while fine-tuning; by
      default fine-tuning uses a constant learning rate.
    **kwargs: optional `momentum` (SGD only, default 0.9) and `weight_decay`
      (Adam, AdamW, Lion, Ranger).

  Returns:
    The configured optimizer instance.

  Raises:
    ValueError: if the selected optimizer name is not supported.
    RuntimeError: if a decay schedule is requested together with Ranger.
  """
  opt = fine_tuning_opt_name if is_fine_tuning else opt_name
  base_lr = fine_tuning_lr if is_fine_tuning else lr
  phase = "Finetuning: " if is_fine_tuning else "NotFinetuning: "

  # Exponential decay is skipped during fine-tuning unless explicitly requested
  decay = opt_exp_decay_rate
  if is_fine_tuning and not use_decay_fine_tuning:
    decay = None

  if opt == "Ranger":
    if decay is not None:
      raise RuntimeError("Not supported")
    optimizer = build_ranger(lr=base_lr, weight_decay=kwargs.get('weight_decay', 0.0))
    print(f'\n\n{phase}using {opt} optimizer\n\n')
    return optimizer

  # Fail loudly instead of silently returning None for a mistyped name
  if opt not in ("SGD", "Adam", "AdamW", "Lion"):
    raise ValueError(f"Unknown optimizer name: {opt!r}")

  if decay is None:
    learning_rate = base_lr
  else:
    # NOTE(review): the steps per epoch are estimated as X_train.shape[0] // batch_size.
    # If the training dataset is enlarged (e.g. augmented copies are concatenated),
    # one epoch has more steps than that — confirm the intended decay interval.
    learning_rate = tf.keras.optimizers.schedules.ExponentialDecay(
      initial_learning_rate=base_lr,
      decay_steps=opt_decay_epoch_delta * (X_train.shape[0] // batch_size),
      decay_rate=decay,
      staircase=True,
    )

  # The learning rate (constant or schedule) is passed at construction time
  # rather than being assigned to the optimizer afterwards
  if opt == "SGD":
    optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate, momentum=kwargs.get('momentum', 0.9))
    if decay is not None:
      print(f'\n\n{phase}using {opt} optimizer with exp decay {decay} (momentum = {optimizer.momentum})\n\n')
    else:
      print(f'\n\n{phase}using {opt} optimizer (momentum = {optimizer.momentum})\n\n')
    return optimizer

  # Adam, AdamW and Lion share one construction path; the class is resolved
  # lazily so only the selected optimizer has to exist in the installed version
  optimizer_kwargs = {'learning_rate': learning_rate}
  if 'weight_decay' in kwargs:
    optimizer_kwargs['weight_decay'] = kwargs['weight_decay']
  optimizer = getattr(tf.keras.optimizers, opt)(**optimizer_kwargs)

  if decay is not None:
    print(f'\n\n{phase}using {opt} optimizer with exp decay of {decay} weight decay = {optimizer.weight_decay}\n\n')
  else:
    print(f'\n\n{phase}using {opt} optimizer (weight decay = {optimizer.weight_decay})\n\n')
  return optimizer

In [None]:
def display_model(model):
  """Print the model summary and render its architecture diagram.

  Args:
    model: the Keras model to describe.

  Returns:
    Whatever `plot_model` returns. The value used to be discarded; returning
    it lets a notebook cell that ends with `display_model(model)` show the
    diagram inline.
  """
  # Display a summary of the model architecture
  model.summary(expand_nested=True)

  # Render the architecture with layer shapes and trainable flags
  return tfk.utils.plot_model(model, expand_nested=True, show_trainable=True, show_shapes=True, dpi=70)

## 🍣 Define and display model

In [None]:
# Define custom Mean Intersection Over Union metric: the competition excludes the background class
class MeanIntersectionOverUnion(tf.keras.metrics.MeanIoU):
  """Mean IoU over class probabilities that ignores pixels with excluded true labels.

  Predictions are reduced to class indices with argmax over the last axis, and
  every pixel whose true label is in `labels_to_exclude` is dropped before the
  parent metric is updated.

  Args:
    num_classes: number of classes, forwarded to `MeanIoU`.
    labels_to_exclude: true-label values to ignore; defaults to `[0]`.
    name: metric name (the key prefix used in the training history).
    dtype: metric dtype, forwarded to `MeanIoU`.
  """
  def __init__(self, num_classes, labels_to_exclude=None, name="mean_iou", dtype=None):
    super().__init__(num_classes=num_classes, name=name, dtype=dtype)
    if labels_to_exclude is None:
      labels_to_exclude = [0]  # Default to excluding label 0
    self.labels_to_exclude = labels_to_exclude

  def update_state(self, y_true, y_pred, sample_weight=None):
    """Accumulate one batch, skipping pixels whose true label is excluded."""
    # Convert predictions to class labels
    y_pred = tf.math.argmax(y_pred, axis=-1)

    # Flatten the tensors so both can be filtered with the same 1-D mask
    y_true = tf.reshape(y_true, [-1])
    y_pred = tf.reshape(y_pred, [-1])

    # Drop the pixels of every excluded label from both tensors
    for label in self.labels_to_exclude:
      mask = tf.not_equal(y_true, label)
      y_true = tf.boolean_mask(y_true, mask)
      y_pred = tf.boolean_mask(y_pred, mask)

    # NOTE(review): `sample_weight` is forwarded unfiltered; presumably it is
    # always None in this notebook — confirm before passing per-pixel weights
    return super().update_state(y_true, y_pred, sample_weight)

  def get_config(self):
    """Return exactly the constructor arguments, so the metric can be rebuilt from its config.

    The inherited config does not include `labels_to_exclude` and may carry
    keys that this class's `__init__` does not accept.
    """
    return {
      "num_classes": self.num_classes,
      "labels_to_exclude": list(self.labels_to_exclude),
      "name": self.name,
      "dtype": self.dtype,
    }

In [None]:
# Build the architecture selected by `model_name`; the input shape is IMG_SIZE plus one channel axis
model = model_dict[model_name](IMG_SIZE + (1,), NUM_CLASSES)

# Track accuracy and the mean IoU that ignores label 0
model.compile(loss=loss_fn, optimizer=get_optimizer(is_fine_tuning=False), metrics=['accuracy', MeanIntersectionOverUnion(num_classes=NUM_CLASSES, labels_to_exclude=[0])])
display_model(model)

## 🧗🏻‍♂️ Train and save

In [None]:
# Fit the initial model on augmented + original training samples
print('\n\nFitting model\n\n')
train_loader = get_dataset(X_train, y_train, batch_size=batch_size, augmentations=apply_augmentation, concat_and_shuffle_aug_with_no_aug=True)
val_loader = get_dataset(X_val, y_val, batch_size=batch_size)
fit_history = fit_model(model, data_loader=train_loader, validation_data_loader=val_loader)

# Best validation mean IoU seen during training, as a percentage.
# NOTE(review): this is the maximum over all epochs; the weights kept by
# EarlyStopping are those of the best 'val_accuracy' epoch, which may differ.
final_val_meanIoU = round(max(fit_history['val_mean_iou'])* 100, 2)
print(f'Final validation Mean Intersection Over Union: {final_val_meanIoU}%')

# Save intermediate model as <model name>-<mean IoU>-<timestamp>.keras
timestamp = datetime.now().strftime("%y%m%d_%H%M")
model_filename = f'{model_name}-{final_val_meanIoU}-{timestamp}.keras'
model.save(model_filename)

# Free memory by deleting the model instance
if FREE_MODEL:
  del model

In [None]:
def plot_trainig(fit):
  """Plot training-vs-validation curves for loss, accuracy and mean IoU.

  Args:
    fit: history mapping with the keys 'loss', 'accuracy', 'mean_iou' and
      their 'val_'-prefixed counterparts, each holding one value per epoch.
  """
  # (history key, figure title) for each of the three figures, in display order
  panels = (
    ('loss', 'Cross Entropy'),
    ('accuracy', 'Accuracy'),
    ('mean_iou', 'Mean Intersection over Union'),
  )

  for metric, title in panels:
    plt.figure(figsize=(18, 3))
    plt.plot(fit[metric], label='Training', alpha=0.8, color='#ff7f0e', linewidth=2)
    plt.plot(fit[f'val_{metric}'], label='Validation', alpha=0.9, color='#5a9aa5', linewidth=2)
    plt.title(title)
    plt.legend()
    plt.grid(alpha=0.3)
    plt.show()

In [None]:
# Show the loss, accuracy and mean IoU curves of the training run
plot_trainig(fit_history)

## ✍🏿 Make inference

In [None]:
# Held-out labelled test split, without augmentation, in batches of 8
test_dataset = get_dataset(X_test, y_test, batch_size=8)

In [None]:
# Load UNet model without compiling (the custom metric is attached again below).
# NOTE(review): the path is hard-coded and does not follow the
# `{model_name}-{mean IoU}-{timestamp}.keras` pattern used when the model is
# saved after training — confirm that this file exists.
model = tfk.models.load_model('UNet_59.26.keras', compile=False)

# Compile the model with specified loss, optimizer, and metrics
model.compile(
    loss=loss_fn,
    optimizer=get_optimizer(),
    metrics=["accuracy", MeanIntersectionOverUnion(num_classes=NUM_CLASSES, labels_to_exclude=[0])]
)

display_model(model)

# Evaluate the model on the test set and print the results.
# The returned values are unpacked as the loss followed by the two compiled metrics.
test_loss, test_accuracy, test_mean_iou = model.evaluate(test_dataset, verbose=1)
print(f'Test Accuracy: {round(test_accuracy, 4)}')
print(f'Test Mean Intersection over Union: {round(test_mean_iou, 4)}')

In [None]:
def create_segmentation_colormap(num_classes):
  """
  Sample `num_classes` evenly spaced colours from matplotlib's 'viridis' colormap.
  'viridis' is used because it is perceptually uniform and works well for
  colorblindness. The result has one colour entry per class.
  """
  sample_points = np.linspace(0, 1, num_classes)
  return plt.cm.viridis(sample_points)

def apply_colormap(label, colormap=None):
  """
  Turn a label map into a colour image by looking each label up in `colormap`.

  Parameters:
  label: array of integer class labels; axes of size 1 are squeezed away
  colormap: array with one colour per class index. When omitted, a colormap
            with `max(label) + 1` entries is created.

  Returns the colormap entries selected by `label` (same leading shape as the
  squeezed label, plus the colour axis).
  """
  # Ensure label is 2D and usable as an integer index
  label = np.asarray(np.squeeze(label)).astype(int)

  if colormap is None:
    # Size the colormap by the highest label, not by the number of distinct
    # labels: a mask holding e.g. only {0, 3} needs 4 entries, and sizing it
    # by the count (2) made the lookup below raise IndexError
    num_classes = int(label.max(initial=0)) + 1
    colormap = create_segmentation_colormap(num_classes)

  # Apply the colormap
  return colormap[label]

def plot_triptychs(dataset, model, num_samples=1):
  """
  Plot triptychs (original image, true mask, predicted mask) for samples from a tf.data.Dataset

  Parameters:
  dataset: tf.data.Dataset - The dataset containing image-label pairs
  model: tf.keras.Model - The trained model to generate predictions
  num_samples: int - Number of dataset elements to take; for a batched dataset
               the first example of each taken batch is plotted
  """
  for images, labels in dataset.take(num_samples):
    # If we have a batch, keep only its first example (as a batch of one)
    if len(images.shape) == 4:
      images = images[0:1]
      labels = labels[0:1]

    # Generate predictions
    probabilities = model.predict(images, verbose=0)
    pred = tf.math.argmax(probabilities, axis=-1)

    # One colour per model output class (or more, if the true mask holds a
    # higher label). Sizing the colormap by the number of distinct labels in
    # the true mask raised IndexError whenever a label or a predicted class
    # was not smaller than that count.
    labels_np = labels.numpy()
    num_classes = max(probabilities.shape[-1], int(labels_np.max(initial=0)) + 1)
    colormap = create_segmentation_colormap(num_classes)

    # Create figure with subplots
    _, axes = plt.subplots(1, 3, figsize=(20, 4))

    # Plot original image; single-channel images are drawn in grayscale
    # instead of matplotlib's default false-colour map
    image = np.squeeze(images[0])
    axes[0].set_title("Original Image")
    axes[0].imshow(image, cmap="gray" if image.ndim == 2 else None)
    axes[0].axis('off')

    # Plot original mask
    axes[1].set_title("Original Mask")
    axes[1].imshow(apply_colormap(labels_np[0], colormap))
    axes[1].axis('off')

    # Plot predicted mask
    axes[2].set_title("Predicted Mask")
    axes[2].imshow(apply_colormap(pred[0], colormap))
    axes[2].axis('off')

    plt.tight_layout()
    plt.show()
    plt.close()

# Compare true and predicted masks for the first example of two test batches
plot_triptychs(test_dataset, model, num_samples=2)