# Imports and Google Drive mount

In [None]:
# Mount Google Drive (Colab runtime)
import os
from google.colab import drive

# An existing /content/drive directory is treated as "already mounted"
drive_mounted = os.path.exists('/content/drive')

if not drive_mounted:
    drive.mount('/content/drive')
else:
    print("Google Drive is already mounted.")

In [None]:
# Switch the working directory to the project folder on Drive so that the
# relative paths used by later cells ('../', './trained_models/...') resolve
# against it
os.chdir('/content/drive/MyDrive/DL_Explained/computer_vision')

In [None]:
import sys
sys.path.append('../')
try:
  import tensorflow_addons as tfa
  import keras_cv
except:
  !pip install keras_cv
  !pip install tensorflow-addons
  import tensorflow_addons as tfa
import tensorflow as tf
import numpy as np
from tensorflow import keras
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
from helper import set_model_config
from helper import plot_loss, visualize_segmentation_predictions
from tensorflow.keras.callbacks import ReduceLROnPlateau
from tensorflow.keras.callbacks import ModelCheckpoint
from helper import plot_loss
from keras import layers
from keras.models import Model
'''Create a random seed generator for randomized TF ops'''
# Generator built from the fixed seed 123 with the Philox algorithm.
# NOTE(review): no cell in view draws from `rng` (the augmentation uses
# tf.random.uniform directly) — confirm whether it is still needed.
rng = tf.random.Generator.from_seed(123, alg='philox')

# Load the Dataset and preview basic info

In [None]:
# Load the Oxford-IIIT Pet dataset. Training uses the 'train' split plus the
# first half of 'test'; the rest of 'test' is divided into a validation part
# (50%-80%) and a held-out test part (80%-100%).
(train_ds, val_ds, test_ds), info = tfds.load(
    'oxford_iiit_pet:3.*.*',
    split=['train+test[:50%]', 'test[50%:80%]', 'test[80%:100%]'],
    with_info=True,
)

# Report the dataset metadata, one fact per output line
print(
    "Oxford pets dataset information:",
    f"Number of classes: {info.features['label'].num_classes}",
    f"Class names: {info.features['label'].names}",
    f"Number of training examples: {info.splits['train'].num_examples}",
    f"Dataset splits: {list(info.splits.keys())}",
    f"Dataset description: {info.description}",
    sep="\n",
)

In [None]:
# Display the feature specification through the public `features` property
# (already used by the cell above) rather than the private `_features` attribute
info.features

# Set model config

In [None]:
# Fetch the settings for the "oxford_unet" setup from the project helper.
# Later cells read its 'batch_size', 'learning_rate' and 'n_classes' entries.
config = set_model_config("oxford_unet")
# Echo the configuration as the cell output
config

# Create data pre-processing pipeline

In [None]:
# Apply augmentations to both mask and images for trainset
def augmentations(image, mask, target_size=(224, 224)):
    """Resize an image/mask pair and randomly mirror both together.

    Args:
        image: Image tensor accepted by `tf.image.resize`.
        mask: Segmentation mask tensor laid out like `image`.
        target_size: `(height, width)` of the outputs. Defaults to the
            224x224 resolution that used to be hard-coded here.

    Returns:
        Tuple `(image, mask)`; the mask is cast to `tf.uint8`.
    """
    image = tf.image.resize(image, target_size)
    # Masks hold class ids and must not be blended: the default bilinear
    # interpolation produces fractional values along object borders, which the
    # uint8 cast below would truncate. Nearest-neighbour keeps the ids intact.
    mask = tf.image.resize(mask, target_size, method='nearest')

    # Random horizontal flip, applied to both tensors so they stay aligned
    if tf.random.uniform(()) > 0.5:
        image = tf.image.flip_left_right(image)
        mask = tf.image.flip_left_right(mask)

    return image, tf.cast(mask, dtype=tf.uint8)

# Resize mask/images pre-process for inference
def resize_inference(image, mask, target_size=(224, 224)):
    """Resize an image/mask pair without any random augmentation.

    Args:
        image: Image tensor accepted by `tf.image.resize`.
        mask: Segmentation mask tensor laid out like `image`.
        target_size: `(height, width)` of the outputs. Defaults to the
            224x224 resolution that used to be hard-coded here.

    Returns:
        Tuple `(image, mask)`; the mask is cast to `tf.uint8`.
    """
    image = tf.image.resize(image, target_size)
    # Nearest-neighbour keeps the mask's class ids intact; the default
    # bilinear interpolation would blend them into fractional values that the
    # uint8 cast below truncates.
    mask = tf.image.resize(mask, target_size, method='nearest')

    return image, tf.cast(mask, dtype=tf.uint8)

