In [None]:
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing.image import load_img, img_to_array

# Kaggle location of the DRISHTI-GS test split; all image and mask paths hang off it.
DATASET_ROOT = "/kaggle/input/drishtigs-retina-dataset-for-onh-segmentation/Test-20211018T060000Z-001/Test"

# One directory per class of test image, plus the ground-truth root directory.
image_dirs = [os.path.join(DATASET_ROOT, "Images", class_name) for class_name in ("glaucoma", "normal")]
mask_dir = os.path.join(DATASET_ROOT, "Test_GT")

# Input geometry (single grayscale channel) and evaluation batch size.
img_height, img_width, img_channels = 256, 256, 1
batch_size = 8

def load_images_and_masks(image_dirs, mask_dir):
    """Load grayscale images and their binarized, inverted ground-truth masks.

    For each file ``<stem>.<ext>`` in every directory of ``image_dirs``, the mask is
    expected at ``<mask_dir>/<stem>/SoftMap/<stem>_ODsegSoftmap.png``. Files whose
    mask is missing are skipped entirely (image included), so the two returned
    arrays always stay index-aligned.

    Args:
        image_dirs: Iterable of directories containing the input images.
        mask_dir: Root directory holding one ground-truth folder per image stem.

    Returns:
        Tuple ``(images, masks)`` of numpy arrays with one entry per matched pair.
        Both are loaded in grayscale at ``(img_height, img_width)``. Images are
        scaled to [0, 1]; masks are float32 with values in {0, 1}.
    """
    images, masks = [], []

    for image_dir in image_dirs:
        # Sorted so the sample order (and hence every batch) is reproducible.
        for filename in sorted(os.listdir(image_dir)):
            stem = filename.split('.')[0]
            mask_path = os.path.join(mask_dir, stem, "SoftMap", stem + "_ODsegSoftmap.png")

            # Check for the mask BEFORE touching the image, so an image is never
            # appended without its mask (which would misalign the two arrays).
            if not os.path.exists(mask_path):
                print(f"Mask not found for {filename} at {mask_path}. Skipping...")
                continue

            # Load image as grayscale to match the model input, normalized to [0, 1].
            img_path = os.path.join(image_dir, filename)
            img = load_img(img_path, color_mode="grayscale", target_size=(img_height, img_width))
            images.append(img_to_array(img) / 255.0)

            mask = load_img(mask_path, color_mode="grayscale", target_size=(img_height, img_width))
            mask_array = img_to_array(mask) / 255.0  # Normalize to [0, 1]

            # Threshold at 0.5 and invert: pixels brighter than 0.5 become 0, the rest 1.
            # NOTE(review): this makes the bright soft-map region the 0 class; presumably
            # this matches how the model was trained — confirm.
            masks.append(np.where(mask_array > 0.5, 0, 1).astype(np.float32))

    return np.array(images), np.array(masks)

# Load every image/mask pair from disk.
images, masks = load_images_and_masks(image_dirs, mask_dir)

# Fail early with a clear message rather than an opaque error from tf.data below.
assert len(images) > 0, "No image/mask pairs were loaded; check image_dirs and mask_dir."
assert len(images) == len(masks), f"{len(images)} images but {len(masks)} masks; pairs are misaligned."

# Convert to TensorFlow tensors
images_tensor = tf.convert_to_tensor(images, dtype=tf.float32)
masks_tensor = tf.convert_to_tensor(masks, dtype=tf.float32)

# Batch in load order. No shuffle: evaluation metrics do not need it, and a fixed
# order keeps the batch visualized later reproducible from run to run.
dataset = tf.data.Dataset.from_tensor_slices((images_tensor, masks_tensor)).batch(batch_size)

# Print dataset shapes for verification
for batch_images, batch_masks in dataset.take(1):
    print("Batch Images Shape:", batch_images.shape)
    print("Batch Masks Shape:", batch_masks.shape)

