In [1]:
import cv2
import numpy as np
from pathlib import Path
import os
import matplotlib.pyplot as plt
from tensorflow.keras.callbacks import ModelCheckpoint
from sklearn.model_selection import train_test_split
from tensorflow.keras.utils import plot_model
from sklearn.metrics import accuracy_score, classification_report
import pandas as pd
from glob import glob

# For neural network
import tensorflow as tf

# For Accuracy metric
from sklearn.metrics import accuracy_score
from sklearn.metrics import classification_report

import time

from tensorflow.keras.models import Model
from tensorflow.keras import layers
from tensorflow.keras.models import load_model

2025-07-30 04:33:50.952574: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1753850031.133752      35 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1753850031.185874      35 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [2]:
base_dir = '/kaggle/input/hyper-curated-busi/hyper_curated_busi'
normal_dir = os.path.join(base_dir, 'normal')
benign_dir = os.path.join(base_dir, 'benign')
malignant_dir = os.path.join(base_dir, 'malignant')
print(normal_dir)

/kaggle/input/hyper-curated-busi/hyper_curated_busi/normal


In [4]:
def load_images_and_masks(directory, class_label, has_mask=True):
    images = []
    masks = []
    labels = []
    # Get all image files (excluding masks)
    image_files = [f for f in os.listdir(directory) if '_mask' not in f and f.endswith('.png')]

    for img_name in image_files:
        # Load image
        img_path = os.path.join(directory, img_name)
        img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)  # Load as grayscale
        if img is None:
            continue

        # Resize image (e.g., to 256x256)
        img = cv2.resize(img, (256, 256))
        images.append(img)
        labels.append(class_label)

        # Load mask if applicable
        if has_mask:
            mask_name = img_name.replace('.png', '_mask.png')
            mask_path = os.path.join(directory, mask_name)
            if os.path.exists(mask_path):
                mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
                mask = cv2.resize(mask, (256, 256))
                # Binarize mask (0 or 255)
                mask = (mask > 0).astype(np.uint8) * 255
                masks.append(mask)
            else:
                masks.append(np.zeros((256, 256), dtype=np.uint8))  # Empty mask if not found
        else:
            masks.append(np.zeros((256, 256), dtype=np.uint8))  # No mask for normal images

    return images, masks, labels

# Load data for each class
normal_images, normal_masks, normal_labels = load_images_and_masks(normal_dir, 'normal', has_mask=False)
benign_images, benign_masks, benign_labels = load_images_and_masks(benign_dir, 'benign', has_mask=True)
malignant_images, malignant_masks, malignant_labels = load_images_and_masks(malignant_dir, 'malignant', has_mask=True)

# Combine all data
all_images = normal_images + benign_images + malignant_images
all_masks = normal_masks + benign_masks + malignant_masks
all_labels = normal_labels + benign_labels + malignant_labels

# Convert to numpy arrays
all_images = np.array(all_images)
all_masks = np.array(all_masks)
all_labels = np.array(all_labels)
print(len(all_labels))
print(len(normal_masks))
print(len(benign_masks))
print(len(malignant_masks))

399
64
185
150


In [6]:
# Normalize images and masks
all_images = all_images / 255.0  # Normalize to [0, 1]
all_masks = all_masks / 255.0    # Normalize to [0, 1]

In [7]:
# Split data (80% train, 20% test)
X_train, X_test, y_train, y_test, labels_train, labels_test = train_test_split(
    all_images, all_masks, all_labels,
    test_size=0.2, random_state=40, stratify=all_labels
)

X_val, X_test1, y_val, y_test1, labels_val, labels_test1 = train_test_split(
    X_test, y_test, labels_test,
    test_size=0.95, random_state=40, stratify=labels_test
)


# Reshape for deep learning models (add channel dimension)
X_train = X_train[..., np.newaxis]  # Shape: (n_train, 256, 256, 1)
X_test = X_test[..., np.newaxis]    # Shape: (n_test, 256, 256, 1)
X_val = X_val[..., np.newaxis]    # Shape: (n_test, 256, 256, 1)
y_train = y_train[..., np.newaxis]  # Shape: (n_train, 256, 256, 1)
y_test = y_test[..., np.newaxis]    # Shape: (n_test, 256, 256, 1)
y_val = y_val[..., np.newaxis]    # Shape: (n_test, 256, 256, 1)

print(f"Training set: {X_train.shape}, {y_train.shape},{labels_train.shape}")
print(f"Testing set: {X_test.shape}, {y_test.shape},{labels_test.shape}")
print(f"Validation set: {X_val.shape}, {y_val.shape},{labels_val.shape}")

Training set: (319, 256, 256, 1), (319, 256, 256, 1),(319,)
Testing set: (80, 256, 256, 1), (80, 256, 256, 1),(80,)
Validation set: (4, 256, 256, 1), (4, 256, 256, 1),(4,)


In [8]:
from collections import Counter

print("Train label distribution:", Counter(labels_train))
print("Test label distribution:", Counter(labels_test))
#print("Valdiation label distribution:", Counter(labels_val))

Train label distribution: Counter({'benign': 148, 'malignant': 120, 'normal': 51})
Test label distribution: Counter({'benign': 37, 'malignant': 30, 'normal': 13})


In [9]:
df_labels_train = pd.get_dummies(labels_train).astype(int)
df_labels_test = pd.get_dummies(labels_test).astype(int)
df_labels_val = pd.get_dummies(labels_val).astype(int)