# Convert to binary
def binary_mask(mask):
    """Collapse the segmentation labels into a two-valued mask.

    Label 1 stays 1, labels 2 and 3 become 0, and any other value passes
    through unchanged. The result is returned as `tf.uint8`.
    """
    labels = tf.cast(mask, dtype=tf.int32)
    # Pixels carrying label 2 or 3 are the ones that get zeroed
    to_zero = tf.logical_or(tf.equal(labels, 2), tf.equal(labels, 3))
    return tf.cast(tf.where(to_zero, 0, labels), dtype=tf.uint8)

# Pre-process trainset
def preprocess_train(element):
    """Turn one raw training example into an augmented (image, mask) pair.

    Args:
        element: Mapping with 'image' and 'segmentation_mask' entries.

    Returns:
        The `(image, mask)` tuple produced by `augmentations`.
    """
    # Binarise the mask
    mask = binary_mask(
        tf.image.convert_image_dtype(element['segmentation_mask'], tf.uint8)
    )
    # Convert the image to float32
    image = tf.image.convert_image_dtype(element['image'], tf.float32)

    # Resize and randomly flip both tensors together
    return augmentations(image, mask)

# Pre-process val/test sets
def preprocess_val_test(element):
    """Turn one raw validation/test example into a resized (image, mask) pair.

    Args:
        element: Mapping with 'image' and 'segmentation_mask' entries.

    Returns:
        The `(image, mask)` tuple produced by `resize_inference`.
    """
    # Binarise the mask
    mask = binary_mask(
        tf.image.convert_image_dtype(element['segmentation_mask'], tf.uint8)
    )
    # Convert the image to float32
    image = tf.image.convert_image_dtype(element['image'], tf.float32)

    # Deterministic resize only, no augmentation
    return resize_inference(image, mask)

# Training pipeline: pre-process, shuffle individual examples, then batch.
# Shuffling must precede batching; done afterwards it only reorders whole
# batches, so the same examples are grouped together in every epoch. The
# buffer size therefore counts examples here, not batches — raise it if a
# stronger shuffle is wanted.
train_dataset = train_ds.map(preprocess_train, num_parallel_calls=tf.data.experimental.AUTOTUNE)
train_dataset = train_dataset.shuffle(buffer_size=config['batch_size'] * 6)
train_dataset = train_dataset.batch(config['batch_size'])
train_dataset = train_dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

# Validation pipeline: no shuffling. The loss and metrics are aggregated over
# the whole split, so example order does not change them and a shuffle buffer
# would only cost memory.
validation_dataset = val_ds.map(preprocess_val_test, num_parallel_calls=tf.data.experimental.AUTOTUNE)
validation_dataset = validation_dataset.batch(config['batch_size'])
validation_dataset = validation_dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

# Test pipeline: shuffling is retained (presumably so the prediction preview
# shows varying samples) but moved ahead of batching. It is sized from the
# configured batch size instead of a hard-coded 8 * 6, and the map step runs
# in parallel like the other two pipelines.
test_dataset = test_ds.map(preprocess_val_test, num_parallel_calls=tf.data.experimental.AUTOTUNE)
test_dataset = test_dataset.shuffle(buffer_size=config['batch_size'] * 6)
test_dataset = test_dataset.batch(config['batch_size'])
test_dataset = test_dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)


# Visualize a few samples from the Dataset

In [None]:
def visualize_samples(dataset, num_samples=3):
    """Plot images next to their segmentation masks in a two-column grid.

    Args:
        dataset: Dataset of un-batched examples, each a mapping with 'image'
            and 'segmentation_mask' tensors.
        num_samples: Number of examples to draw, one per grid row.
    """
    # Take the first 'num_samples' samples from the dataset
    sample_dataset = dataset.take(num_samples)

    # squeeze=False keeps `axes` two-dimensional even for a single row, so the
    # axes[i, col] indexing below also works when num_samples == 1
    fig, axes = plt.subplots(num_samples, 2, figsize=(10, 5 * num_samples), squeeze=False)

    # Iterate through the samples dataset
    for i, sample in enumerate(sample_dataset):
        image = sample['image'].numpy().astype(int)
        segmentation_mask = sample['segmentation_mask'].numpy().astype(int)

        # Plot original image
        axes[i, 0].imshow(image)
        axes[i, 0].set_title(f'Original Image - Sample {i + 1}')
        axes[i, 0].axis('off')

        # Plot segmentation mask (first channel only)
        axes[i, 1].imshow(segmentation_mask[:, :, 0], cmap='gray')
        axes[i, 1].set_title(f'Segmentation Mask - Sample {i + 1}')
        axes[i, 1].axis('off')

# Preview the first examples of the raw (not yet pre-processed) training split
# using matplotlib's dark background style
with plt.style.context('dark_background'):
  visualize_samples(train_ds)

# U-Net architecture

In [None]:
from tensorflow.keras import layers
from tensorflow.keras.applications import MobileNetV2

