In [None]:
import tensorflow as tf

# Check which GPUs TensorFlow can see WITHOUT initializing them.
# tf.test.gpu_device_name() creates the GPU runtime context; once that happens,
# tf.config.experimental.set_memory_growth() in a later cell can no longer modify
# the devices and only prints a RuntimeError. Listing physical devices is side-effect free.
physical_gpus = tf.config.list_physical_devices('GPU')
print("Num GPUs Available: ", len(physical_gpus))
print([gpu.name for gpu in physical_gpus])  # e.g. ['/physical_device:GPU:0']

In [None]:
# Mixed precision is deliberately DISABLED: the global Keras policy is pinned to float32.
# To experiment with it, switch to:
#   tf.keras.mixed_precision.set_global_policy('mixed_float16')
# and make sure the model's outputs are computed in float32 for numerical stability,
# e.g. outputs = tf.cast(outputs, tf.float32) on the final layers.
import tensorflow as tf
tf.keras.mixed_precision.set_global_policy('float32')

In [None]:
gpus = tf.config.list_physical_devices('GPU')
# Ask TensorFlow to allocate GPU memory on demand instead of reserving it all up front.
# This only works before the GPUs are initialized; afterwards TF raises RuntimeError,
# which is reported here instead of aborting the notebook.
try:
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
except RuntimeError as e:
    print(e)

In [None]:
import cv2
import numpy as np
from pathlib import Path
import os
import matplotlib.pyplot as plt

In [None]:
from sklearn.model_selection import train_test_split

In [None]:
from tensorflow.keras.utils import plot_model

In [None]:
from sklearn.metrics import accuracy_score, classification_report

In [None]:
import pandas as pd
from glob import glob

# For neural network
import tensorflow as tf

# For Accuracy metric
from sklearn.metrics import accuracy_score
from sklearn.metrics import classification_report

import time

from tensorflow.keras.models import Model
from tensorflow.keras import layers

In [None]:
base_dir = '/kaggle/input/hyper-curated-busi/hyper_curated_busi'
# One sub-folder per class under the dataset root.
normal_dir, benign_dir, malignant_dir = (
    os.path.join(base_dir, class_folder)
    for class_folder in ('normal', 'benign', 'malignant')
)
print(normal_dir)

In [None]:
def load_images_and_masks(directory, class_label, has_mask=True, size=(256, 256)):
    """Load the grayscale ``.png`` images of one class folder and their lesion masks.

    Files whose name contains ``_mask`` are masks; every other ``.png`` file is an
    image. For an image ``<stem>.png``, every ``<stem>_mask*.png`` file in the same
    folder (``<stem>_mask.png`` plus any extra ``<stem>_mask_1.png``-style files, if
    present) is merged into a single binary mask with a pixelwise OR. An image with
    no mask file gets an all-zero mask.

    Files are processed in sorted order, so the returned lists (and any seeded
    train/test split built from them) do not depend on ``os.listdir`` ordering.

    Args:
        directory: folder holding one class's images (and masks).
        class_label: label appended once per successfully loaded image.
        has_mask: if False, mask files are ignored and every image gets an
            all-zero mask.
        size: ``(width, height)`` target size, passed to ``cv2.resize``.

    Returns:
        ``(images, masks, labels)`` lists. Images are uint8 arrays of shape
        ``(height, width)``. Masks are uint8 arrays of the same shape with values
        in {0, 255}. Unreadable images are skipped; unreadable mask files are ignored.
    """
    width, height = size
    png_files = sorted(f for f in os.listdir(directory) if f.endswith('.png'))
    image_files = [f for f in png_files if '_mask' not in f]

    # Group mask files by the image stem they belong to: "<stem>_mask..." -> "<stem>".
    masks_by_stem = {}
    if has_mask:
        for mask_name in png_files:
            if '_mask' in mask_name:
                stem = mask_name.split('_mask', 1)[0]
                masks_by_stem.setdefault(stem, []).append(mask_name)

    images, masks, labels = [], [], []
    for img_name in image_files:
        img = cv2.imread(os.path.join(directory, img_name), cv2.IMREAD_GRAYSCALE)
        if img is None:  # unreadable / not a valid image file
            continue
        images.append(cv2.resize(img, (width, height)))
        labels.append(class_label)

        mask = np.zeros((height, width), dtype=np.uint8)
        for mask_name in masks_by_stem.get(os.path.splitext(img_name)[0], []):
            part = cv2.imread(os.path.join(directory, mask_name), cv2.IMREAD_GRAYSCALE)
            if part is None:
                continue
            # Nearest-neighbour resizing keeps the mask binary. Bilinear resizing
            # followed by a "> 0" threshold would enlarge the lesion at its border.
            part = cv2.resize(part, (width, height), interpolation=cv2.INTER_NEAREST)
            mask[part > 0] = 255
        masks.append(mask)

    return images, masks, labels