# Optional: reorder columns to follow a consistent order
#df_labels = df_labels[['malignant', 'benign', 'normal']]  # reorder as needed

print(sum(df_labels_train['normal']))
df_labels_train.head()

51


Unnamed: 0,benign,malignant,normal
0,0,1,0
1,1,0,0
2,1,0,0
3,1,0,0
4,0,1,0


In [10]:
import albumentations as A

# Define augmentation pipeline
augmentation = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.Rotate(limit=10, p=0.5),
    A.RandomBrightnessContrast(p=0.3),
    A.RandomCrop(height=224, width=224, p=0.3),
    A.Resize(256, 256)  # Ensure output size
])

# Apply augmentation to training data
augmented_images = []
augmented_masks = []
augmented_labels = []
df_labels_train_np=np.array(df_labels_train)
for img, mask,label in zip(X_train, y_train,df_labels_train_np):
    aug = augmentation(image=img.squeeze(), mask=mask.squeeze())
    augmented_images.append(aug['image'][..., np.newaxis])
    augmented_masks.append(aug['mask'][..., np.newaxis])
    augmented_labels.append(label)
    
# Convert to numpy arrays
augmented_images = np.array(augmented_images)
augmented_masks = np.array(augmented_masks)
augmented_labels=np.array(augmented_labels)
print(df_labels_train.shape)
print(augmented_labels.shape)
# Combine original and augmented data
X_train_aug = np.concatenate([X_train, augmented_images], axis=0)
y_train_aug = np.concatenate([y_train, augmented_masks], axis=0)
df_labels_train_aug=np.concatenate([df_labels_train, augmented_labels], axis=0)

  check_for_updates()


(319, 3)
(319, 3)


In [11]:
print(f"Training set: {X_train_aug.shape}, {y_train_aug.shape},{df_labels_train_aug.shape}")
print(f"Testing set: {X_test.shape}, {y_test.shape},{labels_test.shape}")
print(f"Validation set: {X_val.shape}, {y_val.shape},{labels_val.shape}")

Training set: (638, 256, 256, 1), (638, 256, 256, 1),(638, 3)
Testing set: (80, 256, 256, 1), (80, 256, 256, 1),(80,)
Validation set: (4, 256, 256, 1), (4, 256, 256, 1),(4,)


In [12]:
def dice_bce_loss(y_true, y_pred, axis=(1, 2, 3), smooth=1e-4):
    y_true = tf.cast(y_true, tf.float32)
    y_pred_sigmoid = tf.keras.activations.sigmoid(y_pred)  # Optional: if logits

    # Binary cross-entropy
    bce = tf.keras.losses.binary_crossentropy(y_true, y_pred_sigmoid)

    # Dice loss
    y_pred_bin = tf.cast(y_pred_sigmoid > 0.5, tf.float32)
    tp = tf.reduce_sum(y_true * y_pred_bin, axis=axis)
    fn = tf.reduce_sum(y_true * (1 - y_pred_bin), axis=axis)
    fp = tf.reduce_sum((1 - y_true) * y_pred_bin, axis=axis)
    dice_score = (2 * tp + smooth) / (2 * tp + fn + fp + smooth)
    dice_loss = 1.0 - tf.reduce_mean(dice_score)

    # Combine
    return dice_loss + tf.reduce_mean(bce)

def dice_loss(y_true, y_pred, axis=(1, 2, 3), smooth=1e-4):
    y_true = tf.cast(y_true, tf.float32)
    y_pred_sigmoid = tf.keras.activations.sigmoid(y_pred)  # Optional: if logits

    # Dice loss
    y_pred_bin = tf.cast(y_pred_sigmoid > 0.5, tf.float32)
    tp = tf.reduce_sum(y_true * y_pred_bin, axis=axis)
    fn = tf.reduce_sum(y_true * (1 - y_pred_bin), axis=axis)
    fp = tf.reduce_sum((1 - y_true) * y_pred_bin, axis=axis)
    dice_score = (2 * tp + smooth) / (2 * tp + fn + fp + smooth)
    dice_loss = 1.0 - tf.reduce_mean(dice_score)
    return dice_loss 

def dice(y_true, y_pred, axis=(0, 1, 2), smooth=0.0001, thr=0.5):
    y_true = tf.cast(y_true, tf.float32) # (B, H, W, C)
    y_pred = tf.cast(y_pred > thr, tf.float32) # (B, H, W, C)
    tp = tf.math.reduce_sum(y_true * y_pred, axis=axis) # calculate True Positive
    fn = tf.math.reduce_sum(y_true * (1 - y_pred), axis=axis) # calculate False Negative
    fp = tf.math.reduce_sum((1 - y_true) * y_pred, axis=axis) # calculate False Positive
    dice = (2*tp + smooth) / (2*tp + fn + fp + smooth) # calculate Dice score
    dice = tf.math.reduce_mean(dice) # average over all classes
    return dice # Dice loss is 1 - Dice score

def iou(y_true, y_pred, axis=(0, 1, 2), smooth=0.0001, thr=0.5):
    y_true = tf.cast(y_true, tf.float32) # (B, H, W, C)
    y_pred = tf.cast(y_pred > thr, tf.float32) # (B, H, W, C)
    tp = tf.math.reduce_sum(y_true * y_pred, axis=axis) # calculate True Positive
    fn = tf.math.reduce_sum(y_true * (1 - y_pred), axis=axis) # calculate False Negative
    fp = tf.math.reduce_sum((1 - y_true) * y_pred, axis=axis) # calculate False Positive
    iou = (tp + smooth) / (tp + fn + fp + smooth) # calculate Dice score
    iou = tf.math.reduce_mean(iou) # average over all classes
    return iou # Dice loss is 1 - Dice score