In [None]:
#Define metrics and loss functions as custom model features
@tf.keras.utils.register_keras_serializable()
def weighted_binary_crossentropy(y_true, y_pred):
    """Mean binary cross-entropy with columns in the first half of axis 2 weighted 2x.

    Every index along axis 2 of ``y_true`` below ``size // 2`` gets weight 2.0 and
    the rest 1.0. The weights are multiplied into the per-element BCE tensor, and
    the mean of the product is returned.

    Assumes ``y_true`` is (batch, height, width, channels), so axis 2 is the image
    width. In that case ``binary_crossentropy`` reduces the channel axis and the
    weights broadcast over the trailing width axis — TODO confirm.
    NOTE(review): the reason for up-weighting the left half is not visible here.
    """
    width = tf.shape(y_true)[2]
    column_index = tf.range(width)
    column_weights = tf.where(column_index < width // 2, 2.0, 1.0)

    per_pixel_bce = tf.keras.losses.binary_crossentropy(y_true, y_pred)
    return tf.reduce_mean(per_pixel_bce * column_weights)

@tf.keras.utils.register_keras_serializable()
def dice_coefficient(y_true, y_pred):
    """Soft Dice score pooled over every element of the batch.

    Both inputs are cast to float32 and flattened, so the score is computed over
    the whole batch at once rather than per image. A 1e-6 term is added to the
    numerator and denominator, so two all-zero inputs score 1.0 instead of NaN.
    """
    K = tf.keras.backend
    truth = K.flatten(tf.cast(y_true, tf.float32))
    pred = K.flatten(tf.cast(y_pred, tf.float32))

    overlap = K.sum(truth * pred)
    total = K.sum(truth) + K.sum(pred)
    return (2. * overlap + 1e-6) / (total + 1e-6)

@tf.keras.utils.register_keras_serializable()
def iou(y_true, y_pred):
    """Soft intersection-over-union pooled over every element of the batch.

    Both inputs are cast to float32 and flattened. The intersection is the sum of
    the element-wise product, and union = sum(true) + sum(pred) - intersection.
    A 1e-6 term is added to the numerator and denominator (as in the Dice metric
    of this notebook), so an empty mask with an empty prediction scores 1.0 rather
    than 0/0 = NaN. A NaN would poison the averaged metric in ``model.evaluate``.
    """
    y_true_f = tf.keras.backend.flatten(tf.cast(y_true, tf.float32))
    y_pred_f = tf.keras.backend.flatten(tf.cast(y_pred, tf.float32))

    intersection = tf.keras.backend.sum(y_true_f * y_pred_f)
    union = tf.keras.backend.sum(y_true_f) + tf.keras.backend.sum(y_pred_f) - intersection
    return (intersection + 1e-6) / (union + 1e-6)

In [None]:
#Load the saved XNet model
from tensorflow.keras.models import load_model

# Saved XNet checkpoint. Presumably its deserialization relies on the custom
# loss/metric functions registered earlier in this notebook — confirm.
MODEL_PATH = '/kaggle/input/best_models/keras/default/1/Best Models/xnet.keras'
model = load_model(MODEL_PATH)

In [None]:
# Evaluate the model on the test dataset. return_dict=True labels each value with
# its metric name instead of returning an anonymous, positionally ordered list.
evaluation_results = model.evaluate(dataset, return_dict=True)
evaluation_results

In [None]:
# Plot each sample of one batch: input image, ground-truth mask, and thresholded prediction.
import matplotlib.pyplot as plt

PRED_THRESHOLD = 0.5  # probability cut-off used to binarize the predicted mask

# Distinct names keep the full `images`/`masks` arrays and the `batch_size` setting
# from the loading cell intact, so cells can be re-run in any order safely.
sample_images, sample_masks = next(iter(dataset.take(1)))
sample_preds = model.predict(sample_images)

panel_titles = ("Original Image", "Ground Truth Mask", "Predicted Mask")
for i in range(sample_images.shape[0]):
    panels = (
        tf.squeeze(sample_images[i]),
        tf.squeeze(sample_masks[i]),
        tf.squeeze(sample_preds[i] > PRED_THRESHOLD),
    )
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    for ax, panel, title in zip(axes, panels, panel_titles):
        ax.imshow(panel, cmap="gray")
        ax.set_title(title)
        ax.axis("off")
    plt.show()
    # Release the figure explicitly; one is created per sample in this loop.
    plt.close(fig)