# Load each class. Normal images are loaded without masks, so they get all-zero masks.
normal_images, normal_masks, normal_labels = load_images_and_masks(normal_dir, 'normal', has_mask=False)
benign_images, benign_masks, benign_labels = load_images_and_masks(benign_dir, 'benign', has_mask=True)
malignant_images, malignant_masks, malignant_labels = load_images_and_masks(malignant_dir, 'malignant', has_mask=True)

# Stack all classes (normal, then benign, then malignant) into NumPy arrays.
all_images = np.array(normal_images + benign_images + malignant_images)
all_masks = np.array(normal_masks + benign_masks + malignant_masks)
all_labels = np.array(normal_labels + benign_labels + malignant_labels)

In [None]:
# Scale 8-bit pixel values to [0, 1] as float32. float32 uses half the memory of
# NumPy's default float64, and every array derived from these (splits, augmented
# copies) inherits that saving.
# The uint8 checks make the cell safe to re-run: arrays that were already scaled are
# left alone instead of being divided by 255 a second time.
if all_images.dtype == np.uint8:
    all_images = all_images.astype(np.float32) / 255.0  # Normalize to [0, 1]
if all_masks.dtype == np.uint8:
    all_masks = all_masks.astype(np.float32) / 255.0    # Normalize to [0, 1]

In [None]:
# Split data: 80% train / 20% held out, stratified by class.
X_train, X_test, y_train, y_test, labels_train, labels_test = train_test_split(
    all_images, all_masks, all_labels,
    test_size=0.2, random_state=40, stratify=all_labels
)

# Carve a small validation subset (5% of the held-out 20%) out of the held-out data.
# The remaining 95% (X_test1, y_test1, labels_test1) is the test set.
X_val, X_test1, y_val, y_test1, labels_val, labels_test1 = train_test_split(
    X_test, y_test, labels_test,
    test_size=0.95, random_state=40, stratify=labels_test
)

# Use only the disjoint remainder as the test set. Otherwise X_test / y_test /
# labels_test would still contain the validation samples, which are used during
# training (validation_data), and they would be scored again as "test" data.
X_test, y_test, labels_test = X_test1, y_test1, labels_test1

# Reshape for deep learning models (add channel dimension)
X_train = X_train[..., np.newaxis]  # Shape: (n_train, 256, 256, 1)
X_test = X_test[..., np.newaxis]    # Shape: (n_test, 256, 256, 1)
X_val = X_val[..., np.newaxis]      # Shape: (n_val, 256, 256, 1)
y_train = y_train[..., np.newaxis]  # Shape: (n_train, 256, 256, 1)
y_test = y_test[..., np.newaxis]    # Shape: (n_test, 256, 256, 1)
y_val = y_val[..., np.newaxis]      # Shape: (n_val, 256, 256, 1)

print(f"Training set: {X_train.shape}, {y_train.shape},{labels_train.shape}")
print(f"Testing set: {X_test.shape}, {y_test.shape},{labels_test.shape}")
print(f"Validation set: {X_val.shape}, {y_val.shape},{labels_val.shape}")

In [None]:
from collections import Counter

# Class balance of each split. The splits are stratified, so proportions should match.
print("Train label distribution:", Counter(labels_train))
print("Test label distribution:", Counter(labels_test))
print("Validation label distribution:", Counter(labels_val))

In [None]:
# One-hot encode the class labels with a FIXED column order, so every split has the
# same columns in the same order even if a class were absent from one of them.
# Alphabetical order matches pd.get_dummies' default, which the rest of the notebook
# (model output index -> class name) relies on.
CLASS_ORDER = ['benign', 'malignant', 'normal']


def _one_hot(labels):
    """Return an int one-hot DataFrame with one column per entry of CLASS_ORDER."""
    return pd.get_dummies(pd.Categorical(labels, categories=CLASS_ORDER)).astype(int)


df_labels_train = _one_hot(labels_train)
df_labels_test = _one_hot(labels_test)
df_labels_val = _one_hot(labels_val)

print(df_labels_train['normal'].sum())
df_labels_val.head()

In [None]:
import albumentations as A

# Offline augmentation: every training sample gets exactly one randomly transformed
# copy. Image and mask go through the same Compose call, so the geometric transforms
# (flip, rotation, crop + resize) are applied identically to both.
augmentation = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.Rotate(limit=30, p=0.5),
    A.RandomBrightnessContrast(p=0.3),
    A.RandomCrop(height=224, width=224, p=0.3),
    A.Resize(256, 256)  # Ensure output size
])

df_labels_train_np = np.array(df_labels_train)
augmented_images, augmented_masks, augmented_labels = [], [], []
for image, mask, label in zip(X_train, y_train, df_labels_train_np):
    # The channel axis is squeezed away for augmentation and restored afterwards.
    result = augmentation(image=np.squeeze(image), mask=np.squeeze(mask))
    augmented_images.append(np.expand_dims(result['image'], axis=-1))
    augmented_masks.append(np.expand_dims(result['mask'], axis=-1))
    augmented_labels.append(label)