In [None]:
# #Basic U-Nnet

# def model_seg_class(inp_size=(256,256,1),filter_size=32):

#     inp=layers.Input(inp_size)
#     #stage 1
#     x = layers.Conv2D(filter_size, (3,3), activation=None, padding='same')(inp)
#     x = layers.BatchNormalization()(x)
#     x = layers.Activation('relu')(x)
#     x = layers.Conv2D(filter_size, (3,3), activation=None, padding='same')(x)
#     x = layers.BatchNormalization()(x)
#     x = layers.Activation('relu')(x)
#     c4=x
#     x = layers.MaxPooling2D(pool_size=(2, 2))(x)
    
#     #stage 2
#     x = layers.Conv2D(filter_size*2, (3,3), activation=None, padding='same')(x)
#     x = layers.BatchNormalization()(x)
#     x = layers.Activation('relu')(x)
#     x = layers.Conv2D(filter_size*2, (3,3), activation=None, padding='same')(x)
#     x = layers.BatchNormalization()(x)
#     x = layers.Activation('relu')(x)
#     c3=x
#     x = layers.MaxPooling2D(pool_size=(2, 2))(x)
    
#     #stage 3
#     x = layers.Conv2D(filter_size*4, (3,3), activation=None, padding='same')(x)
#     x = layers.BatchNormalization()(x)
#     x = layers.Activation('relu')(x)
#     x = layers.Conv2D(filter_size*4, (3,3), activation=None, padding='same')(x)
#     x = layers.BatchNormalization()(x)
#     x = layers.Activation('relu')(x)
#     c2=x
#     x = layers.MaxPooling2D(pool_size=(2, 2))(x)
    
#     #stage 4
#     x = layers.Conv2D(filter_size*8, (3,3), activation=None, padding='same')(x)
#     x = layers.BatchNormalization()(x)
#     x = layers.Activation('relu')(x)
#     x = layers.Conv2D(filter_size*8, (3,3), activation=None, padding='same')(x)
#     x = layers.BatchNormalization()(x)
#     x = layers.Activation('relu')(x)
#     x = layers.Conv2D(filter_size*8, (3,3), activation=None, padding='same')(x)
#     x = layers.BatchNormalization()(x)
#     x = layers.Activation('relu')(x)
#     c1=x
#     p1=x
#     x = layers.MaxPooling2D(pool_size=(2, 2))(x)

#     #stage 5
#     x = layers.Conv2D(filter_size*16, (3,3), activation=None, padding='same')(x)
#     x = layers.BatchNormalization()(x)
#     x = layers.Activation('relu')(x)
#     x = layers.Conv2D(filter_size*16, (3,3), activation=None, padding='same')(x)
#     x = layers.BatchNormalization()(x)
#     x = layers.Activation('relu')(x)
#     x = layers.Conv2D(filter_size*16, (3,3), activation=None, padding='same')(x)
#     x = layers.BatchNormalization()(x)
#     x = layers.Activation('relu')(x)
#     p2=x
#     x = layers.UpSampling2D(size=(2, 2))(x)

#     #stage 6
#     x=layers.concatenate([c1, x], axis=3)
#     x = layers.Conv2D(filter_size*8, (3,3), activation=None, padding='same')(x)
#     x = layers.BatchNormalization()(x)
#     x = layers.Activation('relu')(x)
#     x = layers.Conv2D(filter_size*8, (3,3), activation=None, padding='same')(x)
#     x = layers.BatchNormalization()(x)
#     x = layers.Activation('relu')(x)
#     x = layers.Conv2D(filter_size*8, (3,3), activation=None, padding='same')(x)
#     x = layers.BatchNormalization()(x)
#     x = layers.Activation('relu')(x)
#     p3=x
#     x = layers.UpSampling2D(size=(2, 2))(x)

#     #stage 7
#     x=layers.concatenate([c2, x], axis=3)
#     x = layers.Conv2D(filter_size*4, (3,3), activation=None, padding='same')(x)
#     x = layers.BatchNormalization()(x)
#     x = layers.Activation('relu')(x)
#     x = layers.Conv2D(filter_size*4, (3,3), activation=None, padding='same')(x)
#     x = layers.BatchNormalization()(x)
#     x = layers.Activation('relu')(x)
#     x = layers.Conv2D(filter_size*4, (3,3), activation=None, padding='same')(x)
#     x = layers.BatchNormalization()(x)
#     x = layers.Activation('relu')(x)
#     x = layers.UpSampling2D(size=(2, 2))(x)

#     #stage 8
#     x=layers.concatenate([c3, x], axis=3)
#     x = layers.Conv2D(filter_size*2, (3,3), activation=None, padding='same')(x)
#     x = layers.BatchNormalization()(x)
#     x = layers.Activation('relu')(x)
#     x = layers.Conv2D(filter_size*2, (3,3), activation=None, padding='same')(x)
#     x = layers.BatchNormalization()(x)
#     x = layers.Activation('relu')(x)
#     x = layers.Conv2D(filter_size*2, (3,3), activation=None, padding='same')(x)
#     x = layers.BatchNormalization()(x)
#     x = layers.Activation('relu')(x)
#     x = layers.UpSampling2D(size=(2, 2))(x)

