In [None]:
import os
import numpy as np
import cv2
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Dropout, concatenate, BatchNormalization, Activation, Conv2DTranspose, Add
from tensorflow.keras.metrics import Recall, Precision
from sklearn.model_selection import train_test_split
from sklearn.model_selection import KFold
import matplotlib.pyplot as plt
from glob import glob
import shutil
from tensorflow.keras.preprocessing.image import img_to_array, load_img
import keras
from tqdm import tqdm
import seaborn as sns
from sklearn.preprocessing import normalize
import tensorflow_addons as tfa
import math
import random
import joblib
from tensorflow.keras import backend as K
from scipy.ndimage import binary_fill_holes
from sklearn.linear_model import LogisticRegression

# Native resolution of the source images (width x height); the model works on 256x256 resizes.
original_width = 788
original_height = 510

In [None]:
# Start from a clean Keras state (drops models/graphs left over from earlier runs in this kernel).
tf.keras.backend.clear_session()

# Enumerate the physical GPUs once and report them.
gpus = tf.config.list_physical_devices('GPU')
print("Available GPUs:", gpus)

# With at least one GPU: allocate its memory on demand and expose only the first GPU to TF.
# Both calls raise RuntimeError if the GPU runtime was already initialised, hence the guard.
if gpus:
    first_gpu = gpus[0]
    try:
        tf.config.experimental.set_memory_growth(first_gpu, True)
        tf.config.set_visible_devices(first_gpu, 'GPU')
        print("Using GPU:", first_gpu)
    except RuntimeError as e:
        print("GPU Error:", e)


In [None]:
def augment_image(image, mask):
    """Randomly augment an (image, mask) pair with geometry-consistent transforms.

    Each transform fires independently with probability 0.3, in this order:
    horizontal flip, vertical flip, shift (up to 10% of each side), rotation
    (+/-20 degrees), zoom (0.9-1.1x, then crop/pad back to 256x256) and, for the
    image only, a contrast change (factor 0.8-1.2). Geometric transforms are
    applied identically to image and mask; the mask uses nearest-neighbour
    sampling for rotation/zoom so it keeps its original values.

    Args:
        image: image tensor, e.g. (256, 256, 3) in [0, 1] as produced by read_image.
        mask: mask tensor, e.g. (256, 256, 1).

    Returns:
        (image, mask) as float64 tensors. zoom_fn crops/pads to a hard-coded
        256x256, so inputs are assumed to be 256x256.
    """
    image = tf.cast(image, tf.float64)
    mask = tf.cast(mask, tf.float64)

    # NOTE: the branch functions read `image`/`mask` at call time; tf.cond calls
    # them immediately, so each one sees the output of the previous step.

    def hflip_fn():
        return tf.image.flip_left_right(image), tf.image.flip_left_right(mask)

    def vflip_fn():
        return tf.image.flip_up_down(image), tf.image.flip_up_down(mask)

    def shift_fn():
        height = tf.shape(image)[0]
        width = tf.shape(image)[1]
        shift_h = tf.cast(0.1 * tf.cast(height, tf.float64), tf.int32)
        shift_w = tf.cast(0.1 * tf.cast(width, tf.float64), tf.int32)
        dx = tf.random.uniform([], -shift_w, shift_w + 1, dtype=tf.int32)
        dy = tf.random.uniform([], -shift_h, shift_h + 1, dtype=tf.int32)
        # tfa.image.translate converts `translations` to float32; a list of int32
        # scalar tensors cannot be converted that way, so build the float vector here.
        translations = tf.cast(tf.stack([dx, dy]), tf.float32)
        # Transform in float32 (as rotate_fn does) and return float64 like the other branches.
        image_shifted = tfa.image.translate(tf.cast(image, tf.float32), translations)
        mask_shifted = tfa.image.translate(tf.cast(mask, tf.float32), translations)
        return tf.cast(image_shifted, tf.float64), tf.cast(mask_shifted, tf.float64)

    def rotate_fn():
        angle = tf.random.uniform([], -20.0, 20.0) * (np.pi / 180.0)
        image_rot = tf.cast(image, tf.float32)
        mask_rot = tf.cast(mask, tf.float32)
        image_rot = tfa.image.rotate(image_rot, angle, interpolation='BILINEAR')
        mask_rot = tfa.image.rotate(mask_rot, angle, interpolation='NEAREST')
        return tf.cast(image_rot, tf.float64), tf.cast(mask_rot, tf.float64)

    def zoom_fn():
        zoom_factor = tf.random.uniform([], 0.9, 1.1)
        size = tf.cast(tf.shape(image)[:2], tf.float32)
        new_size = tf.cast(zoom_factor * size, tf.int32)
        image_zoom = tf.image.resize(image, new_size, method='bilinear')
        mask_zoom = tf.image.resize(mask, new_size, method='nearest')
        # Back to the fixed model input size (keeps a static shape for tf.cond).
        image_zoom = tf.image.resize_with_crop_or_pad(image_zoom, 256, 256)
        mask_zoom = tf.image.resize_with_crop_or_pad(mask_zoom, 256, 256)
        return tf.cast(image_zoom, tf.float64), tf.cast(mask_zoom, tf.float64)

    def contrast_fn():
        image_adj = tf.image.random_contrast(tf.cast(image, tf.float32), 0.8, 1.2)
        # random_contrast scales around the mean without clipping; keep the
        # image in the [0, 1] range that read_image produces.
        image_adj = tf.clip_by_value(image_adj, 0.0, 1.0)
        return tf.cast(image_adj, tf.float64)

    # Apply augmentations with tf.cond
    image, mask = tf.cond(tf.random.uniform(()) < 0.3, hflip_fn, lambda: (image, mask))
    image, mask = tf.cond(tf.random.uniform(()) < 0.3, vflip_fn, lambda: (image, mask))
    image, mask = tf.cond(tf.random.uniform(()) < 0.3, shift_fn, lambda: (image, mask))
    image, mask = tf.cond(tf.random.uniform(()) < 0.3, rotate_fn, lambda: (image, mask))
    image, mask = tf.cond(tf.random.uniform(()) < 0.3, zoom_fn, lambda: (image, mask))
    image = tf.cond(tf.random.uniform(()) < 0.3, contrast_fn, lambda: image)

    return image, mask


In [None]:
def load_data(path):
    """Collect the image and mask file paths of a dataset folder.

    Looks for ``*.bmp`` files in ``<path>/Images`` and ``<path>/Masks``. Images and
    masks are paired purely by sorted order, so corresponding files must sort the same.

    Returns:
        (image_paths, mask_paths) as numpy string arrays.

    Raises:
        AssertionError: if the two folders hold a different number of .bmp files.
    """
    def _sorted_bmps(subdir):
        return sorted(glob(os.path.join(path, subdir, "*.bmp")))

    image_paths = _sorted_bmps("Images")
    mask_paths = _sorted_bmps("Masks")
    assert len(image_paths) == len(mask_paths), "Mismatch between images and masks!"
    return np.array(image_paths), np.array(mask_paths)

def read_image(path):
    """Read a colour image, resize it to 256x256 and scale it to [0, 1].

    Args:
        path: file path as ``bytes`` (what ``tf.numpy_function`` passes in) or ``str``.

    Returns:
        float64 array of shape (256, 256, 3) in OpenCV's BGR channel order.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file.
    """
    if isinstance(path, bytes):
        path = path.decode()
    x = cv2.imread(path, cv2.IMREAD_COLOR)
    if x is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {path}")
    x = cv2.resize(x, (256, 256))
    return x / 255.0

def read_mask(path):
    """Read a grayscale mask, resize it to 256x256 and scale it to [0, 1].

    Args:
        path: file path as ``bytes`` (what ``tf.numpy_function`` passes in) or ``str``.

    Returns:
        float64 array of shape (256, 256, 1). cv2.resize uses its default
        (bilinear) interpolation, so boundary pixels may be fractional.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file.
    """
    if isinstance(path, bytes):
        path = path.decode()
    x = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if x is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read mask: {path}")
    x = cv2.resize(x, (256, 256))
    x = x / 255.0
    return np.expand_dims(x, axis=-1)

def tf_parse(x, y, augment=False):
    """Decode one (image path, mask path) pair inside a tf.data pipeline.

    Args:
        x: scalar string tensor with the image path.
        y: scalar string tensor with the mask path.
        augment: if True, pass the pair through augment_image (defined earlier in the notebook).

    Returns:
        (image, mask) float64 tensors of shape (256, 256, 3) and (256, 256, 1).
    """
    def _load_pair(image_path, mask_path):
        return read_image(image_path), read_mask(mask_path)

    image, mask = tf.numpy_function(_load_pair, [x, y], [tf.float64, tf.float64])
    # numpy_function outputs have no static shape; restore it for batching and the model input.
    image.set_shape([256, 256, 3])
    mask.set_shape([256, 256, 1])

    return augment_image(image, mask) if augment else (image, mask)