augmented_images = np.array(augmented_images)
augmented_masks = np.array(augmented_masks)
augmented_labels = np.array(augmented_labels)
print(df_labels_train.shape)
print(augmented_labels.shape)

# Original samples first, then their augmented copies (labels keep the same order).
X_train_aug = np.concatenate([X_train, augmented_images], axis=0)
y_train_aug = np.concatenate([y_train, augmented_masks], axis=0)
df_labels_train_aug = np.concatenate([df_labels_train, augmented_labels], axis=0)

In [None]:
# Shapes of the arrays that go into training, testing and validation.
summary_lines = (
    f"Training set: {X_train_aug.shape}, {y_train_aug.shape},{df_labels_train_aug.shape}",
    f"Testing set: {X_test.shape}, {y_test.shape},{labels_test.shape}",
    f"Validation set: {X_val.shape}, {y_val.shape},{labels_val.shape}",
)
for line in summary_lines:
    print(line)

In [None]:
import tensorflow as tf
from tensorflow.keras import layers, Model, regularizers
from tensorflow.keras.applications import EfficientNetB4
import tensorflow.keras.backend as K


class SpatialAttention(layers.Layer):
    """CBAM spatial attention: rescales every spatial position of the input.

    The channel-wise mean and max maps are stacked along the channel axis and fed
    through a 7x7 single-filter sigmoid convolution. The resulting one-channel map
    multiplies the input.
    """
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        self.conv = layers.Conv2D(1, 7, padding='same', activation='sigmoid')
        super().build(input_shape)

    def call(self, inputs):
        pooled = tf.concat(
            [
                tf.reduce_mean(inputs, axis=-1, keepdims=True),
                tf.reduce_max(inputs, axis=-1, keepdims=True),
            ],
            axis=-1,
        )
        return inputs * self.conv(pooled)

