Implementing a U-Net architecture for image segmentation using TensorFlow and Keras:

from google.colab import drive: This line imports the drive module from the google.colab package. This module contains functionality to mount Google Drive within a Google Colab notebook.
drive.mount('/content/drive'): This line calls the mount() function from the drive module, which mounts the user's Google Drive to the specified directory (/content/drive in this case). After running this command, you'll be prompted to authenticate and grant permissions to access your Google Drive. Once authenticated, your Google Drive will be mounted, and you can access its contents through the specified directory path.

In [None]:
# Import the drive module from Google Colab and mount Google Drive at '/content/drive'
from google.colab import drive
drive.mount('/content/drive')

basePath = "/content/drive/MyDrive/Kvasir-SEG": This line assigns the base path where the dataset is located in Google Drive to the variable basePath.
def address(path=''): This line defines a function named address that takes an optional path argument. If no path is provided, it defaults to an empty string.
return f"{basePath}/{path}": Inside the function, this line constructs the full address by combining the basePath and the provided path (if any) using f-string formatting. The resulting address is then returned by the function.

In [3]:
# Define the base path where the dataset is located in Google Drive
basePath = "/content/drive/MyDrive/Kvasir-SEG"

# Define a function to construct the full address based on the base path and additional path
def address(path=''):
    return f"{basePath}/{path}"
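When path is empty, the f-string version returns the base path with a trailing slash (".../Kvasir-SEG/"), so mainDirPath + "/images" later contains a double slash. POSIX paths tolerate this, but a pathlib-based helper avoids it. The following is a minimal sketch that assumes the same Drive location as basePath. BASE_PATH is a name introduced here for illustration and is not part of the original notebook.

```python
from pathlib import PurePosixPath

# Assumed dataset location, same as basePath above
BASE_PATH = PurePosixPath("/content/drive/MyDrive/Kvasir-SEG")

def address(path: str = "") -> str:
    """Join `path` onto the dataset root; an empty path returns the root itself."""
    # pathlib ignores empty segments, so address() has no trailing slash
    return str(BASE_PATH / path)

print(address())          # /content/drive/MyDrive/Kvasir-SEG
print(address("images"))  # /content/drive/MyDrive/Kvasir-SEG/images
```

Running this cell redefines address, so it is a drop-in alternative to the cell above and not an addition to it.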

The next cell imports the libraries and modules used throughout the notebook.
NumPy (np), Pandas (pd), and tqdm are imported for numerical operations, data processing, and progress tracking, respectively. os is used to list files and directories.
TensorFlow, including its bundled Keras API (tf.keras), is imported for building and training the model.
skimage, PIL, and matplotlib.pyplot are imported for image loading, processing, and visualization.

In [4]:
# Importing necessary libraries
import numpy as np  # For numerical operations
import pandas as pd # For data processing and CSV file I/O
from tqdm import tqdm # For progress tracking
import os # For interacting with the operating system

# Deep learning framework imports
# Import everything through tf.keras; mixing the standalone `keras` package with
# `tensorflow.keras` can lead to version mismatches between the two
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Add, Dense, Dropout, Activation, ZeroPadding2D, BatchNormalization, Concatenate, Flatten, Conv2D, AveragePooling2D, MaxPool2D, MaxPooling2D, Reshape, Conv2DTranspose
from tensorflow.keras.initializers import glorot_uniform

# Additional image processing imports
from skimage.color import rgb2gray as rtg # For converting RGB images to grayscale
from skimage.io import imread, imshow # For reading and displaying images
from skimage.transform import resize # For resizing images
import matplotlib.pyplot as plt
from skimage.morphology import label # For labeling connected components in an image
from PIL import Image # For image manipulation

os.walk() traverses a directory tree (top-down by default). For each directory it visits, it yields a tuple (dirpath, dirnames, filenames).
In this loop, dirname is the directory currently being visited. _ holds the names of its subdirectories, which this code does not use, although os.walk still descends into them. filenames is the list of files directly inside dirname.
tqdm(total=len(filenames)) creates a progress bar whose total is the number of files in the current directory.
print(f"[INFO] Scanning directory: {dirname}") reports which directory is being processed.
The inner loop advances the progress bar with t.update(1) for each file. Nothing is loaded at this stage; the loop only confirms that the dataset files are visible from Colab, starting at the dataset root address().

In [None]:
# Walk the dataset directory and all of its subdirectories
for dirname, _, filenames in os.walk(address()):
    # Initialize a tqdm progress bar with the total number of files in the current directory
    with tqdm(total=len(filenames)) as t:
        # Report which directory is being scanned (no files are loaded here)
        print(f"[INFO] Scanning directory: {dirname}")
        # Advance the progress bar once per file in the current directory
        for filename in filenames:
            t.update(1)
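The same traversal can be written to return per-directory file counts instead of drawing progress bars. This is useful for checking the images/masks layout at a glance. The following is a small sketch: count_files_per_dir is a hypothetical helper, and the temporary tree only mimics the assumed images/masks folder structure.

```python
import os
import tempfile

def count_files_per_dir(root: str) -> dict:
    """Return {relative directory: number of files directly inside it} for the tree under root."""
    counts = {}
    # os.walk still descends into subdirectories; the `_subdirs` list is simply unused here
    for dirname, _subdirs, filenames in os.walk(root):
        counts[os.path.relpath(dirname, root)] = len(filenames)
    return counts

# Tiny throwaway tree mimicking the images/masks layout (file names are made up)
with tempfile.TemporaryDirectory() as root:
    for sub in ("images", "masks"):
        os.makedirs(os.path.join(root, sub))
        for name in ("a.jpg", "b.jpg"):
            open(os.path.join(root, sub, name), "w").close()
    counts = count_files_per_dir(root)

print(counts)
```

In the notebook, you would call count_files_per_dir(address()) on the real dataset root.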


mainDirPath = address() sets the variable mainDirPath to the base directory path retrieved using the address() function defined earlier.
imagesPath = mainDirPath + "/images" creates the path for the directory containing images within the main directory (mainDirPath).
masksPath = mainDirPath + "/masks" creates the path for the directory containing masks within the main directory (mainDirPath).

In [6]:
# Define the main directory path based on the base path
mainDirPath = address()

# Define the paths for images and masks directories within the main directory
imagesPath = mainDirPath + "/images"
masksPath = mainDirPath + "/masks"

os.listdir(directory) returns a list containing the names of the entries in the directory given by directory.
len(os.listdir(imagesPath)) calculates the total number of files (images) in the imagesPath directory.
len(os.listdir(masksPath)) calculates the total number of files (masks) in the masksPath directory.
Finally, the print statement displays the total number of images and masks found in their respective directories.

In [None]:
# Print the total number of images and masks in their respective directories
print(f"Total images: {len(os.listdir(imagesPath))}\nTotal masks: {len(os.listdir(masksPath))}")
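Equal counts do not guarantee that every image has a mask with the same file name. The loading loop relies on this, because it opens masksPath + "/" + id_ for each image id. The following is a sketch of a stricter check. check_pairs is a hypothetical helper, not part of the original notebook.

```python
def check_pairs(image_names, mask_names):
    """Return (images without a mask, masks without an image), comparing file names exactly."""
    images, masks = set(image_names), set(mask_names)
    return sorted(images - masks), sorted(masks - images)

# Illustrative names only; in the notebook, pass os.listdir(imagesPath) and os.listdir(masksPath)
missing_masks, orphan_masks = check_pairs(["p1.jpg", "p2.jpg"], ["p1.jpg", "p3.jpg"])
print(missing_masks, orphan_masks)  # ['p2.jpg'] ['p3.jpg']
```

Two empty lists mean the folders pair up one-to-one.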

This function plot_training is designed to visualize the training and validation metrics (loss and accuracy) over epochs.
It extracts the training and validation metrics from the history object returned by the fit() method of a Keras model.
It finds the epoch with the lowest validation loss and the epoch with the highest validation accuracy.
It plots the training and validation loss in one plot and the training and validation accuracy in another plot, with markers indicating the epochs with the lowest validation loss and the highest validation accuracy.

In [21]:

def plot_training(hist):
    # Extracting training and validation metrics from the history object
    tr_acc = hist.history['accuracy'] # Training accuracy
    tr_loss = hist.history['loss']  # Training loss
    val_acc = hist.history['val_accuracy']  # Validation accuracy
    val_loss = hist.history['val_loss']  # Validation loss
    
    # Finding the index of the epoch with the lowest validation loss
    index_loss = np.argmin(val_loss)
    # Extracting the lowest validation loss value
    val_lowest = val_loss[index_loss]
    
    # Finding the index of the epoch with the highest validation accuracy
    index_acc = np.argmax(val_acc)
    # Extracting the highest validation accuracy value
    acc_highest = val_acc[index_acc]
    
    # Creating a list of epochs
    Epochs = [i+1 for i in range(len(tr_acc))]
    
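    # Creating labels for the lowest validation loss and highest validation accuracy
    loss_label = f'Best epoch = {index_loss + 1}'
    acc_label = f'Best epoch = {index_acc + 1}'

    # Plotting the training and validation loss
    # Matplotlib 3.6 renamed 'seaborn-dark' to 'seaborn-v0_8-dark'; the old name fails on recent versions
    style = 'seaborn-v0_8-dark' if 'seaborn-v0_8-dark' in plt.style.available else 'seaborn-dark'
    plt.style.use(style)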
    plt.figure(figsize=(10, 6))
    plt.plot(Epochs, tr_loss, 'orange', label='Training loss')
    plt.plot(Epochs, val_loss, 'blue', label='Validation loss')
    plt.scatter(index_loss + 1, val_lowest, s=150, c='red', label=loss_label)
    plt.title('Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()
    
    # Plotting the training and validation accuracy
    plt.figure(figsize=(10, 6))
    plt.plot(Epochs, tr_acc, 'orange', label='Training Accuracy')
    plt.plot(Epochs, val_acc, 'blue', label='Validation Accuracy')
    plt.scatter(index_acc + 1 , acc_highest, s=150, c='red', label=acc_label)
    plt.title('Accuracy')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()
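You can check the best-epoch logic in plot_training without training anything. Feed it a dict shaped like History.history. The following is a minimal sketch. best_epochs and the numbers are illustrative only, not results from this notebook.

```python
import numpy as np

def best_epochs(history: dict):
    """Return 1-based epochs with the lowest val_loss and the highest val_accuracy."""
    best_loss_epoch = int(np.argmin(history["val_loss"])) + 1
    best_acc_epoch = int(np.argmax(history["val_accuracy"])) + 1
    return best_loss_epoch, best_acc_epoch

# Made-up numbers shaped like Keras' History.history dict
fake_history = {
    "val_loss": [0.60, 0.42, 0.45, 0.40],
    "val_accuracy": [0.80, 0.88, 0.90, 0.89],
}
print(best_epochs(fake_history))  # (4, 3)
```

The two criteria can point to different epochs, as they do here. This is why plot_training draws a separate marker on each plot.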

os.walk(directory) is a generator. Each step yields a tuple (dirpath, dirnames, filenames).
next(os.walk(imagesPath)) takes only the first tuple, which describes imagesPath itself. [2] then selects its filenames: the files directly inside imagesPath, excluding subdirectories.
next(os.walk(masksPath))[2] does the same for the masksPath directory.
So images_ids and masks_ids hold the file names (ids) of the images and masks, respectively. The order of these lists is arbitrary and filesystem-dependent. This matters later because the train/validation split is taken by position.

In [9]:
# Retrieve the list of filenames (ids) of images in the imagesPath directory
images_ids = next(os.walk(imagesPath))[2]
# Retrieve the list of filenames (ids) of masks in the masksPath directory
masks_ids = next(os.walk(masksPath))[2]
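Sorting the file names makes the order, and therefore the positional train/validation split, reproducible across runs and machines. The following is a minimal sketch. list_files_sorted is a hypothetical helper, and the temporary directory stands in for imagesPath.

```python
import os
import tempfile

def list_files_sorted(directory: str) -> list:
    """Names of the files directly inside `directory`, sorted for a reproducible order."""
    # next(os.walk(...)) yields only the top-level (dirpath, dirnames, filenames) tuple
    _dirpath, _dirnames, filenames = next(os.walk(directory))
    return sorted(filenames)

with tempfile.TemporaryDirectory() as d:
    for name in ("c.jpg", "a.jpg", "b.jpg"):
        open(os.path.join(d, name), "w").close()
    os.mkdir(os.path.join(d, "nested"))  # subdirectories are not listed as files
    files = list_files_sorted(d)

print(files)  # ['a.jpg', 'b.jpg', 'c.jpg']
```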

np.zeros() is a NumPy function that creates an array filled with zeros.
For X, the shape (len(images_ids), 256, 256, 3) indicates the number of images, image height, image width, and number of channels (RGB).
For Y, the shape (len(masks_ids), 256, 256, 1) indicates the number of masks, mask height, mask width, and a single channel for binary mask.
dtype=np.uint8 specifies the data type as unsigned 8-bit integer, suitable for storing pixel values in the range [0, 255] for images.
dtype=np.bool_ specifies the data type as Boolean, suitable for representing binary mask values (True/False).

In [10]:
# Initialize an empty NumPy array for storing image data
# The shape is determined by the number of images, image dimensions (256x256), and number of channels (3 for RGB)
# The data type is set to uint8 (unsigned 8-bit integer) to represent pixel values in the range [0, 255]
X = np.zeros((len(images_ids), 256, 256, 3), dtype=np.uint8)

# Initialize an empty NumPy array for storing mask data
# The shape is determined by the number of masks, image dimensions (256x256), and a single channel (binary mask)
# The data type is set to bool_ (Boolean) to represent binary mask values (True/False)
Y = np.zeros((len(masks_ids), 256, 256, 1), dtype=np.bool_)
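These preallocated arrays take a predictable amount of RAM. The cost is shape × bytes per element: 1 byte for both uint8 and bool. The following sketch computes this without allocating anything. It assumes 1,000 image/mask pairs, which is consistent with the 900-image training split used below. array_bytes is a hypothetical helper.

```python
import numpy as np

def array_bytes(shape, dtype) -> int:
    """Bytes needed for a dense NumPy array of the given shape and dtype, without allocating it."""
    return int(np.prod(shape)) * np.dtype(dtype).itemsize

n = 1000  # assumed number of image/mask pairs
x_bytes = array_bytes((n, 256, 256, 3), np.uint8)
y_bytes = array_bytes((n, 256, 256, 1), np.bool_)
print(x_bytes / 2**20, y_bytes / 2**20)  # 187.5 62.5  (MiB)
```

Any float32 copy of X, such as one made when the data is fed to the model, needs four times as much memory, because float32 uses 4 bytes per element.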

The code iterates over the image IDs (images_ids), using enumerate() to also get the index n of each image.
For each image, it constructs the full path to the image file (path).
It reads the image with imread() and keeps only the first three (RGB) channels, dropping an alpha channel if one is present.
The image is resized to 256x256 with resize(). preserve_range=True keeps the original 0-255 value range.
The resized image is stored in the X array at index n.
The code then reads the mask with the same file name from masksPath and converts it to grayscale. It resizes the mask to 256x256 and thresholds it at 0.5 to obtain a binary mask. Finally, it stores the result in the Y array at index n.

In [None]:
print("Resizing training images and masks")

# Iterate over the image IDs and their corresponding indices
for n, id_ in tqdm(enumerate(images_ids), total=len(images_ids)):
    # Construct the path to the image file
    path = imagesPath + "/" + id_

    # Read the image and keep only the first three (RGB) channels
    img = imread(path)[:, :, :3]

    # Resize the image to 256x256; preserve_range keeps the 0-255 scale
    img = resize(img, (256, 256), mode="constant", preserve_range=True)

    # Store the resized image in the X array (cast to uint8 on assignment)
    X[n] = img

    # Read the mask that has the same file name as the image
    mask = imread(masksPath + "/" + id_)

    # Convert the mask to grayscale (rgb2gray returns floats in [0, 1])
    mask = rtg(mask)

    # Resize the mask to 256x256 and add a trailing channel axis -> (256, 256, 1)
    mask = np.expand_dims(resize(mask, (256, 256), mode="constant", preserve_range=True), axis=-1)

    # Threshold before storing: casting floats straight to bool would mark every
    # nonzero pixel, including faint compression/interpolation noise, as foreground
    Y[n] = mask > 0.5
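Assigning a float array into a bool array treats every nonzero value as True. Resized masks contain soft values along the edges, and a JPEG-compressed mask can also carry small nonzero noise. A direct cast can therefore enlarge the foreground. The following toy comparison uses made-up values and contrasts a plain cast with a 0.5 threshold.

```python
import numpy as np

# A resized grayscale mask comes back as floats; edges and noise leave small nonzero values
soft_mask = np.array([0.0, 0.004, 0.3, 0.6, 1.0])

as_bool = soft_mask.astype(np.bool_)   # any nonzero value becomes True
thresholded = soft_mask > 0.5          # only clearly foreground pixels become True

print(as_bool.tolist())      # [False, True, True, True, True]
print(thresholded.tolist())  # [False, False, False, True, True]
```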

X.shape returns the shape of the array X, which holds all loaded images; the training/validation split comes later. It is a tuple of (number of images, image height, image width, number of channels), with 3 channels for RGB.
Y.shape returns the shape of the array Y, which holds the corresponding masks: (number of masks, mask height, mask width, 1), with a single channel for binary masks.
The print statement displays both shapes.

In [None]:
# Print the shapes of the full image and mask arrays
print(f"X.shape: {X.shape}\nY.shape: {Y.shape}")

random.randint(a, b) returns an integer N with a <= N <= b. Both ends are inclusive, so the upper bound must be len(images_ids) - 1 to stay a valid index.
X[image_x] retrieves the image at the randomly selected index image_x.
imshow() draws the image, and plt.show() displays it.
np.squeeze(Y[image_x]) removes the single-channel axis, turning the (256, 256, 1) mask into a (256, 256) array.
plt.imshow() draws the mask, and plt.show() displays the mask plot.

In [None]:
import random

# Randomly select a valid image index (randint includes both endpoints)
image_x = random.randint(0, len(images_ids) - 1)

# Display the randomly selected image
imshow(X[image_x])
plt.show()

# Display the corresponding mask for the selected image
plt.imshow(np.squeeze(Y[image_x]))
plt.show()
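The off-by-one that randint invites can be avoided with random.randrange, which excludes its upper bound like Python's range(). The following is a short sketch. n_images is an assumed size, used for illustration only.

```python
import random

n_images = 1000  # assumed dataset size, for illustration

# randrange(n) draws from 0..n-1, matching valid array indices
idx = random.randrange(n_images)
# Equivalent with randint, which includes both endpoints
same_range = random.randint(0, n_images - 1)
print(idx, same_range)
```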

X[:900] selects the first 900 images of X as the training set.
Y[:900] selects the corresponding 900 masks.
X[900:] selects the remaining images, from index 900 onward, as the validation set.
Y[900:] selects the corresponding validation masks.
The split is positional, not random. Which images end up in the validation set therefore depends on the filesystem-dependent order of images_ids.

In [14]:
# Define training and validation sets for images and masks
x_train = X[:900] # Take the first 900 images for training
y_train = Y[:900] # Take the corresponding masks for training
x_val = X[900:]  # Take the remaining images for validation
y_val = Y[900:] # Take the corresponding masks for validation
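A seeded shuffle is one common alternative to the positional split. It makes the validation set a random sample that stays the same between runs. The following is a minimal sketch. shuffled_split and the seed value 42 are illustrative choices, not the author's method. The key point is that X and Y share one permutation, so each image stays paired with its mask.

```python
import numpy as np

def shuffled_split(X, Y, n_train, seed=42):
    """Shuffle image/mask pairs with one shared permutation, then split into train/validation."""
    rng = np.random.default_rng(seed)
    order = rng.permutation(len(X))  # same order for X and Y keeps pairs aligned
    train_idx, val_idx = order[:n_train], order[n_train:]
    return X[train_idx], Y[train_idx], X[val_idx], Y[val_idx]

# Toy stand-ins: each "mask" equals its "image", so alignment is easy to check
X_toy = np.arange(10).reshape(10, 1)
Y_toy = np.arange(10).reshape(10, 1)
x_tr, y_tr, x_va, y_va = shuffled_split(X_toy, Y_toy, n_train=9)
print(x_tr.shape, x_va.shape)  # (9, 1) (1, 1)
```

In the notebook this would be called as shuffled_split(X, Y, n_train=900).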

In [15]:
# Define the input layer with shape (256, 256, 3)
input = tf.keras.layers.Input((256, 256, 3))

# Normalize input images from [0, 255] to [0, 1] (255 is the maximum uint8 value)
s = tf.keras.layers.Lambda(lambda x: x / 255)(input)

# Contracting path
# Downward path with convolutional and max pooling layers to extract features
c1 = tf.keras.layers.Conv2D(8, (3, 3), activation="relu", kernel_initializer='he_normal', padding='same')(s)
c1 = tf.keras.layers.Conv2D(8, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c1)
p1 = tf.keras.layers.MaxPooling2D((2, 2))(c1)

c2 = tf.keras.layers.Conv2D(16, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(p1)
c2 = tf.keras.layers.Conv2D(16, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c2)
p2 = tf.keras.layers.MaxPooling2D((2, 2))(c2)

c3 = tf.keras.layers.Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(p2)
c3 = tf.keras.layers.Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c3)
p3 = tf.keras.layers.MaxPooling2D((2, 2))(c3)

c4 = tf.keras.layers.Conv2D(64, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(p3)
c4 = tf.keras.layers.Conv2D(64, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c4)
p4 = tf.keras.layers.MaxPooling2D((2, 2))(c4)

c5 = tf.keras.layers.Conv2D(128, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(p4)
c5 = tf.keras.layers.Conv2D(128, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c5)

# Expansive path
# Upward path with transpose convolutional layers and skip connections
u6 = tf.keras.layers.Conv2DTranspose(64, (2, 2), strides=(2, 2), padding='same')(c5)
u6 = tf.keras.layers.concatenate([u6, c4])
c6 = tf.keras.layers.Conv2D(64, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(u6)
c6 = tf.keras.layers.Conv2D(64, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c6)

u7 = tf.keras.layers.Conv2DTranspose(32, (2, 2), strides=(2, 2), padding='same')(c6)
u7 = tf.keras.layers.concatenate([u7, c3])
c7 = tf.keras.layers.Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(u7)
c7 = tf.keras.layers.Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c7)

u8 = tf.keras.layers.Conv2DTranspose(16, (2, 2), strides=(2, 2), padding='same')(c7)
u8 = tf.keras.layers.concatenate([u8, c2])
c8 = tf.keras.layers.Conv2D(16, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(u8)
c8 = tf.keras.layers.Conv2D(16, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c8)

u9 = tf.keras.layers.Conv2DTranspose(8, (2, 2), strides=(2, 2), padding='same')(c8)
u9 = tf.keras.layers.concatenate([u9, c1])
c9 = tf.keras.layers.Conv2D(8, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(u9)
c9 = tf.keras.layers.Conv2D(8, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c9)

# Output layer with sigmoid activation function for binary classification
outputs = tf.keras.layers.Conv2D(1, (1, 1), activation='sigmoid')(c9)

# Define the U-Net model
modelUNet = tf.keras.Model(inputs=input, outputs=outputs, name='U-NET')

# Compile the model with Adam optimizer and binary cross-entropy loss
modelUNet.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

In [None]:
modelUNet.summary()
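The shapes in the summary follow from simple arithmetic. Each 2x2 max pooling halves the height and width. Each stride-2 transposed convolution doubles them back. Each concatenation adds the skip connection's channels to the upsampled ones. The following sketch reproduces the encoder outputs and the concatenated decoder inputs for the architecture above. unet_feature_shapes is a hypothetical helper. The arithmetic also shows why the input size must be divisible by 2^4 = 16: otherwise a pooled size is rounded down and the upsampled tensor no longer matches its skip connection.

```python
def unet_feature_shapes(size=256, filters=(8, 16, 32, 64, 128)):
    """Spatial size and channel count at each level of the U-Net defined above (padding='same')."""
    encoder = []
    s = size
    for f in filters:
        encoder.append((s, s, f))  # output of the two 3x3 convs at this level
        s //= 2                    # 2x2 max pooling halves height and width
    decoder = []
    for level in range(len(filters) - 2, -1, -1):
        h, w, f = encoder[level]
        # Conv2DTranspose(f, 2, strides=2) restores h x w with f channels;
        # concatenating the skip connection from this level adds another f channels
        decoder.append((h, w, f + f))
    return encoder, decoder

enc, dec = unet_feature_shapes()
print(enc)  # [(256, 256, 8), (128, 128, 16), (64, 64, 32), (32, 32, 64), (16, 16, 128)]
print(dec)  # [(32, 32, 128), (64, 64, 64), (128, 128, 32), (256, 256, 16)]
```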

Callbacks are objects whose methods Keras calls at specific points during training, for example at the start or end of each epoch. They can log training statistics, save model checkpoints, or stop training early when a condition is met.
Here, tf.keras.callbacks.TensorBoard logs training statistics, and log_dir sets the directory the logs are written to.
The fit method then trains modelUNet on x_train and y_train.
validation_split=0.01 holds out 1% of the training data for validation. Keras takes these samples from the end of x_train, before shuffling. With 900 training images, the validation metrics therefore come from only the last 9 images, and the separate x_val/y_val arrays are not used during training.
batch_size=8 sets how many samples are processed per gradient update.
epochs=30 sets how many times the full training set is passed through the model.
callbacks passes the list of callbacks to apply during training; here, that list contains only the TensorBoard callback.

In [None]:
# Callbacks are objects that can perform actions at various stages of training (e.g., at the start or end of each epoch)
# Here, we define a callback to log training statistics using TensorBoard
callbacks = [tf.keras.callbacks.TensorBoard(log_dir="logs")]

# The `fit` method trains the U-Net model on the training data
# It takes the input data (`x_train`) and corresponding target data (`y_train`)
# `validation_split=0.01` specifies that 1% of the training data will be used for validation
# `batch_size=8` specifies the number of samples per gradient update
# `epochs=30` specifies the number of epochs (iterations over the entire dataset) for training
# `callbacks=callbacks` specifies the list of callbacks to apply during training (in this case, the TensorBoard callback)
results = modelUNet.fit(x_train, y_train, validation_split=0.01, batch_size=8, epochs=30, callbacks=callbacks)

This call plots the training and validation loss per epoch and the training and validation accuracy per epoch. It marks the epoch with the lowest validation loss and the epoch with the highest validation accuracy.

In [None]:
# Plot the training history using the `plot_training` function
plot_training(results)

The evaluate method measures a trained model's performance on a given dataset.
Here, modelUNet is evaluated on the held-out validation set (x_val and y_val), the images that were not part of x_train. Evaluating on the full X and Y would include the 900 training images and overstate how well the model generalizes.
The method computes the loss and every metric specified at compile time (here, accuracy).
The results are returned as a list: the first element is the loss, followed by the metric values.
These numbers can be used to judge generalization to unseen data and to compare different configurations or architectures.

In [None]:
# Evaluate the trained U-Net on the held-out validation images and masks.
# Returns [loss, accuracy], matching the loss and metrics passed to compile().
modelUNet.evaluate(x_val, y_val)

The code calculates the confusion matrix to evaluate the performance of the U-Net model on the validation data.
The predict method is used to generate predictions (y_pred) for the validation images (x_val).
The predicted masks are binarized using a threshold of 0.5 to convert them into binary format (y_pred_binary).
Similarly, the ground truth masks (y_val) are converted into binary format (y_val_binary).
The confusion matrix is computed using the confusion_matrix function from the scikit-learn library.
The computed confusion matrix is visualized as a heatmap using the heatmap function from the seaborn library.
The heatmap shows pixel counts. scikit-learn places the actual labels in rows and the predicted labels in columns, both ordered 0, 1. The cells are therefore [[TN, FP], [FN, TP]]: true negatives top-left, false positives top-right, false negatives bottom-left, and true positives bottom-right. Color intensity reflects the count, and the annotations give the exact values.

In [None]:
# Import necessary libraries
from sklearn.metrics import confusion_matrix
import seaborn as sns

# Predict masks for validation data using the trained U-Net model
y_pred = modelUNet.predict(x_val)

# Convert predicted masks to binary format using a threshold of 0.5
y_pred_binary = (y_pred > 0.5).astype(int)

# Convert ground truth masks to binary format
y_val_binary = y_val.astype(int)

# Compute the confusion matrix using sklearn
cm = confusion_matrix(y_val_binary.flatten(), y_pred_binary.flatten())

# Plot the confusion matrix as a heatmap using seaborn
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, cmap='Blues', fmt='g', cbar=False)
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.title('Confusion Matrix')
plt.show()
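The four cells of the matrix are enough to derive the usual segmentation scores. Pixel accuracy, by contrast, includes the often large true-negative (background) count. The following sketch assumes scikit-learn's [[TN, FP], [FN, TP]] layout. scores_from_confusion is a hypothetical helper, and the counts are made up to show the arithmetic.

```python
import numpy as np

def scores_from_confusion(cm):
    """Pixel-level precision, recall, IoU and Dice from a 2x2 matrix laid out [[TN, FP], [FN, TP]]."""
    (tn, fp), (fn, tp) = np.asarray(cm)
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    iou = tp / (tp + fp + fn)          # intersection over union
    dice = 2 * tp / (2 * tp + fp + fn) # equivalent to the F1 score
    return precision, recall, iou, dice

# Made-up counts, only to show the arithmetic
print(scores_from_confusion([[90, 5], [3, 2]]))
```

In the notebook, you would call scores_from_confusion(cm) on the matrix computed above. None of the formulas uses TN, so the abundant background pixels do not inflate them.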


The code calculates the Receiver Operating Characteristic (ROC) curve and the Area Under the Curve (AUC) to evaluate the U-Net as a per-pixel binary classifier.
The roc_curve function from scikit-learn computes the ROC curve from the true labels (y_val.flatten()) and the predicted probabilities (y_pred.flatten()).
The ROC curve plots the true positive rate (sensitivity) against the false positive rate (1 - specificity) as the decision threshold varies.
The auc function computes the area under that curve. AUC ranges from 0 to 1. A value of 0.5 means the model ranks pixels no better than chance, and 1.0 means it separates polyp pixels from background pixels perfectly.
Finally, the ROC curve is plotted with plt.plot, with the AUC value included in the curve's label. Axis labels, a title, and a legend are added.

In [None]:
# Import necessary libraries
from sklearn.metrics import roc_curve, auc

# Calculate ROC curve
# ROC curve is a graphical representation of the true positive rate (sensitivity) versus the false positive rate (1 - specificity)
# It shows the trade-off between sensitivity and specificity across different threshold values
fpr, tpr, _ = roc_curve(y_val.flatten(), y_pred.flatten())

# Calculate AUC (Area Under the Curve)
# AUC quantifies the overall performance of a binary classification model
# It represents the probability that a randomly chosen positive sample will be ranked higher than a randomly chosen negative sample
roc_auc = auc(fpr, tpr)

# Plot ROC curve
# The ROC curve is plotted with false positive rate (x-axis) against true positive rate (y-axis)
plt.plot(fpr, tpr, label=f'AUC = {roc_auc:.2f}')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curve')
plt.legend()
plt.show()
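The probability interpretation of AUC in the comments above can be checked directly on a toy example. Count, over all (positive, negative) pairs, how often the positive gets the higher score, with ties counting one half. The following sketch uses a hypothetical helper, auc_by_ranking. It compares every pair, so memory grows with positives × negatives and the approach only suits tiny inputs. For millions of pixels, roc_curve plus auc, as used above, is the practical route.

```python
import numpy as np

def auc_by_ranking(y_true, scores):
    """AUC as the fraction of (positive, negative) pairs where the positive scores higher (ties count 0.5)."""
    y_true = np.asarray(y_true, dtype=bool)
    scores = np.asarray(scores, dtype=float)
    pos, neg = scores[y_true], scores[~y_true]
    diff = pos[:, None] - neg[None, :]  # every positive compared with every negative
    return (np.sum(diff > 0) + 0.5 * np.sum(diff == 0)) / diff.size

print(auc_by_ranking([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.75
```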

The calculate_iou function computes the Intersection over Union (IoU) score between true and predicted masks.
Intersection is calculated by finding the pixels where both true and predicted masks are non-zero (logical AND operation).
Union is calculated by finding the pixels where either true or predicted masks are non-zero (logical OR operation).
IoU score is computed as the ratio of the intersection area to the union area.
The code then computes IoU for thresholds from 0.1 to 0.9 in steps of 0.1 using a list comprehension; np.arange excludes the stop value 1.0.
For each threshold, a pixel counts as predicted foreground when its predicted probability is above the threshold.
calculate_iou sums over all validation pixels at once, so each result is a single dataset-level IoU rather than an average of per-image scores.
Finally, the IoU scores are plotted against the threshold values. The peak of the curve suggests a segmentation threshold. Picking the threshold on the same validation set used to report results gives an optimistic estimate.

In [None]:
def calculate_iou(y_true, y_pred):
    # Calculate intersection and union between true and predicted masks
    intersection = np.logical_and(y_true, y_pred)
    union = np.logical_or(y_true, y_pred)
    
    # Calculate Intersection over Union (IoU) score
    # IoU measures the overlap between the true and predicted masks
    iou_score = np.sum(intersection) / np.sum(union)
    return iou_score


# Calculate IoU for different thresholds
# IoU is calculated for various threshold values to evaluate the model's segmentation performance at different confidence levels
thresholds = np.arange(0.1, 1.0, 0.1)
iou_scores = [calculate_iou(y_val_binary, (y_pred > threshold).astype(int)) for threshold in thresholds]

# Plot IoU curve
# IoU scores are plotted against different threshold values
plt.plot(thresholds, iou_scores, marker='.')
plt.xlabel('Threshold')
plt.ylabel('IoU')
plt.title('IoU Curve')
plt.show()
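Pooling every pixel lets images with large polyps dominate the score. calculate_iou also divides by zero if both masks are empty. A per-image mean IoU with an explicit empty-mask rule is a common alternative. The following is a minimal sketch. iou and mean_iou_per_image are hypothetical helpers, and scoring two empty masks as 1.0 is one convention among several. The toy data shows how the two averages differ: pooled IoU here would be 2/3, while the per-image mean is 0.75.

```python
import numpy as np

def iou(y_true, y_pred):
    """IoU of two binary masks; defined as 1.0 when both masks are empty (union of zero)."""
    y_true, y_pred = np.asarray(y_true, dtype=bool), np.asarray(y_pred, dtype=bool)
    union = np.logical_or(y_true, y_pred).sum()
    if union == 0:
        return 1.0
    return np.logical_and(y_true, y_pred).sum() / union

def mean_iou_per_image(y_true, y_prob, threshold=0.5):
    """Average IoU over images (axis 0), instead of one IoU pooled over every pixel."""
    y_pred = y_prob > threshold
    return float(np.mean([iou(t, p) for t, p in zip(y_true, y_pred)]))

# Two toy 2x2 "masks": a perfect prediction and a half-overlapping one
toy_true = np.array([[[1, 0], [0, 0]], [[1, 1], [0, 0]]])
toy_prob = np.array([[[0.9, 0.1], [0.2, 0.0]], [[0.8, 0.3], [0.0, 0.0]]])
print(mean_iou_per_image(toy_true, toy_prob))  # 0.75
```

In the notebook, the equivalent call would be mean_iou_per_image(y_val, y_pred, threshold).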

The code loops over every image in the validation set.
For each image, it takes the original image (img) and predicts a mask (predMask) with modelUNet.predict(). np.expand_dims adds the batch axis the model expects.
It then creates a figure (fig) with three subplots:
   1) The first subplot (ax1) shows the predicted probability mask with the 'cividis' colormap.
   2) The second subplot (ax2) shows the original image. Because the image has three (RGB) channels, it is displayed in color; a colormap would have no effect.
   3) The third subplot (ax3) shows the ground-truth mask with the 'bone' colormap.
Axes (ticks, labels, and frame) are turned off for all subplots.
Each figure is shown with plt.show() and then closed, so Matplotlib does not keep one open figure per validation image.

In [None]:
# Iterate over each image in the validation dataset
for ind in range(len(x_val)):
    img = x_val[ind]  # Get the original image
    # Predict a mask for this single image (add a batch axis of size 1)
    predMask = modelUNet.predict(np.expand_dims(img, axis=0), verbose=0)

    # Create a new figure for visualization
    fig = plt.figure(figsize=(15, 5))

    # Predicted mask (probabilities) with the 'cividis' colormap
    ax1 = fig.add_subplot(1, 3, 1)
    ax1.set_title("pred mask")
    ax1.imshow(np.squeeze(predMask), cmap='cividis', interpolation='bicubic')
    ax1.axis('off')

    # Original RGB image, shown in color
    ax2 = fig.add_subplot(1, 3, 2)
    ax2.set_title("original image")
    ax2.imshow(img)
    ax2.axis('off')

    # Ground-truth mask with the 'bone' colormap
    ax3 = fig.add_subplot(1, 3, 3)
    ax3.set_title("original mask")
    ax3.imshow(np.squeeze(y_val[ind]), cmap='bone')
    ax3.axis('off')

    # Show this figure, then close it so figures do not accumulate in memory
    plt.show()
    plt.close(fig)