def tf_dataset(x, y, batch=8, augment=False, shuffle=None):
    """Build a batched, prefetched tf.data pipeline of (image, mask) pairs from file paths.

    Args:
        x: sequence of image paths.
        y: sequence of mask paths, aligned with ``x``.
        batch: batch size.
        augment: apply random augmentation (see tf_parse).
        shuffle: reshuffle the sample order every epoch. ``None`` (default) means
            "shuffle when augmenting", i.e. the training split is shuffled while
            validation/test splits keep their order.

    Returns:
        A ``tf.data.Dataset`` yielding (images, masks) batches.
    """
    if shuffle is None:
        shuffle = augment

    dataset = tf.data.Dataset.from_tensor_slices((x, y))
    if shuffle:
        # Shuffle the (cheap) paths before decoding; a buffer as large as the split
        # gives a uniform permutation, redrawn every epoch. Keras' fit(shuffle=...)
        # has no effect on tf.data inputs, so this is the only place order changes.
        dataset = dataset.shuffle(buffer_size=max(len(x), 1), reshuffle_each_iteration=True)
    dataset = dataset.map(lambda a, b: tf_parse(a, b, augment=augment),
                          num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(batch)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    return dataset


In [None]:
dataset_path = r"/path/to/dataset/"

# All data: sorted, index-aligned arrays of image and mask paths.
images, masks = load_data(dataset_path)
if len(images) == 0:
    # load_data returns empty arrays (no error) for a wrong folder; fail here instead
    # of inside KFold/tf.data with a far less obvious message.
    raise FileNotFoundError(f"No .bmp images found under {os.path.join(dataset_path, 'Images')}")
print(len(images))

In [None]:
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (Input, Conv2D, MaxPooling2D, Conv2DTranspose, Add, BatchNormalization,
                                     Activation, concatenate, GlobalAveragePooling2D, Dense, Reshape, Multiply)

def squeeze_excite_block(inputs, ratio=16):
    """Squeeze-and-Excitation channel recalibration.

    Global-average-pools ``inputs`` to one vector per sample, passes it through a
    bottleneck Dense(relu) -> Dense(sigmoid) and multiplies every channel of
    ``inputs`` by the resulting weight.

    Args:
        inputs: 4-D feature tensor whose last (channel) dimension is known.
        ratio: reduction factor of the bottleneck layer.

    Returns:
        Tensor with the same shape as ``inputs``.
    """
    filters = inputs.shape[-1]
    # Keep at least one bottleneck unit so inputs narrower than `ratio` still build
    # (filters // ratio would be 0 and Dense(0) is invalid).
    bottleneck_units = max(1, filters // ratio)
    se = GlobalAveragePooling2D()(inputs)
    se = Dense(bottleneck_units, activation='relu')(se)
    se = Dense(filters, activation='sigmoid')(se)
    se = Reshape((1, 1, filters))(se)
    return Multiply()([inputs, se])

def attention_gate(x, gating):
    """Additive attention gate: reweight the skip features ``x`` using the decoder signal ``gating``.

    Both inputs are projected with 1x1 convolutions to ``x``'s channel count, summed,
    passed through ReLU and a 1-channel sigmoid 1x1 convolution; the resulting map
    scales ``x`` spatially. ``x`` and ``gating`` must share spatial size.

    Returns:
        Tensor shaped like ``x``.
    """
    channels = x.shape[-1]
    skip_proj = Conv2D(channels, (1, 1), padding="same")(x)
    gate_proj = Conv2D(channels, (1, 1), padding="same")(gating)
    relu = Activation('relu')
    joint = relu(Add()([skip_proj, gate_proj]))
    attention_map = Conv2D(1, (1, 1), padding="same", activation="sigmoid")(joint)
    return Multiply()([x, attention_map])

def conv_block(inputs, n_filters, kernel_size=3, strides=1):
    """Conv2D ('same' padding) -> BatchNormalization -> ReLU."""
    conv = Conv2D(n_filters, kernel_size, strides=strides, padding='same')
    features = conv(inputs)
    features = BatchNormalization()(features)
    return Activation('relu')(features)

def residual_se_block(inputs, n_filters):
    """Residual unit: two conv_blocks plus squeeze-and-excitation, added to a 1x1-conv projection of ``inputs``."""
    body = inputs
    for _ in range(2):
        body = conv_block(body, n_filters)
    body = squeeze_excite_block(body)
    # 1x1 projection so the shortcut matches n_filters channels.
    projection = Conv2D(n_filters, (1,1), padding='same')(inputs)
    return Add()([body, projection])

def aspp_block(x, n_filters):
    """Atrous spatial pyramid pooling variant.

    Four parallel 3x3 convolutions with dilation rates 1, 6, 12 and 18 are summed
    (not concatenated), then batch-normalised and passed through ReLU.
    """
    branches = [
        Conv2D(n_filters, 3, padding="same", dilation_rate=rate)(x)
        for rate in (1, 6, 12, 18)
    ]
    merged = Add()(branches)
    merged = BatchNormalization()(merged)
    return Activation('relu')(merged)

def upconv_block(inputs, n_filters):
    """2x upsampling: stride-2 Conv2DTranspose -> BatchNormalization -> ReLU."""
    upsampled = Conv2DTranspose(n_filters, 3, strides=2, padding='same')(inputs)
    normalised = BatchNormalization()(upsampled)
    return Activation('relu')(normalised)

def XBoundNetPP(input_shape=(256, 256, 3), n_classes=1, dropout_rate=0.2):
    """Build the XBoundNet++ encoder-decoder segmentation model.

    Layout (H = input height):
      - Stem: two conv_blocks with 32 filters at full resolution (not a skip source).
      - Encoder: three residual SE blocks with 64 / 128 / 256 filters at H, H/2, H/4;
        each output is stored as a skip connection and then max-pooled by 2.
      - Bottleneck: ASPP with 512 filters at H/8, followed by dropout.
      - Decoder: for each skip (deepest first) a 2x upconv, an attention gate on the
        skip, concatenation and a residual SE block, back up to full resolution.
      - Head: 1x1 convolution with a per-pixel sigmoid over ``n_classes`` channels.

    NOTE: layers get Keras auto-generated names, so the order in which layers are
    created here determines names such as "conv2d_37"; keep the order stable.

    Args:
        input_shape: input tensor shape (H, W, C); H and W should be divisible by 8.
        n_classes: number of output channels.
        dropout_rate: rate of every Dropout layer.

    Returns:
        An uncompiled ``tf.keras.Model``.
    """
    inputs = Input(input_shape)
    filters = [32, 64, 128, 256, 512]
    skip_connections = []

    # --- Stem ---
    x = conv_block(inputs, filters[0])
    x = conv_block(x, filters[0])
    #x = Dropout(dropout_rate)(x)

    # --- Encoder: Residual + SE blocks ---
    # i = 1..3 -> 64, 128, 256 filters; filters[4] is reserved for the bottleneck.
    for i in range(1, len(filters)-1):
        x = residual_se_block(x, filters[i])
        if i >= 2:  # Only apply dropout to deeper levels (i = 2, 3); the skip keeps the dropped-out features
            x = Dropout(dropout_rate)(x)
        skip_connections.append(x)
        x = MaxPooling2D(pool_size=(2, 2))(x)

    # --- Bottleneck with ASPP ---
    x = aspp_block(x, filters[-1])
    x = Dropout(dropout_rate)(x)

    # --- Decoder with Attention + SE ---
    # Walk the skips deepest-first; filters[i+1] matches the channel count of skip i.
    for i in reversed(range(len(skip_connections))):
        x = upconv_block(x, filters[i+1])
        g = attention_gate(skip_connections[i], x)
        x = concatenate([x, g])
        x = residual_se_block(x, filters[i+1])
        if i > 0:  # Only apply dropout to deeper levels (no dropout at the full-resolution stage)
            x = Dropout(dropout_rate)(x)

    # --- Final Output ---
    outputs = Conv2D(n_classes, 1, activation='sigmoid')(x)

    return Model(inputs, outputs)

In [None]:
# Build one instance up front purely to inspect the architecture and parameter counts.
model = XBoundNetPP(input_shape=(256, 256, 3), n_classes=1, dropout_rate=0.2)
model.summary()

In [None]:
def read_image_predict(path):
    """Load an image for inference/visualisation, resized to 256x256 and scaled to [0, 1].

    NOTE: this helper is defined twice in the notebook; the later definition wins.

    Returns:
        x: float64 (256, 256, 3) in OpenCV's BGR order, the same order read_image
           feeds the model during training.
        y: the same image with channels swapped to RGB (e.g. for matplotlib display).

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path``.
    """
    x = cv2.imread(path, cv2.IMREAD_COLOR)  # original resolution, e.g. 788x510
    if x is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {path}")
    # BGR2RGB and RGB2BGR are the same channel swap; BGR2RGB states the intent.
    y = cv2.cvtColor(x, cv2.COLOR_BGR2RGB)
    y = cv2.resize(y, (256, 256)) / 255.0
    x = cv2.resize(x, (256, 256)) / 255.0
    return x, y

def read_mask_predict(path):
    """Load a grayscale mask resized to 256x256, keeping raw 0-255 values.

    NOTE: unlike read_mask, values are not scaled to [0, 1]. This helper is defined
    twice in the notebook; the later definition wins.

    Returns:
        uint8 array of shape (256, 256, 1).

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path``.
    """
    x = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if x is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read mask: {path}")
    x = cv2.resize(x, (256, 256))
    return np.expand_dims(x, axis=-1)

def mask_parse(mask):
    """Turn a single-channel mask into a 3-channel (H, W, 3) array by repeating it.

    Singleton axes are squeezed first, so (H, W) and (H, W, 1) inputs both work.
    """
    plane = np.squeeze(mask)
    stacked = np.repeat(plane[np.newaxis], 3, axis=0)  # (3, H, W)
    return np.transpose(stacked, (1, 2, 0))

In [None]:
# NOTE: TensorFlow's C++ logger reads TF_CPP_MIN_LOG_LEVEL early (typically while
# `import tensorflow` / GPU discovery runs, both of which happened above), so setting it
# here has little effect in this kernel; to silence C++ INFO/WARNING messages, set it
# before importing tensorflow. It is still exported so subprocesses inherit it.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
# The Python-side TensorFlow logger can still be adjusted after import.
tf.get_logger().setLevel("ERROR")

In [None]:
def set_seed(seed):
    """Seed Python's ``random``, NumPy's global RNG and TensorFlow's global seed.

    Also exports PYTHONHASHSEED, which only affects Python processes started afterwards.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seed_fn in (random.seed, np.random.seed, tf.random.set_seed):
        seed_fn(seed)

In [None]:
def iou(y_true, y_pred):
    """Soft IoU (Jaccard) metric over the whole batch.

    Uses the raw predicted probabilities (no thresholding):
    ``(sum(t*p) + eps) / (sum(t) + sum(p) - sum(t*p) + eps)`` with eps = 1e-15,
    so an empty ground truth with an empty prediction scores 1.

    Implemented with TensorFlow ops instead of ``tf.numpy_function`` so the metric
    stays inside the compiled graph (no host round-trip per batch).

    Returns:
        Scalar float32 tensor.
    """
    # Compute in float32 regardless of input dtypes (float64 labels, float32 outputs).
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)
    intersection = tf.reduce_sum(y_true * y_pred)
    union = tf.reduce_sum(y_true) + tf.reduce_sum(y_pred) - intersection
    smooth = 1e-15
    return (intersection + smooth) / (union + smooth)

In [None]:
def log_dice_loss(y_true, y_pred, smooth=1e-6):
    """Negative log of the smoothed Dice coefficient, pooled over the whole batch.

    Dice = (2 * sum(t*p) + smooth) / (sum(t) + sum(p) + smooth); loss = -log(Dice).
    All pixels of all samples are flattened together, so this is a single batch-level value.
    """
    truth = K.flatten(y_true)
    pred = K.flatten(y_pred)

    overlap = K.sum(truth * pred)
    total = K.sum(truth) + K.sum(pred)

    return -K.log((2. * overlap + smooth) / (total + smooth))

In [None]:
def log_dice_bce_loss(y_true, y_pred, alpha=0.7):
    """
    Weighted sum of the batch-level log-Dice loss and per-pixel binary cross-entropy:
    ``alpha * log_dice + (1 - alpha) * bce``.
    """
    dice_term = log_dice_loss(y_true, y_pred)
    bce_term = tf.keras.losses.binary_crossentropy(y_true, y_pred)
    return alpha * dice_term + (1 - alpha) * bce_term

In [None]:
class PredictionMonitor(tf.keras.callbacks.Callback):
    """Keras callback that saves an input / ground truth / prediction figure after every epoch.

    Args:
        sample_data: ``(input_image, true_mask)``; ``input_image`` is one unbatched
            model input (a batch axis is added here).
        save_dir: folder for the ``epoch_<N>.png`` files; created if missing.

    NOTE(review): the training cells pass BGR images from read_image_predict, so a
    colour input would display with red/blue swapped.
    """

    def __init__(self, sample_data, save_dir):
        super().__init__()
        self.sample_data = sample_data  # Tuple (input_image, true_mask)
        self.save_dir = save_dir
        os.makedirs(self.save_dir, exist_ok=True)

    def on_epoch_end(self, epoch, logs=None):
        input_image, true_mask = self.sample_data
        # Direct call with training=False: cheaper than predict() for a single sample
        # and does not print a progress bar every epoch.
        batch = tf.expand_dims(input_image, axis=0)
        prediction = self.model(batch, training=False)[0].numpy()

        panels = (
            (input_image, 'Input Image'),
            (true_mask, 'True Mask'),
            (prediction, f'Prediction at Epoch {epoch + 1}'),
        )
        fig, axes = plt.subplots(1, 3, figsize=(12, 4))
        try:
            for ax, (img, title) in zip(axes, panels):
                ax.imshow(np.asarray(img), cmap='gray')
                ax.set_title(title)
                ax.axis('off')
            fig.tight_layout()
            fig.savefig(os.path.join(self.save_dir, f'epoch_{epoch + 1}.png'))
        finally:
            # Always release the figure; hundreds of epochs would otherwise leak memory.
            plt.close(fig)


In [None]:
# Contiguous (unshuffled) 5-fold split over the sorted file list.
# Alternative tried earlier: KFold(n_splits=5, shuffle=True, random_state=42).
# NOTE(review): if file names group images by patient, the later train/val split can put one
# patient in both train and val — confirm whether a grouped split is needed.
kf = KFold(n_splits=5, shuffle=False)

# Per-fold metrics appended by the training cells.
fold_accuracies = []
fold_losses = []
fold_ious = []

# Experiment tag; also the name of the network TensorBoard log folder.
# Earlier tags: XBoundNetLDLDrop02Lre3.h5, XBoundNetLDLDrop02.h5, XBoundNetLDLNoDrop.h5
model_save_file = "XBoundNetCombinedLDL07BCE03Drop02onlowresLre4.h5"

In [None]:
batch = 4
epochs = 400
lr = 1e-4

# Seed Python/NumPy/TF so splits, weight init and augmentation are reproducible.
set_seed(42)

# Start from empty result lists so re-running this cell does not mix in a previous run.
fold_accuracies, fold_losses, fold_ious = [], [], []

# TensorBoard writes to a local folder (TensorFlow has issues with '$' in the original
# network path); move_logs copies it to the network folder afterwards.
network_log_path = r"/path/to/dataset/XBoundNet/tensorflow_logs"
network_log_dir = os.path.join(network_log_path, model_save_file)
os.makedirs(network_log_dir, exist_ok=True)

local_log_dir = os.path.expanduser("~/tensorflow_logs")  # Works for both Windows & Linux
os.makedirs(local_log_dir, exist_ok=True)
print(f"✅ Local log directory being used: {local_log_dir}")


def _monitor_sample(image_path, mask_path):
    """Load one (image, mask) pair as float32 tensors for PredictionMonitor."""
    image, _ = read_image_predict(image_path)
    mask = read_mask_predict(mask_path)
    return (tf.convert_to_tensor(image, dtype=tf.float32),
            tf.convert_to_tensor(mask, dtype=tf.float32))


for fold, (trainval_idx, test_idx) in enumerate(kf.split(images)):
    print(f"\n🔹 Training Fold {fold + 1}/5")

    trainval_x, test_x = images[trainval_idx], images[test_idx]
    trainval_y, test_y = masks[trainval_idx], masks[test_idx]

    # From trainval (80% of total), carve out 16% as validation
    train_x, val_x, train_y, val_y = train_test_split(
        trainval_x, trainval_y, test_size=0.16, random_state=42
    )
    print(f"  Train: {len(train_x)} | Val: {len(val_x)} | Test: {len(test_x)}")

    # Augmentation is applied on the fly each epoch (training split only).
    train_dataset = tf_dataset(train_x, train_y, batch=batch, augment=True)
    valid_dataset = tf_dataset(val_x, val_y, batch=batch)
    # Held-out fold: never used by checkpointing, the LR schedule or early stopping.
    test_dataset = tf_dataset(test_x, test_y, batch=batch)

    model = XBoundNetPP()
    opt = tf.keras.optimizers.Adam(lr)
    metrics = ["acc", Recall(), Precision(), iou]
    model.compile(loss=log_dice_bce_loss, optimizer=opt, metrics=metrics)

    # Per-fold sub-folders so TensorBoard runs and epoch snapshots are not overwritten.
    fold_tag = f"fold_{fold + 1}"
    callbacks = [
        tf.keras.callbacks.ModelCheckpoint(f"XBoundNet_{fold + 1}Fold_LDL07_BCE03_Drop02onlowres_Lre4.h5", save_best_only=True),
        tf.keras.callbacks.ReduceLROnPlateau(monitor="val_loss", factor=0.1, patience=15),
        tf.keras.callbacks.TensorBoard(log_dir=os.path.join(local_log_dir, fold_tag), histogram_freq=1),
        tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=50, restore_best_weights=True),
        PredictionMonitor(_monitor_sample(train_x[0], train_y[0]),
                          os.path.join('prediction_monitoring_training', fold_tag)),
        PredictionMonitor(_monitor_sample(val_x[0], val_y[0]),
                          os.path.join('prediction_monitoring_validation', fold_tag)),
    ]

    train_steps = math.ceil(len(train_x) / batch)
    valid_steps = math.ceil(len(val_x) / batch)

    history = model.fit(
        train_dataset,
        validation_data=valid_dataset,
        epochs=epochs,
        validation_steps=valid_steps,
        callbacks=callbacks,
        shuffle=False,
    )

    # ✅ Evaluate on the held-out test fold. The validation split already steered
    # checkpointing/early stopping, so scoring it would be optimistically biased.
    loss, acc, rec, prec, iou_score = model.evaluate(test_dataset)
    print(f"🔹 Fold {fold + 1} → Loss: {loss:.4f}, Accuracy: {acc:.4f}, IoU: {iou_score:.4f}")

    # ✅ Store results
    fold_losses.append(loss)
    fold_accuracies.append(acc)
    fold_ious.append(iou_score)

# ✅ Final Results (test folds)
print("\n✅ Final 5-Fold Cross-Validation Results:")
print(f"Average Accuracy: {np.mean(fold_accuracies):.4f} ± {np.std(fold_accuracies):.4f}")
print(f"Average Loss: {np.mean(fold_losses):.4f} ± {np.std(fold_losses):.4f}")
print(f"Average IoU: {np.mean(fold_ious):.4f} ± {np.std(fold_ious):.4f}")


In [None]:
def move_logs(local_dir=None, network_dir=None):
    """Copy TensorBoard logs from the local folder to the network folder.

    The source is copied (not removed); existing files in the destination are
    overwritten/merged. Nothing is copied if the destination folder does not exist.
    Errors are reported with a print instead of being raised (best effort).

    Args:
        local_dir: source folder; defaults to the notebook global ``local_log_dir``.
        network_dir: destination folder; defaults to the notebook global ``network_log_dir``.
    """
    try:
        # Resolve the globals inside the try so a missing training run is reported, not raised.
        src = local_log_dir if local_dir is None else local_dir
        dst = network_log_dir if network_dir is None else network_dir
        if os.path.exists(dst):
            print(f"✅ Moving logs from {src} → {dst}")
            shutil.copytree(src, dst, dirs_exist_ok=True)
            print("✅ Logs moved successfully.")
        else:
            print(f"⚠️ Network path {dst} not found. Copy manually.")
    except Exception as e:
        print(f"❌ Error moving logs: {e}")

move_logs()

In [None]:
batch = 4
epochs = 400
lr = 1e-4

# Train 5 ensemble members (different seeds and validation splits), each with 5-fold CV.
for i in range(1, 6):
    set_seed(1000 + i)
    # Per-member result lists; without this reset each summary would also average the
    # single-model CV results and every earlier member.
    fold_accuracies, fold_losses, fold_ious = [], [], []

    for fold, (trainval_idx, test_idx) in enumerate(kf.split(images)):
        print(f"\n🔹 Training Fold {fold + 1}/5")

        trainval_x, test_x = images[trainval_idx], images[test_idx]
        trainval_y, test_y = masks[trainval_idx], masks[test_idx]

        # From trainval (80% of total), carve out 16% as validation (split seed varies per member)
        train_x, val_x, train_y, val_y = train_test_split(
            trainval_x, trainval_y, test_size=0.16, random_state=42 + i
        )
        print(f"  Train: {len(train_x)} | Val: {len(val_x)} | Test: {len(test_x)}")

        # Augmentation is applied on the fly each epoch (training split only).
        train_dataset = tf_dataset(train_x, train_y, batch=batch, augment=True)
        valid_dataset = tf_dataset(val_x, val_y, batch=batch)
        # Held-out fold: never used by checkpointing, the LR schedule or early stopping.
        test_dataset = tf_dataset(test_x, test_y, batch=batch)

        model = XBoundNetPP()
        opt = tf.keras.optimizers.Adam(lr)
        metrics = ["acc", Recall(), Precision(), iou]
        model.compile(loss=log_dice_bce_loss, optimizer=opt, metrics=metrics)

        callbacks = [
            tf.keras.callbacks.ModelCheckpoint(f"XBoundNet_{fold + 1}Fold_LDL07_BCE03_Drop02onlowres_Lre4_ensemble_model_{i}.h5", save_best_only=True),
            tf.keras.callbacks.ReduceLROnPlateau(monitor="val_loss", factor=0.1, patience=15),
            tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=50, restore_best_weights=True),
        ]

        train_steps = math.ceil(len(train_x) / batch)
        valid_steps = math.ceil(len(val_x) / batch)

        history = model.fit(
            train_dataset,
            validation_data=valid_dataset,
            epochs=epochs,
            validation_steps=valid_steps,
            callbacks=callbacks,
            shuffle=False,
        )

        # ✅ Evaluate on the held-out test fold (validation already steered training).
        loss, acc, rec, prec, iou_score = model.evaluate(test_dataset)
        print(f"🔹 Fold {fold + 1} → Loss: {loss:.4f}, Accuracy: {acc:.4f}, IoU: {iou_score:.4f}")

        # ✅ Store results
        fold_losses.append(loss)
        fold_accuracies.append(acc)
        fold_ious.append(iou_score)

    # ✅ Results for this ensemble member (its 5 test folds only)
    print("\n✅ Final Ensemble 5-Fold Cross-Validation Results:")
    print(f"Average Ensemble Accuracy: {np.mean(fold_accuracies):.4f} ± {np.std(fold_accuracies):.4f}")
    print(f"Average Ensemble Loss: {np.mean(fold_losses):.4f} ± {np.std(fold_losses):.4f}")
    print(f"Average Ensemble IoU: {np.mean(fold_ious):.4f} ± {np.std(fold_ious):.4f}")


In [None]:
def read_image_predict(path):
    """Load an image for inference/visualisation, resized to 256x256 and scaled to [0, 1].

    NOTE: this helper is defined twice in the notebook; this later definition wins.

    Returns:
        x: float64 (256, 256, 3) in OpenCV's BGR order, the same order read_image
           feeds the model during training.
        y: the same image with channels swapped to RGB (e.g. for matplotlib display).

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path``.
    """
    x = cv2.imread(path, cv2.IMREAD_COLOR)  # original resolution, e.g. 788x510
    if x is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {path}")
    # BGR2RGB and RGB2BGR are the same channel swap; BGR2RGB states the intent.
    y = cv2.cvtColor(x, cv2.COLOR_BGR2RGB)
    y = cv2.resize(y, (256, 256)) / 255.0
    x = cv2.resize(x, (256, 256)) / 255.0
    return x, y

def read_mask_predict(path):
    """Load a grayscale mask resized to 256x256, keeping raw 0-255 values.

    NOTE: unlike read_mask, values are not scaled to [0, 1]. This helper is defined
    twice in the notebook; this later definition wins.

    Returns:
        uint8 array of shape (256, 256, 1).

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path``.
    """
    x = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if x is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read mask: {path}")
    x = cv2.resize(x, (256, 256))
    return np.expand_dims(x, axis=-1)

def normalize_img(path):
    """Load an image as RGB, resize it to 256x256 and min-max stretch it to 0..255.

    Returns:
        uint8 array of shape (256, 256, 3) (RGB) whose values span 0..255.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path``.
    """
    x = cv2.imread(path, cv2.IMREAD_COLOR)
    if x is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {path}")
    # BGR2RGB and RGB2BGR are the same channel swap; BGR2RGB states the intent.
    x = cv2.cvtColor(x, cv2.COLOR_BGR2RGB)
    x = cv2.resize(x, (256, 256))
    # Return the stretched image (it was computed but the unstretched `x` was returned).
    normalized_img = cv2.normalize(x, None, 0, 255, cv2.NORM_MINMAX).astype('uint8')
    return normalized_img

def mask_parse(mask):
    """Turn a single-channel mask into a 3-channel (H, W, 3) array by repeating it.

    Singleton axes are squeezed first, so (H, W) and (H, W, 1) inputs both work.
    NOTE: this helper is defined twice in the notebook; this later definition wins.
    """
    plane = np.squeeze(mask)
    stacked = np.repeat(plane[np.newaxis], 3, axis=0)  # (3, H, W)
    return np.transpose(stacked, (1, 2, 0))

In [None]:
def keep_largest_component(mask):
    """Keep only the largest 8-connected foreground component of a binary mask.

    Args:
        mask: single-channel uint8 mask accepted by cv2.connectedComponentsWithStats
            (non-zero pixels are foreground).

    Returns:
        uint8 array (H, W, 1): 255 on the largest component, 0 elsewhere; all zeros
        when the mask has no foreground.
    """
    # Find all connected components
    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8)

    if num_labels <= 1:
        # Only background: return an empty mask in the same (H, W, 1) uint8 format
        # as the normal path (the old code referenced an undefined name here).
        return np.zeros(labels.shape, dtype=np.uint8)[..., np.newaxis]

    # Ignore label 0 (background), get label of largest component
    largest_component = 1 + np.argmax(stats[1:, cv2.CC_STAT_AREA])

    # Create new mask with only the largest component
    cleaned_mask = (labels == largest_component).astype(np.uint8) * 255
    return np.expand_dims(cleaned_mask, axis=-1)

In [None]:
def overlay_dual_mask_boundary(original_image, gt_mask, pred_mask,
                               gt_color=(255, 0, 0), pred_color=(0, 255, 0), thickness=1):
    """
    Overlays both the ground truth and predicted mask boundaries on the original image.

    Args:
        original_image (numpy array): 3-channel image to draw on (not modified; a copy is used).
        gt_mask (numpy array): Ground-truth mask, (H, W), (H, W, 1) or 3-channel;
            pixels > 0.5 count as foreground.
        pred_mask (numpy array): Predicted mask, same conventions as ``gt_mask``.
        gt_color: outline colour for the ground truth; (255, 0, 0) is red for an RGB
            image (blue for BGR).
        pred_color: outline colour for the prediction (default green in either order).
        thickness: outline thickness in pixels.

    Returns:
        overlayed_image (numpy array): copy of ``original_image`` with the outer
        boundaries of both masks drawn on it.
    """
    def _to_binary(mask):
        mask = np.asarray(mask)
        if mask.ndim == 3:
            # (H, W, 1) is already single-channel; cvtColor(RGB2GRAY) would reject it.
            mask = mask[..., 0] if mask.shape[-1] == 1 else cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)
        # Binary 0/255 uint8, the format findContours expects.
        return (mask > 0.5).astype(np.uint8) * 255

    gt_binary = _to_binary(gt_mask)
    pred_binary = _to_binary(pred_mask)

    # RETR_EXTERNAL: only outer boundaries are drawn (holes are not outlined).
    gt_contours, _ = cv2.findContours(gt_binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    pred_contours, _ = cv2.findContours(pred_binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    overlayed_image = original_image.copy()
    cv2.drawContours(overlayed_image, gt_contours, -1, gt_color, thickness)
    cv2.drawContours(overlayed_image, pred_contours, -1, pred_color, thickness)
    return overlayed_image

def generate_heatmap(model, image, layer = "conv2d_37"):
    """Colour-mapped mean activation of one layer of ``model`` for a single image.

    Args:
        model: Keras model.
        image: one unbatched model input.
        layer: name of the layer to visualise. NOTE(review): names like "conv2d_37"
            look auto-generated, so the default only matches a specific model
            instance — confirm the layer exists in the model being used.

    Returns:
        uint8 (256, 256, 3) colour image from cv2.applyColorMap with COLORMAP_JET,
        min-max normalised so the strongest mean activation maps to the top of the colormap.
    """
    probe = tf.keras.models.Model(inputs=model.input, outputs=model.get_layer(layer).output)
    activations = probe.predict(np.expand_dims(image, axis=0))

    # Average over channels -> one 2-D map for the single sample.
    mean_map = activations.mean(axis=-1)[0]
    mean_map = cv2.resize(mean_map, (256, 256))
    mean_map = cv2.normalize(mean_map, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8UC1)
    return cv2.applyColorMap(mean_map, cv2.COLORMAP_JET)


In [None]:
def enable_dropout(model):
    """Prepare ``model`` for Monte Carlo dropout by freezing its BatchNormalization layers.

    Dropout is switched on by calling the model with ``training=True``; a Dropout
    layer's ``trainable`` flag has no effect on that. The problem with
    ``training=True`` is BatchNormalization, which would then normalise with the
    statistics of the current (tiny) batch and update its moving averages. In TF2
    Keras a BatchNormalization layer with ``trainable = False`` runs in inference mode
    regardless of ``training``, so after this call ``model(x, training=True)`` is
    stochastic only through dropout.

    NOTE: the BatchNormalization layers stay frozen afterwards (this would also
    affect any later ``fit``).

    Returns:
        The same ``model`` object, modified in place.
    """
    def _walk(layers):
        for layer in layers:
            yield layer
            # Recurse into nested models, which expose their own `.layers`.
            yield from _walk(getattr(layer, "layers", []))

    for layer in _walk(model.layers):
        if isinstance(layer, tf.keras.layers.BatchNormalization):
            layer.trainable = False
    return model

In [None]:
def monte_carlo_predictions(model, image, n_samples=50):
    """Monte Carlo dropout: ``n_samples`` stochastic forward passes on one image.

    Dropout is activated with ``training=True``. For the duration of the sampling the
    BatchNormalization layers are set to ``trainable = False`` so they run in
    inference mode (moving statistics) and their moving averages are not overwritten
    by single-image batch statistics; their original flags are restored afterwards.

    Args:
        model: Keras model containing Dropout layers.
        image: one unbatched model input.
        n_samples: number of stochastic passes.

    Returns:
        (mean_prediction, std_prediction): per-element mean and standard deviation
        over the samples, each shaped like one model output (batch axis removed).
    """
    batch = np.expand_dims(image, axis=0)  # loop-invariant

    def _walk(layers):
        for layer in layers:
            yield layer
            # Recurse into nested models, which expose their own `.layers`.
            yield from _walk(getattr(layer, "layers", []))

    bn_layers = [layer for layer in _walk(model.layers)
                 if isinstance(layer, tf.keras.layers.BatchNormalization)]
    saved_flags = [layer.trainable for layer in bn_layers]
    for layer in bn_layers:
        layer.trainable = False
    try:
        predictions = np.stack(
            [model(batch, training=True)[0].numpy() for _ in range(n_samples)],
            axis=0,
        )
    finally:
        for layer, flag in zip(bn_layers, saved_flags):
            layer.trainable = flag

    mean_prediction = np.mean(predictions, axis=0)
    std_prediction = np.std(predictions, axis=0)
    return mean_prediction, std_prediction

In [None]:
def calculate_mu(uncertainty_map):
    """Mean Uncertainty (MU): the mean of every entry of ``uncertainty_map``.

    Presumably fed a per-pixel standard-deviation map (e.g. from MC dropout).
    """
    return np.average(uncertainty_map)

def calculate_ruv(uncertainty_map, threshold=0.1):
    """
    Region Uncertainty Volume (RUV): fraction of entries whose uncertainty is strictly
    above ``threshold`` (a value in [0, 1]).
    """
    # The mean of a boolean array is exactly (count of True) / size.
    return np.mean(uncertainty_map > threshold)

In [None]:
def compute_uncertainty_metrics(mc_predictions):
    """Per-pixel uncertainty measures from a stack of Monte Carlo probability maps.

    Args:
        mc_predictions: array of shape (N, H, W, 1) or (N, H, W) holding N sampled
            foreground probabilities per pixel.

    Returns:
        (predictive_entropy, expected_entropy, mutual_information, variance), each (H, W),
        in nats:
        - predictive_entropy: binary entropy of the mean probability, H(E[p]).
        - expected_entropy: mean of the per-sample binary entropies, E[H(p)].
        - mutual_information: H(E[p]) - E[H(p)], clipped at 0 (it is non-negative in
          exact arithmetic; eps can introduce tiny negative values).
        - variance: per-pixel variance across the samples.
    """
    eps = 1e-8
    mc = np.asarray(mc_predictions)
    if mc.shape[-1] == 1:
        mc = np.squeeze(mc, axis=-1)  # (N, H, W)

    def _binary_entropy(p):
        return -(p * np.log(p + eps) + (1 - p) * np.log(1 - p + eps))

    mean_pred = np.mean(mc, axis=0)
    predictive_entropy = _binary_entropy(mean_pred)
    expected_entropy = np.mean(_binary_entropy(mc), axis=0)
    mutual_information = np.maximum(predictive_entropy - expected_entropy, 0.0)
    variance = np.var(mc, axis=0)

    return predictive_entropy, expected_entropy, mutual_information, variance

In [None]:
# Run tag for the ensemble results; output_dir in the next cell is named after it (extension stripped).
model_save_file = "Ensemble_XBoundNetCombinedLDL07BCE03Drop02onlowresLre4.h5"

In [None]:
# Results folder: kaz_segmentation_results/<model_save_file without its extension>.
run_name, *_ = model_save_file.rsplit('.', 1)
output_dir = os.path.join("kaz_segmentation_results", run_name)
os.makedirs(output_dir, exist_ok=True)

In [None]:
# NOTE(review): rebinds model_save_file to the single-model tag after output_dir was already
# built from the "Ensemble_..." name in the previous cells; output_dir keeps the ensemble folder.
model_save_file = "XBoundNetCombinedLDL07BCE03Drop02onlowresLre4.h5"

In [None]:
# Load trained models
models = [
    tf.keras.models.load_model(f"XBoundNet_{i}Fold_LDL07_BCE03_Drop02onlowres_Lre4.h5", custom_objects={"iou": iou, "log_dice_bce_loss": log_dice_bce_loss})
    for i in range(1, 6)
]

In [None]:
def keep_largest_components(prediction, threshold=0.4, min_component_area=50, apply_morph=True, component_rank=1):
    """Binarise a score map and keep one connected component, with holes filled.

    Args:
        prediction: (H, W) or (H, W, 1) score map, compared directly with ``threshold``
            (so the threshold must use the same scale, e.g. 0-255 for uint8 maps).
        threshold: pixels strictly above this value are foreground.
        min_component_area: components with area <= this many pixels are discarded.
        apply_morph: apply a 5x5 morphological closing before labelling.
        component_rank: 1 keeps the largest surviving component, 2 the second largest, ...

    Returns:
        uint8 array (H, W, 1): 255 on the kept (hole-filled) component, 0 elsewhere;
        all zeros if no component qualifies.
    """
    # Work on a 2-D map so every return path has shape (H, W, 1).
    prediction = np.asarray(prediction)
    if prediction.ndim == 3 and prediction.shape[-1] == 1:
        prediction = prediction[..., 0]
    empty_mask = np.zeros(prediction.shape, dtype=np.uint8)[..., np.newaxis]

    # Step 1: Threshold the score map
    binary_mask = (prediction > threshold).astype(np.uint8)

    # Optional: Morphological closing to seal small holes/gaps
    if apply_morph:
        kernel = np.ones((5, 5), np.uint8)
        binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_CLOSE, kernel)

    # Step 2: Connected components (label 0 is background)
    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(binary_mask, connectivity=8)
    if num_labels <= 1:
        return empty_mask

    # Step 3: Sort components larger than min_component_area by area, descending
    component_areas = [
        (i, stats[i, cv2.CC_STAT_AREA]) for i in range(1, num_labels)
        if stats[i, cv2.CC_STAT_AREA] > min_component_area
    ]
    sorted_components = sorted(component_areas, key=lambda x: x[1], reverse=True)

    if len(sorted_components) < component_rank:
        return empty_mask

    # Step 4: Keep the requested component and fill its interior holes
    target_component = sorted_components[component_rank - 1][0]
    cleaned_mask = (labels == target_component).astype(np.uint8)
    filled_mask = binary_fill_holes(cleaned_mask).astype(np.uint8) * 255
    return filled_mask[..., np.newaxis]


In [None]:
def overlay_uncertainty_mask(image, prediction, ground_truth, uncertainty_map, thresholds=(0, 25, 50, 75), alpha=0.4):
    """
    Overlay uncertainty-aware TP/FP/FN/Uncertain masks using thresholds in [0,100] (not percentile).

    The uncertainty map is rescaled to [0, 100] by dividing by its maximum. For each
    ``t`` in ``thresholds``, pixels with rescaled uncertainty <= t are "certain" and
    classified against the ground truth; the others are painted as uncertain.

    Args:
        image: (H, W) grayscale or (H, W, 3) image; if its max is <= 1.0 it is treated
            as [0, 1] floats and scaled to uint8.
        prediction: (H, W) binary mask with values 0/1.
        ground_truth: (H, W) binary mask with values 0/1.
        uncertainty_map: (H, W) non-negative uncertainty values.
        thresholds: cut-offs in [0, 100].
        alpha: blend weight of the category colour.

    Returns:
        List of ``(t, overlay)`` pairs, ``overlay`` being a uint8 (H, W, 3) BGR image:
        TP green, FP red, FN blue, uncertain yellow.
    """
    assert image.shape[:2] == prediction.shape == ground_truth.shape == uncertainty_map.shape

    max_uncertainty = np.max(uncertainty_map)
    if max_uncertainty > 0:
        norm_unc_map = (uncertainty_map / max_uncertainty) * 100.0
    else:
        # Zero uncertainty everywhere: every pixel is certain (avoids 0/0 -> NaN, which
        # would make every comparison False and mark every pixel uncertain).
        norm_unc_map = np.zeros(uncertainty_map.shape, dtype=np.float64)

    # Base image does not depend on t: build it once (colours below are BGR).
    base = (image * 255).astype(np.uint8) if image.max() <= 1.0 else image.copy()
    if base.ndim == 2:
        base = cv2.cvtColor(base, cv2.COLOR_GRAY2BGR)

    results = []
    for t in thresholds:
        # Mask for certain pixels
        certain_mask = (norm_unc_map <= t).astype(np.uint8)

        pred_masked = prediction * certain_mask
        gt_masked = ground_truth * certain_mask

        # Categories (uncertain pixels are zeroed above, so they fall in none of TP/FP/FN)
        tp = np.logical_and(pred_masked == 1, gt_masked == 1)
        fp = np.logical_and(pred_masked == 1, gt_masked == 0)
        fn = np.logical_and(pred_masked == 0, gt_masked == 1)
        uncertain = (certain_mask == 0)

        overlay = base.astype(np.float32)

        overlay[tp] = (1 - alpha) * overlay[tp] + alpha * np.array([0, 255, 0])     # Green
        overlay[fp] = (1 - alpha) * overlay[fp] + alpha * np.array([0, 0, 255])     # Red (BGR)
        overlay[fn] = (1 - alpha) * overlay[fn] + alpha * np.array([255, 0, 0])     # Blue (BGR)
        overlay[uncertain] = (1 - alpha) * overlay[uncertain] + alpha * np.array([0, 255, 255])  # Yellow (BGR)

        overlay = np.clip(overlay, 0, 255).astype(np.uint8)
        results.append((t, overlay))

    return results


In [None]:
def binary_mask_overlay(mask, original_img, color=(200, 200, 255), alpha = 0.25):
    """Blend a flat colour layer onto ``original_img`` where ``mask`` is non-zero.

    Every pixel is scaled by ``1 - alpha``; foreground pixels additionally receive
    ``alpha * color`` (cv2.addWeighted).

    Args:
        mask: (H, W) or (H, W, 1) mask; values > 0 are foreground.
        original_img: (H, W, 3) image; ``color`` uses the same channel order.
        color: colour of the foreground layer.
        alpha: weight of the colour layer.

    Returns:
        float32 (H, W, 3) blended image.
    """
    base = original_img.astype(np.float32)  # astype copies, the input is untouched
    mask = np.asarray(mask)
    if mask.ndim == 3 and mask.shape[-1] == 1:
        # A trailing channel axis would break boolean indexing into the (H, W, 3) layer.
        mask = mask[..., 0]
    mask_color = np.zeros_like(base)
    mask_color[mask > 0] = color
    return cv2.addWeighted(base, 1 - alpha, mask_color, alpha, 0)

def heatmap_overlay(conf_map, original_img, alpha = 0.25):
    """Blend an INFERNO-coloured ``conf_map`` onto ``original_img``.

    ``conf_map`` must be acceptable to cv2.applyColorMap (e.g. uint8 single-channel).
    Returns the float32 result of ``(1 - alpha) * image + alpha * colormap``.
    """
    background = original_img.astype(np.float32)
    coloured = cv2.applyColorMap(conf_map, cv2.COLORMAP_INFERNO).astype(np.float32)
    return cv2.addWeighted(background, 1 - alpha, coloured, alpha, 0)

In [None]:
def test_images(show_final=True, show_conv_layers=False, show_thresholds=False, show_uncertainty=False):
    """Run MC-dropout inference on every test fold and save/visualise the results.

    For each fold of the global ``kf`` split, every image of the held-out test
    fold is processed with ``models[fold]``. The following files are written to
    ``output_dir/<image stem>/``:

    * ``Threshold_<t>.bmp`` overlays from ``overlay_uncertainty_mask`` using the
      predictive entropy, for t in 100/75/50/25 (% of its maximum);
    * ``<layer>.bmp`` and ``<layer>_with_contour.bmp`` for every ``Conv2D`` layer;
    * the original image, the ground-truth mask and the post-processed prediction;
    * INFERNO-coloured mean-prediction, uncertainty, mutual-information,
      predictive-entropy and variance maps;
    * contour, mask and heatmap overlays on the normalised image;
    * ``output_side_to_side.png`` (only when ``show_final``).

    Relies on notebook globals: ``kf``, ``images``, ``masks``, ``models``,
    ``output_dir``, ``calibrator`` and the helper functions called below.
    ``calibrator`` is created in a later cell, which must therefore run first.

    Args:
        show_final: build, save and display the 5-panel summary figure.
        show_conv_layers: display each conv-layer heatmap inline.
        show_thresholds: display each uncertainty-threshold overlay inline.
        show_uncertainty: display the prediction and uncertainty maps inline.
    """

    def _save_inferno(path, values):
        # Min-max scale `values` to uint8, colour it with INFERNO and write it to
        # `path`. Returns the uint8 map so it can be reused for overlays.
        values_u8 = cv2.normalize(values, None, 0, 255, cv2.NORM_MINMAX).astype('uint8')
        cv2.imwrite(path, cv2.applyColorMap(values_u8, cv2.COLORMAP_INFERNO))
        return values_u8

    def _show(values, title, cmap=None):
        # Display a single map in its own 6x6 inline figure.
        plt.figure(figsize=(6, 6))
        plt.imshow(values, cmap=cmap)
        plt.axis('off')
        plt.title(title)
        plt.show()

    n_folds = kf.get_n_splits()
    for fold, (train_idx, test_idx) in enumerate(kf.split(images)):
        print(f"\n🔹 Predicting on Fold {fold + 1}/{n_folds}")

        test_x_fold, test_y_fold = images[test_idx], masks[test_idx]  # test images/masks of this fold
        model = models[fold]  # trained model for this fold
        # Loop-invariant per fold: the conv layers whose heatmaps are exported.
        conv_layers = [layer for layer in model.layers if isinstance(layer, tf.keras.layers.Conv2D)]

        for x_path, y_path in tqdm(zip(test_x_fold, test_y_fold), total=len(test_x_fold)):
            fname = os.path.basename(x_path)
            tname = fname.rsplit('.', 1)[0]
            output_path = os.path.join(output_dir, tname)
            print("saving to: " + output_path)
            os.makedirs(output_path, exist_ok=True)

            x, rgb_img = read_image_predict(x_path)
            y = read_mask_predict(y_path)
            x_normalized = normalize_img(x_path)

            # MC-dropout samples for the entropy / MI / variance metrics,
            # presumably shaped (50, H, W, 1).
            # NOTE(review): monte_carlo_predictions below draws its own 50
            # samples, so mean_pred/uncertainty_map and these metrics come from
            # different MC draws.
            mc_predictions = np.stack([
                model(np.expand_dims(x, axis=0), training=True)[0].numpy()
                for _ in range(50)
            ], axis=0)

            mean_pred, uncertainty_map = monte_carlo_predictions(model, x, n_samples=50)
            mean_pred_calibrated = calibrate_predictions(mean_pred, calibrator)
            p_entropy, e_entropy, mi, var = compute_uncertainty_metrics(mc_predictions)

            calibrated_pred = cv2.normalize(mean_pred_calibrated, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
            processed_pred_resized = keep_largest_components(calibrated_pred, threshold=70)

            binary_gt = (y > 127).astype(np.uint8).squeeze()
            binary_pred = (processed_pred_resized > 127).astype(np.uint8).squeeze()

            # Overlays for uncertainty thresholds given as % of the max predictive entropy.
            overlays = overlay_uncertainty_mask(rgb_img, binary_pred, binary_gt, p_entropy, thresholds=[100, 75, 50, 25])
            for th, vis in overlays:
                cv2.imwrite(os.path.join(output_path, f"Threshold_{th}.bmp"), vis)
                if show_thresholds:
                    plt.figure(figsize=(5, 5))
                    # The overlay is saved with cv2 (BGR); convert it so the
                    # inline view shows the same colours as the saved file.
                    plt.imshow(cv2.cvtColor(vis, cv2.COLOR_BGR2RGB))
                    plt.title(f"Threshold {th}")
                    plt.axis("off")
                    plt.show()

            for layer in conv_layers:
                heatmap = generate_heatmap(model, x, layer.name)
                if show_conv_layers:
                    _show(heatmap, f"Heatmap from Layer: {layer.name}")

                overlay_result = overlay_dual_mask_boundary(heatmap, mask_parse(y), mask_parse(binary_pred))
                cv2.imwrite(os.path.join(output_path, f"{layer.name}_with_contour.bmp"), overlay_result)
                cv2.imwrite(os.path.join(output_path, f"{layer.name}.bmp"), heatmap)

            og_uint8 = cv2.normalize(rgb_img, None, 0, 255, cv2.NORM_MINMAX).astype('uint8')
            cv2.imwrite(os.path.join(output_path, "original_image.bmp"), og_uint8)
            cv2.imwrite(os.path.join(output_path, "groundtruth_mask.bmp"), y)
            cv2.imwrite(os.path.join(output_path, "predicted_mask.bmp"), processed_pred_resized)

            pred_map_uint8 = _save_inferno(os.path.join(output_path, "mean_prediction_map.bmp"), mean_pred.squeeze())
            uncert_map_uint8 = _save_inferno(os.path.join(output_path, "uncertainty_map.bmp"), uncertainty_map)
            _save_inferno(os.path.join(output_path, "Mutual_Information_(Epistemic).bmp"), mi)
            _save_inferno(os.path.join(output_path, "Predictive_Entropy.bmp"), p_entropy)
            _save_inferno(os.path.join(output_path, "Variance.bmp"), var)

            contour_result_on_original = overlay_dual_mask_boundary(x_normalized, mask_parse(y), mask_parse(binary_pred))
            cont_uint8 = cv2.normalize(contour_result_on_original, None, 0, 255, cv2.NORM_MINMAX).astype('uint8')
            cv2.imwrite(os.path.join(output_path, "contour_result_on_original.bmp"), cont_uint8)

            cv2.imwrite(os.path.join(output_path, "prediction_over_img.bmp"),
                        binary_mask_overlay(processed_pred_resized.squeeze(), x_normalized))
            cv2.imwrite(os.path.join(output_path, "mask_over_img.bmp"),
                        binary_mask_overlay(y.squeeze(), x_normalized))
            cv2.imwrite(os.path.join(output_path, "heatmap_over_image.bmp"),
                        heatmap_overlay(pred_map_uint8, x_normalized))
            cv2.imwrite(os.path.join(output_path, "uncertainty_over_image.bmp"),
                        heatmap_overlay(uncert_map_uint8, x_normalized))

            if show_uncertainty:
                _show(processed_pred_resized.squeeze(), "Processed Prediction", cmap='gray')
                _show(mean_pred, "Mean prediction map")
                _show(uncertainty_map, "uncertainty prediction map", cmap='hot')
                _show(mi, "Mutual Information (Epistemic)", cmap='hot')
                _show(p_entropy, "Predictive Entropy", cmap='hot')
                _show(var, "Variance (Std Dev²)", cmap='hot')

            if show_final:
                # Grayscale base rescaled to [0, 1]
                base_gray = x_normalized.astype(np.float32)
                base_gray -= base_gray.min()
                base_gray /= (base_gray.max() + 1e-8)

                fig, axes = plt.subplots(1, 5, figsize=(32, 5))

                # 1. Original image
                axes[0].imshow(base_gray, cmap='gray')
                axes[0].set_title(f"Original Image: {tname}")
                axes[0].axis("off")

                # 2. Ground-truth mask
                axes[1].imshow(base_gray, cmap='gray')
                gt_mask = np.ma.masked_where(binary_gt == 0, binary_gt)
                axes[1].imshow(gt_mask, cmap='spring', alpha=0.25)
                axes[1].set_title("Annotated Mask over Image")
                axes[1].axis("off")

                # 3. Predicted mask
                axes[2].imshow(base_gray, cmap='gray')
                pred_mask = np.ma.masked_where(binary_pred == 0, binary_pred)
                axes[2].imshow(pred_mask, cmap='winter', alpha=0.25)
                axes[2].set_title("Predicted Mask over Image")
                axes[2].axis("off")

                # 4. Probability map overlay
                axes[3].imshow(base_gray, cmap='gray')
                mean_mask = np.ma.masked_where(mean_pred <= 0.001, mean_pred)
                im3 = axes[3].imshow(mean_mask, cmap='magma', alpha=0.25)
                axes[3].set_title("Probability over Image")
                axes[3].axis("off")
                cbar = fig.colorbar(im3, ax=axes[3], fraction=0.046, pad=0.04)
                cbar.set_label("Probability", rotation=270, labelpad=15)

                # 5. Uncertainty map overlay
                axes[4].imshow(base_gray, cmap='gray')
                uncert_mask = np.ma.masked_where(uncertainty_map <= 0.03, uncertainty_map)
                im4 = axes[4].imshow(uncert_mask, cmap='magma', alpha=0.25)
                axes[4].set_title("Uncertainty over Image")
                axes[4].axis("off")
                cbar = fig.colorbar(im4, ax=axes[4], fraction=0.046, pad=0.04)
                cbar.set_label("Uncertainty", rotation=270, labelpad=15)

                # Lay out before saving so the saved file gets the tight layout.
                fig.tight_layout()
                fig.savefig(os.path.join(output_path, "output_side_to_side.png"), bbox_inches="tight", dpi=600)
                plt.show()
                # This runs once per test image; close the figure so open
                # figures do not accumulate for the whole loop.
                plt.close(fig)

In [None]:
# Full run with the threshold overlays and uncertainty maps also displayed inline.
test_images(show_final=True, show_thresholds=True, show_uncertainty=True)

In [None]:
# Save all outputs and show only the side-by-side summary figure.
test_images(
    show_final=True,
    show_conv_layers=False,
    show_thresholds=False,
    show_uncertainty=False,
)

In [None]:
# Collect per-pixel validation predictions and binary labels from every fold's
# model. They are used below to fit the probability calibrator.
all_val_preds = []
all_val_labels = []

for fold, (trainval_idx, test_idx) in enumerate(kf.split(images)):
    # Nothing is trained here; this only gathers calibration data.
    print(f"\n🔹 Collecting calibration data, Fold {fold + 1}/{kf.get_n_splits()}")

    trainval_x_fold, test_x_fold = images[trainval_idx], images[test_idx]
    trainval_y_fold, test_y_fold = masks[trainval_idx], masks[test_idx]

    # Hold out 16% of train+val as validation.
    # NOTE(review): presumably this must reproduce the split used during training
    # (same test_size / random_state) — confirm against the training cell.
    train_x_fold, val_x_fold, train_y_fold, val_y_fold = train_test_split(
        trainval_x_fold, trainval_y_fold, test_size=0.16, random_state=42
    )

    model = models[fold]  # trained model for this fold

    for xval, yval in tqdm(zip(val_x_fold, val_y_fold), total=len(val_x_fold)):
        x, rgb_img = read_image_predict(xval)
        y = read_mask_predict(yval)
        y_bin = (y > 127).astype(np.uint8)  # 0/1 labels

        # verbose=0: no Keras progress bar for every single image
        y_pred = model.predict(np.expand_dims(x, axis=0), verbose=0)[0]

        # Predictions and labels are flattened into parallel arrays, so they
        # must describe the same number of pixels.
        if y_pred.size != y_bin.size:
            raise ValueError(
                f"prediction/label size mismatch for {xval}: {y_pred.shape} vs {y_bin.shape}"
            )

        all_val_preds.append(y_pred.flatten())
        all_val_labels.append(y_bin.flatten())

val_preds = np.concatenate(all_val_preds)
val_labels = np.concatenate(all_val_labels)

In [None]:
# Sigmoid (Platt-style) calibrator: a one-feature logistic regression that maps
# the model's raw per-pixel output to a calibrated foreground probability.
calibrator = LogisticRegression(solver='liblinear').fit(np.reshape(val_preds, (-1, 1)), val_labels)
calibrator

In [None]:
# Calibrator file name derived from the model file name (extension stripped).
calibratorname = "sigmoid_calibrator{}.pkl".format(model_save_file.rsplit('.', 1)[0])

In [None]:
# Persist the fitted calibrator next to the model.
joblib.dump(value=calibrator, filename=calibratorname)

In [None]:
# Reload the saved calibrator. joblib files are pickles: only load files you created yourself.
calibrator = joblib.load(filename=calibratorname)

In [None]:
def calibrate_predictions(raw_preds, calibrator):
    """Map every element of ``raw_preds`` through a fitted binary calibrator.

    Args:
        raw_preds: array of raw model outputs, of any shape.
        calibrator: object with ``predict_proba(X)`` that takes an (N, 1) array
            and returns (N, 2) class probabilities (e.g. the LogisticRegression
            fitted in this notebook).

    Returns:
        Array with the same shape as ``raw_preds`` holding the calibrated
        probability of class 1.
    """
    as_column = np.reshape(raw_preds, (-1, 1))  # one sample per element
    positive_prob = calibrator.predict_proba(as_column)[:, 1]
    return positive_prob.reshape(raw_preds.shape)

In [None]:
def _plot_history_pair(train_key, val_key, train_label, val_label, title):
    """Plot one training metric and its validation counterpart from the global ``history``."""
    plt.plot(history.history[train_key], label=train_label)
    plt.plot(history.history[val_key], label=val_label)
    plt.title(title)
    plt.legend()
    plt.show()


_plot_history_pair("acc", "val_acc", "Train Acc", "Val Acc", "Training & Validation Accuracy")
_plot_history_pair("loss", "val_loss", "Train Loss", "Val Loss", "Training & Validation Loss")

In [None]:
# Evaluate whichever model / validation dataset are currently bound in the kernel
# (both names are reassigned by several cells — check which ones you get).
model.evaluate(x=valid_dataset, steps=valid_steps)

In [None]:
# Directory for the per-image prediction masks written by the export cells below.
# Earlier experiment folder names: XBN_AUG_LDL_NODROP, XBN_AUG_LDL_DROP02,
# XBN_AUG_LDL_DROP02Lre3, XBN_AUG_LDL07_BCE03_DROP02onlylowres_Lre4.
pred_output_path = r"/path/to/dataset/Predictions"
pred_folder_name = "Ensemble_XBN_AUG_LDL07_BCE03_DROP02onlylowres_Lre4"
pred_output_dir = os.path.join(pred_output_path, pred_folder_name)
os.makedirs(name=pred_output_dir, exist_ok=True)

In [None]:
# Evaluate every fold's model on its held-out test fold and on the validation
# split carved out of its train+val data, then report mean ± std over folds.
EVAL_BATCH = 8  # batch size of the evaluation datasets

# Store evaluation results
test_fold_losses, test_fold_accuracies, test_fold_recall, test_fold_precision, test_fold_ious = [], [], [], [], []

valid_fold_losses, valid_fold_accuracies, valid_fold_recall, valid_fold_precision, valid_fold_ious = [], [], [], [], []

for fold, (trainval_idx, test_idx) in enumerate(kf.split(images)):

    trainval_x, test_x = images[trainval_idx], images[test_idx]
    trainval_y, test_y = masks[trainval_idx], masks[test_idx]

    # Hold out 16% of train+val as validation
    train_x, val_x, train_y, val_y = train_test_split(
        trainval_x, trainval_y, test_size=0.16, random_state=42
    )

    # Prepare TF datasets
    test_dataset = tf_dataset(test_x, test_y, batch=EVAL_BATCH)
    valid_dataset = tf_dataset(val_x, val_y, batch=EVAL_BATCH)

    # One pass over each split: the step count must match the batch size
    # (len(x) // 4 with batch=8 requested about twice as many batches as exist).
    # NOTE(review): assumes tf_dataset yields batches of `batch` samples — confirm.
    test_eval_steps = math.ceil(len(test_x) / EVAL_BATCH)
    valid_eval_steps = math.ceil(len(val_x) / EVAL_BATCH)

    # Unpacking assumes the models were compiled with metrics [acc, recall, precision, iou].
    model_f = models[fold]
    testloss, testacc, testrec, testprec, testiou_score = model_f.evaluate(test_dataset, steps=test_eval_steps)
    validloss, validacc, validrec, validprec, validiou_score = model_f.evaluate(valid_dataset, steps=valid_eval_steps)

    # Print results
    print(f"🔹 Fold {fold + 1} on Test Set → Loss: {testloss:.3f}, Accuracy: {testacc:.3f}, Recall: {testrec:.3f}, Precision: {testprec:.3f}, IoU: {testiou_score:.3f}")
    print(f"🔹 Fold {fold + 1} on Validation Set → Loss: {validloss:.3f}, Accuracy: {validacc:.3f}, Recall: {validrec:.3f}, Precision: {validprec:.3f}, IoU: {validiou_score:.3f}")

    # Store results
    test_fold_losses.append(testloss)
    test_fold_accuracies.append(testacc)
    test_fold_recall.append(testrec)
    test_fold_precision.append(testprec)
    test_fold_ious.append(testiou_score)

    valid_fold_losses.append(validloss)
    valid_fold_accuracies.append(validacc)
    valid_fold_recall.append(validrec)
    valid_fold_precision.append(validprec)
    valid_fold_ious.append(validiou_score)

# ✅ Print Final Results
print("\n✅ Final 5-Fold Cross-Validation Evaluation:")
print(f"Average Test Accuracy: {np.mean(test_fold_accuracies):.3f} ± {np.std(test_fold_accuracies):.3f}")
print(f"Average Test Recall: {np.mean(test_fold_recall):.3f} ± {np.std(test_fold_recall):.3f}")
print(f"Average Test Precision: {np.mean(test_fold_precision):.3f} ± {np.std(test_fold_precision):.3f}")
print(f"Average Test IoU: {np.mean(test_fold_ious):.3f} ± {np.std(test_fold_ious):.3f}")
print(f"Average Test Loss: {np.mean(test_fold_losses):.3f} ± {np.std(test_fold_losses):.3f}")

print(f"Average Valid Accuracy: {np.mean(valid_fold_accuracies):.3f} ± {np.std(valid_fold_accuracies):.3f}")
print(f"Average Valid Recall: {np.mean(valid_fold_recall):.3f} ± {np.std(valid_fold_recall):.3f}")
print(f"Average Valid Precision: {np.mean(valid_fold_precision):.3f} ± {np.std(valid_fold_precision):.3f}")
print(f"Average Valid IoU: {np.mean(valid_fold_ious):.3f} ± {np.std(valid_fold_ious):.3f}")
print(f"Average Valid Loss: {np.mean(valid_fold_losses):.3f} ± {np.std(valid_fold_losses):.3f}")

In [None]:
def generate_quantitative_heatmap_grid(heatmap_matrix, layer_name, save_path, step=16, save=False):
    """Show a down-sampled, value-annotated grid of a 2-D heatmap.

    Every ``step``-th row and column of ``heatmap_matrix`` is drawn with
    seaborn, and each cell is annotated with its value (2 decimals) so the
    numbers stay readable.

    Args:
        heatmap_matrix: 2-D array of heatmap values.
        layer_name: used in the figure title.
        save_path: destination file for the figure. It is written only when
            ``save`` is True; parent directories are created if needed.
        step: sub-sampling stride along both axes.
        save: opt-in flag to write the figure to ``save_path``. It defaults to
            False, so existing callers keep the display-only behaviour.
    """
    grid_data = heatmap_matrix[::step, ::step]

    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(grid_data, annot=True, fmt=".2f", cmap="viridis",
                xticklabels=False, yticklabels=False, cbar=True, square=True, ax=ax)
    ax.set_title(f"Quantitative Heatmap - {layer_name}")
    fig.tight_layout()

    if save and save_path:
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        fig.savefig(save_path, dpi=300)

    plt.show()

In [None]:
# Inspection cell: for one test image (the first image of the first fold), show
# every Conv2D layer's heatmap and its down-sampled numeric grid.
fold = 0
train_idx, test_idx = next(kf.split(images))
print(f"\n🔹 Inspecting Fold {fold + 1}/{kf.get_n_splits()}")

test_x_fold = images[test_idx]  # test images of this fold
model = models[fold]  # trained model for this fold

x_path = test_x_fold[0]
fname = os.path.basename(x_path)
tname = fname.rsplit('.', 1)[0]
output_path = os.path.join(output_dir, tname)
x, rgb_img = read_image_predict(x_path)

conv_layers = [layer for layer in model.layers if isinstance(layer, tf.keras.layers.Conv2D)]
for layer in conv_layers:
    heatmap = generate_heatmap(model, x, layer.name)

    plt.figure(figsize=(6, 6))
    plt.imshow(heatmap)
    plt.axis('off')
    plt.title(f"Heatmap from Layer: {layer.name}")
    plt.show()

    # Grid values: grayscale heatmap / 255.
    # NOTE(review): assumes heatmap is a 3-channel uint8 BGR image — confirm.
    quant_path = os.path.join(output_path, f"{layer.name}_quant_heatmap.png")
    generate_quantitative_heatmap_grid(cv2.cvtColor(heatmap, cv2.COLOR_BGR2GRAY) / 255.0, layer.name, quant_path)

In [None]:


# Export one single-model prediction per test image to pred_output_dir, resized
# back to the original image size and min-max scaled to uint8.
for fold, (train_idx, test_idx) in enumerate(kf.split(images)):
    print(f"\n🔹 Predicting on Fold {fold + 1}/5")

    test_x_fold = images[test_idx]
    test_y_fold = masks[test_idx]
    model = models[fold]  # trained model for this fold

    for x_path, y_path in tqdm(zip(test_x_fold, test_y_fold), total=len(test_x_fold)):
        x, rgb_img = read_image_predict(x_path)
        y = read_mask_predict(y_path)

        y_pred = model.predict(x[np.newaxis, ...])[0]

        # Back to (width=original_width, height=original_height), then keep the largest component.
        y_pred_resized = keep_largest_component(
            cv2.resize(y_pred, (original_width, original_height), interpolation=cv2.INTER_LANCZOS4)
        )
        if y_pred_resized.ndim == 3:
            y_pred_resized = y_pred_resized.squeeze(axis=-1)  # drop channel axis

        fname = os.path.basename(x_path)
        output_path = os.path.join(pred_output_dir, fname)
        pred_uint8 = cv2.normalize(y_pred_resized, None, 0, 255, cv2.NORM_MINMAX).astype('uint8')
        cv2.imwrite(output_path, pred_uint8)

In [None]:
def run_inference(save_ensemble=False):
    """MC-dropout inference with each fold's model over its held-out test fold.

    For every test image, the mean MC-dropout prediction is processed in four
    steps:

    1. calibrate it with the global ``calibrator``;
    2. min-max scale it to uint8;
    3. resize it to (``original_width``, ``original_height``);
    4. post-process it with ``keep_largest_components(threshold=70)``.

    Args:
        save_ensemble: if True, only write the post-processed mask to
            ``pred_output_dir/<image stem>.bmp``. The extra MC samples and the
            uncertainty metrics are then skipped, since they would be unused.
            If False, print uncertainty summaries and show the prediction,
            ground truth and uncertainty maps inline.

    Relies on notebook globals: ``kf``, ``images``, ``masks``, ``models``,
    ``calibrator``, ``pred_output_dir``, ``original_width``, ``original_height``
    and the helper functions called below.
    """

    def _show(values, title, cmap=None):
        # Display a single map in its own 6x6 inline figure.
        plt.figure(figsize=(6, 6))
        plt.imshow(values, cmap=cmap)
        plt.axis('off')
        plt.title(title)
        plt.show()

    n_folds = kf.get_n_splits()
    for fold, (train_idx, test_idx) in enumerate(kf.split(images)):
        print(f"\n🔹 Predicting on Fold {fold + 1}/{n_folds}")

        test_x_fold, test_y_fold = images[test_idx], masks[test_idx]
        model = enable_dropout(models[fold])

        for x_path, y_path in tqdm(zip(test_x_fold, test_y_fold), total=len(test_x_fold)):
            x, _ = read_image_predict(x_path)
            tname = os.path.basename(x_path).rsplit('.', 1)[0]

            mean_pred, uncertainty_map = monte_carlo_predictions(model, x, n_samples=50)

            # Post-process the prediction.
            # NOTE(review): min-max scaling is per image, so if threshold=70 is an
            # intensity cut-off it is relative to this image's calibrated range,
            # not a fixed calibrated probability — confirm this is intended.
            mean_pred_calibrated = calibrate_predictions(mean_pred, calibrator)
            calibrated_pred = cv2.normalize(mean_pred_calibrated, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
            calibrated_pred_resized = cv2.resize(
                calibrated_pred, (original_width, original_height), interpolation=cv2.INTER_LANCZOS4
            )
            processed_pred_resized = keep_largest_components(calibrated_pred_resized, threshold=70)

            # Remove the channel dimension if present
            if processed_pred_resized.ndim == 3:
                processed_pred_resized = np.squeeze(processed_pred_resized, axis=-1)

            if save_ensemble:
                cv2.imwrite(os.path.join(pred_output_dir, f"{tname}.bmp"), processed_pred_resized)
                continue

            # Separate MC-dropout samples (presumably (50, H, W, 1)) for the
            # entropy / mutual-information / variance metrics.
            mc_predictions = np.stack([
                model(np.expand_dims(x, axis=0), training=True)[0].numpy()
                for _ in range(50)
            ], axis=0)
            p_entropy, e_entropy, mi, var = compute_uncertainty_metrics(mc_predictions)

            mu = calculate_mu(uncertainty_map)
            ruv = calculate_ruv(uncertainty_map)

            # Numerical summaries
            print(f"\n{tname}")
            print(f"Mean Uncertainty (MU): {mu}")
            print(f"Region Uncertainty Volume (RUV): {ruv}")
            print(f"Mean Mutual Information: {np.mean(mi):.4f}")
            print(f"Predictive Entropy Mean: {np.mean(p_entropy):.4f}")
            print(f"Variance (Std Dev²) Mean: {np.mean(var):.4f}")

            _show(processed_pred_resized.squeeze(), "Resized Processed Prediction", cmap='gray')
            _show(cv2.imread(y_path, cv2.IMREAD_GRAYSCALE), "Ground Truth mask", cmap='gray')
            _show(mean_pred, "Mean prediction map")
            _show(uncertainty_map, "uncertainty prediction map")
            _show(mi, "Mutual Information (Epistemic)", cmap='hot')
            _show(p_entropy, "Predictive Entropy", cmap='hot')
            _show(var, "Variance (Std Dev²)", cmap='hot')

            

In [None]:
# Display mode: print uncertainty summaries and show the maps for every test image.
run_inference(save_ensemble=False)

In [None]:
# Save mode: write every post-processed prediction to pred_output_dir (no plots).
run_inference(
    save_ensemble=True,
)

In [None]:
def predict_ensemble(nfolds=5, nensembles=5, save_ensemble=False, mod="", manual_segment_this_path=None):
    """MC-dropout + deep-ensemble inference over every held-out test fold.

    For each fold, ``nensembles`` models are loaded from ``mod.format(fold, k)``
    (both indices 1-based), with dropout enabled. Each test image's prediction
    is processed in five steps:

    1. average the models' MC-dropout mean predictions;
    2. calibrate the average with the global ``calibrator``;
    3. min-max scale it to uint8;
    4. resize it to (``original_width``, ``original_height``);
    5. post-process it with ``keep_largest_components``.

    Args:
        nfolds: unused; the number of folds comes from the global ``kf``.
            Kept for backward compatibility.
        nensembles: number of ensemble members loaded per fold.
        save_ensemble: if True, write each post-processed mask to
            ``pred_output_dir/<image stem>.bmp``. Otherwise print uncertainty
            summaries, show the maps inline, and return after the first
            processed image.
        mod: model file template with two ``{}`` fields (fold, member).
        manual_segment_this_path: if given, only this image path is processed,
            and post-processing uses ``threshold=30, component_rank=2``. Folds
            that do not contain it are skipped without loading their models,
            and in save mode the function returns once it has been written.

    Relies on notebook globals: ``kf``, ``images``, ``masks``, ``calibrator``,
    ``pred_output_dir``, ``iou``, ``log_dice_bce_loss``, ``original_width``,
    ``original_height`` and the helper functions called below.
    """

    def _show(values, title, cmap=None):
        # Display a single map in its own 6x6 inline figure.
        plt.figure(figsize=(6, 6))
        plt.imshow(values, cmap=cmap)
        plt.axis('off')
        plt.title(title)
        plt.show()

    custom_objects = {"iou": iou, "log_dice_bce_loss": log_dice_bce_loss}
    n_folds = kf.get_n_splits()

    for fold, (train_idx, test_idx) in enumerate(kf.split(images)):
        print(f"\n🔹 Predicting on Fold {fold + 1}/{n_folds}")
        test_x_fold, test_y_fold = images[test_idx], masks[test_idx]

        # Loading the ensemble is expensive: skip folds whose test split cannot
        # contain the requested image.
        if manual_segment_this_path is not None and manual_segment_this_path not in list(test_x_fold):
            continue

        ensemble_models = [
            enable_dropout(tf.keras.models.load_model(mod.format(fold + 1, model_num), custom_objects=custom_objects))
            for model_num in range(1, nensembles + 1)
        ]

        for x_path, y_path in tqdm(zip(test_x_fold, test_y_fold), total=len(test_x_fold)):
            if manual_segment_this_path is not None and manual_segment_this_path != x_path:
                continue

            x, _ = read_image_predict(x_path)
            tname = os.path.basename(x_path).rsplit('.', 1)[0]

            mcpreds = []
            meanpreds = []
            uncmaps = []
            for model in ensemble_models:
                if not save_ensemble:
                    # Extra MC-dropout samples (presumably (50, H, W, 1)); they
                    # are only needed for the uncertainty metrics in display mode.
                    mcpreds.append(np.stack([
                        model(np.expand_dims(x, axis=0), training=True)[0].numpy()
                        for _ in range(50)
                    ], axis=0))

                mean_pred, uncertainty_map = monte_carlo_predictions(model, x, n_samples=50)
                meanpreds.append(mean_pred)
                uncmaps.append(uncertainty_map)

            # Ensemble average across models
            mean_pred_ensemble = np.mean(meanpreds, axis=0)
            mean_uncmap_ensemble = np.mean(uncmaps, axis=0)

            # Post-process the prediction
            mean_pred_calibrated = calibrate_predictions(mean_pred_ensemble, calibrator)
            calibrated_pred = cv2.normalize(mean_pred_calibrated, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
            calibrated_pred_resized = cv2.resize(
                calibrated_pred, (original_width, original_height), interpolation=cv2.INTER_LANCZOS4
            )

            if manual_segment_this_path is not None:
                processed_pred_resized = keep_largest_components(calibrated_pred_resized, threshold=30, component_rank=2)
            else:
                processed_pred_resized = keep_largest_components(calibrated_pred_resized, threshold=70)

            # Remove the channel dimension if present
            if processed_pred_resized.ndim == 3:
                processed_pred_resized = np.squeeze(processed_pred_resized, axis=-1)

            if save_ensemble:
                cv2.imwrite(os.path.join(pred_output_dir, f"{tname}.bmp"), processed_pred_resized)
                if manual_segment_this_path is not None:
                    return  # the requested image has been written
                continue

            # Pool every MC sample from every member: (n_models * 50, H, W, ...).
            # Averaging across models first would pair unrelated samples and hide
            # the between-model disagreement that the ensemble is meant to expose.
            # NOTE(review): assumes compute_uncertainty_metrics reduces over axis 0
            # (the sample axis of the 50-sample stacks it gets elsewhere) — confirm.
            mc_preds_all = np.concatenate(mcpreds, axis=0)
            p_entropy, e_entropy, mi, var = compute_uncertainty_metrics(mc_preds_all)

            mu = calculate_mu(mean_uncmap_ensemble)
            ruv = calculate_ruv(mean_uncmap_ensemble)

            # Numerical summaries
            print(f"\n{tname}")
            print(f"Mean Uncertainty (MU): {mu}")
            print(f"Region Uncertainty Volume (RUV): {ruv}")
            print(f"Mean Mutual Information: {np.mean(mi):.4f}")
            print(f"Predictive Entropy Mean: {np.mean(p_entropy):.4f}")
            print(f"Variance (Std Dev²) Mean: {np.mean(var):.4f}")

            _show(processed_pred_resized.squeeze(), "Resized Processed Prediction", cmap='gray')
            _show(cv2.imread(y_path, cv2.IMREAD_GRAYSCALE), "Ground Truth mask", cmap='gray')
            _show(mean_pred_ensemble, "Mean prediction map")
            _show(mean_uncmap_ensemble, "uncertainty prediction map")
            _show(mi, "Mutual Information (Epistemic)", cmap='hot')
            _show(p_entropy, "Predictive Entropy", cmap='hot')
            _show(var, "Variance (Std Dev²)", cmap='hot')

            # Display mode inspects a single image, then stops.
            return

In [None]:
# Save the ensemble prediction for every test image of every fold.
predict_ensemble(
    mod="XBoundNet_{}Fold_LDL07_BCE03_Drop02onlowres_Lre4_ensemble_model_{}.h5",
    save_ensemble=True,
)

In [None]:
# Re-segment one specific image: predict_ensemble then post-processes with
# keep_largest_components(threshold=30, component_rank=2) and writes the mask
# to pred_output_dir. Pass save_ensemble=False to inspect the plots instead.
manual_path = r"/path/to/dataset/Images/Pt 0057 - RA - TestID20240223131548--02-2812-45-35.bmp"
predict_ensemble(
    mod="XBoundNet_{}Fold_LDL07_BCE03_Drop02onlowres_Lre4_ensemble_model_{}.h5",
    save_ensemble=True,
    manual_segment_this_path=manual_path,
)

In [None]:
def _show_map(image, title, cmap=None, figsize=(6, 6)):
    """Display ``image`` in its own figure with ``title`` and hidden axes."""
    plt.figure(figsize=figsize)
    plt.imshow(image, cmap=cmap)
    plt.axis('off')
    plt.title(title)
    plt.show()


def _write_image(output_path, name, image):
    """Write ``image`` to ``os.path.join(output_path, name)``, warning if OpenCV reports failure."""
    path = os.path.join(output_path, name)
    # cv2.imwrite signals failure only through its boolean return value.
    if not cv2.imwrite(path, image):
        print(f"⚠️ Could not write {path}")


def _write_colormap(output_path, name, values):
    """Min-max scale ``values`` to uint8, save it with the INFERNO colormap, and return the uint8 map."""
    values_uint8 = cv2.normalize(values, None, 0, 255, cv2.NORM_MINMAX).astype('uint8')
    _write_image(output_path, name, cv2.applyColorMap(values_uint8, cv2.COLORMAP_INFERNO))
    return values_uint8


def _save_summary_figure(output_path, tname, x_normalized, binary_gt, binary_pred,
                         mean_pred_ensemble, mean_uncmap_ensemble):
    """Save (``output_side_to_side.png``) and show the 5-panel image/mask/probability/uncertainty summary."""
    # Rescale the grayscale base image to [0, 1]; astype() copies, so x_normalized is untouched.
    base_gray = x_normalized.astype(np.float32)
    base_gray -= base_gray.min()
    base_gray /= (base_gray.max() + 1e-8)

    fig, axes = plt.subplots(1, 5, figsize=(32, 5))

    # 1. Original image
    axes[0].imshow(base_gray, cmap='gray')
    axes[0].set_title(f"Original Image: {tname}")
    axes[0].axis("off")

    # 2. Ground-truth mask over the image
    axes[1].imshow(base_gray, cmap='gray')
    axes[1].imshow(np.ma.masked_where(binary_gt == 0, binary_gt), cmap='spring', alpha=0.25)
    axes[1].set_title("Annotated Mask over Image")
    axes[1].axis("off")

    # 3. Predicted mask over the image
    axes[2].imshow(base_gray, cmap='gray')
    axes[2].imshow(np.ma.masked_where(binary_pred == 0, binary_pred), cmap='winter', alpha=0.25)
    axes[2].set_title("Predicted Mask over Image")
    axes[2].axis("off")

    # 4. Probability overlay (near-zero probabilities hidden)
    axes[3].imshow(base_gray, cmap='gray')
    mean_mask = np.ma.masked_where(mean_pred_ensemble <= 0.001, mean_pred_ensemble)
    im3 = axes[3].imshow(mean_mask, cmap='magma', alpha=0.25)
    axes[3].set_title("Probability over Image")
    axes[3].axis("off")
    cbar = fig.colorbar(im3, ax=axes[3], fraction=0.046, pad=0.04)
    cbar.set_label("Probability", rotation=270, labelpad=15)

    # 5. Uncertainty overlay (low uncertainty hidden)
    axes[4].imshow(base_gray, cmap='gray')
    uncert_mask = np.ma.masked_where(mean_uncmap_ensemble <= 0.03, mean_uncmap_ensemble)
    im4 = axes[4].imshow(uncert_mask, cmap='magma', alpha=0.25)
    axes[4].set_title("Uncertainty over Image")
    axes[4].axis("off")
    cbar = fig.colorbar(im4, ax=axes[4], fraction=0.046, pad=0.04)
    cbar.set_label("Uncertainty", rotation=270, labelpad=15)

    # tight_layout must run BEFORE savefig; afterwards it cannot affect the saved file.
    plt.tight_layout()
    fig.savefig(os.path.join(output_path, "output_side_to_side.png"), bbox_inches="tight", dpi=600)
    plt.show()
    # Release the figure explicitly: previously one figure per image stayed open for the whole run.
    plt.close(fig)


def test_ensemble(nfolds=5, nensembles=5, mod="", show_final=True, show_conv_layers=False,
                  show_thresholds=False, show_uncertainty=False, n_mc_samples=50):
    """Run the per-fold model ensembles on each held-out fold and save prediction/uncertainty diagnostics.

    For every fold yielded by the notebook-global ``kf.split(images)``, this function:
      1. loads ``nensembles`` models from ``mod.format(fold + 1, model_num)`` (both 1-based);
      2. runs ``n_mc_samples`` forward passes per model with ``training=True`` on each test image;
      3. writes masks, probability/uncertainty maps, Conv2D-layer heatmaps and overlays to
         ``os.path.join(output_dir, <image stem>)``.

    Relies on notebook globals defined elsewhere: ``kf``, ``images``, ``masks``, ``output_dir``,
    ``calibrator`` and helpers such as ``enable_dropout``, ``monte_carlo_predictions``,
    ``compute_uncertainty_metrics``, ``generate_heatmap`` and the overlay functions.

    Args:
        nfolds: Unused; kept for backward compatibility. The fold count comes from ``kf``.
        nensembles: Number of ensemble members loaded per fold (must be >= 1).
        mod: Model-path format string taking (fold number, model number).
        show_final: Save and display the 5-panel summary figure for each image.
        show_conv_layers: Display the heatmap of every Conv2D layer.
        show_thresholds: Display each uncertainty-threshold overlay.
        show_uncertainty: Display the prediction and uncertainty maps.
        n_mc_samples: Stochastic forward passes per model (previously hard-coded to 50).

    Raises:
        ValueError: If ``nensembles`` < 1 (there would be no model to predict with).
    """
    if nensembles < 1:
        raise ValueError("nensembles must be >= 1")
    n_splits = kf.get_n_splits()  # real fold count instead of a hard-coded "/5"

    for fold, (_, test_idx) in enumerate(kf.split(images)):
        print(f"\n🔹 Predicting on Fold {fold + 1}/{n_splits}")

        test_x_fold, test_y_fold = images[test_idx], masks[test_idx]

        ensemble_models = []
        for model_num in range(1, nensembles + 1):
            model_string = mod.format(fold + 1, model_num)
            m = tf.keras.models.load_model(model_string, custom_objects={"iou": iou, "log_dice_bce_loss": log_dice_bce_loss})
            ensemble_models.append(enable_dropout(m))
        # Heatmaps come from the last ensemble member (previously via the leaked loop variable `model`).
        heatmap_model = ensemble_models[-1]

        for x_path, y_path in tqdm(zip(test_x_fold, test_y_fold), total=len(test_x_fold)):
            tname = os.path.basename(x_path).rsplit('.', 1)[0]
            output_path = os.path.join(output_dir, tname)
            print("saving to: " + output_path)
            os.makedirs(output_path, exist_ok=True)

            x, rgb_img = read_image_predict(x_path)
            y = read_mask_predict(y_path)
            x_normalized = normalize_img(x_path)
            x_batch = np.expand_dims(x, axis=0)  # identical for every model and sample, so build once

            mcpreds, meanpreds, uncmaps = [], [], []
            for model in ensemble_models:
                mc_predictions = np.stack(
                    [model(x_batch, training=True)[0].numpy() for _ in range(n_mc_samples)],
                    axis=0,
                )  # (n_mc_samples, H, W, 1)
                # NOTE(review): monte_carlo_predictions draws its own samples, doubling inference cost;
                # kept because the exact definition of its uncertainty_map is not visible here.
                mean_pred, uncertainty_map = monte_carlo_predictions(model, x, n_samples=n_mc_samples)
                mcpreds.append(mc_predictions)
                meanpreds.append(mean_pred)
                uncmaps.append(uncertainty_map)

            mc_preds_all = np.stack(mcpreds, axis=0)                     # (n_models, n_mc_samples, H, W, 1)
            mean_pred_ensemble = np.stack(meanpreds, axis=0).mean(axis=0)
            mean_uncmap_ensemble = np.stack(uncmaps, axis=0).mean(axis=0)

            # Pool all models' draws into one sample set. Averaging sample i across models paired
            # unrelated dropout draws, erasing between-model disagreement and shrinking MI/variance.
            pooled_samples = mc_preds_all.reshape((-1,) + mc_preds_all.shape[2:])
            p_entropy, e_entropy, mi, var = compute_uncertainty_metrics(pooled_samples)

            # Post-process prediction
            mean_pred_calibrated = calibrate_predictions(mean_pred_ensemble, calibrator)
            calibrated_pred = cv2.normalize(mean_pred_calibrated, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
            processed_pred_resized = keep_largest_components(calibrated_pred, threshold=60)

            binary_gt = (y > 127).astype(np.uint8).squeeze()
            binary_pred = (processed_pred_resized > 127).astype(np.uint8).squeeze()

            overlays = overlay_uncertainty_mask(rgb_img, binary_pred, binary_gt, p_entropy, thresholds=[100, 75, 50, 25])
            for th, vis in overlays:
                _write_image(output_path, f"Threshold_{th}.bmp", vis)
                if show_thresholds:
                    _show_map(vis, f"Threshold {th}", figsize=(5, 5))

            conv_layers = [layer for layer in heatmap_model.layers if isinstance(layer, tf.keras.layers.Conv2D)]
            for layer in conv_layers:
                heatmap = generate_heatmap(heatmap_model, x, layer.name)
                if show_conv_layers:
                    _show_map(heatmap, f"Heatmap from Layer: {layer.name}")
                overlay_result = overlay_dual_mask_boundary(heatmap, mask_parse(y), mask_parse(binary_pred))
                _write_image(output_path, f"{layer.name}_with_contour.bmp", overlay_result)
                _write_image(output_path, f"{layer.name}.bmp", heatmap)

            og_uint8 = cv2.normalize(rgb_img, None, 0, 255, cv2.NORM_MINMAX).astype('uint8')
            _write_image(output_path, "original_image.bmp", og_uint8)
            _write_image(output_path, "groundtruth_mask.bmp", y)
            _write_image(output_path, "predicted_mask.bmp", processed_pred_resized)

            pred_map_uint8 = _write_colormap(output_path, "mean_prediction_map.bmp", mean_pred_ensemble.squeeze())
            uncert_map_uint8 = _write_colormap(output_path, "uncertainty_map.bmp", mean_uncmap_ensemble)
            _write_colormap(output_path, "Mutual_Information_(Epistemic).bmp", mi)
            _write_colormap(output_path, "Predictive_Entropy.bmp", p_entropy)
            _write_colormap(output_path, "Variance.bmp", var)

            contour_result_on_original = overlay_dual_mask_boundary(x_normalized, mask_parse(y), mask_parse(binary_pred))
            cont_uint8 = cv2.normalize(contour_result_on_original, None, 0, 255, cv2.NORM_MINMAX).astype('uint8')
            _write_image(output_path, "contour_result_on_original.bmp", cont_uint8)

            _write_image(output_path, "prediction_over_img.bmp",
                         binary_mask_overlay(processed_pred_resized.squeeze(), x_normalized))
            _write_image(output_path, "mask_over_img.bmp", binary_mask_overlay(y.squeeze(), x_normalized))
            _write_image(output_path, "heatmap_over_image.bmp", heatmap_overlay(pred_map_uint8, x_normalized))
            _write_image(output_path, "uncertainty_over_image.bmp", heatmap_overlay(uncert_map_uint8, x_normalized))

            if show_uncertainty:
                _show_map(processed_pred_resized.squeeze(), "Processed Prediction", cmap='gray')
                _show_map(mean_pred_ensemble, "Mean prediction map")
                _show_map(mean_uncmap_ensemble, "uncertainty prediction map", cmap='hot')
                _show_map(mi, "Mutual Information (Epistemic)", cmap='hot')
                _show_map(p_entropy, "Predictive Entropy", cmap='hot')
                _show_map(var, "Variance (Std Dev²)", cmap='hot')

            if show_final:
                _save_summary_figure(output_path, tname, x_normalized, binary_gt, binary_pred,
                                     mean_pred_ensemble, mean_uncmap_ensemble)

In [None]:
# Evaluate each fold's ensemble on its held-out images and write per-image diagnostics,
# including the 5-panel summary figure.
test_ensemble(
    mod="XBoundNet_{}Fold_LDL07_BCE03_Drop02onlowres_Lre4_ensemble_model_{}.h5",
    show_final=True,
)