class ChannelAttention(layers.Layer):
    """Channel attention (squeeze-and-excitation variant).

    Global average pool -> Dense(C // ratio, relu) -> Dense(C, sigmoid), reshaped to
    (1, 1, C) and multiplied into the input, where C is the input's channel count.

    Args:
        ratio: bottleneck reduction factor for the hidden Dense layer.
    """
    def __init__(self, ratio=16, **kwargs):
        super().__init__(**kwargs)
        self.ratio = ratio

    def build(self, input_shape):
        self.channels = input_shape[-1]
        # Clamp to at least one unit: with fewer channels than `ratio`, C // ratio
        # would be 0 and Dense(0) cannot be built.
        hidden_units = max(1, self.channels // self.ratio)
        self.gap = layers.GlobalAveragePooling2D()
        self.fc1 = layers.Dense(hidden_units, activation='relu')
        self.fc2 = layers.Dense(self.channels, activation='sigmoid')
        self.reshape = layers.Reshape((1, 1, self.channels))
        super().build(input_shape)

    def call(self, inputs):
        x = self.gap(inputs)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.reshape(x)
        return inputs * x

    def get_config(self):
        # Needed to save/reload models containing this layer (constructor args).
        config = super().get_config()
        config.update({'ratio': self.ratio})
        return config

class CBAM(layers.Layer):
    """Convolutional Block Attention Module.

    Applies ChannelAttention (default ratio) and then SpatialAttention to the input.
    """
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.channel_attention = ChannelAttention()
        self.spatial_attention = SpatialAttention()

    def call(self, inputs):
        return self.spatial_attention(self.channel_attention(inputs))

class ResidualBlock(layers.Layer):
    """Two-convolution residual block with optional CBAM attention.

    Main path: Conv3x3(strides) -> BN -> swish -> Dropout -> Conv3x3 -> BN.
    Shortcut: identity, or Conv1x1(strides) -> BN when the stride is not 1 or the
    input channel count differs from ``filters``. The sum is optionally refined by
    CBAM and then passed through swish.

    Args:
        filters: number of output channels of the block.
        strides: stride of the first conv and of the projection shortcut.
        use_attention: apply CBAM after the residual addition.
        dropout_rate: dropout rate between the two convolutions.
    """
    def __init__(self, filters, strides=1, use_attention=True, dropout_rate=0.1, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        self.strides = strides
        self.use_attention = use_attention
        self.dropout_rate = dropout_rate

    def build(self, input_shape):
        # Main path
        self.conv1 = layers.Conv2D(self.filters, 3, strides=self.strides, padding='same',
                                   kernel_regularizer=regularizers.l2(1e-4))
        self.bn1 = layers.BatchNormalization()
        self.conv2 = layers.Conv2D(self.filters, 3, padding='same',
                                   kernel_regularizer=regularizers.l2(1e-4))
        self.bn2 = layers.BatchNormalization()
        self.dropout = layers.Dropout(self.dropout_rate)

        # Stateless layers are created once here instead of on every call().
        self.mid_activation = layers.Activation('swish')
        self.out_activation = layers.Activation('swish')
        self.residual_add = layers.Add()

        # Shortcut path: project when the spatial size or channel count changes.
        if self.strides != 1 or input_shape[-1] != self.filters:
            self.shortcut_conv = layers.Conv2D(self.filters, 1, strides=self.strides, padding='same')
            self.shortcut_bn = layers.BatchNormalization()
        else:
            self.shortcut_conv = None

        # Attention
        if self.use_attention:
            self.attention = CBAM()

        super().build(input_shape)

    def call(self, inputs, training=None):
        # Main path
        x = self.conv1(inputs)
        x = self.bn1(x, training=training)
        x = self.mid_activation(x)
        x = self.dropout(x, training=training)

        x = self.conv2(x)
        x = self.bn2(x, training=training)

        # Shortcut path
        if self.shortcut_conv is not None:
            shortcut = self.shortcut_conv(inputs)
            shortcut = self.shortcut_bn(shortcut, training=training)
        else:
            shortcut = inputs

        # Add, optionally refine with attention, then activate.
        x = self.residual_add([x, shortcut])
        if self.use_attention:
            x = self.attention(x)
        return self.out_activation(x)

    def get_config(self):
        # Needed to save/reload models containing this layer (constructor args).
        config = super().get_config()
        config.update({
            'filters': self.filters,
            'strides': self.strides,
            'use_attention': self.use_attention,
            'dropout_rate': self.dropout_rate,
        })
        return config

class PyramidPooling(layers.Layer):
    """Pyramid pooling module for multi-scale context.

    For each pool size p in (1, 2, 3, 6), the input is average-pooled with a p x p
    window and stride p. The result goes through Conv1x1 -> BN -> swish with
    ``filters // 4`` channels and is resized back to the input's spatial size.
    The four branches are concatenated on the channel axis.

    NOTE(review): with 'valid' pooling, the input's height and width must be at
    least 6, or the p=6 branch produces an empty map.

    Args:
        filters: total output channels (split evenly across the 4 branches).
    """
    def __init__(self, filters, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters

    def build(self, input_shape):
        self.pool_sizes = [1, 2, 3, 6]
        # Stateless pooling layers, created once rather than on every call().
        self.pools = [layers.AveragePooling2D(p, strides=p) for p in self.pool_sizes]
        self.convs = []
        for _ in self.pool_sizes:
            self.convs.append([
                layers.Conv2D(self.filters // len(self.pool_sizes), 1, padding='same'),
                layers.BatchNormalization(),
                layers.Activation('swish')
            ])
        super().build(input_shape)

    def call(self, inputs, training=None):
        spatial_size = tf.shape(inputs)[1:3]
        branches = []
        for pool, (conv, bn, activation) in zip(self.pools, self.convs):
            y = pool(inputs)
            y = conv(y)
            y = bn(y, training=training)
            y = activation(y)
            branches.append(tf.image.resize(y, spatial_size))
        return tf.concat(branches, axis=-1)

    def get_config(self):
        # Needed to save/reload models containing this layer (constructor args).
        config = super().get_config()
        config.update({'filters': self.filters})
        return config

# Dice and Focal Loss Functions
def dice_coefficient(y_true, y_pred, smooth=1e-6):
    """Soft Dice coefficient over the whole batch, used as a segmentation metric.

    Both tensors are cast to float32 and flattened. The result is
    (2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth).

    Uses plain TensorFlow ops instead of the legacy ``K.flatten``/``K.sum``/``K.cast``
    helpers, which Keras 3's ``keras.backend`` no longer exports.
    """
    y_true_f = tf.reshape(tf.cast(y_true, tf.float32), [-1])
    y_pred_f = tf.reshape(tf.cast(y_pred, tf.float32), [-1])
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth)


def dice_loss(y_true, y_pred):
    """Soft Dice loss: one minus ``dice_coefficient``, so 0 means a perfect overlap."""
    score = dice_coefficient(y_true, y_pred)
    return 1 - score

def focal_loss(y_true, y_pred, alpha=0.8, gamma=2.0):
    """Focal loss for one-hot targets and predicted probabilities.

    Per element: -alpha_t * (1 - p_t) ** gamma * log(p_t). Here p_t is y_pred where
    y_true == 1 and 1 - y_pred elsewhere, and alpha_t is ``alpha`` for positive
    entries and ``1 - alpha`` for negative ones. Predictions are clipped to
    [eps, 1 - eps] first. The loss is summed over the last (class) axis and
    averaged over the remaining axes.

    ``y_true`` is cast to ``y_pred``'s dtype first. Integer one-hot labels (such as
    the int arrays built with ``pd.get_dummies(...).astype(int)``) would otherwise
    break the float arithmetic below.
    """
    y_true = tf.cast(y_true, y_pred.dtype)
    epsilon = tf.keras.backend.epsilon()
    y_pred = tf.clip_by_value(y_pred, epsilon, 1. - epsilon)
    is_positive = tf.equal(y_true, 1)
    p_t = tf.where(is_positive, y_pred, 1 - y_pred)
    alpha_t = tf.where(
        is_positive,
        tf.cast(alpha, y_pred.dtype),
        tf.cast(1 - alpha, y_pred.dtype),
    )
    cross_entropy = -tf.math.log(p_t)
    weight = alpha_t * tf.pow(1 - p_t, gamma)
    loss = weight * cross_entropy
    return tf.reduce_mean(tf.reduce_sum(loss, axis=-1))

def combined_loss(y_true_seg, y_pred_seg, y_true_cls, y_pred_cls, seg_weight=0.7):
    """Weighted sum of the segmentation and classification losses, as a scalar.

    Segmentation loss = soft Dice loss + mean pixelwise binary cross-entropy.
    Classification loss = ``focal_loss``. Result is
    ``seg_weight * seg_loss + (1 - seg_weight) * cls_loss``.

    The BCE term is averaged so the result is a scalar. Previously the scalar Dice
    loss was added to an unreduced per-pixel BCE tensor.
    NOTE: compile_model() does not use this helper; it passes per-output losses
    with loss_weights instead.
    """
    y_true_seg = tf.cast(y_true_seg, y_pred_seg.dtype)
    bce = tf.reduce_mean(tf.keras.losses.binary_crossentropy(y_true_seg, y_pred_seg))
    seg_loss = dice_loss(y_true_seg, y_pred_seg) + bce
    cls_loss = focal_loss(y_true_cls, y_pred_cls)
    return seg_weight * seg_loss + (1 - seg_weight) * cls_loss

# Main model definition
def model_seg_class_optimized(inp_size=(256, 256, 3), num_classes=3, base_filters=32,
                              backbone_cls=None, input_scale=255.0):
    """Build the joint lesion-segmentation + classification network.

    An ImageNet-pretrained EfficientNet encoder feeds two heads:
      * a U-Net-style decoder with CBAM-refined skip connections and residual blocks,
        producing a 1-channel sigmoid mask ``seg_out``;
      * a classifier over globally pooled bottleneck/decoder features, producing a
        softmax ``cls_out``.

    Args:
        inp_size: (H, W, C) of the model input. C must be 1 (grayscale, replicated
            to 3 channels inside the model) or 3. The decoder upsamples the
            bottleneck five times by 2, so H and W presumably need to be multiples
            of 32 for ``seg_out`` to match the input size.
        num_classes: number of classes of the softmax head.
        base_filters: width multiplier for the decoder.
        backbone_cls: a ``tf.keras.applications`` EfficientNet constructor. It must
            expose the ``block{2a,3a,4a,6a}_expand_activation`` layers. Defaults to
            EfficientNetB7.
        input_scale: multiplier applied to the input before the backbone. Keras
            EfficientNet models rescale and normalize internally and expect pixels
            in [0, 255]. This notebook feeds images in [0, 1], hence the default of
            255. Use 1.0 for inputs that are already in [0, 255].

    Returns:
        A Keras Model with outputs ``[seg_out, cls_out]``.

    Raises:
        ValueError: if the channel count of ``inp_size`` is neither 1 nor 3.
    """
    if backbone_cls is None:
        # Local import: only EfficientNetB4 is imported at the top of this cell.
        from tensorflow.keras.applications import EfficientNetB7
        backbone_cls = EfficientNetB7

    height, width, channels = inp_size
    if channels not in (1, 3):
        raise ValueError(
            f"inp_size must have 1 (grayscale) or 3 (RGB) channels, got {channels}")

    inputs = layers.Input(inp_size, name='input')

    # No in-model augmentation: random flips/rotations/zooms applied here would
    # transform the image but NOT its target mask, misaligning the segmentation
    # labels. Training data is augmented beforehand (image and mask together).
    x = inputs
    if input_scale != 1.0:
        x = layers.Rescaling(input_scale, name='to_backbone_range')(x)
    if channels == 1:
        # ImageNet weights require 3 input channels: replicate the grayscale plane.
        x = layers.Concatenate(axis=-1, name='gray_to_rgb')([x, x, x])

    # Build the pretrained backbone on its own input and wrap it in an encoder that
    # also exposes the intermediate activations used as skip connections.
    backbone = backbone_cls(include_top=False, weights='imagenet', input_shape=(height, width, 3))
    feature_layers = ['block2a_expand_activation', 'block3a_expand_activation',
                      'block4a_expand_activation', 'block6a_expand_activation']
    encoder = Model(
        inputs=backbone.input,
        outputs=[backbone.get_layer(name).output for name in feature_layers] + [backbone.output],
        name='encoder',
    )
    *skip_connections, bottom_features = encoder(x)

    # Multi-scale context on the deepest features.
    ppm = PyramidPooling(base_filters * 16)(bottom_features)
    bottom_features = layers.Concatenate()([bottom_features, ppm])

    x = ResidualBlock(base_filters * 16, use_attention=True)(bottom_features)
    x = ResidualBlock(base_filters * 16, use_attention=True)(x)

    # The classifier pools features from the bottleneck and from every decoder
    # stage except the last.
    cls_features = [layers.GlobalAveragePooling2D()(x)]

    decoder_filters = [base_filters * 8, base_filters * 4, base_filters * 2, base_filters]

    # Decode from the deepest skip to the shallowest: each stage doubles the
    # resolution and fuses one attention-refined skip connection.
    for i, (skip, filters) in enumerate(zip(reversed(skip_connections), decoder_filters)):
        x = layers.UpSampling2D(2, interpolation='bilinear')(x)
        skip_processed = layers.Conv2D(filters, 1, padding='same')(skip)
        skip_processed = layers.BatchNormalization()(skip_processed)
        skip_processed = layers.Activation('swish')(skip_processed)
        skip_processed = CBAM()(skip_processed)

        x = layers.Concatenate()([skip_processed, x])
        x = ResidualBlock(filters, use_attention=True, dropout_rate=0.1)(x)
        x = ResidualBlock(filters, use_attention=True, dropout_rate=0.1)(x)

        if i < len(decoder_filters) - 1:
            cls_features.append(layers.GlobalAveragePooling2D()(x))

    x = layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(x)

    seg_out = layers.Conv2D(1, 3, padding='same')(x)
    seg_out = layers.BatchNormalization()(seg_out)
    seg_out = layers.Activation('sigmoid', name='seg_out')(seg_out)

    cls_combined = layers.Concatenate()(cls_features)
    cls_x = layers.Dense(512, activation='swish', kernel_regularizer=regularizers.l2(1e-4))(cls_combined)
    cls_x = layers.Dropout(0.3)(cls_x)
    cls_x = layers.Dense(256, activation='swish', kernel_regularizer=regularizers.l2(1e-4))(cls_x)
    cls_x = layers.Dropout(0.2)(cls_x)
    cls_x = layers.Dense(128, activation='swish', kernel_regularizer=regularizers.l2(1e-4))(cls_x)
    cls_out = layers.Dense(num_classes, activation='softmax', name='cls_out')(cls_x)

    model = Model(inputs=inputs, outputs=[seg_out, cls_out], name='OptimizedBUSIModel')
    return model

class DiceBCELoss(tf.keras.losses.Loss):
    """Segmentation loss: soft Dice loss plus pixelwise binary cross-entropy.

    ``y_true`` is cast to ``y_pred``'s dtype before use. ``binary_crossentropy``
    averages over the last axis, which for 1-channel masks leaves one value per
    pixel. The scalar Dice term is broadcast onto it, and the Loss base class then
    reduces the result according to its ``reduction`` setting.
    Uses ``tf.keras.losses.binary_crossentropy`` instead of the legacy
    ``K.binary_crossentropy``, which Keras 3's ``keras.backend`` no longer exports.
    """
    def call(self, y_true, y_pred):
        y_true = tf.cast(y_true, y_pred.dtype)
        return dice_loss(y_true, y_pred) + tf.keras.losses.binary_crossentropy(y_true, y_pred)


def compile_model(model, learning_rate=1e-4, steps_per_epoch=100, epochs=200):
    """Compile the two-headed model with AdamW and a cosine-decayed learning rate.

    Losses: ``DiceBCELoss`` for 'seg_out' and ``focal_loss`` for 'cls_out', weighted
    0.7 / 0.3.

    Args:
        model: Keras model with outputs named 'seg_out' and 'cls_out'.
        learning_rate: initial learning rate of the cosine schedule.
        steps_per_epoch: optimizer steps per epoch, used only to size the schedule.
            The default of 100 is a placeholder; pass ``ceil(n_train / batch_size)``
            so the schedule matches the real run.
        epochs: planned number of epochs. The LR decays over
            ``steps_per_epoch * epochs`` steps down to ``alpha`` (10%) of
            ``learning_rate``.

    Returns:
        The same model, compiled.

    Raises:
        ValueError: if ``steps_per_epoch`` or ``epochs`` is not positive.
    """
    if steps_per_epoch <= 0 or epochs <= 0:
        raise ValueError(
            f"steps_per_epoch and epochs must be positive, got {steps_per_epoch} and {epochs}")
    total_steps = steps_per_epoch * epochs
    lr_schedule = tf.keras.optimizers.schedules.CosineDecay(
        initial_learning_rate=learning_rate,
        decay_steps=total_steps,
        alpha=0.1
    )
    optimizer = tf.keras.optimizers.AdamW(
        learning_rate=lr_schedule,
        weight_decay=1e-4,
        beta_1=0.9,
        beta_2=0.999,
        epsilon=1e-7
    )

    model.compile(
        optimizer=optimizer,
        loss={
            'seg_out': DiceBCELoss(),
            'cls_out': focal_loss
        },
        loss_weights={'seg_out': 0.7, 'cls_out': 0.3},
        metrics={
            'seg_out': [dice_coefficient, 'binary_accuracy'],
            'cls_out': ['accuracy', 'categorical_accuracy']
        }
    )
    return model



In [None]:
# model=model_seg_class()

In [None]:
 # Create optimized model
model = model_seg_class_optimized(inp_size=(256, 256, 3), num_classes=3)
    
    # Compile with optimal settings
model = compile_model(model)
    
    # Print model summary
# print(model.summary())
    
    # Additional callbacks for training
callbacks = [
        tf.keras.callbacks.ReduceLROnPlateau(
            monitor='val_loss', factor=0.5, patience=10, min_lr=1e-7
        ),
        tf.keras.callbacks.EarlyStopping(
            monitor='val_loss', patience=20, restore_best_weights=True
        ),
        tf.keras.callbacks.ModelCheckpoint(
            'best_busi_model.h5', monitor='val_loss', save_best_only=True
        )
    ]
    
    # print("Model ready for training on BUSI dataset!")
    # print("Key optimizations applied:")
    # print("- EfficientNetB4 backbone with ImageNet pretraining")
    # print("- CBAM attention mechanisms")
    # print("- Residual connections with proper regularization") 
    # print("- Pyramid pooling for multi-scale context")
    # print("- FPN-style decoder")
    # print("- Advanced loss functions (Dice+BCE, Focal)")
    # print("- Built-in data augmentation")
    # print("- AdamW optimizer with cosine decay")
    # print("- Multi-scale classification features")

In [None]:
# model.summary()
# print("\n\n## Model Plot")
# plot_model(model, show_shapes=True)

In [None]:
# def dice_bce_loss(y_true, y_pred, axis=(1, 2, 3), smooth=1e-4):
#     y_true = tf.cast(y_true, tf.float32)
#     y_pred_sigmoid = tf.keras.activations.sigmoid(y_pred)  # Optional: if logits

#     # Binary cross-entropy
#     bce = tf.keras.losses.binary_crossentropy(y_true, y_pred_sigmoid)

#     # Dice loss
#     y_pred_bin = tf.cast(y_pred_sigmoid > 0.5, tf.float32)
#     tp = tf.reduce_sum(y_true * y_pred_bin, axis=axis)
#     fn = tf.reduce_sum(y_true * (1 - y_pred_bin), axis=axis)
#     fp = tf.reduce_sum((1 - y_true) * y_pred_bin, axis=axis)
#     dice_score = (2 * tp + smooth) / (2 * tp + fn + fp + smooth)
#     dice_loss = 1.0 - tf.reduce_mean(dice_score)

#     # Combine
#     return dice_loss + tf.reduce_mean(bce)

def _thresholded_counts(y_true, y_pred, axis, thr):
    """Binarize ``y_pred`` at ``thr`` and return (TP, FN, FP) summed over ``axis``."""
    truth = tf.cast(y_true, tf.float32)
    hard_pred = tf.cast(y_pred > thr, tf.float32)
    true_pos = tf.math.reduce_sum(truth * hard_pred, axis=axis)
    false_neg = tf.math.reduce_sum(truth * (1 - hard_pred), axis=axis)
    false_pos = tf.math.reduce_sum((1 - truth) * hard_pred, axis=axis)
    return true_pos, false_neg, false_pos


def dice(y_true, y_pred, axis=(0, 1, 2), smooth=0.0001, thr=0.5):
    """Hard Dice SCORE (not loss) of thresholded predictions.

    Counts are summed over ``axis`` (by default batch, height and width of a
    (B, H, W, C) tensor), and the per-remaining-axis scores are averaged.
    """
    tp, fn, fp = _thresholded_counts(y_true, y_pred, axis, thr)
    per_channel = (2*tp + smooth) / (2*tp + fn + fp + smooth)
    return tf.math.reduce_mean(per_channel)


def iou(y_true, y_pred, axis=(0, 1, 2), smooth=0.0001, thr=0.5):
    """Intersection-over-Union (Jaccard) score of thresholded predictions.

    Counts are summed over ``axis`` (by default batch, height and width of a
    (B, H, W, C) tensor), and the per-remaining-axis scores are averaged.
    """
    tp, fn, fp = _thresholded_counts(y_true, y_pred, axis, thr)
    per_channel = (tp + smooth) / (tp + fn + fp + smooth)
    return tf.math.reduce_mean(per_channel)

In [None]:
# model.compile(
#     optimizer='adam',
#     loss={
#         'seg_out': 'BinaryCrossentropy',
#         'cls_out': 'CategoricalCrossentropy'
#     },
#     metrics={
#         'seg_out': [dice, iou], # You might use Dice Coefficient or IoU here
#         'cls_out': ['accuracy']
#     }
# )

In [None]:
# Sanity check before training: model output names and the arrays given to fit().
# The image arrays are single-channel grayscale, i.e. (N, 256, 256, 1); masks are
# (N, 256, 256, 1) and the one-hot labels are (N, 3).
for item in (model.output_names, X_train_aug.shape, y_train_aug.shape, df_labels_train_aug.shape):
    print(item)


In [None]:
# sample_X = X_train_aug[:2]
# sample_y = {
#     'seg_out': y_train_aug[:2],
#     'cls_out': df_labels_train_aug[:2]
# }
# # print(model.predict(sample_X))
# # print(model.train_on_batch(sample_X, sample_y))
# results = model.train_on_batch(sample_X, sample_y)
# print("Train on batch results:", results)


In [None]:
import tensorflow as tf

# Re-check GPU visibility right before training (same check as the first cell).
gpu_count = len(tf.config.list_physical_devices('GPU'))
print("Num GPUs Available: ", gpu_count)
print(tf.test.gpu_device_name())  # Should output something like '/device:GPU:0'

In [None]:
import os
# NOTE(review): TensorFlow is already imported at this point, and this variable name
# is not one I can confirm TF reads, so this setting likely has no effect. Confirm,
# or set it before `import tensorflow`.
os.environ.update(TF_GPU_TIMER_LOG_LEVEL='3')

In [None]:


# The preprocessed images are single-channel grayscale (N, 256, 256, 1). If the model
# was built for more channels (the ImageNet backbone needs 3), repeat the grayscale
# plane so the arrays match the model's input.
n_model_channels = model.inputs[0].shape[-1]


def _to_model_channels(images):
    """Repeat a single grayscale channel to the model's channel count, if needed."""
    if images.shape[-1] == 1 and n_model_channels != 1:
        return np.repeat(images, n_model_channels, axis=-1)
    return images


train_inputs = _to_model_channels(X_train_aug)
val_inputs = _to_model_channels(X_val)

start = time.time()

history = model.fit(
    x=train_inputs,
    y={'seg_out': y_train_aug, 'cls_out': df_labels_train_aug},
    batch_size=8,
    epochs=5,
    validation_data=(val_inputs, {'seg_out': y_val, 'cls_out': df_labels_val}),
    callbacks=callbacks,  # defined alongside the model; previously never passed to fit
)

end = time.time()
print(f"\n\nTraining time: {(end-start):.2f} seconds")
model.save('/kaggle/working/pc.h5')

In [None]:
# X_test is single-channel grayscale. Repeat the plane if the model expects more
# input channels (e.g. 3 for the ImageNet backbone); otherwise predict() rejects it.
n_model_channels = model.inputs[0].shape[-1]
X_test_model = X_test
if X_test.shape[-1] == 1 and n_model_channels != 1:
    X_test_model = np.repeat(X_test, n_model_channels, axis=-1)
seg_preds, cls_preds = model.predict(X_test_model)
 

In [None]:
# The column order of the one-hot label frame defines the index -> class-name mapping
# used by the model's softmax output, so take the names from it instead of hard-coding.
class_names = list(df_labels_test.columns)

# True and predicted class indices
true_labels = df_labels_test.values.argmax(axis=1)
predicted_labels = cls_preds.argmax(axis=1)

# Accuracy
acc = accuracy_score(true_labels, predicted_labels)
print(f"\n## Accuracy: {acc:.4f}")

# Classification report. `labels` is passed explicitly so there is one row per class
# (and target_names stays aligned) even if a class never occurs in the true or
# predicted labels.
report = classification_report(
    true_labels,
    predicted_labels,
    labels=list(range(len(class_names))),
    target_names=class_names,
)
print("\n## Classification Report:\n", report)

# Segmentation metrics (dice/iou threshold the predicted masks at 0.5)
print("\n## Dice Score:\n", dice(y_test, seg_preds).numpy())
print("\n## IOU:\n", iou(y_test, seg_preds).numpy(), "\n")

In [None]:
def plot_prediction(image, predicted_mask, ground_truth_mask=None):
    """Show an ultrasound image next to its predicted mask, and the ground truth if given.

    The predicted mask is binarized at 0.5 before display. Singleton axes are
    squeezed away; if the image is still 3-D afterwards, its first channel is shown.
    Panels, left to right: image, ground truth (optional), prediction.
    """
    binary_pred = (predicted_mask > 0.5).astype(np.float32)

    shown_image = np.squeeze(image)
    if shown_image.ndim == 3:  # (H, W, C) after squeezing -> keep channel 0
        shown_image = shown_image[:, :, 0]

    panels = [("Ultrasound Image", shown_image)]
    if ground_truth_mask is not None:
        panels.append(("Ground Truth Mask", np.squeeze(ground_truth_mask)))
    panels.append(("Predicted Mask", np.squeeze(binary_pred)))

    plt.figure(figsize=(12, 4))
    for position, (title, picture) in enumerate(panels, start=1):
        plt.subplot(1, len(panels), position)
        plt.imshow(picture, cmap='gray')
        plt.title(title)
        plt.axis('off')

    plt.tight_layout()
    plt.show()


In [None]:
# Visualize one test case: image, ground-truth mask and thresholded prediction.
n = 18
plot_prediction(image=X_test[n], predicted_mask=seg_preds[n], ground_truth_mask=y_test[n])

In [None]:
# Browse test samples 10..19.
for i in range(10, 20):
    plot_prediction(image=X_test[i], predicted_mask=seg_preds[i], ground_truth_mask=y_test[i])