#     #stage 9
#     x=layers.concatenate([c4, x], axis=3)
#     x = layers.Conv2D(filter_size, (3,3), activation=None, padding='same')(x)
#     x = layers.BatchNormalization()(x)
#     x = layers.Activation('relu')(x)
#     x = layers.Conv2D(1, (3,3), activation=None, padding='same')(x)
#     x = layers.BatchNormalization()(x)
#     x = layers.Activation('relu')(x)
#     seg_out=layers.Activation('sigmoid',name='seg_out')(x)
    
#     #classifier
#     p1=layers.Conv2D(1,(3,3),padding='same',activation='relu')(p1)
#     p2=layers.Conv2D(1,(3,3),padding='same',activation='relu')(p2)
#     p3=layers.Conv2D(1,(3,3),padding='same',activation='relu')(p3)
#     g1 = layers.GlobalAveragePooling2D()(p1)
#     g2 = layers.GlobalAveragePooling2D()(p2)
#     g3 = layers.GlobalAveragePooling2D()(p3)
#     x=layers.concatenate([g1,g2,g3])
#     x=layers.Dense(32,activation='relu')(x)
#     cls_out=layers.Dense(3,activation='softmax',name='cls_out')(x)

    
    
#     model = Model(inputs=inp, outputs=[seg_out,cls_out])
#     return model

