<a href="https://colab.research.google.com/github/deasadiqbal/computer-vision-project-with-keras-and-tensorflow/blob/main/Image_Segmentation_with_U_Net.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install keras_cv

In [None]:
!pip install keras-core


In [None]:
import random

import keras
import keras_cv
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds


## Configuration

In [None]:
# Image config: every example is resized to IMG_HEIGHT x IMG_WIDTH before training.
IMG_HEIGHT = 160
IMG_WIDTH = 160
# Number of segmentation classes; also the number of softmax output channels.
NUM_CLASSES = 3

# Augmentation config.
# NOTE(review): keras_cv RandomRotation reads `factor` as a fraction of a full turn — confirm.
ROTATION_FACTOR = (-0.2, 0.2)

# Training config.
BATCH_SIZE = 64  # You can set a different BATCH_SIZE (e.g. to fit GPU memory)
EPOCHS = 10  # You can set a different number of EPOCHS
LEARNING_RATE = 1e-4
AUTOTUNE = tf.data.AUTOTUNE

# Reproducibility: seed Python `random`, NumPy and the backend framework in one call,
# so shuffling, augmentation, weight init and the callback's sample pick are repeatable.
SEED = 42
keras.utils.set_random_seed(SEED)

## Download the data


In [None]:
# The training set is the whole `train` split plus the first 80% of `test`;
# the last 20% of `test` is held out for validation.
DATA_SPLITS = ["train+test[:80%]", "test[80%:]"]

tfds.disable_progress_bar()
train_ds, test_ds = tfds.load(name="oxford_iiit_pet", split=DATA_SPLITS)

## Preprocessing

In [None]:
def preprocessed(inputs):
    """Map a raw TFDS Oxford-IIIT Pet example to the keras_cv dict format.

    - ``"images"``: the image cast to float32 and divided by 255, giving
      values in [0, 1] for 8-bit input.
    - ``"segmentation_masks"``: the mask with every label shifted down by one.
      NOTE(review): assumes TFDS labels are 1..NUM_CLASSES, so the result is
      0-based — confirm.
    """
    return {
        "images": tf.cast(inputs["image"], dtype=tf.float32) / 255.0,
        "segmentation_masks": inputs["segmentation_mask"] - 1,
    }


# `train_ds` is rebound below. Only map it while it still carries the raw TFDS
# keys; otherwise re-running this cell would rescale the images and shift the
# mask labels a second time.
if "image" in train_ds.element_spec:
    train_ds = train_ds.map(preprocessed, num_parallel_calls=AUTOTUNE)
val_ds = test_ds.map(preprocessed, num_parallel_calls=AUTOTUNE)

## Utility Function

In [None]:
def unpack_inputs(inputs):
    """Convert a keras_cv-style example dict into an ``(images, masks)`` tuple.

    The masks are cast to float32 so that every pipeline branch (raw,
    resized, augmented) yields masks of the same dtype. The input dict is
    left unmodified.

    Args:
        inputs: dict with ``"images"`` and ``"segmentation_masks"`` entries.

    Returns:
        Tuple ``(images, segmentation_masks)`` with the masks as ``tf.float32``.
    """
    masks = tf.cast(inputs["segmentation_masks"], tf.float32)
    return inputs["images"], masks

In [None]:
# Raw examples are not resized yet, so batch them as ragged tensors just for plotting.
plot_train_ds = train_ds.map(unpack_inputs).ragged_batch(4)
images, seg_masks = next(iter(plot_train_ds.take(1)))

keras_cv.visualization.plot_segmentation_mask_gallery(
    images,
    value_range=(0, 1),
    num_classes=NUM_CLASSES,
    y_true=seg_masks,
    y_pred=None,
    scale=4,
    rows=2,
    cols=2,
)

## Data augmentation

In [None]:
# Resizes each example dict to the model's fixed input size. It is also applied
# on its own to the validation set, so it is kept as a separate name.
resize_fn = keras_cv.layers.Resizing(IMG_HEIGHT, IMG_WIDTH)

# Training-time pipeline: resize first, then random flips and rotations.
# NOTE(review): `segmentation_classes` is presumably what lets RandomRotation
# treat the mask as sparse class ids — confirm against the keras_cv docs.
augmentation_layers = [
    resize_fn,
    keras_cv.layers.RandomFlip(),
    keras_cv.layers.RandomRotation(
        factor=ROTATION_FACTOR,
        segmentation_classes=NUM_CLASSES,
    ),
]
augment_fn = keras.Sequential(augmentation_layers)