# Unet with pre-trained encoder
def unet_model(img_size, num_classes):
    """Build and compile a U-Net whose encoder is a frozen MobileNetV2.

    Args:
        img_size: Input shape handed to MobileNetV2 and `keras.Input`,
            e.g. `(224, 224, 3)`.
        num_classes: Currently unused — the output layer always has a single
            channel. Kept so existing callers keep working.

    Returns:
        The compiled Keras model. Its output is a single-channel map of raw
        logits (the last layer has no activation).

    Note:
        The learning rate is read from the module-level `config` dict.
    """
    base_model = tf.keras.applications.MobileNetV2(input_shape=img_size, include_top=False)

    # Intermediate activations used as skip connections, ordered from the
    # shallowest layer to the deepest one (which acts as the bottleneck)
    names = ['block_1_expand_relu', 'block_3_expand_relu', 'block_6_expand_relu',
             'block_13_expand_relu', 'block_16_expand_relu']
    encoder_layers = [base_model.get_layer(name).output for name in names]

    # Feature extractor returning all five activations; its weights are frozen
    down_sample = tf.keras.Model(base_model.input, encoder_layers)
    down_sample.trainable = False

    # Downsampling through the model
    inputs = keras.Input(shape=img_size)
    x = inputs
    skips = down_sample(x)
    x = skips[-1]
    # Remaining activations, deepest first, to pair with the decoder stages
    skips = reversed(skips[:-1])

    up_stack = [layers.Conv2DTranspose(512, (3, 3), strides=(2, 2), padding='same', activation='relu', kernel_initializer='he_normal'),
                layers.Conv2DTranspose(256, (3, 3), strides=(2, 2), padding='same', activation='relu', kernel_initializer='he_normal'),
                layers.Conv2DTranspose(128, (3, 3), strides=(2, 2), padding='same', activation='relu', kernel_initializer='he_normal'),
                layers.Conv2DTranspose(64, (3, 3), strides=(2, 2), padding='same', activation='relu', kernel_initializer='he_normal'),
                ]

    # Upsampling and establishing the skip connections
    for up, skip in zip(up_stack, skips):
        x = up(x)
        concat = tf.keras.layers.Concatenate()
        x = concat([x, skip])

    # Final stride-2 transposed convolution: one more 2x upsampling down to a
    # single channel, with no activation so the output stays a logit
    conv10 = layers.Conv2DTranspose(1, 3, strides=2, padding='same')(x)

    # Create the UNet model
    model = keras.models.Model(inputs=inputs, outputs=[conv10])

    # Compile the model. The network emits logits, so the metrics must
    # threshold at 0.0 (equivalent to probability 0.5). The default threshold
    # of 0.5 would be applied to the raw logits and only count pixels whose
    # probability exceeds sigmoid(0.5) ~= 0.62 as foreground.
    model.compile(optimizer=keras.optimizers.Adam(learning_rate=config['learning_rate']),
                  loss=keras.losses.BinaryCrossentropy(from_logits=True),
                  metrics=[tf.keras.metrics.BinaryIoU(target_class_ids=[0, 1],
                                                      threshold=0.0,
                                                      name='IoU'),
                           # Explicit metric object instead of the 'accuracy'
                           # string so the threshold can be set; the name is
                           # kept so the history keys stay the same
                           tf.keras.metrics.BinaryAccuracy(name='accuracy',
                                                           threshold=0.0)])

    return model

# Instantiate the network for 224x224 three-channel inputs and print its layers
model = unet_model(img_size=(224, 224, 3), num_classes=config['n_classes'])
model.summary()

# Set training callbacks and train the model

In [None]:
# Learning-rate schedule: multiply the rate by 0.2 (never going below 1e-4)
# once the validation loss has stopped improving for 5 epochs
lr_scheduler = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.2,
    patience=5,
    min_lr=0.0001,
)

# Checkpoint: store only the weights, and only when the validation loss
# reaches a new minimum
saving_cb = ModelCheckpoint(
    filepath='./trained_models/oxford_segmentation/best_weights.h5',
    monitor='val_loss',
    mode='min',
    save_best_only=True,
    save_weights_only=True,
    verbose=1,
)

# Train for 2 epochs with the scheduler and checkpoint callbacks attached,
# evaluating on the validation set
history = model.fit(
    train_dataset,
    validation_data=validation_dataset,
    epochs=2,
    callbacks=[lr_scheduler, saving_cb],
)

# Plot losses and save the final model

In [None]:
# Plot the training history with a dark background
with plt.style.context('dark_background'):
    plot_loss(history, model_type = 'segmentation')

In [None]:
# Save the final model, i.e. the state it is in after the last training epoch
model.save("./trained_models/oxford_segmentation/oxford_unet")

# Make some predictions on the test set and visualize them

In [None]:
# Preview predictions for 3 samples of the test pipeline.
# NOTE(review): the model emits raw logits (no output activation, loss built
# with from_logits=True) — confirm the helper applies a sigmoid before
# comparing against threshold=0.5.
with plt.style.context('dark_background'):
  visualize_segmentation_predictions(test_dataset, model, num_samples=3, threshold=0.5)