In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import KFold
import nibabel as nib
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, concatenate, BatchNormalization, Activation
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from tensorflow.keras import backend as K
import tensorflow as tf
from skimage.transform import resize
from tqdm import tqdm
from scipy.ndimage import distance_transform_edt
from skimage.morphology import erosion, dilation, square
from sklearn.model_selection import train_test_split
from scipy.ndimage import rotate
from skimage.transform import rescale

In [None]:
# Configuration (matches paper parameters)
# Image size must be divisible by 16: both U-Nets pool four times and then
# concatenate upsampled features with the matching encoder output.
IMG_HEIGHT = 256
IMG_WIDTH = 256
IMG_CHANNELS = 1
BATCH_SIZE = 8  # U-Net 1 uses 8, U-Net 2 uses 4 (NOTE(review): main() picks its own batch size)
EPOCHS = 100
INIT_LR = 1e-4
N_FOLDS = 10  # 10-fold cross-validation as in paper
SEED = 42

# Path to the CAMUS dataset. Set the CAMUS_DATA_PATH environment variable to
# point elsewhere instead of editing the notebook; the default is the Google
# Drive location used on Colab.
DATA_PATH = os.environ.get("CAMUS_DATA_PATH", "/content/drive/MyDrive/database_nifti")


In [None]:
# Define metrics
def dice_coef(y_true, y_pred, smooth=1):
    """Soft Dice coefficient between two tensors.

    Both tensors are flattened, so the score is pooled over every element of
    the batch rather than averaged per image. ``smooth`` is added to the
    numerator and the denominator so the ratio stays defined when both inputs
    are all zeros.
    """
    truth = K.flatten(y_true)
    prediction = K.flatten(y_pred)
    overlap = K.sum(truth * prediction)
    total = K.sum(truth) + K.sum(prediction) + smooth
    return (2. * overlap + smooth) / total

def dice_loss(y_true, y_pred):
    """Dice loss: one minus the soft Dice coefficient (default smoothing)."""
    dice = dice_coef(y_true, y_pred)
    return 1 - dice

def bce_dice_loss(y_true, y_pred):
    """Combined loss: Keras binary cross-entropy plus the soft Dice loss."""
    cross_entropy = tf.keras.losses.binary_crossentropy(y_true, y_pred)
    return cross_entropy + dice_loss(y_true, y_pred)

# Data loading and preprocessing
def load_nifti_image(file_path):
    """Read a NIfTI file and return its data array with singleton axes removed."""
    return np.squeeze(nib.load(file_path).get_fdata())