In [47]:
#With Gate units
def gate_unit(inp,size):
    x = layers.Conv2D(size, (3,3), activation=None, padding='same')(inp)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('sigmoid')(x)
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Reshape((1, 1, size))(x)
    g1=layers.Multiply()([x,inp])
    x = layers.Conv2D(size, (3,3), activation=None, padding='same')(inp)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('sigmoid')(x)
    x = layers.GlobalMaxPooling2D()(x)
    x = layers.Reshape((1, 1, size))(x)
    g2=layers.Multiply()([x,inp])
    x=layers.add([g1,g2])
    x = layers.Conv2D(size, (3,3), activation=None, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    return x

def model_seg_class(inp_size=(256,256,1),filter_size=32):

    inp=layers.Input(inp_size)
    #stage 1
    x = layers.Conv2D(filter_size, (3,3), activation=None, padding='same')(inp)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    # x = layers.Conv2D(filter_size, (3,3), activation=None, padding='same')(x)
    # x = layers.BatchNormalization()(x)
    # x = layers.Activation('relu')(x)
    c4=x
    x = layers.MaxPooling2D(pool_size=(2, 2))(x)
    
    #stage 2
    x = layers.Conv2D(filter_size*2, (3,3), activation=None, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.Conv2D(filter_size*2, (3,3), activation=None, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    c3=x
    x = layers.MaxPooling2D(pool_size=(2, 2))(x)
    
    #stage 3
    x = layers.Conv2D(filter_size*4, (3,3), activation=None, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.Conv2D(filter_size*4, (3,3), activation=None, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.Conv2D(filter_size*4, (3,3), activation=None, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    c2=x
    x = layers.MaxPooling2D(pool_size=(2, 2))(x)
    
    #stage 4
    x = layers.Conv2D(filter_size*8, (3,3), activation=None, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.Conv2D(filter_size*8, (3,3), activation=None, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.Conv2D(filter_size*8, (3,3), activation=None, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.Conv2D(filter_size*8, (3,3), activation=None, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    c1=x
    p1=x
    x = layers.MaxPooling2D(pool_size=(2, 2))(x)

    #stage 5
    x = layers.Conv2D(filter_size*16, (3,3), activation=None, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.Conv2D(filter_size*16, (3,3), activation=None, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.Conv2D(filter_size*16, (3,3), activation=None, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    p2=x
    x = layers.UpSampling2D(size=(2, 2))(x)

    #stage 6
    c1=gate_unit(c1,filter_size*8)
    x=layers.concatenate([c1, x], axis=3)
    x = layers.Conv2D(filter_size*8, (3,3), activation=None, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.Conv2D(filter_size*8, (3,3), activation=None, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    # x = layers.Conv2D(filter_size*8, (3,3), activation=None, padding='same')(x)
    # x = layers.BatchNormalization()(x)
    # x = layers.Activation('relu')(x)
    p3=x
    x = layers.UpSampling2D(size=(2, 2))(x)

    #stage 7
    c2=gate_unit(c2,filter_size*4)
    x=layers.concatenate([c2, x], axis=3)
    x = layers.Conv2D(filter_size*4, (3,3), activation=None, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.Conv2D(filter_size*4, (3,3), activation=None, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    # x = layers.Conv2D(filter_size*4, (3,3), activation=None, padding='same')(x)
    # x = layers.BatchNormalization()(x)
    # x = layers.Activation('relu')(x)
    x = layers.UpSampling2D(size=(2, 2))(x)

    #stage 8
    c3=gate_unit(c3,filter_size*2)
    x=layers.concatenate([c3, x], axis=3)
    x = layers.Conv2D(filter_size*2, (3,3), activation=None, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.Conv2D(filter_size*2, (3,3), activation=None, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    # x = layers.Conv2D(filter_size*2, (3,3), activation=None, padding='same')(x)
    # x = layers.BatchNormalization()(x)
    # x = layers.Activation('relu')(x)
    x = layers.UpSampling2D(size=(2, 2))(x)

    #stage 9
    c4=gate_unit(c4,filter_size)
    x=layers.concatenate([c4, x], axis=3)
    x = layers.Conv2D(filter_size, (3,3), activation=None, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.Conv2D(1, (3,3), activation=None, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    seg_out=layers.Activation('sigmoid',name='seg_out')(x)
    
    #classifier
    p1=layers.Conv2D(1,(3,3),padding='same',activation='relu')(p1)
    p2=layers.Conv2D(1,(3,3),padding='same',activation='relu')(p2)
    p3=layers.Conv2D(1,(3,3),padding='same',activation='relu')(p3)
    g1 = layers.GlobalAveragePooling2D()(p1)
    g2 = layers.GlobalAveragePooling2D()(p2)
    g3 = layers.GlobalAveragePooling2D()(p3)
    x=layers.concatenate([g1,g2,g3])
    x=layers.Dense(32,activation='relu')(x)
    cls_out=layers.Dense(3,activation='softmax',name='cls_out')(x)

    
    
    model = Model(inputs=inp, outputs=[seg_out,cls_out])
    return model

In [36]:
# #With attention gates
# def gate_unit(inp,size):
#     x = layers.Conv2D(size, (3,3), activation=None, padding='same')(inp)
#     x = layers.Activation('relu')(x)
#     x = layers.GlobalAveragePooling2D()(x)
#     x = layers.Reshape((1, 1, size))(x)
#     g1=layers.Multiply()([x,inp])
#     x = layers.Conv2D(size, (3,3), activation=None, padding='same')(inp)
#     x = layers.Activation('relu')(x)
#     x = layers.GlobalMaxPooling2D()(x)
#     x = layers.Reshape((1, 1, size))(x)
#     g2=layers.Multiply()([x,inp])
#     x=layers.add([g1,g2])
#     x = layers.Conv2D(size, (3,3), activation=None, padding='same')(x)
#     x = layers.Activation('relu')(x)
#     return x
    
    
# def model_seg_class(inp_size=(256,256,1),filter_size=32):

#     inp=layers.Input(inp_size)
#     #stage 1
#     x = layers.Conv2D(filter_size, (3,3), activation=None, padding='same')(inp)
#     x = layers.BatchNormalization()(x)
#     x = layers.Activation('relu')(x)
#     x = layers.Conv2D(filter_size, (3,3), activation=None, padding='same')(x)
#     x = layers.BatchNormalization()(x)
#     x = layers.Activation('relu')(x)
#     c4=x
#     x = layers.MaxPooling2D(pool_size=(2, 2))(x)
    
#     #stage 2
#     x = layers.Conv2D(filter_size*2, (3,3), activation=None, padding='same')(x)
#     x = layers.BatchNormalization()(x)
#     x = layers.Activation('relu')(x)
#     x = layers.Conv2D(filter_size*2, (3,3), activation=None, padding='same')(x)
#     x = layers.BatchNormalization()(x)
#     x = layers.Activation('relu')(x)
#     c3=x
#     x = layers.MaxPooling2D(pool_size=(2, 2))(x)
    
#     #stage 3
#     x = layers.Conv2D(filter_size*4, (3,3), activation=None, padding='same')(x)
#     x = layers.BatchNormalization()(x)
#     x = layers.Activation('relu')(x)
#     x = layers.Conv2D(filter_size*4, (3,3), activation=None, padding='same')(x)
#     x = layers.BatchNormalization()(x)
#     x = layers.Activation('relu')(x)
#     c2=x
#     x = layers.MaxPooling2D(pool_size=(2, 2))(x)
    
#     #stage 4
#     x = layers.Conv2D(filter_size*8, (3,3), activation=None, padding='same')(x)
#     x = layers.BatchNormalization()(x)
#     x = layers.Activation('relu')(x)
#     x = layers.Conv2D(filter_size*8, (3,3), activation=None, padding='same')(x)
#     x = layers.BatchNormalization()(x)
#     x = layers.Activation('relu')(x)
#     x = layers.Conv2D(filter_size*8, (3,3), activation=None, padding='same')(x)
#     x = layers.BatchNormalization()(x)
#     x = layers.Activation('relu')(x)
#     c1=x
#     p1=x
#     x = layers.MaxPooling2D(pool_size=(2, 2))(x)

#     #stage 5
#     x = layers.Conv2D(filter_size*16, (3,3), activation=None, padding='same')(x)
#     x = layers.BatchNormalization()(x)
#     x = layers.Activation('relu')(x)
#     x = layers.Conv2D(filter_size*16, (3,3), activation=None, padding='same')(x)
#     x = layers.BatchNormalization()(x)
#     x = layers.Activation('relu')(x)
#     x = layers.Conv2D(filter_size*16, (3,3), activation=None, padding='same')(x)
#     x = layers.BatchNormalization()(x)
#     x = layers.Activation('relu')(x)
#     p2=x
#     x = layers.UpSampling2D(size=(2, 2))(x)

#     #stage 6
#     g1 = gate_unit(c1,filter_size*8)
#     x = layers.concatenate([g1, x], axis=3)   
#     # c1 = layers.Conv2D(filter_size*8, (3,3), activation=None, padding='same')(c1)
#     # c1 = layers.BatchNormalization()(c1)
#     # c1 = layers.Activation('relu')(c1)
#     # x = layers.concatenate([c1, x], axis=3) 
#     x = layers.Conv2D(filter_size*8, (3,3), activation=None, padding='same')(x)
#     x = layers.BatchNormalization()(x)
#     x = layers.Activation('relu')(x)
#     x = layers.Conv2D(filter_size*8, (3,3), activation=None, padding='same')(x)
#     x = layers.BatchNormalization()(x)
#     x = layers.Activation('relu')(x)
#     x = layers.Conv2D(filter_size*8, (3,3), activation=None, padding='same')(x)
#     x = layers.BatchNormalization()(x)
#     x = layers.Activation('relu')(x)
#     p3=x
#     x = layers.UpSampling2D(size=(2, 2))(x)

#     #stage 7
#     g2 = gate_unit(c2,filter_size*4)
#     x = layers.concatenate([g2, x], axis=3)
#     # c2 = layers.Conv2D(filter_size*4, (3,3), activation=None, padding='same')(c2)
#     # c2 = layers.BatchNormalization()(c2)
#     # c2 = layers.Activation('relu')(c2)
#     # x=layers.concatenate([c2, x], axis=3)
#     x = layers.Conv2D(filter_size*4, (3,3), activation=None, padding='same')(x)
#     x = layers.BatchNormalization()(x)
#     x = layers.Activation('relu')(x)
#     x = layers.Conv2D(filter_size*4, (3,3), activation=None, padding='same')(x)
#     x = layers.BatchNormalization()(x)
#     x = layers.Activation('relu')(x)
#     x = layers.Conv2D(filter_size*4, (3,3), activation=None, padding='same')(x)
#     x = layers.BatchNormalization()(x)
#     x = layers.Activation('relu')(x)
#     x = layers.UpSampling2D(size=(2, 2))(x)

#     #stage 8
#     g3 = gate_unit(c3,filter_size*2)
#     x = layers.concatenate([g3, x], axis=3)
#     # c3 = layers.Conv2D(filter_size*2, (3,3), activation=None, padding='same')(c3)
#     # c3 = layers.BatchNormalization()(c3)
#     # c3 = layers.Activation('relu')(c3)
#     # x=layers.concatenate([c3, x], axis=3)
#     x = layers.Conv2D(filter_size*2, (3,3), activation=None, padding='same')(x)
#     x = layers.BatchNormalization()(x)
#     x = layers.Activation('relu')(x)
#     x = layers.Conv2D(filter_size*2, (3,3), activation=None, padding='same')(x)
#     x = layers.BatchNormalization()(x)
#     x = layers.Activation('relu')(x)
#     x = layers.Conv2D(filter_size*2, (3,3), activation=None, padding='same')(x)
#     x = layers.BatchNormalization()(x)
#     x = layers.Activation('relu')(x)
#     x = layers.UpSampling2D(size=(2, 2))(x)

#     #stage 9
#     g4 = gate_unit(c4,filter_size)
#     x = layers.concatenate([g4, x], axis=3)
#     # c4 = layers.Conv2D(filter_size, (3,3), activation=None, padding='same')(c4)
#     # c4 = layers.BatchNormalization()(c4)
#     # c4 = layers.Activation('relu')(c4)
#     # x=layers.concatenate([c4, x], axis=3)
#     x = layers.Conv2D(filter_size, (3,3), activation=None, padding='same')(x)
#     x = layers.BatchNormalization()(x)
#     x = layers.Activation('relu')(x)
#     x = layers.Conv2D(1, (3,3), activation=None, padding='same')(x)
#     x = layers.BatchNormalization()(x)
#     x = layers.Activation('relu')(x)
#     seg_out=layers.Activation('sigmoid',name='seg_out')(x)
    
#     #classifier
#     p1=layers.Conv2D(1,(3,3),padding='same',activation='relu')(p1)
#     p2=layers.Conv2D(1,(3,3),padding='same',activation='relu')(p2)
#     p3=layers.Conv2D(1,(3,3),padding='same',activation='relu')(p3)
#     g1 = layers.GlobalAveragePooling2D()(p1)
#     g2 = layers.GlobalAveragePooling2D()(p2)
#     g3 = layers.GlobalAveragePooling2D()(p3)
#     x=layers.concatenate([g1,g2,g3])
#     x=layers.Dense(32,activation='relu')(x)
#     cls_out=layers.Dense(3,activation='softmax',name='cls_out')(x)

    
    
#     model = Model(inputs=inp, outputs=[seg_out,cls_out])
#     return model

In [28]:
#UNET ++

def model_seg_class(inp_size=(256,256,1),filter_size=16):

    inp=layers.Input(inp_size)
    #stage 0,0
    x = layers.Conv2D(filter_size, (3,3), activation=None, padding='same')(inp)
    x = layers.BatchNormalization()(x)    
    x = layers.Activation('relu')(x)
    x = layers.Conv2D(filter_size, (3,3), activation=None, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    # x = layers.Conv2D(filter_size, (3,3), activation=None, padding='same')(x)
    # x = layers.BatchNormalization()(x)
    # x = layers.Activation('relu')(x)
    s00=x
    x = layers.MaxPooling2D(pool_size=(2, 2))(x)
    
    #stage 1,0
    x = layers.Conv2D(filter_size*2, (3,3), activation=None, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.Conv2D(filter_size*2, (3,3), activation=None, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    # x = layers.Conv2D(filter_size*2, (3,3), activation=None, padding='same')(x)
    # x = layers.BatchNormalization()(x)
    # x = layers.Activation('relu')(x)
    s10=x
    x = layers.MaxPooling2D(pool_size=(2, 2))(x)
    
    #stage 2,0
    x = layers.Conv2D(filter_size*4, (3,3), activation=None, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.Conv2D(filter_size*4, (3,3), activation=None, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    # x = layers.Conv2D(filter_size*4, (3,3), activation=None, padding='same')(x)
    # x = layers.BatchNormalization()(x)
    # x = layers.Activation('relu')(x)
    s20=x
    x = layers.MaxPooling2D(pool_size=(2, 2))(x)
    
    #stage 3,0
    x = layers.Conv2D(filter_size*8, (3,3), activation=None, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.Conv2D(filter_size*8, (3,3), activation=None, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.Conv2D(filter_size*8, (3,3), activation=None, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    s30=x
    p1=x
    x = layers.MaxPooling2D(pool_size=(2, 2))(x)
    s30_mp=x

    #stage 0,1
    x = layers.concatenate([s00, layers.UpSampling2D(size=(2, 2))(s10)], axis=3) 
    x = layers.Conv2D(filter_size, (3,3), activation=None, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.Conv2D(filter_size, (3,3), activation=None, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.Conv2D(filter_size, (3,3), activation=None, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    s01=x

    #stage 1,1
    x = layers.concatenate([s10, layers.UpSampling2D(size=(2, 2))(s20)], axis=3) 
    x = layers.Conv2D(filter_size*2, (3,3), activation=None, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.Conv2D(filter_size*2, (3,3), activation=None, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.Conv2D(filter_size*2, (3,3), activation=None, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    s11=x

    #stage 0,2
    x = layers.concatenate([s01,s00, layers.UpSampling2D(size=(2, 2))(s11)], axis=3) 
    x = layers.Conv2D(filter_size, (3,3), activation=None, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.Conv2D(filter_size, (3,3), activation=None, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.Conv2D(filter_size, (3,3), activation=None, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    s02=x

    #stage 2,1
    x = layers.concatenate([s20,layers.UpSampling2D(size=(2, 2))(s30)], axis=3) 
    x = layers.Conv2D(filter_size*4, (3,3), activation=None, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.Conv2D(filter_size*4, (3,3), activation=None, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.Conv2D(filter_size*4, (3,3), activation=None, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    s21=x

    #stage 1,2
    x = layers.concatenate([s11,s10,layers.UpSampling2D(size=(2, 2))(s21)], axis=3) 
    x = layers.Conv2D(filter_size*2, (3,3), activation=None, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.Conv2D(filter_size*2, (3,3), activation=None, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.Conv2D(filter_size*2, (3,3), activation=None, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    s12=x

    #stage 0,3
    x = layers.concatenate([s00,s01,s02,layers.UpSampling2D(size=(2, 2))(s12)], axis=3) 
    x = layers.Conv2D(filter_size, (3,3), activation=None, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.Conv2D(filter_size, (3,3), activation=None, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.Conv2D(filter_size, (3,3), activation=None, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    s03=x
    

    #stage 4,0
    x = layers.Conv2D(filter_size*16, (3,3), activation=None, padding='same')(s30_mp)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.Conv2D(filter_size*16, (3,3), activation=None, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.Conv2D(filter_size*16, (3,3), activation=None, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    p2=x
    s40=x
    x = layers.UpSampling2D(size=(2, 2))(x)

    #stage 3,1
    x = layers.concatenate([x,s30], axis=3) 
    x = layers.Conv2D(filter_size*8, (3,3), activation=None, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.Conv2D(filter_size*8, (3,3), activation=None, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.Conv2D(filter_size*8, (3,3), activation=None, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    p3=x
    x = layers.UpSampling2D(size=(2, 2))(x)

    #stage 2,2
    x = layers.concatenate([x,s20,s21], axis=3) 
    x = layers.Conv2D(filter_size*4, (3,3), activation=None, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.Conv2D(filter_size*4, (3,3), activation=None, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.Conv2D(filter_size*4, (3,3), activation=None, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.UpSampling2D(size=(2, 2))(x)

    #stage 1,3
    x = layers.concatenate([x,s10,s11,s12], axis=3) 
    x = layers.Conv2D(filter_size*2, (3,3), activation=None, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.Conv2D(filter_size*2, (3,3), activation=None, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.Conv2D(filter_size*2, (3,3), activation=None, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.UpSampling2D(size=(2, 2))(x)

    #stage 0,4
    x = layers.concatenate([x,s00,s01,s02,s03], axis=3) 
    x = layers.Conv2D(filter_size, (3,3), activation=None, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.Conv2D(1, (3,3), activation=None, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    seg_out=layers.Activation('sigmoid',name='seg_out')(x)
    
    #classifier
    p1=layers.Conv2D(1,(3,3),padding='same',activation='relu')(p1)
    p2=layers.Conv2D(1,(3,3),padding='same',activation='relu')(p2)
    p3=layers.Conv2D(1,(3,3),padding='same',activation='relu')(p3)
    g1 = layers.GlobalAveragePooling2D()(p1)
    g2 = layers.GlobalAveragePooling2D()(p2)
    g3 = layers.GlobalAveragePooling2D()(p3)
    x=layers.concatenate([g1,g2,g3])
    x=layers.Dense(32,activation='relu')(x)
    cls_out=layers.Dense(3,activation='softmax',name='cls_out')(x)

    
    
    model = Model(inputs=inp, outputs=[seg_out,cls_out])
    return model

In [48]:
model=model_seg_class()

In [90]:
model.compile(
    optimizer='adam',
    loss={
        'seg_out': dice_bce_loss,
        'cls_out': 'CategoricalCrossentropy'
    },
    metrics={
        'seg_out': [dice], # You might use Dice Coefficient or IoU here
        'cls_out': ['accuracy']
    }
)

checkpoint_cb = ModelCheckpoint(
    filepath="/kaggle/working/best_model.h5",
    monitor="val_seg_out_dice",        # or "val_dice" or "val_iou"
    mode="max",                # "min" for losses, "max" for accuracy metrics
    save_best_only=True,
    verbose=1
)


start=time.time()
history = model.fit(
    x=X_train_aug,
    y={'seg_out': y_train_aug, 'cls_out': df_labels_train_aug},
    batch_size=16,
    epochs=200,
    validation_data=(X_test, {'seg_out': y_test, 'cls_out': df_labels_test}),
    callbacks=[checkpoint_cb]
)

end=time.time()
print(f"\n\nTraining time: {(end-start):.2f} seconds")
model.save('/kaggle/working/unet.h5')

Epoch 1/200
[1m40/40[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 830ms/step - cls_out_accuracy: 0.9916 - cls_out_loss: 0.0267 - loss: 1.7665 - seg_out_dice: 0.6854 - seg_out_loss: 1.7398
Epoch 1: val_seg_out_dice improved from -inf to 0.55875, saving model to /kaggle/working/best_model.h5
[1m40/40[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m91s[0m 981ms/step - cls_out_accuracy: 0.9915 - cls_out_loss: 0.0272 - loss: 1.7674 - seg_out_dice: 0.6846 - seg_out_loss: 1.7402 - val_cls_out_accuracy: 0.8375 - val_cls_out_loss: 1.0872 - val_loss: 2.8921 - val_seg_out_dice: 0.5588 - val_seg_out_loss: 1.8050
Epoch 2/200
[1m40/40[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 501ms/step - cls_out_accuracy: 0.9720 - cls_out_loss: 0.0703 - loss: 1.8226 - seg_out_dice: 0.6273 - seg_out_loss: 1.7523
Epoch 2: val_seg_out_dice did not improve from 0.55875
[1m40/40[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m21s[0m 519ms/step - cls_out_accuracy: 0.9721 - cls_out_loss: 0.0703 - 

KeyboardInterrupt: 

In [91]:
model = load_model('/kaggle/working/best_model.h5')

In [17]:
model=load_model('/kaggle/input/unet/keras/default/3/unet6582.h5')

In [14]:
def predict(model,X_test):
    seg_preds, cls_preds = model.predict(X_test)
    idx=0
    for cls in cls_preds:
        if(cls.argmax()==2):
            seg_preds[idx][:]=0
        idx=idx+1
    return seg_preds, cls_preds

In [18]:
seg_preds, cls_preds = predict(model,X_test)

[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 2s/step  


In [19]:
true_labels = df_labels_test.values.argmax(axis=1)

# Predicted class indices
predicted_labels = cls_preds.argmax(axis=1)

# Accuracy
acc = accuracy_score(true_labels, predicted_labels)
print(f"\n## Accuracy: {acc:.4f}")

# Classification report
report = classification_report(true_labels, predicted_labels, target_names=['benign', 'malignant', 'normal'])
print("\n## Classification Report:\n", report)

#segmentation
print("\n## Dice Score:\n",dice(y_test,seg_preds).numpy())
print("\n## IOU:\n",iou(y_test,seg_preds).numpy(),"\n")


## Accuracy: 0.8125

## Classification Report:
               precision    recall  f1-score   support

      benign       0.81      0.81      0.81        37
   malignant       0.77      0.90      0.83        30
      normal       1.00      0.62      0.76        13

    accuracy                           0.81        80
   macro avg       0.86      0.78      0.80        80
weighted avg       0.83      0.81      0.81        80


## Dice Score:
 0.66862774

## IOU:
 0.5022095 



In [20]:
def plot_prediction(image, predicted_mask, ground_truth_mask=None):
    plt.figure(figsize=(12, 4))
    predicted_mask =  (predicted_mask > 0.5).astype(np.float32)
    # Fix image shape
    image_2d = np.squeeze(image)
    if image_2d.ndim == 3:  # if shape is (H, W, C)
        image_2d = image_2d[:, :, 0]  # Take first channel

    # Plot input image
    plt.subplot(1, 3 if ground_truth_mask is not None else 2, 1)
    plt.imshow(image_2d, cmap='gray')
    plt.title("Ultrasound Image")
    plt.axis('off')

    # Ground truth
    if ground_truth_mask is not None:
        gt_mask = np.squeeze(ground_truth_mask)
        plt.subplot(1, 3, 2)
        plt.imshow(gt_mask, cmap='gray')
        plt.title("Ground Truth Mask")
        plt.axis('off')

    # Predicted mask
    pred_mask = np.squeeze(predicted_mask)
    plt.subplot(1, 3 if ground_truth_mask is not None else 2, 3 if ground_truth_mask is not None else 2)
    plt.imshow(pred_mask, cmap='gray')
    plt.title("Predicted Mask")
    plt.axis('off')

    plt.tight_layout()
    plt.show()


In [None]:
n=78
plot_prediction(X_test[n],seg_preds[n],y_test[n])

In [None]:
for i in range(15,25):
    plot_prediction(X_test[i],seg_preds[i],y_test[i])

In [49]:
 model.summary()
# print("\n\n## Model Plot")
# plot_model(model, show_shapes=True)