In [None]:
# Training pipeline: shuffle raw examples, augment (which includes resizing),
# convert dicts to (image, mask) tuples, then batch and prefetch.
augment_train_ds = (
    train_ds.shuffle(BATCH_SIZE * 2)
    .map(augment_fn, num_parallel_calls=AUTOTUNE)
    .map(unpack_inputs, num_parallel_calls=AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(buffer_size=AUTOTUNE)
)

# Validation pipeline: deterministic resize only; no augmentation and no shuffling.
resized_test_ds = (
    val_ds.map(resize_fn, num_parallel_calls=AUTOTUNE)
    .map(unpack_inputs, num_parallel_calls=AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(buffer_size=AUTOTUNE)
)

## Visualization

In [None]:
# Take one augmented training batch and show the first 2x2 samples with their masks.
images, seg_masks = next(iter(augment_train_ds.take(1)))
keras_cv.visualization.plot_segmentation_mask_gallery(
    images,
    value_range=(0, 1),
    num_classes=NUM_CLASSES,
    y_true=seg_masks,
    y_pred=None,
    scale=4,
    rows=2,
    cols=2,
)

## Model Architecture

In [None]:
def get_model(img_size, num_classes):
    """Build an encoder/decoder segmentation network with residual projections.

    The structure is U-Net-like: an encoder halves the spatial size in each
    stage and a decoder doubles it back. Each stage adds a 1x1-conv
    projection of the previous stage's output (a residual shortcut). There
    are no encoder-to-decoder concatenations.

    Args:
        img_size: ``(height, width)`` tuple; a 3-channel axis is appended.
        num_classes: number of per-pixel softmax output channels.

    Returns:
        An uncompiled ``keras.Model`` mapping ``img_size + (3,)`` inputs to
        per-pixel class probabilities.
    """
    layers = keras.layers

    def relu_conv_bn(tensor, conv_cls, filters):
        # One ReLU -> 3x3 convolution -> batch-norm unit; each stage stacks two.
        tensor = layers.Activation('relu')(tensor)
        tensor = conv_cls(filters, 3, padding='same')(tensor)
        return layers.BatchNormalization()(tensor)

    inputs = keras.Input(shape=img_size + (3,))

    # Entry block: a stride-2 convolution halves the spatial resolution.
    x = layers.Conv2D(32, 3, strides=2, padding='same')(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    skip = x

    # Encoder: downsample with max pooling. The shortcut is a strided 1x1
    # projection of the previous stage so the shapes match for the add.
    for filters in (64, 128, 256):
        for _ in range(2):
            x = relu_conv_bn(x, layers.SeparableConv2D, filters)
        x = layers.MaxPooling2D(3, strides=2, padding='same')(x)
        shortcut = layers.Conv2D(filters, 1, strides=2, padding='same')(skip)
        x = layers.add([x, shortcut])
        skip = x

    # Decoder: upsample by 2. The shortcut is the previous stage upsampled,
    # then projected to `filters` channels with a 1x1 convolution.
    for filters in (256, 128, 64, 32):
        for _ in range(2):
            x = relu_conv_bn(x, layers.Conv2DTranspose, filters)
        x = layers.UpSampling2D(2)(x)
        shortcut = layers.UpSampling2D(2)(skip)
        shortcut = layers.Conv2D(filters, 1, padding='same')(shortcut)
        x = layers.add([x, shortcut])
        skip = x

    # Per-pixel classification head.
    outputs = layers.Conv2D(num_classes, 3, activation='softmax', padding='same')(x)
    return keras.Model(inputs, outputs)


In [None]:
# Fix one validation batch, so the training callback and the final inference
# cell always visualize the same images and track progress on them.
first_val_batch = next(iter(resized_test_ds))
test_images, test_masks = first_val_batch

In [None]:
class DisplayCallback(keras.callbacks.Callback):
    """Plot a random validation image with its true and predicted masks.

    Args:
        epoch_interval: plot at the end of every ``epoch_interval``-th epoch,
            counting from epoch 0. ``None`` (or 0) disables plotting.
        images: batch of images to predict on. When ``None``, the
            notebook-level ``test_images`` is used at plot time.
        masks: ground-truth masks aligned with ``images``. When ``None``,
            the notebook-level ``test_masks`` is used at plot time.
    """

    def __init__(self, epoch_interval=None, images=None, masks=None):
        # Let the base Callback run its own initialization.
        super().__init__()
        self.epoch_interval = epoch_interval
        self.images = images
        self.masks = masks

    def on_epoch_end(self, epoch, logs=None):
        if not self.epoch_interval or epoch % self.epoch_interval != 0:
            return

        images = test_images if self.images is None else self.images
        masks = test_masks if self.masks is None else self.masks

        # verbose=0 stops a predict progress bar from interleaving with fit's output.
        probs = self.model.predict(images, verbose=0)
        pred_masks = tf.math.argmax(probs, axis=-1)

        # Draw from the actual batch length; a batch can be smaller than BATCH_SIZE.
        idx = random.randrange(len(images))

        fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(10, 5))
        axes[0].imshow(images[idx])
        axes[0].set_title(f"Image: {epoch:03d}")

        # Fix the color range so each class id gets the same color in both masks
        # (otherwise imshow rescales each mask to its own min/max).
        mask_kwargs = {"vmin": 0, "vmax": NUM_CLASSES - 1}
        axes[1].imshow(np.squeeze(np.asarray(masks[idx])), **mask_kwargs)
        axes[1].set_title(f"Ground Truth Mask: {epoch:03d}")

        axes[2].imshow(np.asarray(pred_masks[idx]), **mask_kwargs)
        axes[2].set_title(f"Predicted Mask: {epoch:03d}")

        plt.show()
        plt.close(fig)


# Plot every 5 epochs (epochs 0, 5, ...).
callbacks = [DisplayCallback(5)]

## Model Training

In [None]:
model = get_model(img_size=(IMG_HEIGHT, IMG_WIDTH), num_classes=NUM_CLASSES)

# Sparse categorical cross-entropy: the targets are per-pixel class ids
# (not one-hot), matched against the model's NUM_CLASSES-way softmax.
optimizer = keras.optimizers.Adam(LEARNING_RATE)
model.compile(
    optimizer=optimizer,
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

history = model.fit(
    augment_train_ds,
    validation_data=resized_test_ds,
    epochs=EPOCHS,
    callbacks=callbacks,
)

## Inference

In [None]:
# Predict on the fixed validation batch and turn per-pixel class probabilities
# into class-id masks with a trailing channel axis, matching `test_masks`.
pred_probs = model.predict(test_images)
pred_masks = tf.math.argmax(pred_probs, axis=-1)[..., tf.newaxis]

keras_cv.visualization.plot_segmentation_mask_gallery(
    test_images,
    value_range=(0, 1),
    num_classes=NUM_CLASSES,
    y_true=test_masks,
    y_pred=pred_masks,
    scale=4,
    rows=2,
    cols=2,
)