# Preprocessing with rotation/scale augmentation (used for the training patients)
def preprocess_patient_with_augmentation(patient_folder):
    """Load one patient's images and return rotation/zoom-augmented samples.

    For each (view, time point) in ('2CH', '4CH') x ('ED', 'ES') whose image
    and ``_gt`` label file both exist, the image is resized to
    (IMG_HEIGHT, IMG_WIDTH) and min-max normalised. Twelve variants are then
    produced: rotations by 0/90/180/270 degrees, each zoomed by 0.9/1.0/1.1.
    Ground-truth label values 1, 2 and 3 are split into ``masks_endo``,
    ``masks_epi`` and ``masks_la``.

    Args:
        patient_folder: Path to a patient directory; its basename is the
            file-name prefix of that patient's NIfTI files.

    Returns:
        (images, masks_endo, masks_epi, masks_la). Assuming each NIfTI file
        squeezes to a 2-D image, each is a float32 array of shape
        (N, IMG_HEIGHT, IMG_WIDTH, 1). All four are empty arrays when nothing
        could be loaded.
    """
    images = []
    masks_endo = []
    masks_epi = []
    masks_la = []

    def find_nifti(stem):
        # Accept both plain (.nii) and gzip-compressed (.nii.gz) NIfTI files;
        # looking for '.nii' alone loaded no samples from database_nifti.
        for ext in ('.nii', '.nii.gz'):
            if os.path.exists(stem + ext):
                return stem + ext
        return None

    def fit_to_canvas(arr):
        # Centre-crop (zoom in) or zero-pad (zoom out) back to the target size.
        # Resizing the rescaled array back instead would undo the zoom entirely.
        top = max((arr.shape[0] - IMG_HEIGHT) // 2, 0)
        left = max((arr.shape[1] - IMG_WIDTH) // 2, 0)
        arr = arr[top:top + IMG_HEIGHT, left:left + IMG_WIDTH]
        canvas = np.zeros((IMG_HEIGHT, IMG_WIDTH), dtype=arr.dtype)
        pad_top = (IMG_HEIGHT - arr.shape[0]) // 2
        pad_left = (IMG_WIDTH - arr.shape[1]) // 2
        canvas[pad_top:pad_top + arr.shape[0], pad_left:pad_left + arr.shape[1]] = arr
        return canvas

    base_name = os.path.basename(os.path.normpath(patient_folder))

    for view in ['2CH', '4CH']:
        for tp in ['ED', 'ES']:
            stem = f"{patient_folder}/{base_name}_{view}_{tp}"
            img_path = find_nifti(stem)
            gt_path = find_nifti(f"{stem}_gt")
            if img_path is None or gt_path is None:
                continue

            try:
                # Drop singleton axes so resize operates on a 2-D array.
                img = np.squeeze(nib.load(img_path).get_fdata())
                gt = np.squeeze(nib.load(gt_path).get_fdata())

                img_resized = resize(img, (IMG_HEIGHT, IMG_WIDTH), preserve_range=True, anti_aliasing=True)
                # order=0 (nearest neighbour) keeps label values integral. The
                # default interpolation blends labels at class borders, so the
                # `== 1/2/3` tests below would silently drop those pixels.
                gt_resized = resize(gt, (IMG_HEIGHT, IMG_WIDTH), order=0, preserve_range=True, anti_aliasing=False)

                lo, hi = img_resized.min(), img_resized.max()
                # A constant image would divide by zero and produce NaNs.
                img_resized = (img_resized - lo) / (hi - lo) if hi > lo else np.zeros_like(img_resized)

                for angle in [0, 90, 180, 270]:
                    img_rotated = rotate(img_resized, angle, reshape=False)
                    gt_rotated = rotate(gt_resized, angle, reshape=False, order=0)

                    for scale in [0.9, 1.0, 1.1]:
                        # No `multichannel` argument: it is deprecated in
                        # scikit-image and unnecessary for a 2-D input.
                        img_scaled = fit_to_canvas(
                            rescale(img_rotated, scale, preserve_range=True, anti_aliasing=True))
                        gt_scaled = np.rint(fit_to_canvas(
                            rescale(gt_rotated, scale, order=0, preserve_range=True, anti_aliasing=False)))

                        # float32 halves memory vs. float64 and matches the masks' dtype.
                        images.append(img_scaled.astype(np.float32)[..., np.newaxis])
                        masks_endo.append((gt_scaled == 1).astype(np.float32)[..., np.newaxis])
                        masks_epi.append((gt_scaled == 2).astype(np.float32)[..., np.newaxis])
                        masks_la.append((gt_scaled == 3).astype(np.float32)[..., np.newaxis])

            except Exception as e:
                print(f"Error processing {view}_{tp}: {str(e)}")
                continue

    return np.array(images), np.array(masks_endo), np.array(masks_epi), np.array(masks_la)


def preprocess_patient(patient_folder):
    """Load one patient's images and masks without augmentation.

    For each (view, time point) in ('2CH', '4CH') x ('ED', 'ES') whose image
    and ``_gt`` label file both exist, the image is resized to
    (IMG_HEIGHT, IMG_WIDTH) and min-max normalised. The label map is resized
    with nearest-neighbour interpolation and its values 1, 2 and 3 are split
    into ``masks_endo``, ``masks_epi`` and ``masks_la``.

    Args:
        patient_folder: Path to a patient directory; its basename is the
            file-name prefix of that patient's NIfTI files.

    Returns:
        (images, masks_endo, masks_epi, masks_la). Assuming each NIfTI file
        squeezes to a 2-D image, each is a float32 array of shape
        (N, IMG_HEIGHT, IMG_WIDTH, 1) with N <= 4. All four are empty arrays
        when nothing could be loaded.
    """
    images = []
    masks_endo = []
    masks_epi = []
    masks_la = []

    def find_nifti(stem):
        # Accept both plain (.nii) and gzip-compressed (.nii.gz) NIfTI files;
        # looking for '.nii' alone loaded no samples from database_nifti.
        for ext in ('.nii', '.nii.gz'):
            if os.path.exists(stem + ext):
                return stem + ext
        return None

    base_name = os.path.basename(os.path.normpath(patient_folder))

    for view in ['2CH', '4CH']:
        for tp in ['ED', 'ES']:
            stem = f"{patient_folder}/{base_name}_{view}_{tp}"
            img_path = find_nifti(stem)
            gt_path = find_nifti(f"{stem}_gt")
            if img_path is None or gt_path is None:
                continue

            try:
                # Drop singleton axes so resize operates on a 2-D array.
                img = np.squeeze(nib.load(img_path).get_fdata())
                gt = np.squeeze(nib.load(gt_path).get_fdata())

                img_resized = resize(img, (IMG_HEIGHT, IMG_WIDTH), preserve_range=True, anti_aliasing=True)
                # Nearest-neighbour keeps labels integral; interpolated labels
                # would fail the exact `== 1/2/3` comparisons at class borders.
                gt_resized = np.rint(resize(gt, (IMG_HEIGHT, IMG_WIDTH), order=0,
                                            preserve_range=True, anti_aliasing=False))

                lo, hi = img_resized.min(), img_resized.max()
                # A constant image would divide by zero and produce NaNs.
                img_resized = (img_resized - lo) / (hi - lo) if hi > lo else np.zeros_like(img_resized)

                images.append(img_resized.astype(np.float32)[..., np.newaxis])
                masks_endo.append((gt_resized == 1).astype(np.float32)[..., np.newaxis])
                masks_epi.append((gt_resized == 2).astype(np.float32)[..., np.newaxis])
                masks_la.append((gt_resized == 3).astype(np.float32)[..., np.newaxis])

            except Exception as e:
                print(f"Error processing {view}_{tp}: {str(e)}")
                continue

    return np.array(images), np.array(masks_endo), np.array(masks_epi), np.array(masks_la)

In [None]:
from tqdm import tqdm

# Load all patients, split them into train/test at the patient level, and
# preprocess each split (augmentation is applied to the training split only).
def load_dataset_with_split(base_path, test_ratio=0.15, seed=SEED):
    """List patient folders, split them into train/test and preprocess each split.

    The split is made per patient, so all images of a patient land in the same
    split. Training patients go through ``preprocess_patient_with_augmentation``;
    test patients go through ``preprocess_patient`` (no augmentation).

    Args:
        base_path: Directory containing ``patient*`` sub-directories.
        test_ratio: Fraction of patients held out for testing, in [0, 1).
        seed: Seed of the patient shuffle, making the split reproducible.

    Returns:
        (X_train, y_endo_train, y_epi_train, y_la_train,
         X_test, y_endo_test, y_epi_test, y_la_test). A split for which no
        sample was loaded is returned as empty arrays of shape
        (0, IMG_HEIGHT, IMG_WIDTH, 1).

    Raises:
        ValueError: if ``test_ratio`` is out of range, no patient folder is
            found, or no training sample could be loaded.
    """
    if not 0 <= test_ratio < 1:
        raise ValueError(f"test_ratio must be in [0, 1), got {test_ratio}")

    patient_folders = sorted(
        os.path.join(base_path, f)
        for f in os.listdir(base_path)
        if f.startswith('patient') and os.path.isdir(os.path.join(base_path, f))
    )
    if not patient_folders:
        raise ValueError(f"No patient folders found in {base_path}. Check your DATA_PATH.")

    # Local RNG: reproducible split without reseeding NumPy's global generator.
    np.random.RandomState(seed).shuffle(patient_folders)
    split_idx = int(len(patient_folders) * (1 - test_ratio))
    train_folders = patient_folders[:split_idx]
    test_folders = patient_folders[split_idx:]

    def collect(folders, loader, desc):
        # Patients with no loadable view return 1-D empty arrays, which cannot
        # be concatenated with other patients' 4-D arrays, so they are skipped.
        buckets = ([], [], [], [])
        for folder in tqdm(folders, desc=desc):
            outputs = loader(folder)
            if len(outputs[0]) == 0:
                continue
            for bucket, array in zip(buckets, outputs):
                bucket.append(array)
        return tuple(
            np.concatenate(bucket, axis=0) if bucket
            else np.empty((0, IMG_HEIGHT, IMG_WIDTH, 1), dtype=np.float32)
            for bucket in buckets
        )

    train = collect(train_folders, preprocess_patient_with_augmentation, "Loading training patients")
    test = collect(test_folders, preprocess_patient, "Loading test patients")

    if len(train[0]) == 0:
        raise ValueError("No training data was loaded. Check your data paths and preprocessing.")
    if len(test[0]) == 0:
        print("Warning: No test data was loaded.")

    return train + test

In [None]:
# U-Net 1 architecture (optimized for speed)
def unet1(input_size=(IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS)):
    """Build the lighter U-Net variant (32 to 512 filters, no batch normalisation).

    Four encoder levels of two 3x3 ReLU convolutions followed by 2x2 max
    pooling, a 512-filter bottleneck, and a mirrored decoder. Each decoder
    level upsamples by 2, concatenates the matching encoder output, and
    applies two convolutions. A 1x1 sigmoid convolution produces a
    single-channel probability map. Height and width of ``input_size`` must be
    divisible by 16 for the skip connections to line up.
    """
    def double_conv(tensor, filters):
        for _ in range(2):
            tensor = Conv2D(filters, 3, activation='relu', padding='same')(tensor)
        return tensor

    inputs = Input(input_size)

    skips = []
    x = inputs
    for filters in (32, 64, 128, 256):
        x = double_conv(x, filters)
        skips.append(x)
        x = MaxPooling2D(pool_size=(2, 2))(x)

    x = double_conv(x, 512)

    for filters, skip in zip((256, 128, 64, 32), reversed(skips)):
        x = concatenate([UpSampling2D(size=(2, 2))(x), skip], axis=-1)
        x = double_conv(x, filters)

    outputs = Conv2D(1, 1, activation='sigmoid')(x)
    return Model(inputs=inputs, outputs=outputs)

# U-Net 2 architecture (optimized for accuracy)
def unet2(input_size=(IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS)):
    """Build the larger U-Net variant (64 to 1024 filters, with batch normalisation).

    Every convolution is a 3x3 Conv2D followed by BatchNormalization and ReLU,
    applied twice per level. There are four encoder levels with 2x2 max
    pooling, a 1024-filter bottleneck, and a mirrored decoder that upsamples
    by 2 and concatenates the matching encoder output. A 1x1 sigmoid
    convolution produces a single-channel probability map. Height and width of
    ``input_size`` must be divisible by 16.
    """
    def conv_bn_relu_twice(tensor, filters):
        for _ in range(2):
            tensor = Conv2D(filters, 3, padding='same')(tensor)
            tensor = BatchNormalization()(tensor)
            tensor = Activation('relu')(tensor)
        return tensor

    inputs = Input(input_size)

    skips = []
    x = inputs
    for filters in (64, 128, 256, 512):
        x = conv_bn_relu_twice(x, filters)
        skips.append(x)
        x = MaxPooling2D(pool_size=(2, 2))(x)

    x = conv_bn_relu_twice(x, 1024)

    for filters, skip in zip((512, 256, 128, 64), reversed(skips)):
        x = concatenate([UpSampling2D(size=(2, 2))(x), skip], axis=-1)
        x = conv_bn_relu_twice(x, filters)

    outputs = Conv2D(1, 1, activation='sigmoid')(x)
    return Model(inputs=inputs, outputs=outputs)


In [None]:
# Evaluation metrics
def calculate_metrics(y_true, y_pred, threshold=0.5, spacing=None):
    """Dice, mean absolute contour distance and Hausdorff distance for one mask pair.

    Args:
        y_true: Ground-truth mask that squeezes to 2-D; values > 0.5 are foreground.
        y_pred: Predicted mask or probability map of the same shape; values
            greater than ``threshold`` are foreground.
        threshold: Binarisation threshold for ``y_pred``.
        spacing: Optional pixel spacing forwarded as ``sampling`` to
            ``distance_transform_edt``. With the default None, distances are
            in pixels.

    Returns:
        (dice, mean_dist, hausdorff). ``dice`` uses +1 smoothing, so two empty
        masks score 1. ``mean_dist`` is the symmetric mean distance between
        the two contours and ``hausdorff`` the symmetric maximum. Both
        distances are ``np.inf`` when either contour is empty.
    """
    from scipy.ndimage import binary_erosion

    true_mask = np.squeeze(np.asarray(y_true) > 0.5)
    pred_mask = np.squeeze(np.asarray(y_pred) > threshold)

    intersection = np.count_nonzero(true_mask & pred_mask)
    dice = (2. * intersection + 1.) / (np.count_nonzero(true_mask) + np.count_nonzero(pred_mask) + 1.)

    # One-pixel contour = foreground removed by a 3x3 erosion. border_value=1
    # treats pixels outside the image as foreground, so the image border is not
    # reported as contour (same result as skimage's reflect-mode erosion).
    footprint = np.ones((3, 3), dtype=bool)
    contour_true = true_mask & ~binary_erosion(true_mask, structure=footprint, border_value=1)
    contour_pred = pred_mask & ~binary_erosion(pred_mask, structure=footprint, border_value=1)

    if not contour_true.any() or not contour_pred.any():
        return dice, np.inf, np.inf

    # Distance of every pixel to the nearest contour pixel of the other mask.
    dist_to_true = distance_transform_edt(~contour_true, sampling=spacing)
    dist_to_pred = distance_transform_edt(~contour_pred, sampling=spacing)
    pred_to_true = dist_to_true[contour_pred]
    true_to_pred = dist_to_pred[contour_true]

    mean_dist = (pred_to_true.mean() + true_to_pred.mean()) / 2
    hausdorff = max(pred_to_true.max(), true_to_pred.max())
    return dice, mean_dist, hausdorff

In [None]:
def main(unet_version=1):
    """Train a U-Net on the endocardium masks with K-fold CV and evaluate on the test set.

    For each of the ``N_FOLDS`` folds, a fresh model is trained and the
    checkpoint with the best ``val_dice_coef`` is reloaded. That model is then
    evaluated on every image of the held-out test set with
    ``calculate_metrics``. One random test image per fold is also plotted.

    Args:
        unet_version: ``1`` selects ``unet1`` (batch size 8); any other value
            selects ``unet2`` (batch size 4).

    Returns:
        A list with one dict per evaluated fold, holding the fold's mean
        ``'dice'``, ``'mean_dist'`` and ``'hausdorff'`` over the test set.
        Distances are in pixels of the resized grid.

    Raises:
        ValueError: if fewer than ``N_FOLDS`` training samples were loaded.
    """
    def finite_mean(values):
        # Empty masks yield inf (or NaN) distances; average only the finite ones.
        arr = np.asarray(values, dtype=float)
        arr = arr[np.isfinite(arr)]
        return float(arr.mean()) if arr.size else float('nan')

    print("Loading and preprocessing dataset...")
    (X_train, y_endo_train, y_epi_train, y_la_train,
     X_test, y_endo_test, y_epi_test, y_la_test) = load_dataset_with_split(DATA_PATH)

    if len(X_train) < N_FOLDS:
        raise ValueError(
            f"Need at least {N_FOLDS} training samples for {N_FOLDS}-fold cross-validation, "
            f"got {len(X_train)}. Check DATA_PATH and the NIfTI file names.")

    batch_size = 8 if unet_version == 1 else 4

    # Training samples are stored patient by patient, and each image's
    # augmented copies are adjacent. Patients were already shuffled when the
    # train/test split was made. Contiguous folds (no sample-level shuffle)
    # therefore keep a patient's copies together. Shuffling individual samples
    # would put copies of the same image in both training and validation,
    # inflating val_dice_coef.
    kf = KFold(n_splits=N_FOLDS)
    rng = np.random.default_rng(SEED)
    fold_results = []

    for fold, (train_idx, val_idx) in enumerate(kf.split(X_train)):
        print(f"\n=== Fold {fold + 1}/{N_FOLDS} ===")
        X_fold_train, X_val = X_train[train_idx], X_train[val_idx]
        y_fold_train, y_val = y_endo_train[train_idx], y_endo_train[val_idx]

        # Release the previous fold's model so graphs/weights don't accumulate.
        K.clear_session()
        model = unet1() if unet_version == 1 else unet2()
        model.compile(optimizer=Adam(learning_rate=INIT_LR),
                      loss=bce_dice_loss,
                      metrics=[dice_coef, 'accuracy'])

        checkpoint_path = f"unet{unet_version}_fold{fold}_best.keras"
        callbacks = [
            ModelCheckpoint(checkpoint_path,
                            monitor='val_dice_coef',
                            mode='max',
                            save_best_only=True,
                            verbose=1),
            EarlyStopping(monitor='val_dice_coef', patience=15, mode='max', verbose=1),
            ReduceLROnPlateau(monitor='val_dice_coef', factor=0.5, patience=5, min_lr=1e-6, mode='max', verbose=1)
        ]

        print(f"Training on {len(X_fold_train)} samples, validating on {len(X_val)} samples")
        model.fit(X_fold_train, y_fold_train, batch_size=batch_size, epochs=EPOCHS,
                  validation_data=(X_val, y_val), callbacks=callbacks, verbose=1)

        model.load_weights(checkpoint_path)

        if len(X_test) == 0:
            print("No test data available; skipping evaluation for this fold.")
            continue

        # Evaluate on the whole test set rather than a single random image.
        y_pred_test = model.predict(X_test, batch_size=batch_size, verbose=0)
        per_image = [calculate_metrics(truth, prob > 0.5)
                     for truth, prob in zip(y_endo_test, y_pred_test)]
        dices, mean_dists, hausdorffs = zip(*per_image)
        fold_results.append({
            'dice': float(np.mean(dices)),
            'mean_dist': finite_mean(mean_dists),
            'hausdorff': finite_mean(hausdorffs),
        })
        print(f"Fold {fold + 1} test Dice: {fold_results[-1]['dice']:.3f}")

        # Visual sanity check on one reproducibly chosen test image.
        idx = int(rng.integers(len(X_test)))
        fig, axes = plt.subplots(1, 3, figsize=(12, 4))
        panels = [(X_test[idx], "Original Image"),
                  (y_endo_test[idx], "Ground Truth"),
                  (y_pred_test[idx] > 0.5, "Predicted Mask")]
        for ax, (image, title) in zip(axes, panels):
            ax.imshow(np.squeeze(image), cmap='gray')
            ax.set_title(title)
        fig.tight_layout()
        plt.show()
        plt.close(fig)

    print("\n=== Final Cross-Validation Results ===")
    if not fold_results:
        print("No folds were evaluated (no test data).")
        return fold_results

    avg_dice = np.mean([r['dice'] for r in fold_results])
    avg_mean_dist = finite_mean([r['mean_dist'] for r in fold_results])
    avg_hausdorff = finite_mean([r['hausdorff'] for r in fold_results])

    # Distances are measured on the resized IMG_HEIGHT x IMG_WIDTH grid without
    # pixel spacing, so they are in pixels, not millimetres.
    print(f"Average Dice: {avg_dice:.3f}")
    print(f"Average Mean Absolute Distance: {avg_mean_dist:.3f} px")
    print(f"Average Hausdorff Distance: {avg_hausdorff:.3f} px")
    return fold_results

In [None]:
# In a notebook kernel __name__ is "__main__", so running this cell executes the full pipeline.
if __name__ == "__main__":
    main()

Loading and preprocessing dataset...


Loading training patients: 100%|██████████| 425/425 [00:00<00:00, 2880.03it/s]



Processing with augmentation: /content/drive/MyDrive/database_nifti/patient0450

Processing with augmentation: /content/drive/MyDrive/database_nifti/patient0340

Processing with augmentation: /content/drive/MyDrive/database_nifti/patient0004

Processing with augmentation: /content/drive/MyDrive/database_nifti/patient0176

Processing with augmentation: /content/drive/MyDrive/database_nifti/patient0044

Processing with augmentation: /content/drive/MyDrive/database_nifti/patient0219

Processing with augmentation: /content/drive/MyDrive/database_nifti/patient0197

Processing with augmentation: /content/drive/MyDrive/database_nifti/patient0264

Processing with augmentation: /content/drive/MyDrive/database_nifti/patient0130

Processing with augmentation: /content/drive/MyDrive/database_nifti/patient0200

Processing with augmentation: /content/drive/MyDrive/database_nifti/patient0012

Processing with augmentation: /content/drive/MyDrive/database_nifti/patient0161

Processing with augmentatio

Loading test patients:   0%|          | 0/75 [00:00<?, ?it/s]


Processing: /content/drive/MyDrive/database_nifti/patient0029

Processing: /content/drive/MyDrive/database_nifti/patient0487

Processing: /content/drive/MyDrive/database_nifti/patient0296

Processing: /content/drive/MyDrive/database_nifti/patient0160

Processing: /content/drive/MyDrive/database_nifti/patient0135

Processing: /content/drive/MyDrive/database_nifti/patient0317

Processing: /content/drive/MyDrive/database_nifti/patient0080

Processing: /content/drive/MyDrive/database_nifti/patient0411

Processing: /content/drive/MyDrive/database_nifti/patient0123

Processing: /content/drive/MyDrive/database_nifti/patient0289

Processing: /content/drive/MyDrive/database_nifti/patient0113

Processing: /content/drive/MyDrive/database_nifti/patient0292

Processing: /content/drive/MyDrive/database_nifti/patient0491

Processing: /content/drive/MyDrive/database_nifti/patient0233

Processing: /content/drive/MyDrive/database_nifti/patient0338

Processing: /content/drive/MyDrive/database_nifti/pati

Loading test patients: 100%|██████████| 75/75 [00:00<00:00, 2717.10it/s]


Processing: /content/drive/MyDrive/database_nifti/patient0336

Processing: /content/drive/MyDrive/database_nifti/patient0337

Processing: /content/drive/MyDrive/database_nifti/patient0120

Processing: /content/drive/MyDrive/database_nifti/patient0148

Processing: /content/drive/MyDrive/database_nifti/patient0186

Processing: /content/drive/MyDrive/database_nifti/patient0011

Processing: /content/drive/MyDrive/database_nifti/patient0243

Processing: /content/drive/MyDrive/database_nifti/patient0003

Processing: /content/drive/MyDrive/database_nifti/patient0259

Processing: /content/drive/MyDrive/database_nifti/patient0159

Processing: /content/drive/MyDrive/database_nifti/patient0153

Processing: /content/drive/MyDrive/database_nifti/patient0137

Processing: /content/drive/MyDrive/database_nifti/patient0121

Processing: /content/drive/MyDrive/database_nifti/patient0496

Processing: /content/drive/MyDrive/database_nifti/patient0363

Processing: /content/drive/MyDrive/database_nifti/pati




ValueError: Cannot have number of splits n_splits=10 greater than the number of samples: n_samples=0.

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import KFold
import nibabel as nib
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, concatenate, BatchNormalization, Activation
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from tensorflow.keras import backend as K
import tensorflow as tf
from skimage.transform import resize
from tqdm import tqdm
from scipy.ndimage import distance_transform_edt
from skimage.morphology import erosion, dilation, square
from sklearn.model_selection import train_test_split
from scipy.ndimage import rotate
from skimage.transform import rescale

# Configuration
# Image size must be divisible by 16: both U-Nets pool four times and then
# concatenate upsampled features with the matching encoder output.
IMG_HEIGHT = 256
IMG_WIDTH = 256
IMG_CHANNELS = 1
BATCH_SIZE = 8  # U-Net 1 uses 8, U-Net 2 uses 4
EPOCHS = 100
INIT_LR = 1e-4
N_FOLDS = 10  # 10-fold cross-validation as in paper
SEED = 42

# Path to the CAMUS dataset. Set the CAMUS_DATA_PATH environment variable to
# point elsewhere instead of editing the notebook; the default is the Google
# Drive location used on Colab.
DATA_PATH = os.environ.get("CAMUS_DATA_PATH", "/content/drive/MyDrive/database_nifti")

# Define metrics
def dice_coef(y_true, y_pred, smooth=1):
    """Soft Dice coefficient between two tensors.

    Both tensors are flattened, so the score is pooled over every element of
    the batch rather than averaged per image. ``smooth`` is added to the
    numerator and the denominator so the ratio stays defined when both inputs
    are all zeros.
    """
    truth = K.flatten(y_true)
    prediction = K.flatten(y_pred)
    overlap = K.sum(truth * prediction)
    total = K.sum(truth) + K.sum(prediction) + smooth
    return (2. * overlap + smooth) / total

def dice_loss(y_true, y_pred):
    """Dice loss: one minus the soft Dice coefficient (default smoothing)."""
    dice = dice_coef(y_true, y_pred)
    return 1 - dice

def bce_dice_loss(y_true, y_pred):
    """Combined loss: Keras binary cross-entropy plus the soft Dice loss."""
    cross_entropy = tf.keras.losses.binary_crossentropy(y_true, y_pred)
    return cross_entropy + dice_loss(y_true, y_pred)

# Data loading and preprocessing
def load_nifti_image(file_path):
    """Read a NIfTI file and return its data array with singleton axes removed."""
    return np.squeeze(nib.load(file_path).get_fdata())

def preprocess_patient_with_augmentation(patient_folder):
    """Load one patient's images and return rotation/zoom-augmented samples.

    For each (view, time point) in ('2CH', '4CH') x ('ED', 'ES') whose image
    and ``_gt`` label file both exist, the image is resized to
    (IMG_HEIGHT, IMG_WIDTH) and min-max normalised. Twelve variants are then
    produced: rotations by 0/90/180/270 degrees, each zoomed by 0.9/1.0/1.1.
    Ground-truth label values 1, 2 and 3 are split into ``masks_endo``,
    ``masks_epi`` and ``masks_la``.

    Args:
        patient_folder: Path to a patient directory; its basename is the
            file-name prefix of that patient's NIfTI files.

    Returns:
        (images, masks_endo, masks_epi, masks_la). Assuming each NIfTI file
        squeezes to a 2-D image, each is a float32 array of shape
        (N, IMG_HEIGHT, IMG_WIDTH, 1). All four are empty arrays when nothing
        could be loaded.
    """
    images = []
    masks_endo = []
    masks_epi = []
    masks_la = []

    def find_nifti(stem):
        # Prefer compressed (.nii.gz) NIfTI, but also accept plain .nii files.
        for ext in ('.nii.gz', '.nii'):
            if os.path.exists(stem + ext):
                return stem + ext
        return None

    def fit_to_canvas(arr):
        # Centre-crop (zoom in) or zero-pad (zoom out) back to the target size.
        # Resizing the rescaled array back instead would undo the zoom entirely.
        top = max((arr.shape[0] - IMG_HEIGHT) // 2, 0)
        left = max((arr.shape[1] - IMG_WIDTH) // 2, 0)
        arr = arr[top:top + IMG_HEIGHT, left:left + IMG_WIDTH]
        canvas = np.zeros((IMG_HEIGHT, IMG_WIDTH), dtype=arr.dtype)
        pad_top = (IMG_HEIGHT - arr.shape[0]) // 2
        pad_left = (IMG_WIDTH - arr.shape[1]) // 2
        canvas[pad_top:pad_top + arr.shape[0], pad_left:pad_left + arr.shape[1]] = arr
        return canvas

    base_name = os.path.basename(os.path.normpath(patient_folder))

    for view in ['2CH', '4CH']:
        for tp in ['ED', 'ES']:
            stem = f"{patient_folder}/{base_name}_{view}_{tp}"
            img_path = find_nifti(stem)
            gt_path = find_nifti(f"{stem}_gt")
            if img_path is None or gt_path is None:
                continue

            try:
                # Drop singleton axes so resize operates on a 2-D array.
                img = np.squeeze(nib.load(img_path).get_fdata())
                gt = np.squeeze(nib.load(gt_path).get_fdata())

                img_resized = resize(img, (IMG_HEIGHT, IMG_WIDTH), preserve_range=True, anti_aliasing=True)
                # order=0 (nearest neighbour) keeps label values integral. The
                # default interpolation blends labels at class borders, so the
                # `== 1/2/3` tests below would silently drop those pixels.
                gt_resized = resize(gt, (IMG_HEIGHT, IMG_WIDTH), order=0, preserve_range=True, anti_aliasing=False)

                lo, hi = img_resized.min(), img_resized.max()
                # A constant image would divide by zero and produce NaNs.
                img_resized = (img_resized - lo) / (hi - lo) if hi > lo else np.zeros_like(img_resized)

                for angle in [0, 90, 180, 270]:
                    img_rotated = rotate(img_resized, angle, reshape=False)
                    gt_rotated = rotate(gt_resized, angle, reshape=False, order=0)

                    for scale in [0.9, 1.0, 1.1]:
                        img_scaled = fit_to_canvas(
                            rescale(img_rotated, scale, preserve_range=True, anti_aliasing=True))
                        gt_scaled = np.rint(fit_to_canvas(
                            rescale(gt_rotated, scale, order=0, preserve_range=True, anti_aliasing=False)))

                        # float32 halves memory vs. float64 and matches the masks' dtype.
                        images.append(img_scaled.astype(np.float32)[..., np.newaxis])
                        masks_endo.append((gt_scaled == 1).astype(np.float32)[..., np.newaxis])
                        masks_epi.append((gt_scaled == 2).astype(np.float32)[..., np.newaxis])
                        masks_la.append((gt_scaled == 3).astype(np.float32)[..., np.newaxis])

            except Exception as e:
                print(f"Error processing {view}_{tp}: {str(e)}")
                continue

    return np.array(images), np.array(masks_endo), np.array(masks_epi), np.array(masks_la)

def preprocess_patient(patient_folder):
    """Load one patient's images and masks without augmentation.

    For each (view, time point) in ('2CH', '4CH') x ('ED', 'ES') whose image
    and ``_gt`` label file both exist, the image is resized to
    (IMG_HEIGHT, IMG_WIDTH) and min-max normalised. The label map is resized
    with nearest-neighbour interpolation and its values 1, 2 and 3 are split
    into ``masks_endo``, ``masks_epi`` and ``masks_la``.

    Args:
        patient_folder: Path to a patient directory; its basename is the
            file-name prefix of that patient's NIfTI files.

    Returns:
        (images, masks_endo, masks_epi, masks_la). Assuming each NIfTI file
        squeezes to a 2-D image, each is a float32 array of shape
        (N, IMG_HEIGHT, IMG_WIDTH, 1) with N <= 4. All four are empty arrays
        when nothing could be loaded.
    """
    images = []
    masks_endo = []
    masks_epi = []
    masks_la = []

    def find_nifti(stem):
        # Prefer compressed (.nii.gz) NIfTI, but also accept plain .nii files.
        for ext in ('.nii.gz', '.nii'):
            if os.path.exists(stem + ext):
                return stem + ext
        return None

    base_name = os.path.basename(os.path.normpath(patient_folder))

    for view in ['2CH', '4CH']:
        for tp in ['ED', 'ES']:
            stem = f"{patient_folder}/{base_name}_{view}_{tp}"
            img_path = find_nifti(stem)
            gt_path = find_nifti(f"{stem}_gt")
            if img_path is None or gt_path is None:
                continue

            try:
                # Drop singleton axes so resize operates on a 2-D array.
                img = np.squeeze(nib.load(img_path).get_fdata())
                gt = np.squeeze(nib.load(gt_path).get_fdata())

                img_resized = resize(img, (IMG_HEIGHT, IMG_WIDTH), preserve_range=True, anti_aliasing=True)
                # Nearest-neighbour keeps labels integral; interpolated labels
                # would fail the exact `== 1/2/3` comparisons at class borders.
                gt_resized = np.rint(resize(gt, (IMG_HEIGHT, IMG_WIDTH), order=0,
                                            preserve_range=True, anti_aliasing=False))

                lo, hi = img_resized.min(), img_resized.max()
                # A constant image would divide by zero and produce NaNs.
                img_resized = (img_resized - lo) / (hi - lo) if hi > lo else np.zeros_like(img_resized)

                images.append(img_resized.astype(np.float32)[..., np.newaxis])
                masks_endo.append((gt_resized == 1).astype(np.float32)[..., np.newaxis])
                masks_epi.append((gt_resized == 2).astype(np.float32)[..., np.newaxis])
                masks_la.append((gt_resized == 3).astype(np.float32)[..., np.newaxis])

            except Exception as e:
                print(f"Error processing {view}_{tp}: {str(e)}")
                continue

    return np.array(images), np.array(masks_endo), np.array(masks_epi), np.array(masks_la)

def load_dataset_with_split(base_path, test_ratio=0.15):
    """Split the patients into train/test sets and preprocess each split.

    Patient folders (names starting with 'patient') are sorted, then shuffled
    after seeding NumPy's global RNG with SEED. The first
    ``1 - test_ratio`` share becomes the training split. Training patients go
    through ``preprocess_patient_with_augmentation`` and test patients through
    ``preprocess_patient``. Patients that yield no samples are skipped.

    Returns:
        (X_train, y_endo_train, y_epi_train, y_la_train,
         X_test, y_endo_test, y_epi_test, y_la_test); an empty split is
        returned as ``np.array([])``.

    Raises:
        ValueError: if no patient folder is found or no training data loads.
    """
    patient_folders = sorted(
        os.path.join(base_path, name)
        for name in os.listdir(base_path)
        if name.startswith('patient') and os.path.isdir(os.path.join(base_path, name))
    )

    if len(patient_folders) == 0:
        raise ValueError(f"No patient folders found in {base_path}. Check your DATA_PATH.")

    # Seeding the global RNG makes the split (and later global draws) reproducible.
    np.random.seed(SEED)
    np.random.shuffle(patient_folders)
    n_train = int(len(patient_folders) * (1 - test_ratio))
    train_folders, test_folders = patient_folders[:n_train], patient_folders[n_train:]

    print(f"Found {len(patient_folders)} patients. Using {len(train_folders)} for training, {len(test_folders)} for testing.")

    def gather(folders, loader, desc):
        # One list per output (images, endo, epi, LA); empty patients are skipped.
        buckets = ([], [], [], [])
        for folder in tqdm(folders, desc=desc):
            arrays = loader(folder)
            if len(arrays[0]) > 0:
                for bucket, arr in zip(buckets, arrays):
                    bucket.append(arr)
        return buckets

    def stack(chunks):
        return np.concatenate(chunks, axis=0) if chunks else np.array([])

    train_buckets = gather(train_folders, preprocess_patient_with_augmentation, "Loading training patients")
    test_buckets = gather(test_folders, preprocess_patient, "Loading test patients")

    X_train, y_endo_train, y_epi_train, y_la_train = (stack(b) for b in train_buckets)
    X_test, y_endo_test, y_epi_test, y_la_test = (stack(b) for b in test_buckets)

    if len(X_train) == 0:
        raise ValueError("No training data was loaded. Check your data paths and preprocessing.")
    if len(X_test) == 0:
        print("Warning: No test data was loaded.")

    print(f"\nTraining data shape: {X_train.shape}")
    print(f"Test data shape: {X_test.shape}")

    return X_train, y_endo_train, y_epi_train, y_la_train, X_test, y_endo_test, y_epi_test, y_la_test

# U-Net architectures (identical to the definitions in the earlier cell)
def unet1(input_size=(IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS)):
    """Build the lighter U-Net variant (32 to 512 filters, no batch normalisation).

    Four encoder levels of two 3x3 ReLU convolutions followed by 2x2 max
    pooling, a 512-filter bottleneck, and a mirrored decoder. Each decoder
    level upsamples by 2, concatenates the matching encoder output, and
    applies two convolutions. A 1x1 sigmoid convolution produces a
    single-channel probability map. Height and width of ``input_size`` must be
    divisible by 16 for the skip connections to line up.
    """
    def double_conv(tensor, filters):
        for _ in range(2):
            tensor = Conv2D(filters, 3, activation='relu', padding='same')(tensor)
        return tensor

    inputs = Input(input_size)

    skips = []
    x = inputs
    for filters in (32, 64, 128, 256):
        x = double_conv(x, filters)
        skips.append(x)
        x = MaxPooling2D(pool_size=(2, 2))(x)

    x = double_conv(x, 512)

    for filters, skip in zip((256, 128, 64, 32), reversed(skips)):
        x = concatenate([UpSampling2D(size=(2, 2))(x), skip], axis=-1)
        x = double_conv(x, filters)

    outputs = Conv2D(1, 1, activation='sigmoid')(x)
    return Model(inputs=inputs, outputs=outputs)

def unet2(input_size=(IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS)):
    """Build the wider, batch-normalised U-Net with one sigmoid output channel.

    Every convolution unit is Conv2D(3x3, 'same') -> BatchNormalization -> ReLU,
    and every stage applies two such units.  Encoder stages use 64, 128, 256 and
    512 filters, each followed by 2x2 max pooling; the bottleneck uses 1024.
    Decoder stages 2x up-sample, concatenate the matching encoder output and
    apply two units.  Head: a 1x1 convolution with sigmoid activation.

    Args:
        input_size: (height, width, channels) of the input.  Height and width
            must be divisible by 16 so the up-sampled maps line up with the skips.

    Returns:
        An uncompiled Keras ``Model``.
    """
    def conv_bn_relu(tensor, n_filters):
        tensor = Conv2D(n_filters, 3, padding='same')(tensor)
        tensor = BatchNormalization()(tensor)
        return Activation('relu')(tensor)

    def double_conv(tensor, n_filters):
        return conv_bn_relu(conv_bn_relu(tensor, n_filters), n_filters)

    encoder_filters = (64, 128, 256, 512)

    inputs = Input(input_size)
    x = inputs
    skips = []
    for n_filters in encoder_filters:
        x = double_conv(x, n_filters)
        skips.append(x)
        x = MaxPooling2D(pool_size=(2, 2))(x)

    x = double_conv(x, 1024)

    for n_filters, skip in zip(reversed(encoder_filters), reversed(skips)):
        x = concatenate([UpSampling2D(size=(2, 2))(x), skip], axis=-1)
        x = double_conv(x, n_filters)

    outputs = Conv2D(1, 1, activation='sigmoid')(x)
    return Model(inputs=inputs, outputs=outputs)

# Evaluation metrics (same as before)
def calculate_metrics(y_true, y_pred):
    """Compare one ground-truth mask with one prediction.

    Args:
        y_true: ground-truth mask; values > 0.5 are foreground.  Singleton axes
            are squeezed (e.g. ``(H, W, 1)`` -> ``(H, W)``).
        y_pred: predicted probabilities or a boolean mask; values > 0.5 are
            foreground.  Must match ``y_true``'s shape after squeezing.

    Returns:
        ``(dice, mean_dist, hausdorff)``:

        * ``dice``: Dice overlap with +1 smoothing (1.0 when both masks are empty).
        * ``mean_dist``: symmetric mean distance between the two mask contours.
        * ``hausdorff``: symmetric Hausdorff distance between the two contours.

        Distances are Euclidean, in pixels of the array grid.  They are 0.0 when
        neither mask has a contour and the masks are identical (e.g. both empty).
        They are ``inf`` when only one mask has a contour, or when neither has one
        but the masks differ.
    """
    from scipy.ndimage import binary_erosion

    true_mask = np.asarray(y_true).squeeze() > 0.5
    pred_mask = np.asarray(y_pred).squeeze() > 0.5

    # Dice coefficient
    intersection = np.logical_and(true_mask, pred_mask).sum()
    dice = float((2. * intersection + 1.) / (true_mask.sum() + pred_mask.sum() + 1.))

    # Contour = foreground pixels with a background pixel in their 3x3 neighbourhood.
    # border_value=1 treats outside-the-array pixels as foreground, which matches a
    # reflect-mode 3x3 grey erosion: masks touching the image edge get no contour there.
    footprint = np.ones((3,) * true_mask.ndim, dtype=bool)
    true_contour = true_mask & ~binary_erosion(true_mask, structure=footprint, border_value=1)
    pred_contour = pred_mask & ~binary_erosion(pred_mask, structure=footprint, border_value=1)

    has_true, has_pred = true_contour.any(), pred_contour.any()
    if not (has_true and has_pred):
        if not has_true and not has_pred and np.array_equal(true_mask, pred_mask):
            return dice, 0.0, 0.0
        return dice, np.inf, np.inf

    # Distance from every pixel to the nearest contour pixel of each mask.
    dist_to_true = distance_transform_edt(~true_contour)
    dist_to_pred = distance_transform_edt(~pred_contour)
    pred_to_true = dist_to_true[pred_contour]
    true_to_pred = dist_to_pred[true_contour]

    # Mean absolute distance between contours, averaged over both directions.
    mean_dist = float((pred_to_true.mean() + true_to_pred.mean()) / 2)
    # Hausdorff distance: worst contour deviation in either direction.
    hausdorff = float(max(pred_to_true.max(), true_to_pred.max()))

    return dice, mean_dist, hausdorff

def main():
    """Cross-validate a U-Net on the endocardium masks and evaluate every fold on the test set.

    Steps:
      1. Load the patient-level train/test split via ``load_dataset_with_split``.
      2. Split the training samples into ``min(N_FOLDS, n_samples)`` contiguous folds.
      3. Per fold: train a fresh model (BCE + Dice loss), keep the checkpoint with the
         best ``val_dice_coef`` and reload it.
      4. If a test set exists, score the fold's model on every test image with
         ``calculate_metrics`` and plot one random test image.
      5. Print the averages over folds.

    ``calculate_metrics`` measures distances on the resized ``IMG_HEIGHT x IMG_WIDTH``
    pixel grid, so distances are reported in pixels, not millimetres.

    Raises:
        ValueError: if no training data is loaded or fewer than 2 samples exist.
        Any other exception is printed and re-raised.
    """

    def finite_mean(values):
        """Mean of the finite entries of ``values``; NaN when there are none."""
        values = np.asarray(values, dtype=float)
        finite = values[np.isfinite(values)]
        return float(finite.mean()) if finite.size else float('nan')

    print("Loading and preprocessing dataset...")
    try:
        (X_train, y_endo_train, y_epi_train, y_la_train,
         X_test, y_endo_test, y_epi_test, y_la_test) = load_dataset_with_split(DATA_PATH)

        # Verify we have data
        if len(X_train) == 0:
            raise ValueError("No training data available after loading.")

        print(f"\nTraining on {len(X_train)} samples")
        if len(X_test) > 0:
            print(f"Testing on {len(X_test)} samples")

        # Choose which U-Net to use
        unet_version = 1
        batch_size = 8 if unet_version == 1 else 4

        # Adjust number of folds if we don't have enough samples
        actual_folds = min(N_FOLDS, len(X_train))
        if actual_folds < 2:
            raise ValueError(
                f"Cross-validation needs at least 2 training samples, got {len(X_train)}.")
        if actual_folds < N_FOLDS:
            print(f"Reducing number of folds from {N_FOLDS} to {actual_folds} due to limited samples")

        # No re-shuffling: load_dataset_with_split stores each patient's (augmented)
        # samples contiguously, in an already seed-shuffled patient order.  Contiguous
        # folds therefore keep a patient's augmented copies together; only a patient
        # straddling a fold boundary is split.  Shuffling sample indices would instead
        # put near-identical copies of the same image in both training and validation.
        kf = KFold(n_splits=actual_folds, shuffle=False)
        fold_results = []

        for fold, (train_idx, val_idx) in enumerate(kf.split(X_train)):
            print(f"\n=== Fold {fold + 1}/{actual_folds} ===")
            # Drop the previous fold's model/graph so memory does not pile up.
            tf.keras.backend.clear_session()

            X_fold_train, X_val = X_train[train_idx], X_train[val_idx]
            y_fold_train, y_val = y_endo_train[train_idx], y_endo_train[val_idx]

            model = unet1() if unet_version == 1 else unet2()
            model.compile(optimizer=Adam(learning_rate=INIT_LR),
                          loss=bce_dice_loss,
                          metrics=[dice_coef, 'accuracy'])

            checkpoint_path = f"unet{unet_version}_fold{fold}_best.keras"
            callbacks = [
                ModelCheckpoint(checkpoint_path,
                                monitor='val_dice_coef',
                                mode='max',
                                save_best_only=True,
                                verbose=1),
                EarlyStopping(monitor='val_dice_coef', patience=15, mode='max', verbose=1),
                ReduceLROnPlateau(monitor='val_dice_coef', factor=0.5, patience=5, min_lr=1e-6, mode='max', verbose=1)
            ]

            print(f"Training on {len(X_fold_train)} samples, validating on {len(X_val)} samples")
            model.fit(
                X_fold_train, y_fold_train,
                batch_size=batch_size,
                epochs=EPOCHS,
                validation_data=(X_val, y_val),
                callbacks=callbacks,
                verbose=1
            )
            # The fold copies are no longer needed; free them before evaluation.
            del X_fold_train, X_val, y_fold_train, y_val

            model.load_weights(checkpoint_path)

            if len(X_test) > 0:
                # Score the whole test set; a single random image is far too noisy an estimate.
                test_probs = model.predict(X_test, batch_size=batch_size, verbose=0)
                scores = [calculate_metrics(y_true, y_prob > 0.5)
                          for y_true, y_prob in zip(y_endo_test, test_probs)]
                dices, mean_dists, hausdorffs = zip(*scores)
                fold_result = {
                    'dice': float(np.mean(dices)),
                    # Empty masks can yield inf/NaN distances; average the finite ones only.
                    'mean_dist': finite_mean(mean_dists),
                    'hausdorff': finite_mean(hausdorffs),
                }
                fold_results.append(fold_result)
                print(f"Fold {fold + 1} test Dice: {fold_result['dice']:.3f}, "
                      f"mean distance: {fold_result['mean_dist']:.3f} px, "
                      f"Hausdorff: {fold_result['hausdorff']:.3f} px")

                # Visual sanity check on one random test image.
                idx = np.random.randint(0, len(X_test))
                fig, axes = plt.subplots(1, 3, figsize=(12, 4))
                panels = [
                    (X_test[idx].squeeze(), "Original Image"),
                    (y_endo_test[idx].squeeze(), "Ground Truth"),
                    (test_probs[idx].squeeze() > 0.5, "Predicted Mask"),
                ]
                for ax, (image, title) in zip(axes, panels):
                    ax.imshow(image, cmap='gray')
                    ax.set_title(title)
                fig.tight_layout()
                plt.show()
                plt.close(fig)

        # Print final cross-validation results if we have any
        if fold_results:
            print("\n=== Final Cross-Validation Results ===")
            avg_dice = np.mean([r['dice'] for r in fold_results])
            avg_mean_dist = finite_mean([r['mean_dist'] for r in fold_results])
            avg_hausdorff = finite_mean([r['hausdorff'] for r in fold_results])

            print(f"Average Dice: {avg_dice:.3f}")
            print(f"Average Mean Absolute Distance: {avg_mean_dist:.3f} px")
            print(f"Average Hausdorff Distance: {avg_hausdorff:.3f} px")
        else:
            print("\nNo fold results to report (possibly no test data available)")

    except Exception as e:
        print(f"Error in main execution: {str(e)}")
        raise

# Entry point: runs the full load -> cross-validate -> evaluate pipeline.
# Inside a notebook kernel __name__ is '__main__', so executing this cell starts training.
if __name__ == '__main__':
    main()

Loading and preprocessing dataset...
Found 500 patients. Using 425 for training, 75 for testing.


Loading training patients:  28%|██▊       | 118/425 [07:10<18:40,  3.65s/it]


KeyboardInterrupt: 

## Here starts the U-Net code with the new split, imported from Kaggle


In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import KFold
import nibabel as nib
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, concatenate, BatchNormalization, Activation
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from tensorflow.keras import backend as K
import tensorflow as tf
from skimage.transform import resize
from tqdm import tqdm
from scipy.ndimage import distance_transform_edt
from skimage.morphology import erosion, dilation, square
from sklearn.model_selection import train_test_split
from scipy.ndimage import rotate
from skimage.transform import rescale

In [None]:
# Configuration
IMG_HEIGHT = 256
IMG_WIDTH = 256
IMG_CHANNELS = 1
# NOTE: main() below chooses its own batch size from the U-Net version
# (8 for U-Net 1, 4 for U-Net 2) and does not read BATCH_SIZE.
BATCH_SIZE = 4  # U-Net 1 uses 8, U-Net 2 uses 4
EPOCHS = 100
INIT_LR = 1e-4
N_FOLDS = 10  # 10-fold cross-validation as in paper
SEED = 42

# Path to the CAMUS dataset (the folder holding the patient* sub-folders).
# Set the CAMUS_DATA_PATH environment variable to use another location
# without editing the notebook.
DATA_PATH = os.environ.get("CAMUS_DATA_PATH", "/content/drive/MyDrive/database_nifti")


In [None]:
# Define metrics
def dice_coef(y_true, y_pred, smooth=1):
    """Soft Dice coefficient pooled over every element of the batch.

    Both tensors are flattened, so the overlap is summed over all images at
    once rather than averaged per image.  ``smooth`` is added to numerator and
    denominator, which keeps the ratio finite (and equal to 1) when both inputs
    are all zeros.
    """
    truth = tf.reshape(y_true, [-1])
    probs = tf.reshape(y_pred, [-1])
    overlap = tf.reduce_sum(truth * probs)
    denominator = tf.reduce_sum(truth) + tf.reduce_sum(probs) + smooth
    return (2. * overlap + smooth) / denominator


def dice_loss(y_true, y_pred):
    """Dice loss ``1 - dice_coef`` (0 for a perfect overlap)."""
    return 1 - dice_coef(y_true, y_pred)


def bce_dice_loss(y_true, y_pred):
    """Keras binary cross-entropy plus the batch-level Dice loss."""
    bce = tf.keras.losses.binary_crossentropy(y_true, y_pred)
    return bce + dice_loss(y_true, y_pred)

# Data loading and preprocessing
def load_nifti_image(file_path):
    """Read a NIfTI file and return its voxel data with singleton axes removed."""
    volume = nib.load(file_path).get_fdata()
    return np.squeeze(volume)

def preprocess_patient_with_augmentation(patient_folder):
    """Load one CAMUS patient and return 12 augmented copies of each image/label pair.

    For every ``{view}_{tp}`` pair (views 2CH/4CH, time points ED/ES) whose image
    and ``_gt`` NIfTI files both exist in ``patient_folder``:

    * The image is resized to ``(IMG_HEIGHT, IMG_WIDTH)`` and min-max normalised.
      A constant image becomes all zeros.
    * The label map is resized with nearest-neighbour interpolation.
    * Both are rotated by 0/90/180/270 degrees and scaled by 0.9/1.0/1.1.  The
      scaled result is center-cropped or zero-padded back to the target size, so
      the anatomy really changes size.
    * Label values 1, 2 and 3 become the endo, epi and LA binary masks.

    Pairs that raise while processing are reported and skipped.

    Returns:
        ``(images, masks_endo, masks_epi, masks_la)`` as float32 arrays, shaped
        ``(N, IMG_HEIGHT, IMG_WIDTH, 1)`` for 2-D NIfTI slices.  Empty arrays
        are returned if nothing was loaded.
    """
    target_shape = (IMG_HEIGHT, IMG_WIDTH)

    def fit_to_target(arr):
        """Center-crop or zero-pad the two spatial axes of ``arr`` to ``target_shape``."""
        out = np.zeros(target_shape + arr.shape[2:], dtype=arr.dtype)
        src, dst = [], []
        for size, target in zip(arr.shape[:2], target_shape):
            offset = abs(size - target) // 2
            if size >= target:
                src.append(slice(offset, offset + target))
                dst.append(slice(0, target))
            else:
                src.append(slice(0, size))
                dst.append(slice(offset, offset + size))
        out[tuple(dst)] = arr[tuple(src)]
        return out

    images = []
    masks_endo = []
    masks_epi = []
    masks_la = []

    base_name = os.path.basename(patient_folder)

    for view in ['2CH', '4CH']:
        for tp in ['ED', 'ES']:
            img_path = f"{patient_folder}/{base_name}_{view}_{tp}.nii.gz"
            gt_path = f"{patient_folder}/{base_name}_{view}_{tp}_gt.nii.gz"

            if not os.path.exists(img_path) or not os.path.exists(gt_path):
                continue

            try:
                img = nib.load(img_path).get_fdata()
                gt = nib.load(gt_path).get_fdata()

                img_resized = resize(img, target_shape, preserve_range=True, anti_aliasing=True)
                # Label maps use nearest-neighbour (order=0) resampling at every step.
                # Interpolation would blend neighbouring labels into fractional values
                # that the exact `== label` tests below silently discard.
                gt_resized = resize(gt, target_shape, order=0, preserve_range=True, anti_aliasing=False)

                lo, hi = img_resized.min(), img_resized.max()
                # A constant image would otherwise divide by zero and produce NaNs.
                img_resized = (img_resized - lo) / (hi - lo) if hi > lo else np.zeros_like(img_resized)

                # Augmentation: Apply rotations and scalings
                for angle in [0, 90, 180, 270]:
                    img_rotated = rotate(img_resized, angle, reshape=False)
                    gt_rotated = rotate(gt_resized, angle, reshape=False, order=0)

                    for scale in [0.9, 1.0, 1.1]:
                        img_scaled = rescale(img_rotated, scale, preserve_range=True, anti_aliasing=True)
                        gt_scaled = rescale(gt_rotated, scale, order=0, preserve_range=True, anti_aliasing=False)

                        # Crop/pad rather than resize back: resizing to the original
                        # size would undo the scaling and make the three scales near-duplicates.
                        img_out = fit_to_target(img_scaled).astype(np.float32)
                        gt_out = fit_to_target(gt_scaled)

                        mask_endo = (gt_out == 1).astype(np.float32)
                        mask_epi = (gt_out == 2).astype(np.float32)
                        mask_la = (gt_out == 3).astype(np.float32)

                        images.append(img_out[..., np.newaxis])
                        masks_endo.append(mask_endo[..., np.newaxis])
                        masks_epi.append(mask_epi[..., np.newaxis])
                        masks_la.append(mask_la[..., np.newaxis])

            except Exception as e:
                print(f"Error processing {view}_{tp}: {str(e)}")
                continue

    return np.array(images), np.array(masks_endo), np.array(masks_epi), np.array(masks_la)

def preprocess_patient(patient_folder):
    """Load every available view/time-point of one CAMUS patient, without augmentation.

    For every ``{view}_{tp}`` pair (views 2CH/4CH, time points ED/ES) whose image
    and ``_gt`` NIfTI files both exist in ``patient_folder``:

    * The image is resized to ``(IMG_HEIGHT, IMG_WIDTH)`` and min-max normalised.
      A constant image becomes all zeros.
    * The label map is resized with nearest-neighbour interpolation.
    * Label values 1, 2 and 3 become the endo, epi and LA binary masks.

    Pairs that raise while processing are reported and skipped.

    Returns:
        ``(images, masks_endo, masks_epi, masks_la)`` as float32 arrays, shaped
        ``(N, IMG_HEIGHT, IMG_WIDTH, 1)`` for 2-D NIfTI slices.  Empty arrays
        are returned if nothing was loaded.
    """
    images = []
    masks_endo = []
    masks_epi = []
    masks_la = []

    target_shape = (IMG_HEIGHT, IMG_WIDTH)
    base_name = os.path.basename(patient_folder)

    for view in ['2CH', '4CH']:
        for tp in ['ED', 'ES']:
            img_path = f"{patient_folder}/{base_name}_{view}_{tp}.nii.gz"
            gt_path = f"{patient_folder}/{base_name}_{view}_{tp}_gt.nii.gz"

            if not os.path.exists(img_path) or not os.path.exists(gt_path):
                continue

            try:
                img = nib.load(img_path).get_fdata()
                gt = nib.load(gt_path).get_fdata()

                img_resized = resize(img, target_shape, preserve_range=True, anti_aliasing=True)
                # Nearest-neighbour (order=0) keeps the label map integer-valued.  The
                # default bilinear interpolation blends neighbouring labels, and the
                # exact `== label` tests below would then drop those boundary pixels.
                gt_resized = resize(gt, target_shape, order=0, preserve_range=True, anti_aliasing=False)

                lo, hi = img_resized.min(), img_resized.max()
                # A constant image would otherwise divide by zero and produce NaNs.
                img_resized = (img_resized - lo) / (hi - lo) if hi > lo else np.zeros_like(img_resized)

                mask_endo = (gt_resized == 1).astype(np.float32)
                mask_epi = (gt_resized == 2).astype(np.float32)
                mask_la = (gt_resized == 3).astype(np.float32)

                images.append(img_resized.astype(np.float32)[..., np.newaxis])
                masks_endo.append(mask_endo[..., np.newaxis])
                masks_epi.append(mask_epi[..., np.newaxis])
                masks_la.append(mask_la[..., np.newaxis])

            except Exception as e:
                print(f"Error processing {view}_{tp}: {str(e)}")
                continue

    return np.array(images), np.array(masks_endo), np.array(masks_epi), np.array(masks_la)

def load_dataset_with_split(base_path, test_ratio=0.15):
    """Load the CAMUS patients under ``base_path`` and split them by patient.

    Patient folders (directories whose name starts with ``patient``) are sorted,
    shuffled with ``SEED`` and split so that roughly ``test_ratio`` of the patients
    form the test set.  Training patients go through
    ``preprocess_patient_with_augmentation`` and test patients through
    ``preprocess_patient``.  Every loaded array is stored as float32, which halves
    memory compared with float64.  Each split is concatenated and its per-patient
    copies freed before the next split is loaded, which keeps peak memory down.

    Args:
        base_path: directory containing the ``patient*`` folders.
        test_ratio: fraction of patients held out for testing, in [0, 1).

    Returns:
        ``(X_train, y_endo_train, y_epi_train, y_la_train,
        X_test, y_endo_test, y_epi_test, y_la_test)``.  The samples of one patient
        are contiguous and follow the shuffled patient order.  A split with no data
        is returned as ``np.array([])``.

    Raises:
        ValueError: if ``test_ratio`` is outside [0, 1), no patient folder exists,
            or no training sample could be loaded.
    """
    if not 0 <= test_ratio < 1:
        raise ValueError(f"test_ratio must be in [0, 1), got {test_ratio}")

    def stack_patients(folders, preprocess, desc):
        """Run ``preprocess`` on each folder and concatenate the four returned arrays."""
        buckets = ([], [], [], [])
        for folder in tqdm(folders, desc=desc):
            arrays = preprocess(folder)
            if len(arrays[0]) > 0:  # Only add if we got data
                for bucket, arr in zip(buckets, arrays):
                    bucket.append(np.asarray(arr, dtype=np.float32))
        stacked = []
        for bucket in buckets:
            stacked.append(np.concatenate(bucket, axis=0) if bucket else np.array([]))
            bucket.clear()  # free the per-patient copies as soon as they are merged
        return stacked

    patient_folders = sorted(
        os.path.join(base_path, name)
        for name in os.listdir(base_path)
        if name.startswith('patient') and os.path.isdir(os.path.join(base_path, name))
    )

    # Verify we found patient folders
    if not patient_folders:
        raise ValueError(f"No patient folders found in {base_path}. Check your DATA_PATH.")

    # Split into train/test.  This seeds NumPy's *global* RNG, so later unseeded
    # np.random calls in the notebook inherit this seed as well.
    np.random.seed(SEED)
    np.random.shuffle(patient_folders)
    split_idx = int(len(patient_folders) * (1 - test_ratio))
    train_folders = patient_folders[:split_idx]
    test_folders = patient_folders[split_idx:]

    print(f"Found {len(patient_folders)} patients. Using {len(train_folders)} for training, {len(test_folders)} for testing.")

    # Train data with augmentation, test data without
    X_train, y_endo_train, y_epi_train, y_la_train = stack_patients(
        train_folders, preprocess_patient_with_augmentation, "Loading training patients")
    X_test, y_endo_test, y_epi_test, y_la_test = stack_patients(
        test_folders, preprocess_patient, "Loading test patients")

    # Verify we have data
    if len(X_train) == 0:
        raise ValueError("No training data was loaded. Check your data paths and preprocessing.")
    if len(X_test) == 0:
        print("Warning: No test data was loaded.")

    print(f"\nTraining data shape: {X_train.shape}")
    print(f"Test data shape: {X_test.shape}")

    return X_train, y_endo_train, y_epi_train, y_la_train, X_test, y_endo_test, y_epi_test, y_la_test


In [None]:
# U-Net architectures (same as before)
def unet1(input_size=(IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS)):
    """Build the plain U-Net (no batch normalisation) with one sigmoid output channel.

    Encoder: four stages of two 3x3 'same' ReLU convolutions (32, 64, 128, 256
    filters), each followed by 2x2 max pooling.  Bottleneck: two 512-filter
    convolutions.  Decoder: mirrored stages that 2x up-sample, concatenate the
    matching encoder output (skip connection) and apply two convolutions.
    Head: a 1x1 convolution with sigmoid activation.

    Args:
        input_size: (height, width, channels) of the input.  Height and width
            must be divisible by 16 so the up-sampled maps line up with the skips.

    Returns:
        An uncompiled Keras ``Model``.
    """
    def double_conv(tensor, n_filters):
        # Two 3x3 convolutions with ReLU, keeping the spatial size.
        tensor = Conv2D(n_filters, 3, activation='relu', padding='same')(tensor)
        return Conv2D(n_filters, 3, activation='relu', padding='same')(tensor)

    encoder_filters = (32, 64, 128, 256)

    inputs = Input(input_size)
    x = inputs
    skips = []
    for n_filters in encoder_filters:
        x = double_conv(x, n_filters)
        skips.append(x)
        x = MaxPooling2D(pool_size=(2, 2))(x)

    x = double_conv(x, 512)

    for n_filters, skip in zip(reversed(encoder_filters), reversed(skips)):
        x = concatenate([UpSampling2D(size=(2, 2))(x), skip], axis=-1)
        x = double_conv(x, n_filters)

    outputs = Conv2D(1, 1, activation='sigmoid')(x)
    return Model(inputs=inputs, outputs=outputs)

def unet2(input_size=(IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS)):
    """Build the wider, batch-normalised U-Net with one sigmoid output channel.

    Every convolution unit is Conv2D(3x3, 'same') -> BatchNormalization -> ReLU,
    and every stage applies two such units.  Encoder stages use 64, 128, 256 and
    512 filters, each followed by 2x2 max pooling; the bottleneck uses 1024.
    Decoder stages 2x up-sample, concatenate the matching encoder output and
    apply two units.  Head: a 1x1 convolution with sigmoid activation.

    Args:
        input_size: (height, width, channels) of the input.  Height and width
            must be divisible by 16 so the up-sampled maps line up with the skips.

    Returns:
        An uncompiled Keras ``Model``.
    """
    def conv_bn_relu(tensor, n_filters):
        tensor = Conv2D(n_filters, 3, padding='same')(tensor)
        tensor = BatchNormalization()(tensor)
        return Activation('relu')(tensor)

    def double_conv(tensor, n_filters):
        return conv_bn_relu(conv_bn_relu(tensor, n_filters), n_filters)

    encoder_filters = (64, 128, 256, 512)

    inputs = Input(input_size)
    x = inputs
    skips = []
    for n_filters in encoder_filters:
        x = double_conv(x, n_filters)
        skips.append(x)
        x = MaxPooling2D(pool_size=(2, 2))(x)

    x = double_conv(x, 1024)

    for n_filters, skip in zip(reversed(encoder_filters), reversed(skips)):
        x = concatenate([UpSampling2D(size=(2, 2))(x), skip], axis=-1)
        x = double_conv(x, n_filters)

    outputs = Conv2D(1, 1, activation='sigmoid')(x)
    return Model(inputs=inputs, outputs=outputs)


In [None]:
# Evaluation metrics (same as before)
def calculate_metrics(y_true, y_pred):
    """Compare one ground-truth mask with one prediction.

    Args:
        y_true: ground-truth mask; values > 0.5 are foreground.  Singleton axes
            are squeezed (e.g. ``(H, W, 1)`` -> ``(H, W)``).
        y_pred: predicted probabilities or a boolean mask; values > 0.5 are
            foreground.  Must match ``y_true``'s shape after squeezing.

    Returns:
        ``(dice, mean_dist, hausdorff)``:

        * ``dice``: Dice overlap with +1 smoothing (1.0 when both masks are empty).
        * ``mean_dist``: symmetric mean distance between the two mask contours.
        * ``hausdorff``: symmetric Hausdorff distance between the two contours.

        Distances are Euclidean, in pixels of the array grid.  They are 0.0 when
        neither mask has a contour and the masks are identical (e.g. both empty).
        They are ``inf`` when only one mask has a contour, or when neither has one
        but the masks differ.
    """
    from scipy.ndimage import binary_erosion

    true_mask = np.asarray(y_true).squeeze() > 0.5
    pred_mask = np.asarray(y_pred).squeeze() > 0.5

    # Dice coefficient
    intersection = np.logical_and(true_mask, pred_mask).sum()
    dice = float((2. * intersection + 1.) / (true_mask.sum() + pred_mask.sum() + 1.))

    # Contour = foreground pixels with a background pixel in their 3x3 neighbourhood.
    # border_value=1 treats outside-the-array pixels as foreground, which matches a
    # reflect-mode 3x3 grey erosion: masks touching the image edge get no contour there.
    footprint = np.ones((3,) * true_mask.ndim, dtype=bool)
    true_contour = true_mask & ~binary_erosion(true_mask, structure=footprint, border_value=1)
    pred_contour = pred_mask & ~binary_erosion(pred_mask, structure=footprint, border_value=1)

    has_true, has_pred = true_contour.any(), pred_contour.any()
    if not (has_true and has_pred):
        if not has_true and not has_pred and np.array_equal(true_mask, pred_mask):
            return dice, 0.0, 0.0
        return dice, np.inf, np.inf

    # Distance from every pixel to the nearest contour pixel of each mask.
    dist_to_true = distance_transform_edt(~true_contour)
    dist_to_pred = distance_transform_edt(~pred_contour)
    pred_to_true = dist_to_true[pred_contour]
    true_to_pred = dist_to_pred[true_contour]

    # Mean absolute distance between contours, averaged over both directions.
    mean_dist = float((pred_to_true.mean() + true_to_pred.mean()) / 2)
    # Hausdorff distance: worst contour deviation in either direction.
    hausdorff = float(max(pred_to_true.max(), true_to_pred.max()))

    return dice, mean_dist, hausdorff

In [None]:
def main():
    """Cross-validate a U-Net on the endocardium masks and evaluate every fold on the test set.

    Steps:
      1. Load the patient-level train/test split via ``load_dataset_with_split``.
      2. Split the training samples into ``min(N_FOLDS, n_samples)`` contiguous folds.
      3. Per fold: train a fresh model (BCE + Dice loss), keep the checkpoint with the
         best ``val_dice_coef`` and reload it.
      4. If a test set exists, score the fold's model on every test image with
         ``calculate_metrics`` and plot one random test image.
      5. Print the averages over folds.

    ``calculate_metrics`` measures distances on the resized ``IMG_HEIGHT x IMG_WIDTH``
    pixel grid, so distances are reported in pixels, not millimetres.

    Raises:
        ValueError: if no training data is loaded or fewer than 2 samples exist.
        Any other exception is printed and re-raised.
    """

    def finite_mean(values):
        """Mean of the finite entries of ``values``; NaN when there are none."""
        values = np.asarray(values, dtype=float)
        finite = values[np.isfinite(values)]
        return float(finite.mean()) if finite.size else float('nan')

    print("Loading and preprocessing dataset...")
    try:
        (X_train, y_endo_train, y_epi_train, y_la_train,
         X_test, y_endo_test, y_epi_test, y_la_test) = load_dataset_with_split(DATA_PATH)

        # Verify we have data
        if len(X_train) == 0:
            raise ValueError("No training data available after loading.")

        print(f"\nTraining on {len(X_train)} samples")
        if len(X_test) > 0:
            print(f"Testing on {len(X_test)} samples")

        # Choose which U-Net to use
        unet_version = 1
        batch_size = 8 if unet_version == 1 else 4

        # Adjust number of folds if we don't have enough samples
        actual_folds = min(N_FOLDS, len(X_train))
        if actual_folds < 2:
            raise ValueError(
                f"Cross-validation needs at least 2 training samples, got {len(X_train)}.")
        if actual_folds < N_FOLDS:
            print(f"Reducing number of folds from {N_FOLDS} to {actual_folds} due to limited samples")

        # No re-shuffling: load_dataset_with_split stores each patient's (augmented)
        # samples contiguously, in an already seed-shuffled patient order.  Contiguous
        # folds therefore keep a patient's augmented copies together; only a patient
        # straddling a fold boundary is split.  Shuffling sample indices would instead
        # put near-identical copies of the same image in both training and validation.
        kf = KFold(n_splits=actual_folds, shuffle=False)
        fold_results = []

        for fold, (train_idx, val_idx) in enumerate(kf.split(X_train)):
            print(f"\n=== Fold {fold + 1}/{actual_folds} ===")
            # Drop the previous fold's model/graph so memory does not pile up.
            tf.keras.backend.clear_session()

            X_fold_train, X_val = X_train[train_idx], X_train[val_idx]
            y_fold_train, y_val = y_endo_train[train_idx], y_endo_train[val_idx]

            model = unet1() if unet_version == 1 else unet2()
            model.compile(optimizer=Adam(learning_rate=INIT_LR),
                          loss=bce_dice_loss,
                          metrics=[dice_coef, 'accuracy'])

            checkpoint_path = f"unet{unet_version}_fold{fold}_best.keras"
            callbacks = [
                ModelCheckpoint(checkpoint_path,
                                monitor='val_dice_coef',
                                mode='max',
                                save_best_only=True,
                                verbose=1),
                EarlyStopping(monitor='val_dice_coef', patience=15, mode='max', verbose=1),
                ReduceLROnPlateau(monitor='val_dice_coef', factor=0.5, patience=5, min_lr=1e-6, mode='max', verbose=1)
            ]

            print(f"Training on {len(X_fold_train)} samples, validating on {len(X_val)} samples")
            model.fit(
                X_fold_train, y_fold_train,
                batch_size=batch_size,
                epochs=EPOCHS,
                validation_data=(X_val, y_val),
                callbacks=callbacks,
                verbose=1
            )
            # The fold copies are no longer needed; free them before evaluation.
            del X_fold_train, X_val, y_fold_train, y_val

            model.load_weights(checkpoint_path)

            if len(X_test) > 0:
                # Score the whole test set; a single random image is far too noisy an estimate.
                test_probs = model.predict(X_test, batch_size=batch_size, verbose=0)
                scores = [calculate_metrics(y_true, y_prob > 0.5)
                          for y_true, y_prob in zip(y_endo_test, test_probs)]
                dices, mean_dists, hausdorffs = zip(*scores)
                fold_result = {
                    'dice': float(np.mean(dices)),
                    # Empty masks can yield inf/NaN distances; average the finite ones only.
                    'mean_dist': finite_mean(mean_dists),
                    'hausdorff': finite_mean(hausdorffs),
                }
                fold_results.append(fold_result)
                print(f"Fold {fold + 1} test Dice: {fold_result['dice']:.3f}, "
                      f"mean distance: {fold_result['mean_dist']:.3f} px, "
                      f"Hausdorff: {fold_result['hausdorff']:.3f} px")

                # Visual sanity check on one random test image.
                idx = np.random.randint(0, len(X_test))
                fig, axes = plt.subplots(1, 3, figsize=(12, 4))
                panels = [
                    (X_test[idx].squeeze(), "Original Image"),
                    (y_endo_test[idx].squeeze(), "Ground Truth"),
                    (test_probs[idx].squeeze() > 0.5, "Predicted Mask"),
                ]
                for ax, (image, title) in zip(axes, panels):
                    ax.imshow(image, cmap='gray')
                    ax.set_title(title)
                fig.tight_layout()
                plt.show()
                plt.close(fig)

        # Print final cross-validation results if we have any
        if fold_results:
            print("\n=== Final Cross-Validation Results ===")
            avg_dice = np.mean([r['dice'] for r in fold_results])
            avg_mean_dist = finite_mean([r['mean_dist'] for r in fold_results])
            avg_hausdorff = finite_mean([r['hausdorff'] for r in fold_results])

            print(f"Average Dice: {avg_dice:.3f}")
            print(f"Average Mean Absolute Distance: {avg_mean_dist:.3f} px")
            print(f"Average Hausdorff Distance: {avg_hausdorff:.3f} px")
        else:
            print("\nNo fold results to report (possibly no test data available)")

    except Exception as e:
        print(f"Error in main execution: {str(e)}")
        raise

In [None]:
# Entry point: runs the full load -> cross-validate -> evaluate pipeline.
# Inside a notebook kernel __name__ is '__main__', so executing this cell starts training.
if __name__ == '__main__':
    main()

Loading and preprocessing dataset...
Found 500 patients. Using 425 for training, 75 for testing.


Loading training patients: 100%|██████████| 425/425 [1:00:10<00:00,  8.50s/it]
Loading test patients: 100%|██████████| 75/75 [09:22<00:00,  7.50s/it]



Training data shape: (20400, 256, 256, 1)
Test data shape: (300, 256, 256, 1)

Training on 20400 samples
Testing on 300 samples

=== Fold 1/10 ===
Training on 18360 samples, validating on 2040 samples


# Optimized version (patients are loaded on the fly by a Keras data generator)


In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import KFold
import nibabel as nib
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, concatenate, BatchNormalization, Activation
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from tensorflow.keras import backend as K
import tensorflow as tf
from skimage.transform import resize
from tqdm import tqdm
from scipy.ndimage import distance_transform_edt
from skimage.morphology import erosion, dilation, square
from sklearn.model_selection import train_test_split
from scipy.ndimage import rotate
from skimage.transform import rescale
import gc

# Configuration
IMG_HEIGHT = 256
IMG_WIDTH = 256
IMG_CHANNELS = 1
BATCH_SIZE = 4  # Reduced from original
EPOCHS = 100
INIT_LR = 1e-4
N_FOLDS = 5  # Reduced from 10 to 5 for faster testing
SEED = 42

# Path to the CAMUS dataset (the folder holding the patient* sub-folders).
# Set the CAMUS_DATA_PATH environment variable to use another location
# without editing the notebook.
DATA_PATH = os.environ.get("CAMUS_DATA_PATH", "/content/drive/MyDrive/database_nifti")

# Define metrics
def dice_coef(y_true, y_pred, smooth=1):
    """Soft Dice coefficient pooled over every element of the batch.

    Both tensors are flattened, so the overlap is summed over all images at
    once rather than averaged per image.  ``smooth`` is added to numerator and
    denominator, which keeps the ratio finite (and equal to 1) when both inputs
    are all zeros.
    """
    truth = tf.reshape(y_true, [-1])
    probs = tf.reshape(y_pred, [-1])
    overlap = tf.reduce_sum(truth * probs)
    denominator = tf.reduce_sum(truth) + tf.reduce_sum(probs) + smooth
    return (2. * overlap + smooth) / denominator


def dice_loss(y_true, y_pred):
    """Dice loss ``1 - dice_coef`` (0 for a perfect overlap)."""
    return 1 - dice_coef(y_true, y_pred)


def bce_dice_loss(y_true, y_pred):
    """Keras binary cross-entropy plus the batch-level Dice loss."""
    bce = tf.keras.losses.binary_crossentropy(y_true, y_pred)
    return bce + dice_loss(y_true, y_pred)

# Data Generator Class
class CAMUSDataGenerator(tf.keras.utils.Sequence):
    """Keras data generator that streams CAMUS patients from NIfTI files.

    One batch covers ``batch_size`` *patients*, not images. Each patient
    contributes up to four frames (views 2CH/4CH x phases ED/ES). With
    ``augment=True`` every frame is expanded into
    ``len(ROTATION_ANGLES) * len(ZOOM_SCALES)`` variants, so the number of
    images per batch varies. Only ground-truth label 1 is kept, as a binary
    mask.

    Parameters
    ----------
    patient_folders : sequence of str
        Patient directories. Each is expected to hold
        ``<name>_<view>_<phase>.nii.gz`` and ``<name>_<view>_<phase>_gt.nii.gz``.
        The sequence is copied, so shuffling never reorders the caller's list.
    target_shape : tuple of int
        (height, width) that every image and mask is brought to.
    batch_size : int
        Number of patients per batch.
    augment : bool
        Whether to add rotated / zoomed variants of every frame.
    shuffle : bool
        Whether to shuffle patient order at construction and after every epoch.
    **kwargs
        Forwarded to the Keras base class (e.g. ``workers``,
        ``use_multiprocessing``, ``max_queue_size`` on Keras 3).
    """

    # Augmentation grid: every frame yields one variant per (angle, scale) pair.
    ROTATION_ANGLES = (0, 90)   # degrees
    ZOOM_SCALES = (1.0, 1.05)   # > 1 zooms in around the image centre

    def __init__(self, patient_folders, target_shape=(256, 256), batch_size=8,
                 augment=True, shuffle=True, **kwargs):
        # Keras 3 emits `_warn_if_super_not_called` when this is skipped.
        super().__init__(**kwargs)
        # Copy: np.random.shuffle works in place and must not reorder the
        # list (or array) owned by the caller.
        self.patient_folders = list(patient_folders)
        self.target_shape = tuple(target_shape)
        self.batch_size = batch_size
        self.augment = augment
        self.shuffle = shuffle
        self.on_epoch_end()

    def __len__(self):
        """Number of batches per epoch (the last one may hold fewer patients)."""
        return int(np.ceil(len(self.patient_folders) / self.batch_size))

    def __getitem__(self, index):
        """Return ``(images, masks)`` for batch ``index``, each shaped (N, H, W, 1)."""
        if not 0 <= index < len(self):
            # Out-of-range indices raise instead of silently yielding an empty batch.
            raise IndexError(f"batch index {index} out of range for {len(self)} batches")
        start = index * self.batch_size
        batch_folders = self.patient_folders[start:start + self.batch_size]
        return self.__load_and_process_batch(batch_folders)

    def on_epoch_end(self):
        """Reshuffle patient order between epochs when ``shuffle`` is enabled."""
        if self.shuffle:
            np.random.shuffle(self.patient_folders)

    def __load_and_process_batch(self, batch_folders):
        """Concatenate the frames of every patient in ``batch_folders`` into float32 arrays."""
        batch_images, batch_masks = [], []
        for patient_folder in batch_folders:
            images, masks = self.__load_patient(patient_folder)
            batch_images.extend(images)
            batch_masks.extend(masks)

        if not batch_images:
            # Keep a well-formed (0, H, W, 1) shape when no frame could be loaded.
            empty = np.zeros((0, *self.target_shape, 1), dtype=np.float32)
            return empty, empty.copy()
        return (np.asarray(batch_images, dtype=np.float32),
                np.asarray(batch_masks, dtype=np.float32))

    def __load_patient(self, patient_folder):
        """Load and preprocess every available view/phase of one patient.

        Returns two equally long lists of (H, W, 1) arrays: images and masks.
        Frames with missing files are skipped. Frames that fail to load or
        augment are reported and skipped.
        """
        images, masks = [], []
        base_name = os.path.basename(os.path.normpath(patient_folder))

        for view in ('2CH', '4CH'):
            for tp in ('ED', 'ES'):
                img_path = os.path.join(patient_folder, f"{base_name}_{view}_{tp}.nii.gz")
                gt_path = os.path.join(patient_folder, f"{base_name}_{view}_{tp}_gt.nii.gz")
                if not (os.path.exists(img_path) and os.path.exists(gt_path)):
                    continue

                try:
                    img, mask = self.__load_frame(img_path, gt_path)
                    variants = self.__augment(img, mask) if self.augment else [(img, mask)]
                except Exception as e:
                    print(f"Error processing {base_name} {view}_{tp}: {e}")
                    continue

                for variant_img, variant_mask in variants:
                    images.append(variant_img[..., np.newaxis])
                    masks.append(variant_mask[..., np.newaxis])

        return images, masks

    def __load_frame(self, img_path, gt_path):
        """Read one image/label pair, resize it to ``target_shape``, min-max
        normalise the image to [0, 1] and extract the label-1 binary mask."""
        img = np.squeeze(nib.load(img_path).get_fdata())
        gt = np.squeeze(nib.load(gt_path).get_fdata())

        img = resize(img, self.target_shape, preserve_range=True, anti_aliasing=True)
        # The ground truth is a label map. Nearest-neighbour (order=0) keeps the
        # labels exact, whereas the default bilinear resize creates fractional
        # values at region borders that the `== 1` test below would drop.
        gt = resize(gt, self.target_shape, order=0, preserve_range=True, anti_aliasing=False)

        lo, hi = img.min(), img.max()
        # Guard against constant images, which would otherwise divide by zero.
        img = (img - lo) / (hi - lo) if hi > lo else np.zeros_like(img)

        mask = (gt == 1).astype(np.float32)
        return img, mask

    def __augment(self, img, mask):
        """Return one (image, mask) pair per (angle, scale) combination."""
        variants = []
        for angle in self.ROTATION_ANGLES:
            img_rot = rotate(img, angle, reshape=False)
            # order=0 keeps the mask strictly binary; the default cubic spline
            # produces values outside {0, 1}.
            mask_rot = rotate(mask, angle, reshape=False, order=0)
            for scale in self.ZOOM_SCALES:
                variants.append((self.__zoom(img_rot, scale, order=1),
                                 self.__zoom(mask_rot, scale, order=0)))
        return variants

    def __zoom(self, arr, scale, order):
        """Zoom ``arr`` about its centre by ``scale`` and return it at ``target_shape``.

        Zooming in (scale > 1) is followed by a centre crop, because resizing
        the enlarged image back to ``target_shape`` would undo the zoom. Any
        other size mismatch falls back to a plain resize.
        """
        if scale != 1.0:
            arr = rescale(arr, scale, order=order, preserve_range=True,
                          anti_aliasing=order > 0)
        target_h, target_w = self.target_shape
        h, w = arr.shape[:2]
        if h >= target_h and w >= target_w:
            top, left = (h - target_h) // 2, (w - target_w) // 2
            return arr[top:top + target_h, left:left + target_w]
        return resize(arr, self.target_shape, order=order, preserve_range=True,
                      anti_aliasing=False)

# U-Net architectures (unchanged from original)
def unet1(input_size=(IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS)):
    """Build the lighter U-Net: 32 -> 512 filters, no normalisation.

    There are four encoder levels, each made of two 3x3 ReLU convolutions
    followed by 2x2 max-pooling, plus a 512-filter bottleneck. Four decoder
    levels each up-sample, concatenate the matching encoder features, and
    apply two 3x3 ReLU convolutions. A 1x1 sigmoid convolution yields a
    single-channel probability map.
    """
    def conv_pair(tensor, filters):
        # Two same-padded 3x3 ReLU convolutions.
        for _ in range(2):
            tensor = Conv2D(filters, 3, activation='relu', padding='same')(tensor)
        return tensor

    inputs = Input(input_size)

    # Downsampling path: keep each level's output for the skip connections.
    encoder_features = []
    x = inputs
    for filters in (32, 64, 128, 256):
        x = conv_pair(x, filters)
        encoder_features.append(x)
        x = MaxPooling2D(pool_size=(2, 2))(x)

    # Bottom
    x = conv_pair(x, 512)

    # Upsampling path: deepest skip first.
    for filters, skip in zip((256, 128, 64, 32), reversed(encoder_features)):
        x = concatenate([UpSampling2D(size=(2, 2))(x), skip], axis=-1)
        x = conv_pair(x, filters)

    outputs = Conv2D(1, 1, activation='sigmoid')(x)
    return Model(inputs=inputs, outputs=outputs)

def unet2(input_size=(IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS)):
    """Build the wider U-Net: 64 -> 1024 filters, BatchNorm after every convolution.

    Every stage applies two (Conv2D 3x3 'same' -> BatchNormalization -> ReLU)
    units. The network has four max-pool / up-sample levels joined by skip
    concatenations, and a 1x1 sigmoid convolution produces a single-channel
    probability map.
    """
    def conv_bn_relu_x2(tensor, filters):
        # Two Conv -> BN -> ReLU units (convolutions carry no activation of their own).
        for _ in range(2):
            tensor = Conv2D(filters, 3, padding='same')(tensor)
            tensor = BatchNormalization()(tensor)
            tensor = Activation('relu')(tensor)
        return tensor

    inputs = Input(input_size)

    # Downsampling path: keep each level's output for the skip connections.
    encoder_features = []
    x = inputs
    for filters in (64, 128, 256, 512):
        x = conv_bn_relu_x2(x, filters)
        encoder_features.append(x)
        x = MaxPooling2D(pool_size=(2, 2))(x)

    # Bottom
    x = conv_bn_relu_x2(x, 1024)

    # Upsampling path: deepest skip first.
    for filters, skip in zip((512, 256, 128, 64), reversed(encoder_features)):
        x = concatenate([UpSampling2D(size=(2, 2))(x), skip], axis=-1)
        x = conv_bn_relu_x2(x, filters)

    outputs = Conv2D(1, 1, activation='sigmoid')(x)
    return Model(inputs=inputs, outputs=outputs)

def calculate_metrics(y_true, y_pred, threshold=0.5, sampling=None):
    """Compute Dice, mean symmetric distance and Hausdorff distance for one mask pair.

    Parameters
    ----------
    y_true : array-like
        Ground-truth mask. Singleton dimensions are squeezed and values > 0.5
        count as foreground, so interpolated (non-binary) masks are handled.
    y_pred : array-like
        Predicted probabilities or mask; values > ``threshold`` count as foreground.
    threshold : float, optional
        Binarisation threshold for ``y_pred`` (default 0.5).
    sampling : float or sequence of float, optional
        Pixel spacing passed to ``distance_transform_edt``. With ``None``
        (default), distances are in pixels.

    Returns
    -------
    dice : float
        Dice with +1 smoothing, so two empty masks score 1.0.
    mean_dist : float
        Average of the two directed mean distances from each region's pixels
        to the other region. NaN when either mask is empty.
    hausdorff : float
        Symmetric Hausdorff distance between the one-pixel-wide contours.
        ``inf`` when either contour is empty.
    """
    from scipy.ndimage import binary_erosion

    true_mask = np.squeeze(np.asarray(y_true)) > 0.5
    pred_mask = np.squeeze(np.asarray(y_pred)) > threshold

    # Dice coefficient
    intersection = np.count_nonzero(true_mask & pred_mask)
    dice = (2. * intersection + 1.) / (np.count_nonzero(true_mask) + np.count_nonzero(pred_mask) + 1.)

    # Mean distance between the regions (undefined if either region is empty)
    if true_mask.any() and pred_mask.any():
        # edt(~mask) = distance of every pixel to the nearest foreground pixel of mask.
        dist_to_true = distance_transform_edt(~true_mask, sampling=sampling)
        dist_to_pred = distance_transform_edt(~pred_mask, sampling=sampling)
        mean_dist = (dist_to_pred[true_mask].mean() + dist_to_true[pred_mask].mean()) / 2
    else:
        mean_dist = np.nan

    # Hausdorff distance between contours. border_value=1 treats outside pixels
    # as foreground, so masks touching the image edge keep their border row,
    # like a reflect-mode 3x3 grey erosion.
    footprint = np.ones((3,) * true_mask.ndim, dtype=bool)
    contour_true = true_mask & ~binary_erosion(true_mask, structure=footprint, border_value=1)
    contour_pred = pred_mask & ~binary_erosion(pred_mask, structure=footprint, border_value=1)

    if not contour_true.any() or not contour_pred.any():
        hausdorff = np.inf
    else:
        dist_to_true_contour = distance_transform_edt(~contour_true, sampling=sampling)
        dist_to_pred_contour = distance_transform_edt(~contour_pred, sampling=sampling)
        hausdorff = max(dist_to_true_contour[contour_pred].max(),
                        dist_to_pred_contour[contour_true].max())

    return dice, mean_dist, hausdorff

def _plot_test_sample(image, mask_true, mask_prob, fold):
    """Show image, ground truth and thresholded (> 0.5) prediction side by side."""
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    panels = (
        (image.squeeze(), "Original Image"),
        (mask_true.squeeze(), "Ground Truth"),
        (mask_prob.squeeze() > 0.5, "Predicted Mask"),
    )
    for ax, (data, title) in zip(axes, panels):
        ax.imshow(data, cmap='gray')
        ax.set_title(f"Fold {fold + 1}: {title}")
        ax.axis('off')
    fig.tight_layout()
    plt.show()
    plt.close(fig)


def _evaluate_on_test(model, test_gen, fold):
    """Score ``model`` on every batch of ``test_gen`` and plot the first test sample.

    Returns one dict per test image with keys ``fold``, ``dice``,
    ``mean_dist`` and ``hausdorff``. Distances are in pixels, because
    ``calculate_metrics`` is called without a spacing.
    """
    results = []
    plotted = False
    # Index explicitly instead of iterating, so exactly len(test_gen) batches are read.
    for batch_idx in range(len(test_gen)):
        X_batch, y_batch = test_gen[batch_idx]
        if len(X_batch) == 0:
            continue
        y_pred = model.predict(X_batch, verbose=0)

        if not plotted:
            _plot_test_sample(X_batch[0], y_batch[0], y_pred[0], fold)
            plotted = True

        for y_true_i, y_pred_i in zip(y_batch, y_pred):
            dice, mean_dist, hausdorff = calculate_metrics(y_true_i, y_pred_i)
            results.append({
                'fold': fold,
                'dice': dice,
                'mean_dist': mean_dist,
                'hausdorff': hausdorff,
            })
    return results


def _mean_of_finite(values):
    """Return (mean over finite entries, number of finite entries); the mean is NaN if none."""
    arr = np.asarray(values, dtype=float)
    finite = arr[np.isfinite(arr)]
    return (finite.mean() if finite.size else np.nan), finite.size


def _report_results(fold_results):
    """Print test metrics averaged over all folds and test images."""
    if not fold_results:
        print("\nNo results to report")
        return

    print("\n=== Final Results ===")
    total = len(fold_results)
    avg_dice, _ = _mean_of_finite([r['dice'] for r in fold_results])
    # mean_dist is NaN and hausdorff is inf when a mask is empty. Averaging
    # them in would turn the whole summary into NaN/inf, so only defined
    # values are averaged and their count is reported.
    avg_mean_dist, n_md = _mean_of_finite([r['mean_dist'] for r in fold_results])
    avg_hausdorff, n_hd = _mean_of_finite([r['hausdorff'] for r in fold_results])

    print(f"Average Dice: {avg_dice:.3f}")
    print(f"Average Mean Absolute Distance: {avg_mean_dist:.3f} px ({n_md}/{total} samples defined)")
    print(f"Average Hausdorff Distance: {avg_hausdorff:.3f} px ({n_hd}/{total} samples defined)")


def main():
    """Train and evaluate the selected U-Net with K-fold cross-validation on CAMUS.

    Patients are split once into a held-out test set (15%) and a training
    pool, and the training pool is split into folds. For each fold:
      * a fresh model is trained, checkpointing the best ``val_dice_coef``;
      * the best checkpoint is reloaded and scored on the *entire* test set;
      * the first test sample is plotted.
    Distances are reported in pixels of the resampled image grid.
    """
    print("Setting up data generators...")

    # Seed Python, NumPy and TensorFlow so that weight init and shuffling are
    # repeatable, not just the patient split.
    tf.keras.utils.set_random_seed(SEED)

    # Get list of patient folders
    patient_folders = sorted(
        os.path.join(DATA_PATH, f)
        for f in os.listdir(DATA_PATH)
        if f.startswith('patient') and os.path.isdir(os.path.join(DATA_PATH, f))
    )
    if not patient_folders:
        raise ValueError(f"No patient folders found in {DATA_PATH}")

    # Split into train/test
    np.random.seed(SEED)
    np.random.shuffle(patient_folders)
    split_idx = int(len(patient_folders) * 0.85)  # 85% train, 15% test
    train_folders = patient_folders[:split_idx]
    test_folders = patient_folders[split_idx:]

    print(f"Found {len(patient_folders)} patients. Using {len(train_folders)} for training, {len(test_folders)} for testing.")

    # One fixed, un-augmented, unshuffled test set shared by every fold.
    test_gen = CAMUSDataGenerator(test_folders, batch_size=BATCH_SIZE, augment=False, shuffle=False)

    # Choose which U-Net to use
    unet_version = 1
    build_model = unet1 if unet_version == 1 else unet2

    # Adjust number of folds if we don't have enough samples
    actual_folds = min(N_FOLDS, len(train_folders))
    if actual_folds < 2:
        raise ValueError(
            f"Cross-validation needs at least 2 training patients, got {len(train_folders)}")
    if actual_folds < N_FOLDS:
        print(f"Reducing number of folds from {N_FOLDS} to {actual_folds} due to limited samples")

    kf = KFold(n_splits=actual_folds, shuffle=True, random_state=SEED)
    train_array = np.array(train_folders)
    fold_results = []

    for fold, (train_idx, val_idx) in enumerate(kf.split(train_array)):
        print(f"\n=== Fold {fold + 1}/{actual_folds} ===")

        train_gen = CAMUSDataGenerator(train_array[train_idx], batch_size=BATCH_SIZE, augment=True)
        val_gen = CAMUSDataGenerator(train_array[val_idx], batch_size=BATCH_SIZE, augment=False, shuffle=False)

        model = build_model()
        model.compile(optimizer=Adam(learning_rate=INIT_LR),
                      loss=bce_dice_loss,
                      metrics=[dice_coef, 'accuracy'])

        checkpoint_path = f"unet{unet_version}_fold{fold}_best.keras"
        callbacks = [
            ModelCheckpoint(checkpoint_path,
                            monitor='val_dice_coef',
                            mode='max',
                            save_best_only=True,
                            verbose=1),
            EarlyStopping(monitor='val_dice_coef', patience=15, mode='max', verbose=1),
            ReduceLROnPlateau(monitor='val_dice_coef', factor=0.5, patience=5, min_lr=1e-6, mode='max', verbose=1)
        ]

        model.fit(
            train_gen,
            epochs=EPOCHS,
            validation_data=val_gen,
            callbacks=callbacks,
            verbose=1
        )

        # Evaluate the best epoch, not the last one.
        model.load_weights(checkpoint_path)
        fold_results.extend(_evaluate_on_test(model, test_gen, fold))

        # Free memory only after evaluation: clear_session() resets Keras'
        # global state and must not run while this model is still in use.
        del model
        K.clear_session()
        gc.collect()

    _report_results(fold_results)


if __name__ == '__main__':
    main()

Setting up data generators...
Found 500 patients. Using 425 for training, 75 for testing.

=== Fold 1/5 ===


  self._warn_if_super_not_called()


Epoch 1/100
[1m 7/85[0m [32m━[0m[37m━━━━━━━━━━━━━━━━━━━[0m [1m10:10[0m 8s/step - accuracy: 0.5807 - dice_coef: 0.1494 - loss: 1.5384

KeyboardInterrupt: 