In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import KFold
import nibabel as nib
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (Input, Conv2D, MaxPooling2D, UpSampling2D, 
                                    concatenate, BatchNormalization, Activation, 
                                    Multiply, Add, Lambda, DepthwiseConv2D)
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from tensorflow.keras import backend as K
import tensorflow as tf
from skimage.transform import resize
from tqdm import tqdm
from scipy.ndimage import distance_transform_edt
from skimage.morphology import erosion, dilation, square
from sklearn.model_selection import train_test_split


In [None]:
# Configuration
IMG_HEIGHT = 256
IMG_WIDTH = 256
IMG_CHANNELS = 1
BATCH_SIZE = 8  # Reduced from original due to memory requirements of attention
EPOCHS = 100
INIT_LR = 1e-4
N_FOLDS = 10
SEED = 42
DATA_PATH = "/kaggle/input/camus-dataset/database_nifti"

In [None]:
# Enhanced metrics with class-specific calculations
def dice_coef(y_true, y_pred, smooth=1):
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)

def dice_loss(y_true, y_pred):
    return 1 - dice_coef(y_true, y_pred)

def bce_dice_loss(y_true, y_pred):
    return tf.keras.losses.binary_crossentropy(y_true, y_pred) + dice_loss(y_true, y_pred)


In [None]:
# Novel Attention Gate Module
def attention_gate(x, g, inter_channel):
    """Hybrid Attention Gate with proper dimension handling"""
    # Get the number of channels in the input x
    x_channels = K.int_shape(x)[-1]
    
    # Process the gating signal (upsampled feature)
    g_conv = Conv2D(x_channels, (1, 1), strides=1, padding='same')(g)
    g_conv = BatchNormalization()(g_conv)
    g_conv = Activation('relu')(g_conv)
    
    # Process the input features
    x_conv = Conv2D(x_channels, (1, 1), strides=1, padding='same')(x)
    x_conv = BatchNormalization()(x_conv)
    x_conv = Activation('relu')(x_conv)
    
    # Add the processed features
    combined = Add()([x_conv, g_conv])
    combined = Activation('relu')(combined)
    
    # Attention coefficients
    attention = Conv2D(1, (1, 1), strides=1, padding='same', activation='sigmoid')(combined)
    
    # Apply attention
    return Multiply()([x, attention])


# Depthwise Separable Block
def depthwise_sep_block(x, filters, kernel_size=3, strides=1):
    """
    Depthwise separable convolution block
    Args:
        x: input tensor
        filters: number of output filters
        kernel_size: size of convolution kernel
        strides: stride length
    Returns:
        Output tensor after applying depthwise separable convolution
    """
    # Depthwise convolution
    x = DepthwiseConv2D(kernel_size=kernel_size,
                       strides=strides,
                       padding='same',
                       depth_multiplier=1)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    
    # Pointwise convolution
    x = Conv2D(filters=filters,
               kernel_size=1,
               strides=1,
               padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    
    return x

In [None]:
# HAG-UNet Architecture
def hag_unet(input_size=(IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS)):
    inputs = Input(input_size)
    
    # Downsample path with depthwise blocks
    # Block 1
    conv1 = depthwise_sep_block(inputs, 32)
    conv1 = depthwise_sep_block(conv1, 32)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
    
    # Block 2
    conv2 = depthwise_sep_block(pool1, 64)
    conv2 = depthwise_sep_block(conv2, 64)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
    
    # Block 3
    conv3 = depthwise_sep_block(pool2, 128)
    conv3 = depthwise_sep_block(conv3, 128)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
    
    # Block 4
    conv4 = depthwise_sep_block(pool3, 256)
    conv4 = depthwise_sep_block(conv4, 256)
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)
    
    # Bottleneck
    conv5 = depthwise_sep_block(pool4, 512)
    conv5 = depthwise_sep_block(conv5, 512)
    
    # Upsample path with attention gates
    # Up 1
    up6 = UpSampling2D(size=(2, 2))(conv5)
    att6 = attention_gate(conv4, up6, 256)
    merge6 = concatenate([up6, att6], axis=-1)
    conv6 = depthwise_sep_block(merge6, 256)
    conv6 = depthwise_sep_block(conv6, 256)
    
    # Up 2
    up7 = UpSampling2D(size=(2, 2))(conv6)
    att7 = attention_gate(conv3, up7, 128)
    merge7 = concatenate([up7, att7], axis=-1)
    conv7 = depthwise_sep_block(merge7, 128)
    conv7 = depthwise_sep_block(conv7, 128)
    
    # Up 3
    up8 = UpSampling2D(size=(2, 2))(conv7)
    att8 = attention_gate(conv2, up8, 64)
    merge8 = concatenate([up8, att8], axis=-1)
    conv8 = depthwise_sep_block(merge8, 64)
    conv8 = depthwise_sep_block(conv8, 64)
    
    # Up 4
    up9 = UpSampling2D(size=(2, 2))(conv8)
    att9 = attention_gate(conv1, up9, 32)
    merge9 = concatenate([up9, att9], axis=-1)
    conv9 = depthwise_sep_block(merge9, 32)
    conv9 = depthwise_sep_block(conv9, 32)
    
    # Output
    outputs = Conv2D(1, 1, activation='sigmoid')(conv9)
    
    model = Model(inputs=inputs, outputs=outputs)
    return model

# Enhanced evaluation metrics with contour refinement
def calculate_metrics(y_true, y_pred):
    # Apply morphological refinement to predictions
    refined_pred = np.zeros_like(y_pred)
    for i in range(y_pred.shape[0]):
        pred = (y_pred[i].squeeze() > 0.5).astype(np.uint8)
        # Small dilation then erosion to close small holes
        refined = dilation(pred, square(2))
        refined = erosion(refined, square(2))
        refined_pred[i] = refined[..., np.newaxis]
    
    y_pred = refined_pred
    
    # Dice coefficient
    intersection = np.sum(y_true * y_pred)
    dice = (2. * intersection + 1.) / (np.sum(y_true) + np.sum(y_pred) + 1.)
    
    # Mean absolute distance with contour weighting
    dt_true = distance_transform_edt(1 - y_true.squeeze())
    dt_pred = distance_transform_edt(1 - y_pred.squeeze())
    
    # Weight distances by how close they are to contours
    contour_true = y_true.squeeze() - erosion(y_true.squeeze(), square(3))
    contour_pred = y_pred.squeeze() - erosion(y_pred.squeeze(), square(3))
    
    contour_weights_true = 1 + 2 * contour_true  # Give 3x weight to contour pixels
    contour_weights_pred = 1 + 2 * contour_pred
    
    mean_dist = (np.mean(dt_pred[y_true.squeeze() > 0.5] * contour_weights_true[y_true.squeeze() > 0.5]) + 
                np.mean(dt_true[y_pred.squeeze() > 0.5] * contour_weights_pred[y_pred.squeeze() > 0.5])) / 2
    
    # Hausdorff distance with 95th percentile (more robust)
    if np.sum(contour_true) == 0 or np.sum(contour_pred) == 0:
        hausdorff = np.inf
    else:
        dt_true_contour = distance_transform_edt(1 - contour_true)
        dt_pred_contour = distance_transform_edt(1 - contour_pred)
        hd_true = np.percentile(contour_pred * dt_true_contour, 95)
        hd_pred = np.percentile(contour_true * dt_pred_contour, 95)
        hausdorff = max(hd_true, hd_pred)
    
    return dice, mean_dist, hausdorff


In [None]:

# Main training and evaluation (similar to original but with HAG-UNet)
def main():
    # Load dataset (same as original)
    print("Loading and preprocessing dataset...")
    X, y_endo, y_epi, y_la = load_dataset(DATA_PATH)
    
    # Create 10-fold cross-validation
    kf = KFold(n_splits=N_FOLDS, shuffle=True, random_state=SEED)
    fold_results = []
    
    for fold, (train_idx, test_idx) in enumerate(kf.split(X)):
        print(f"\n=== Fold {fold + 1}/{N_FOLDS} ===")
        
        # Split data
        X_train, X_test = X[train_idx], X[test_idx]
        y_train, y_test = y_endo[train_idx], y_endo[test_idx]
        
        # Train/val split
        X_train, X_val, y_train, y_val = train_test_split(
            X_train, y_train, test_size=0.1, random_state=SEED)
        
        # Create HAG-UNet model
        model = hag_unet()
        
        # Enhanced optimizer with warmup
        lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
            INIT_LR,
            decay_steps=1000,
            decay_rate=0.96,
            staircase=True)
        
        model.compile(optimizer=Adam(learning_rate=lr_schedule), 
                     loss=bce_dice_loss, 
                     metrics=[dice_coef, 'accuracy'])
        
        callbacks = [
            ModelCheckpoint(f"hag_unet_fold{fold}_best.keras", 
                           monitor='val_dice_coef', 
                           mode='max', 
                           save_best_only=True, 
                           verbose=1),
            EarlyStopping(monitor='val_dice_coef', 
                         patience=20,  # Increased patience for attention learning
                         mode='max', 
                         verbose=1),
            ReduceLROnPlateau(monitor='val_dice_coef', 
                            factor=0.5, 
                            patience=8,  # More patience for learning rate
                            min_lr=1e-6, 
                            mode='max', 
                            verbose=1)
        ]
        
        # Train with class weights (focus more on boundary pixels)
        sample_weights = np.ones_like(y_train)
        for i in range(len(y_train)):
            contours = y_train[i].squeeze() - erosion(y_train[i].squeeze(), square(3))
            sample_weights[i] = 1 + 2*contours  # 3x weight for boundary pixels
        
        print(f"Training on {len(X_train)} samples")
        history = model.fit(X_train, y_train,
                          batch_size=BATCH_SIZE,
                          epochs=EPOCHS,
                          validation_data=(X_val, y_val),
                          callbacks=callbacks,
                          sample_weight=sample_weights,
                          verbose=1)
        
        # Load best model
        model.load_weights(f"hag_unet_fold{fold}_best.keras")
        
        # Evaluate
        print(f"Evaluating on {len(X_test)} test samples")
        y_pred = model.predict(X_test, batch_size=BATCH_SIZE)
        
        # Calculate enhanced metrics
        dice_scores = []
        mean_distances = []
        hausdorff_distances = []
        
        for i in range(len(X_test)):
            dice, mean_dist, hausdorff = calculate_metrics(y_test[i], y_pred[i])
            dice_scores.append(dice)
            mean_distances.append(mean_dist)
            hausdorff_distances.append(hausdorff)
        
        fold_results.append({
            'dice': np.mean(dice_scores),
            'mean_dist': np.mean(mean_distances),
            'hausdorff': np.mean(hausdorff_distances)
        })
        
        print(f"Fold {fold + 1} Results:")
        print(f"Dice: {np.mean(dice_scores):.4f} ± {np.std(dice_scores):.4f}")
        print(f"Mean Distance: {np.mean(mean_distances):.4f} ± {np.std(mean_distances):.4f} mm")
        print(f"Hausdorff Distance: {np.mean(hausdorff_distances):.4f} ± {np.std(hausdorff_distances):.4f} mm")
        
        # Visualize attention effects
        plot_attention_maps(model, X_test[:3], y_test[:3], fold+1)
    
    # Final results
    print("\n=== HAG-UNet Final Results ===")
    avg_dice = np.mean([r['dice'] for r in fold_results])
    avg_mean_dist = np.mean([r['mean_dist'] for r in fold_results])
    avg_hausdorff = np.mean([r['hausdorff'] for r in fold_results])
    
    print(f"Average Dice: {avg_dice:.4f}")
    print(f"Average Mean Distance: {avg_mean_dist:.4f} mm")
    print(f"Average Hausdorff Distance: {avg_hausdorff:.4f} mm")
    
    # Compare with baseline
    print("\n=== Improvement Over Baseline ===")
    print(f"Dice Improvement: {(avg_dice - 0.932):.4f} ({(avg_dice - 0.932)/0.932*100:.2f}%)")
    print(f"Mean Distance Reduction: {(0.266 - avg_mean_dist):.4f} mm ({(0.266 - avg_mean_dist)/0.266*100:.2f}%)")
    print(f"Hausdorff Distance Reduction: {(11.009 - avg_hausdorff):.4f} mm ({(11.009 - avg_hausdorff)/11.009*100:.2f}%)")




In [None]:
def plot_attention_maps(model, images, masks, fold):
    """Visualize attention gate activations"""
    # Create partial model to get attention outputs
    layer_outputs = [layer.output for layer in model.layers if 'multiply' in layer.name]
    attention_model = Model(inputs=model.input, outputs=layer_outputs)
    
    # Get attention maps
    attention_maps = attention_model.predict(images)
    
    # Plot results
    plt.figure(figsize=(18, 6*len(images)))
    for i in range(len(images)):
        # Original
        plt.subplot(len(images), 4, 1+i*4)
        plt.imshow(images[i].squeeze(), cmap='gray')
        plt.title(f"Original Image (Fold {fold})")
        plt.axis('off')
        
        # Ground truth
        plt.subplot(len(images), 4, 2+i*4)
        plt.imshow(masks[i].squeeze(), cmap='gray')
        plt.title("Ground Truth")
        plt.axis('off')
        
        # Attention map 1 (deepest)
        plt.subplot(len(images), 4, 3+i*4)
        plt.imshow(attention_maps[0][i].squeeze(), cmap='hot')
        plt.title("Deep Attention Map")
        plt.axis('off')
        
        # Attention map 2 (mid-level)
        plt.subplot(len(images), 4, 4+i*4)
        plt.imshow(attention_maps[2][i].squeeze(), cmap='hot')
        plt.title("Mid-Level Attention Map")
        plt.axis('off')
    
    plt.tight_layout()
    plt.show()

In [None]:
if __name__ == '__main__':
    main()

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import KFold
import nibabel as nib
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    Input, Conv2D, MaxPooling2D, UpSampling2D, 
    concatenate, BatchNormalization, Activation, 
    Multiply, Add, Lambda, DepthwiseConv2D
)
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from tensorflow.keras import backend as K
import tensorflow as tf
from skimage.transform import resize
from tqdm import tqdm
from scipy.ndimage import distance_transform_edt
from skimage.morphology import erosion, dilation, square
from sklearn.model_selection import train_test_split

# Configuration
IMG_HEIGHT = 256
IMG_WIDTH = 256
IMG_CHANNELS = 1
BATCH_SIZE = 8
EPOCHS = 100
INIT_LR = 1e-4
N_FOLDS = 10
SEED = 42
DATA_PATH = "/kaggle/input/camus-dataset/database_nifti"  # Make sure this path is correct

# ==================== DATA LOADING FUNCTIONS ====================
def load_nifti_image(file_path):
    img = nib.load(file_path)
    data = img.get_fdata()
    return np.squeeze(data)  # Remove singleton dimensions

def preprocess_patient(patient_folder):
    images = []
    masks_endo = []
    masks_epi = []
    masks_la = []
    
    views = ['2CH', '4CH']
    time_points = ['ED', 'ES']
    
    for view in views:
        for tp in time_points:
            # Construct filenames
            base_name = os.path.basename(patient_folder)
            img_path = f"{patient_folder}/{base_name}_{view}_{tp}.nii"
            gt_path = f"{patient_folder}/{base_name}_{view}_{tp}_gt.nii"
            
            if not os.path.exists(img_path) or not os.path.exists(gt_path):
                print(f"Files not found for {view}_{tp}")
                continue
                
            try:
                # Load image
                img = nib.load(img_path).get_fdata()
                
                # Load ground truth
                gt = nib.load(gt_path).get_fdata()
                
                # Resize and normalize
                img_resized = resize(img, (IMG_HEIGHT, IMG_WIDTH), preserve_range=True, anti_aliasing=True)
                gt_resized = resize(gt, (IMG_HEIGHT, IMG_WIDTH), preserve_range=True, anti_aliasing=False)
                
                img_resized = (img_resized - img_resized.min()) / (img_resized.max() - img_resized.min())
                
                # Create masks
                mask_endo = (gt_resized == 1).astype(np.float32)
                mask_epi = (gt_resized == 2).astype(np.float32)
                mask_la = (gt_resized == 3).astype(np.float32)
                
                images.append(img_resized[..., np.newaxis])
                masks_endo.append(mask_endo[..., np.newaxis])
                masks_epi.append(mask_epi[..., np.newaxis])
                masks_la.append(mask_la[..., np.newaxis])
                
            except Exception as e:
                print(f"Error processing {view}_{tp}: {str(e)}")
                continue
    
    if not images:
        print("No valid images found for this patient")
        return np.array([]), np.array([]), np.array([]), np.array([])
    
    return np.array(images), np.array(masks_endo), np.array(masks_epi), np.array(masks_la)

def load_dataset(base_path):
    patient_folders = sorted([
        os.path.join(base_path, f) 
        for f in os.listdir(base_path) 
        if f.startswith('patient') and os.path.isdir(os.path.join(base_path, f))
    ])
    
    all_images = []
    all_masks_endo = []
    all_masks_epi = []
    all_masks_la = []
    
    for patient_folder in tqdm(patient_folders, desc="Loading patients"):
        images, masks_endo, masks_epi, masks_la = preprocess_patient(patient_folder)
        
        if images.size > 0:  # Only append if we got valid data
            all_images.append(images)
            all_masks_endo.append(masks_endo)
            all_masks_epi.append(masks_epi)
            all_masks_la.append(masks_la)
    
    # Check if we got any data at all
    if not all_images:
        raise ValueError("No valid image data found in any patient folder!")
    
    return (np.concatenate(all_images, axis=0),
            np.concatenate(all_masks_endo, axis=0),
            np.concatenate(all_masks_epi, axis=0),
            np.concatenate(all_masks_la, axis=0))

# ==================== MODEL ARCHITECTURE ====================
def attention_gate(x, g, inter_channel):
    """Hybrid Attention Gate with proper dimension handling"""
    # Get the number of channels in the input x
    x_channels = K.int_shape(x)[-1]
    
    # Process the gating signal (upsampled feature)
    g_conv = Conv2D(x_channels, (1, 1), strides=1, padding='same')(g)
    g_conv = BatchNormalization()(g_conv)
    g_conv = Activation('relu')(g_conv)
    
    # Process the input features
    x_conv = Conv2D(x_channels, (1, 1), strides=1, padding='same')(x)
    x_conv = BatchNormalization()(x_conv)
    x_conv = Activation('relu')(x_conv)
    
    # Add the processed features
    combined = Add()([x_conv, g_conv])
    combined = Activation('relu')(combined)
    
    # Attention coefficients
    attention = Conv2D(1, (1, 1), strides=1, padding='same', activation='sigmoid')(combined)
    
    # Apply attention
    return Multiply()([x, attention])

def depthwise_sep_block(x, filters, kernel_size=3, strides=1):
    """
    Depthwise separable convolution block
    Args:
        x: input tensor
        filters: number of output filters
        kernel_size: size of convolution kernel
        strides: stride length
    Returns:
        Output tensor after applying depthwise separable convolution
    """
    # Depthwise convolution
    x = DepthwiseConv2D(kernel_size=kernel_size,
                       strides=strides,
                       padding='same',
                       depth_multiplier=1)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    
    # Pointwise convolution
    x = Conv2D(filters=filters,
               kernel_size=1,
               strides=1,
               padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    
    return x

def hag_unet(input_size=(IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS)):
    inputs = Input(input_size)
    
    # Downsample path
    # Block 1
    conv1 = depthwise_sep_block(inputs, 32)
    conv1 = depthwise_sep_block(conv1, 32)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
    
    # Block 2
    conv2 = depthwise_sep_block(pool1, 64)
    conv2 = depthwise_sep_block(conv2, 64)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
    
    # Block 3
    conv3 = depthwise_sep_block(pool2, 128)
    conv3 = depthwise_sep_block(conv3, 128)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
    
    # Block 4
    conv4 = depthwise_sep_block(pool3, 256)
    conv4 = depthwise_sep_block(conv4, 256)
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)
    
    # Bottleneck
    conv5 = depthwise_sep_block(pool4, 512)
    conv5 = depthwise_sep_block(conv5, 512)
    
    # Upsample path with attention gates
    # Up 1
    up6 = UpSampling2D(size=(2, 2))(conv5)
    up6 = Conv2D(256, (2, 2), activation='relu', padding='same')(up6)  # Channel adjustment
    att6 = attention_gate(conv4, up6, 256)
    merge6 = concatenate([up6, att6], axis=-1)
    conv6 = depthwise_sep_block(merge6, 256)
    conv6 = depthwise_sep_block(conv6, 256)
    
    # Up 2
    up7 = UpSampling2D(size=(2, 2))(conv6)
    up7 = Conv2D(128, (2, 2), activation='relu', padding='same')(up7)  # Channel adjustment
    att7 = attention_gate(conv3, up7, 128)
    merge7 = concatenate([up7, att7], axis=-1)
    conv7 = depthwise_sep_block(merge7, 128)
    conv7 = depthwise_sep_block(conv7, 128)
    
    # Up 3
    up8 = UpSampling2D(size=(2, 2))(conv7)
    up8 = Conv2D(64, (2, 2), activation='relu', padding='same')(up8)  # Channel adjustment
    att8 = attention_gate(conv2, up8, 64)
    merge8 = concatenate([up8, att8], axis=-1)
    conv8 = depthwise_sep_block(merge8, 64)
    conv8 = depthwise_sep_block(conv8, 64)
    
    # Up 4
    up9 = UpSampling2D(size=(2, 2))(conv8)
    up9 = Conv2D(32, (2, 2), activation='relu', padding='same')(up9)  # Channel adjustment
    att9 = attention_gate(conv1, up9, 32)
    merge9 = concatenate([up9, att9], axis=-1)
    conv9 = depthwise_sep_block(merge9, 32)
    conv9 = depthwise_sep_block(conv9, 32)
    
    outputs = Conv2D(1, 1, activation='sigmoid')(conv9)
    
    model = Model(inputs=inputs, outputs=outputs)
    return model

# ==================== METRICS AND LOSS ====================
def dice_coef(y_true, y_pred, smooth=1):
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)

def dice_loss(y_true, y_pred):
    return 1 - dice_coef(y_true, y_pred)

def bce_dice_loss(y_true, y_pred):
    return tf.keras.losses.binary_crossentropy(y_true, y_pred) + dice_loss(y_true, y_pred)

# ==================== MAIN TRAINING LOOP ====================
def main():
    # Load dataset
    print("Loading and preprocessing dataset...")
    X, y_endo, y_epi, y_la = load_dataset(DATA_PATH)
    
    # Create 10-fold cross-validation
    kf = KFold(n_splits=N_FOLDS, shuffle=True, random_state=SEED)
    fold_results = []
    
    for fold, (train_idx, test_idx) in enumerate(kf.split(X)):
        print(f"\n=== Fold {fold + 1}/{N_FOLDS} ===")
        
        # Split data
        X_train, X_test = X[train_idx], X[test_idx]
        y_train, y_test = y_endo[train_idx], y_endo[test_idx]
        
        # Further split training into train/val (90/10)
        X_train, X_val, y_train, y_val = train_test_split(
            X_train, y_train, test_size=0.1, random_state=SEED)
        
        # Create model
        model = hag_unet()
        model.compile(optimizer=Adam(learning_rate=INIT_LR), 
                     loss=bce_dice_loss, 
                     metrics=[dice_coef, 'accuracy'])
        
        callbacks = [
            ModelCheckpoint(f"hag_unet_fold{fold}_best.keras", 
                          monitor='val_dice_coef', 
                          mode='max', 
                          save_best_only=True),
            EarlyStopping(monitor='val_dice_coef', 
                        patience=15, 
                        mode='max'),
            ReduceLROnPlateau(monitor='val_dice_coef',
                            factor=0.5,
                            patience=5,
                            min_lr=1e-6,
                            mode='max')
        ]
        
        # Train model
        print(f"Training on {len(X_train)} samples")
        history = model.fit(X_train, y_train,
                          batch_size=BATCH_SIZE,
                          epochs=EPOCHS,
                          validation_data=(X_val, y_val),
                          callbacks=callbacks,
                          verbose=1)
        
        # Evaluate
        model.load_weights(f"hag_unet_fold{fold}_best.keras")
        y_pred = model.predict(X_test, batch_size=BATCH_SIZE)
        
        # Calculate dice scores
        dice_scores = [dice_coef(y_test[i], y_pred[i]) for i in range(len(y_test))]
        avg_dice = np.mean(dice_scores)
        
        fold_results.append(avg_dice)
        print(f"Fold {fold+1} Dice: {avg_dice:.4f}")
    
    # Final results
    print("\n=== Final Cross-Validation Results ===")
    print(f"Average Dice: {np.mean(fold_results):.4f} ± {np.std(fold_results):.4f}")

if __name__ == '__main__':
    main()

In [18]:
### ---------- The working code starts here --------------------

In [19]:
import os
import numpy as np
import nibabel as nib
from skimage.transform import resize
from scipy.ndimage import rotate
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import (Input, Conv2D, BatchNormalization, ReLU, MaxPooling2D, 
                                     Conv2DTranspose, GlobalAveragePooling2D, Dense, Multiply, 
                                     Add, concatenate, LayerNormalization, Reshape, Lambda)
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from tensorflow.keras import backend as K
from sklearn.model_selection import KFold, train_test_split
from tqdm import tqdm
import matplotlib.pyplot as plt
import tensorflow as tf
import pickle

In [20]:
# Configuration
IMG_HEIGHT = 256
IMG_WIDTH = 256
IMG_CHANNELS = 1
BATCH_SIZE = 8
EPOCHS = 100
INIT_LR = 1e-4
N_FOLDS = 3
SEED = 42
DATA_PATH = "/kaggle/input/camus-dataset/database_nifti"  # Update this with the actual dataset path

# Metrics
def dice_coef(y_true, y_pred, smooth=1):
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)

def dice_loss(y_true, y_pred):
    return 1 - dice_coef(y_true, y_pred)

def bce_dice_loss(y_true, y_pred):
    return tf.keras.losses.binary_crossentropy(y_true, y_pred) + dice_loss(y_true, y_pred)

def iou_coef(y_true, y_pred, smooth=1):
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    union = K.sum(y_true_f) + K.sum(y_pred_f) - intersection
    return (intersection + smooth) / (union + smooth)

def average_precision(y_true, y_pred, thresholds=tf.constant(np.arange(0.0, 1.1, 0.1))):
    precisions = []
    for threshold in thresholds:
        y_pred_thresholded = tf.cast(y_pred > threshold, tf.float32)
        tp = tf.reduce_sum(y_true * y_pred_thresholded)
        fp = tf.reduce_sum((1 - y_true) * y_pred_thresholded)
        precision = tp / (tp + fp + K.epsilon())
        precisions.append(precision)
    return tf.reduce_mean(tf.stack(precisions))

In [21]:
# Preprocessing
def load_nifti_image(file_path):
    img = nib.load(file_path)
    return np.squeeze(img.get_fdata())

def preprocess_patient_with_rotation(patient_folder, rotation_angles=[0, 90, 180, 270]):
    images = []
    masks_endo = []
    masks_epi = []
    masks_la = []
    
    views = ['2CH', '4CH']
    time_points = ['ED', 'ES']
    
    for view in views:
        for tp in time_points:
            base_name = os.path.basename(patient_folder)
            img_path = f"{patient_folder}/{base_name}_{view}_{tp}.nii"
            gt_path = f"{patient_folder}/{base_name}_{view}_{tp}_gt.nii"
            
            if not os.path.exists(img_path) or not os.path.exists(gt_path):
                continue
            
            try:
                img = load_nifti_image(img_path)
                gt = load_nifti_image(gt_path)
                
                img_resized = resize(img, (IMG_HEIGHT, IMG_WIDTH), preserve_range=True, anti_aliasing=True)
                gt_resized = resize(gt, (IMG_HEIGHT, IMG_WIDTH), preserve_range=True, anti_aliasing=False)
                
                img_resized = (img_resized - img_resized.min()) / (img_resized.max() - img_resized.min())
                
                mask_endo = (gt_resized == 1).astype(np.float32)
                mask_epi = (gt_resized == 2).astype(np.float32)
                mask_la = (gt_resized == 3).astype(np.float32)
                
                for angle in rotation_angles:
                    rotated_img = rotate(img_resized, angle, reshape=False, mode='reflect')
                    rotated_mask_endo = rotate(mask_endo, angle, reshape=False, mode='reflect')
                    rotated_mask_epi = rotate(mask_epi, angle, reshape=False, mode='reflect')
                    rotated_mask_la = rotate(mask_la, angle, reshape=False, mode='reflect')
                    
                    images.append(rotated_img[..., np.newaxis])
                    masks_endo.append(rotated_mask_endo[..., np.newaxis])
                    masks_epi.append(rotated_mask_epi[..., np.newaxis])
                    masks_la.append(rotated_mask_la[..., np.newaxis])
            
            except Exception as e:
                print(f"Error processing {base_name}: {e}")
                continue
    
    if images:
        return np.array(images), np.array(masks_endo), np.array(masks_epi), np.array(masks_la)
    else:
        return np.array([]), np.array([]), np.array([]), np.array([])

def load_dataset(base_path):
    patient_folders = sorted([
        os.path.join(base_path, f) 
        for f in os.listdir(base_path) 
        if f.startswith('patient') and os.path.isdir(os.path.join(base_path, f))
    ])
    
    all_images = []
    all_masks_endo = []
    all_masks_epi = []
    all_masks_la = []
    
    for patient_folder in tqdm(patient_folders, desc="Loading patients"):
        images, masks_endo, masks_epi, masks_la = preprocess_patient_with_rotation(patient_folder)
        
        if images.size > 0:
            all_images.append(images)
            all_masks_endo.append(masks_endo)
            all_masks_epi.append(masks_epi)
            all_masks_la.append(masks_la)
    
    if not all_images:
        raise ValueError("No valid image data found in any patient folder!")
    
    return (np.concatenate(all_images, axis=0),
            np.concatenate(all_masks_endo, axis=0),
            np.concatenate(all_masks_epi, axis=0),
            np.concatenate(all_masks_la, axis=0))

In [22]:

# Mamba-based TransUNet Model
def mamba_block(x, hidden_dim, ssm_dim, dropout_rate=0.1):
    """Implements a simplified Mamba block for 2D feature maps."""
    batch, height, width, channels = K.int_shape(x)
    
    # 1. Depthwise convolution for local feature mixing
    x_res = x
    x = LayerNormalization()(x)
    x = Conv2D(channels, kernel_size=3, padding='same', groups=channels)(x)
    
    # 2. Project to hidden dimension
    x = Conv2D(hidden_dim, kernel_size=1)(x)
    x = ReLU()(x)
    
    # 3. Simplified SSM (State Space Model) path
    # Reshape to sequence for SSM processing
    x_reshaped = Reshape((height * width, hidden_dim))(x)
    
    # Simplified SSM implementation (using dense layers as approximation)
    ssm = Dense(ssm_dim, activation='swish')(x_reshaped)
    ssm = Dense(hidden_dim)(ssm)
    
    # Residual connection
    ssm = Dense(hidden_dim)(x_reshaped) + ssm
    
    # Reshape back to spatial dimensions
    x_ssm = Reshape((height, width, hidden_dim))(ssm)
    
    # 4. Project back to channel dimension
    x_out = Conv2D(channels, kernel_size=1)(x_ssm)
    
    # Add dropout and residual
    x_out = tf.keras.layers.Dropout(dropout_rate)(x_out)
    x_out = x_res + x_out
    
    return x_out

def transunet_mamba(input_size=(IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS)):
    """TransUNet architecture with Mamba blocks in the encoder."""
    inputs = Input(input_size)
    
    # Initial convolution to project to higher dimension
    x = Conv2D(64, kernel_size=7, strides=2, padding='same')(inputs)
    x = BatchNormalization()(x)
    x = ReLU()(x)
    
    # Encoder path with Mamba blocks
    # Level 1
    x1 = mamba_block(x, hidden_dim=128, ssm_dim=64)
    x1 = mamba_block(x1, hidden_dim=128, ssm_dim=64)
    p1 = MaxPooling2D(pool_size=(2, 2))(x1)
    
    # Level 2
    x2 = mamba_block(p1, hidden_dim=256, ssm_dim=128)
    x2 = mamba_block(x2, hidden_dim=256, ssm_dim=128)
    p2 = MaxPooling2D(pool_size=(2, 2))(x2)
    
    # Level 3
    x3 = mamba_block(p2, hidden_dim=512, ssm_dim=256)
    x3 = mamba_block(x3, hidden_dim=512, ssm_dim=256)
    p3 = MaxPooling2D(pool_size=(2, 2))(x3)
    
    # Level 4 (bottleneck)
    x4 = mamba_block(p3, hidden_dim=1024, ssm_dim=512)
    x4 = mamba_block(x4, hidden_dim=1024, ssm_dim=512)
    
    # Decoder path with skip connections
    # Up level 4 to level 3
    u3 = Conv2DTranspose(512, kernel_size=3, strides=2, padding='same')(x4)
    u3 = concatenate([u3, x3], axis=-1)
    u3 = Conv2D(512, kernel_size=3, padding='same')(u3)
    u3 = BatchNormalization()(u3)
    u3 = ReLU()(u3)
    
    # Up level 3 to level 2
    u2 = Conv2DTranspose(256, kernel_size=3, strides=2, padding='same')(u3)
    u2 = concatenate([u2, x2], axis=-1)
    u2 = Conv2D(256, kernel_size=3, padding='same')(u2)
    u2 = BatchNormalization()(u2)
    u2 = ReLU()(u2)
    
    # Up level 2 to level 1
    u1 = Conv2DTranspose(128, kernel_size=3, strides=2, padding='same')(u2)
    u1 = concatenate([u1, x1], axis=-1)
    u1 = Conv2D(128, kernel_size=3, padding='same')(u1)
    u1 = BatchNormalization()(u1)
    u1 = ReLU()(u1)
    
    # Final upsampling to original resolution
    u0 = Conv2DTranspose(64, kernel_size=3, strides=2, padding='same')(u1)
    u0 = Conv2D(64, kernel_size=3, padding='same')(u0)
    u0 = BatchNormalization()(u0)
    u0 = ReLU()(u0)
    
    # Output layer
    outputs = Conv2D(1, kernel_size=1, activation='sigmoid')(u0)
    
    return Model(inputs=inputs, outputs=outputs)

# KFold Splits
def save_kfold_splits(X, y, n_splits, seed, save_path):
    kf = KFold(n_splits=n_splits, shuffle=True, random_state=seed)
    splits = [(train_idx.tolist(), test_idx.tolist()) for train_idx, test_idx in kf.split(X)]
    with open(save_path, 'wb') as f:
        pickle.dump(splits, f)

def load_kfold_splits(file_path):
    with open(file_path, 'rb') as f:
        return pickle.load(f)




In [23]:
# Main Function
def main():
    print("Loading and preprocessing dataset...")
    X, y_endo, _, _ = load_dataset(DATA_PATH)
    save_kfold_splits(X, y_endo, N_FOLDS, SEED, "kfold_splits.pkl")
    
    splits = load_kfold_splits("kfold_splits.pkl")
    fold_results = []
    
    for fold, (train_idx, test_idx) in enumerate(splits):
        print(f"\n=== Fold {fold + 1}/{N_FOLDS} ===")
        
        X_train, X_test = X[train_idx], X[test_idx]
        y_train, y_test = y_endo[train_idx], y_endo[test_idx]
        
        X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.1, random_state=SEED)
        
        model = transunet_mamba()
        model.compile(optimizer=Adam(learning_rate=INIT_LR), 
                      loss=bce_dice_loss, 
                      metrics=[dice_coef, iou_coef, 'accuracy'])
        
        callbacks = [
            ModelCheckpoint(f"mamba_unet_fold{fold}_best.keras", 
                          monitor='val_dice_coef', 
                          mode='max', 
                          save_best_only=True, 
                          verbose=1),
            EarlyStopping(monitor='val_dice_coef', 
                         patience=15, 
                         mode='max', 
                         verbose=1),
            ReduceLROnPlateau(monitor='val_dice_coef', 
                            factor=0.5, 
                            patience=5, 
                            min_lr=1e-6, 
                            mode='max', 
                            verbose=1)
        ]
        
        print(f"Training on {len(X_train)} samples, validating on {len(X_val)} samples")
        model.fit(X_train, y_train, 
                batch_size=BATCH_SIZE, 
                epochs=EPOCHS, 
                validation_data=(X_val, y_val), 
                callbacks=callbacks, 
                verbose=1)
        
        model.load_weights(f"mamba_unet_fold{fold}_best.keras")
        print(f"Evaluating on {len(X_test)} test samples")
        y_pred = (model.predict(X_test, batch_size=BATCH_SIZE) > 0.5).astype(np.float32)
        
        dice_scores = [dice_coef(y_test[i], y_pred[i]).numpy() for i in range(len(y_test))]
        fold_results.append({'dice': np.mean(dice_scores)})
    
    print("\n=== Final Cross-Validation Results ===")
    avg_dice = np.mean([r['dice'] for r in fold_results])
    print(f"Average Dice: {avg_dice:.3f}")

In [None]:
if __name__ == '__main__':
    main()

Loading and preprocessing dataset...


Loading patients: 100%|██████████| 500/500 [07:21<00:00,  1.13it/s]



=== Fold 1/3 ===
Training on 4799 samples, validating on 534 samples
Epoch 1/100


I0000 00:00:1744902243.078811     278 service.cc:148] XLA service 0x7858d0002490 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1744902243.078859     278 service.cc:156]   StreamExecutor device (0): Tesla T4, Compute Capability 7.5
I0000 00:00:1744902243.078863     278 service.cc:156]   StreamExecutor device (1): Tesla T4, Compute Capability 7.5
I0000 00:00:1744902246.491155     278 cuda_dnn.cc:529] Loaded cuDNN version 90300
I0000 00:00:1744902285.386093     278 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


[1m600/600[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 386ms/step - accuracy: 0.5907 - dice_coef: 0.3951 - iou_coef: 0.2546 - loss: 0.9742
Epoch 1: val_dice_coef improved from -inf to 0.61768, saving model to mamba_unet_fold0_best.keras
[1m600/600[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m332s[0m 422ms/step - accuracy: 0.5908 - dice_coef: 0.3953 - iou_coef: 0.2547 - loss: 0.9738 - val_accuracy: 0.6573 - val_dice_coef: 0.6177 - val_iou_coef: 0.4483 - val_loss: 0.5345 - learning_rate: 1.0000e-04
Epoch 2/100
[1m600/600[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 321ms/step - accuracy: 0.6587 - dice_coef: 0.6876 - iou_coef: 0.5261 - loss: 0.4245
Epoch 2: val_dice_coef improved from 0.61768 to 0.78362, saving model to mamba_unet_fold0_best.keras
[1m600/600[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m200s[0m 334ms/step - accuracy: 0.6587 - dice_coef: 0.6876 - iou_coef: 0.5262 - loss: 0.4244 - val_accuracy: 0.6649 - val_dice_coef: 0.7836 - val_iou_coef: 0.645

In [17]:
##Mamba Try Number 2

In [None]:
if __name__ == '__main__':
    main()

Loading and preprocessing dataset...


Loading patients: 100%|██████████| 500/500 [06:01<00:00,  1.38it/s]



=== Fold 1/3 ===
Training on 4799 samples, validating on 534 samples
Epoch 1/100


I0000 00:00:1744994760.122192     175 service.cc:148] XLA service 0x7fc440002780 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1744994760.127102     175 service.cc:156]   StreamExecutor device (0): Tesla T4, Compute Capability 7.5
I0000 00:00:1744994760.127125     175 service.cc:156]   StreamExecutor device (1): Tesla T4, Compute Capability 7.5
I0000 00:00:1744994763.766654     175 cuda_dnn.cc:529] Loaded cuDNN version 90300
I0000 00:00:1744994808.541280     175 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


[1m600/600[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 475ms/step - accuracy: 0.6069 - dice_coef: 0.4381 - iou_coef: 0.2916 - loss: 0.8788
Epoch 1: val_dice_coef improved from -inf to 0.23565, saving model to mamba_unet_fold0_best.keras
[1m600/600[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m387s[0m 514ms/step - accuracy: 0.6069 - dice_coef: 0.4383 - iou_coef: 0.2918 - loss: 0.8784 - val_accuracy: 0.5981 - val_dice_coef: 0.2356 - val_iou_coef: 0.1351 - val_loss: 1.0325 - learning_rate: 1.0000e-04
Epoch 2/100
[1m600/600[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 402ms/step - accuracy: 0.6587 - dice_coef: 0.7245 - iou_coef: 0.5705 - loss: 0.3798
Epoch 2: val_dice_coef improved from 0.23565 to 0.79953, saving model to mamba_unet_fold0_best.keras
[1m600/600[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m251s[0m 418ms/step - accuracy: 0.6587 - dice_coef: 0.7245 - iou_coef: 0.5705 - loss: 0.3798 - val_accuracy: 0.6650 - val_dice_coef: 0.7995 - val_iou_coef: 0.667

### Here start the implementaion of the true mamba structure with SSM that failed.

In [4]:
import os
import numpy as np
import nibabel as nib
from skimage.transform import resize
from scipy.ndimage import rotate
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    Input, Conv2D, BatchNormalization, ReLU, MaxPooling2D,
    Conv2DTranspose, concatenate, LayerNormalization, 
    Reshape, Dense, Multiply, Add, Lambda, Dropout
)
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from tensorflow.keras import backend as K
import tensorflow as tf
from sklearn.model_selection import KFold, train_test_split
from tqdm import tqdm
import pickle

# Configuration
IMG_HEIGHT = 256
IMG_WIDTH = 256
IMG_CHANNELS = 1
BATCH_SIZE = 4  # Reduced to help with memory
EPOCHS = 100
INIT_LR = 1e-4
N_FOLDS = 3
SEED = 42
DATA_PATH = "/kaggle/input/camus-dataset/database_nifti"

# Metrics (unchanged from your original code)
def dice_coef(y_true, y_pred, smooth=1):
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)

def dice_loss(y_true, y_pred):
    return 1 - dice_coef(y_true, y_pred)

def bce_dice_loss(y_true, y_pred):
    return tf.keras.losses.binary_crossentropy(y_true, y_pred) + dice_loss(y_true, y_pred)

# Preprocessing (unchanged from your original code)
def load_nifti_image(file_path):
    img = nib.load(file_path)
    return np.squeeze(img.get_fdata())

def preprocess_patient_with_rotation(patient_folder, rotation_angles=[0, 90, 180, 270]):
    # ... (keep your existing implementation)
    pass

def load_dataset(base_path):
    # ... (keep your existing implementation)
    pass

# ----------------------------
# Mamba-Specific Components
# ----------------------------

class SelectiveSSM(tf.keras.layers.Layer):
    def __init__(self, d_model, ssm_dim, **kwargs):
        super().__init__(**kwargs)
        self.d_model = d_model
        self.ssm_dim = ssm_dim
        
        # Projections
        self.in_proj = Dense(d_model * 3)  # Produces delta, B, and skip
        self.out_proj = Dense(d_model)
        
        # SSM parameters
        self.A = self.add_weight(shape=(ssm_dim,), initializer='random_normal')  # State matrix
        self.C = self.add_weight(shape=(d_model, ssm_dim), initializer='random_normal')  # Output projection
        
    def call(self, x):
        batch, seq_len, _ = x.shape
        
        # Project input to get delta, B, and skip connection
        x_proj = self.in_proj(x)  # (batch, seq_len, 3*d_model)
        delta, B, x_skip = tf.split(x_proj, 3, axis=-1)  # Each (batch, seq_len, d_model)
        
        # Discretization
        A_bar = tf.exp(tf.einsum('bnd,d->bnd', delta, self.A))  # (batch, seq_len, d_model)
        B_bar = tf.einsum('bnd,bnd->bn', delta, B)  # (batch, seq_len)
        
        # Selective scan
        h = tf.zeros((batch, self.ssm_dim))
        outputs = []
        for t in range(seq_len):
            h = A_bar[:,t,:] * h + B_bar[:,t,None]
            yt = tf.einsum('bd,dk->bk', h, self.C)
            outputs.append(yt)
        
        y = tf.stack(outputs, axis=1)  # (batch, seq_len, d_model)
        return self.out_proj(y + x_skip)

def mamba_block(x, d_model, ssm_dim, dropout_rate=0.1):
    """Full Mamba block implementation"""
    # 1. Layer normalization
    x_norm = LayerNormalization()(x)
    
    # 2. Depthwise convolution for local feature mixing
    x_conv = Conv2D(d_model, kernel_size=3, padding='same', groups=d_model)(x_norm)
    
    # 3. Tokenization: Flatten spatial dims to sequence
    batch, h, w, c = tf.shape(x_conv)
    tokens = Reshape((h * w, c))(x_conv)  # (batch, seq_len, d_model)
    
    # 4. Add positional embeddings
    positions = tf.range(h * w)
    pos_emb = tf.one_hot(positions, depth=h*w)
    pos_emb = Dense(d_model)(pos_emb)  # (seq_len, d_model)
    tokens = tokens + pos_emb[None,:,:]
    
    # 5. Apply selective SSM
    ssm = SelectiveSSM(d_model, ssm_dim)(tokens)
    
    # 6. Reshape back to spatial
    x_ssm = Reshape((h, w, d_model))(ssm)
    
    # 7. Project back to channel dimension
    x_out = Conv2D(K.int_shape(x)[-1], kernel_size=1)(x_ssm)
    
    # 8. Add dropout and residual
    x_out = Dropout(dropout_rate)(x_out)
    return x + x_out

# ----------------------------
# CNN Front-End
# ----------------------------

def cnn_feature_extractor(inputs):
    """CNN backbone to extract multi-scale features"""
    # Level 1 (1/2 resolution)
    x1 = Conv2D(64, kernel_size=3, strides=1, padding='same')(inputs)
    x1 = BatchNormalization()(x1)
    x1 = ReLU()(x1)
    p1 = MaxPooling2D(pool_size=2)(x1)
    
    # Level 2 (1/4 resolution)
    x2 = Conv2D(128, kernel_size=3, strides=1, padding='same')(p1)
    x2 = BatchNormalization()(x2)
    x2 = ReLU()(x2)
    p2 = MaxPooling2D(pool_size=2)(x2)
    
    # Level 3 (1/8 resolution)
    x3 = Conv2D(256, kernel_size=3, strides=1, padding='same')(p2)
    x3 = BatchNormalization()(x3)
    x3 = ReLU()(x3)
    p3 = MaxPooling2D(pool_size=2)(x3)
    
    return x1, x2, x3, p3  # Return all skip connections

# ----------------------------
# Full Mamba TransUNet Model
# ----------------------------

def transunet_mamba(input_size=(IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS)):
    inputs = Input(input_size)
    
    # CNN Front-End
    x1, x2, x3, cnn_feats = cnn_feature_extractor(inputs)
    
    # Mamba Encoder
    x_mamba = mamba_block(cnn_feats, d_model=256, ssm_dim=128)
    x_mamba = mamba_block(x_mamba, d_model=256, ssm_dim=128)
    
    # Decoder with skip connections
    # Up 1/8 -> 1/4
    u3 = Conv2DTranspose(256, kernel_size=3, strides=2, padding='same')(x_mamba)
    u3 = concatenate([u3, x3], axis=-1)
    u3 = Conv2D(256, kernel_size=3, padding='same')(u3)
    u3 = BatchNormalization()(u3)
    u3 = ReLU()(u3)
    
    # Up 1/4 -> 1/2
    u2 = Conv2DTranspose(128, kernel_size=3, strides=2, padding='same')(u3)
    u2 = concatenate([u2, x2], axis=-1)
    u2 = Conv2D(128, kernel_size=3, padding='same')(u2)
    u2 = BatchNormalization()(u2)
    u2 = ReLU()(u2)
    
    # Up 1/2 -> original
    u1 = Conv2DTranspose(64, kernel_size=3, strides=2, padding='same')(u2)
    u1 = concatenate([u1, x1], axis=-1)
    u1 = Conv2D(64, kernel_size=3, padding='same')(u1)
    u1 = BatchNormalization()(u1)
    u1 = ReLU()(u1)
    
    # Final output
    outputs = Conv2D(1, kernel_size=1, activation='sigmoid')(u1)
    
    return Model(inputs=inputs, outputs=outputs)

# Rest of your code remains the same (KFold, main function, etc.)

2025-04-18 15:40:20.006870: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1744990820.272115      82 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1744990820.345105      82 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [None]:
# ----------------------------
# K-Fold Cross Validation Setup
# ----------------------------

def save_kfold_splits(X, y, n_splits, seed, save_path):
    kf = KFold(n_splits=n_splits, shuffle=True, random_state=seed)
    splits = [(train_idx.tolist(), test_idx.tolist()) for train_idx, test_idx in kf.split(X)]
    with open(save_path, 'wb') as f:
        pickle.dump(splits, f)

def load_kfold_splits(file_path):
    with open(file_path, 'rb') as f:
        return pickle.load(f)

# ----------------------------
# Training and Evaluation
# ----------------------------

def train_model(X_train, y_train, X_val, y_val, fold, input_shape):
    # Build model
    model = transunet_mamba(input_shape)
    
    # Compile with mixed precision
    optimizer = Adam(learning_rate=INIT_LR)
    
    model.compile(
        optimizer=optimizer,
        loss=bce_dice_loss,
        metrics=[dice_coef, iou_coef, 'accuracy']
    )
    
    # Callbacks
    callbacks = [
        ModelCheckpoint(
            f"mamba_unet_fold{fold}_best.keras",
            monitor='val_dice_coef',
            mode='max',
            save_best_only=True,
            save_weights_only=False
        ),
        EarlyStopping(
            monitor='val_dice_coef',
            patience=15,
            mode='max',
            restore_best_weights=True
        ),
        ReduceLROnPlateau(
            monitor='val_dice_coef',
            factor=0.5,
            patience=5,
            min_lr=1e-6,
            mode='max'
        )
    ]
    
    # Train with reduced memory footprint
    history = model.fit(
        X_train, y_train,
        batch_size=BATCH_SIZE,
        epochs=EPOCHS,
        validation_data=(X_val, y_val),
        callbacks=callbacks,
        verbose=1
    )
    
    return model, history

def evaluate_model(model, X_test, y_test):
    y_pred = model.predict(X_test, batch_size=BATCH_SIZE)
    y_pred_thresh = (y_pred > 0.5).astype(np.float32)
    
    dice_scores = [dice_coef(y_test[i], y_pred_thresh[i]).numpy() 
                  for i in range(len(y_test))]
    return np.mean(dice_scores)


def preprocess_patient_with_rotation(patient_folder, rotation_angles=[0, 90, 180, 270]):
    images = []
    masks_endo = []
    masks_epi = []
    masks_la = []
    
    views = ['2CH', '4CH']
    time_points = ['ED', 'ES']
    
    for view in views:
        for tp in time_points:
            base_name = os.path.basename(patient_folder)
            img_path = f"{patient_folder}/{base_name}_{view}_{tp}.nii"
            gt_path = f"{patient_folder}/{base_name}_{view}_{tp}_gt.nii"
            
            if not os.path.exists(img_path) or not os.path.exists(gt_path):
                continue
            
            try:
                img = load_nifti_image(img_path)
                gt = load_nifti_image(gt_path)
                
                img_resized = resize(img, (IMG_HEIGHT, IMG_WIDTH), 
                                  preserve_range=True, anti_aliasing=True)
                gt_resized = resize(gt, (IMG_HEIGHT, IMG_WIDTH), 
                                  preserve_range=True, anti_aliasing=False)
                
                img_resized = (img_resized - img_resized.min()) / (img_resized.max() - img_resized.min())
                
                mask_endo = (gt_resized == 1).astype(np.float32)
                mask_epi = (gt_resized == 2).astype(np.float32)
                mask_la = (gt_resized == 3).astype(np.float32)
                
                for angle in rotation_angles:
                    rotated_img = rotate(img_resized, angle, reshape=False, mode='reflect')
                    rotated_mask_endo = rotate(mask_endo, angle, reshape=False, mode='reflect')
                    rotated_mask_epi = rotate(mask_epi, angle, reshape=False, mode='reflect')
                    rotated_mask_la = rotate(mask_la, angle, reshape=False, mode='reflect')
                    
                    images.append(rotated_img[..., np.newaxis])
                    masks_endo.append(rotated_mask_endo[..., np.newaxis])
                    masks_epi.append(rotated_mask_epi[..., np.newaxis])
                    masks_la.append(rotated_mask_la[..., np.newaxis])
            
            except Exception as e:
                print(f"Error processing {base_name}: {e}")
                continue
    
    return np.array(images), np.array(masks_endo), np.array(masks_epi), np.array(masks_la)

def load_dataset(base_path):
    patient_folders = sorted([
        os.path.join(base_path, f) 
        for f in os.listdir(base_path) 
        if f.startswith('patient') and os.path.isdir(os.path.join(base_path, f))
    ])
    
    all_images = []
    all_masks_endo = []
    all_masks_epi = []
    all_masks_la = []
    
    for patient_folder in tqdm(patient_folders, desc="Loading patients"):
        images, masks_endo, masks_epi, masks_la = preprocess_patient_with_rotation(patient_folder)
        
        if images.size > 0:
            all_images.append(images)
            all_masks_endo.append(masks_endo)
            all_masks_epi.append(masks_epi)
            all_masks_la.append(masks_la)
    
    if not all_images:
        raise ValueError("No valid image data found in any patient folder!")
    
    # Stack all patient data
    X = np.concatenate(all_images, axis=0)
    y_endo = np.concatenate(all_masks_endo, axis=0)
    y_epi = np.concatenate(all_masks_epi, axis=0)
    y_la = np.concatenate(all_masks_la, axis=0)
    
    return X, y_endo, y_epi, y_la





# ----------------------------
# Main Execution
# ----------------------------

def main():
    # Load and preprocess dataset
    print("Loading and preprocessing dataset...")
    X, y_endo, _, _ = load_dataset(DATA_PATH)
    
    # Save k-fold splits
    save_kfold_splits(X, y_endo, N_FOLDS, SEED, "kfold_splits.pkl")
    splits = load_kfold_splits("kfold_splits.pkl")
    
    fold_results = []
    
    for fold, (train_idx, test_idx) in enumerate(splits):
        print(f"\n=== Fold {fold + 1}/{N_FOLDS} ===")
        
        # Split data
        X_train, X_test = X[train_idx], X[test_idx]
        y_train, y_test = y_endo[train_idx], y_endo[test_idx]
        
        # Further split into train/val
        X_train, X_val, y_train, y_val = train_test_split(
            X_train, y_train, 
            test_size=0.1, 
            random_state=SEED
        )
        
        # Train model
        print(f"Training on {len(X_train)} samples, validating on {len(X_val)} samples")
        model, history = train_model(
            X_train, y_train,
            X_val, y_val,
            fold,
            input_shape=(IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS)
        )
        
        # Evaluate on test set
        print(f"Evaluating on {len(X_test)} test samples")
        model.load_weights(f"mamba_unet_fold{fold}_best.keras")  # Load best weights
        fold_dice = evaluate_model(model, X_test, y_test)
        fold_results.append(fold_dice)
        print(f"Fold {fold + 1} Dice: {fold_dice:.4f}")
        
        # Clean up to save memory
        del model
        tf.keras.backend.clear_session()
    
    # Final results
    print("\n=== Cross-Validation Results ===")
    print(f"Average Dice: {np.mean(fold_results):.4f} ± {np.std(fold_results):.4f}")
    print("Per-fold results:", [f"{x:.4f}" for x in fold_results])

if __name__ == '__main__':
    # Configure TensorFlow for better memory management
    physical_devices = tf.config.list_physical_devices('GPU')
    if physical_devices:
        try:
            tf.config.experimental.set_memory_growth(physical_devices[0], True)
        except:
            pass
    
    main()

Loading and preprocessing dataset...


Loading patients: 100%|██████████| 500/500 [05:20<00:00,  1.56it/s]


## Optimized version of the mamba implementation

In [10]:
import os
import numpy as np
import nibabel as nib
from skimage.transform import resize
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    Input, Conv2D, BatchNormalization, ReLU, MaxPooling2D,
    Conv2DTranspose, concatenate, LayerNormalization, 
    Reshape, Dense, Dropout, Lambda
)
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
import tensorflow as tf
from sklearn.model_selection import train_test_split
from tqdm import tqdm

# Reduced configuration for memory
IMG_HEIGHT, IMG_WIDTH = 224, 224  # Reduced from 256
IMG_CHANNELS = 1
BATCH_SIZE = 2  # Reduced from 8
EPOCHS = 50
INIT_LR = 1e-4
DATA_PATH = "/kaggle/input/camus-dataset/database_nifti"

# Simplified metrics
def dice_coef(y_true, y_pred, smooth=1):
    y_true_f = tf.reshape(y_true, [-1])
    y_pred_f = tf.reshape(y_pred, [-1])
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth)

def dice_loss(y_true, y_pred):
    return 1 - dice_coef(y_true, y_pred)

# Memory-efficient data loading
def load_and_preprocess_single(path, target_size):
    img = nib.load(path).get_fdata()
    img = resize(img, target_size, preserve_range=True, anti_aliasing=True)
    return (img - img.min()) / (img.max() - img.min() + 1e-7)

def load_patient(patient_folder):
    images, masks = [], []
    views = ['2CH', '4CH']
    time_points = ['ED', 'ES']
    
    for view in views:
        for tp in time_points:
            base = os.path.basename(patient_folder)
            img_path = f"{patient_folder}/{base}_{view}_{tp}.nii"
            mask_path = f"{patient_folder}/{base}_{view}_{tp}_gt.nii"
            
            if os.path.exists(img_path) and os.path.exists(mask_path):
                try:
                    img = load_and_preprocess_single(img_path, (IMG_HEIGHT, IMG_WIDTH))
                    mask = load_and_preprocess_single(mask_path, (IMG_HEIGHT, IMG_WIDTH))
                    
                    images.append(img[..., np.newaxis])
                    masks.append(mask[..., np.newaxis])
                except Exception as e:
                    print(f"Error loading {base}: {str(e)}")
    
    return np.array(images), np.array(masks)

class SelectiveSSM(tf.keras.layers.Layer):
    def __init__(self, d_model, ssm_dim):
        super().__init__()
        self.d_model = d_model
        self.ssm_dim = ssm_dim
        
    def build(self, input_shape):
        # Input shape: (batch, seq_len, d_model)
        self.dense1 = Dense(self.d_model * 2)
        self.dense2 = Dense(self.d_model)
        self.built = True
        
    def call(self, x):
        x_proj = self.dense1(x)
        delta, B = tf.split(x_proj, 2, axis=-1)
        A_bar = tf.exp(delta)
        B_bar = tf.einsum('bnd,bnd->bn', delta, B)
        
        seq_len = tf.shape(x)[1]  # Dynamic sequence length

        # Initialize the hidden state with the same shape as A_bar[:, 0, :]
        h_init = tf.zeros((tf.shape(x)[0], self.d_model))  # Match d_model dimension
        
        # Define the recurrent step for tf.scan
        def step(h, t):
            h = A_bar[:, t, :] * h + B_bar[:, t, None]
            return h

        # Use tf.scan to replace the for loop
        h_states = tf.scan(step, tf.range(seq_len), initializer=h_init)
        
        # Transpose to match the required shape: (batch, seq_len, d_model)
        h_states = tf.transpose(h_states, perm=[1, 0, 2])
        
        return self.dense2(h_states)
    
    def compute_output_shape(self, input_shape):
        return (input_shape[0], input_shape[1], self.d_model)

def mamba_block(x, d_model, ssm_dim):
    x_norm = LayerNormalization()(x)
    x_conv = Conv2D(d_model, 3, padding='same')(x_norm)
    
    # Wrap dynamic spatial dimension computation in a Lambda layer
    def reshape_fn(x_conv):
        batch, h, w, c = tf.shape(x_conv)[0], tf.shape(x_conv)[1], tf.shape(x_conv)[2], tf.shape(x_conv)[3]
        return tf.reshape(x_conv, [batch, h * w, d_model])
    
    # Compute the output shape for reshape_fn
    def reshape_output_shape(input_shape):
        batch_size, height, width, _ = input_shape
        return (batch_size, height * width, d_model)
    
    x_flat = Lambda(reshape_fn, output_shape=reshape_output_shape)(x_conv)
    x_ssm = SelectiveSSM(d_model, ssm_dim)(x_flat)
    
    # Wrap output reshaping in a Lambda layer
    def reshape_back_fn(x_ssm):
        # Dynamically calculate dimensions using the shape of x_ssm
        batch = tf.shape(x_ssm)[0]
        seq_len = tf.shape(x_ssm)[1]
        height = width = tf.cast(tf.sqrt(tf.cast(seq_len, tf.float32)), tf.int32)  # Fix: Cast seq_len to float32 before applying sqrt
        return tf.reshape(x_ssm, [batch, height, width, d_model])
    
    # Compute the output shape for reshape_back_fn
    def reshape_back_output_shape(input_shape):
        batch_size, seq_len, _ = input_shape
        height = width = int(tf.sqrt(seq_len))  # This is static and will not be used at runtime
        return (batch_size, height, width, d_model)
    
    x_out = Lambda(reshape_back_fn, output_shape=reshape_back_output_shape)(x_ssm)
    return x + x_out
#Simplified model
def build_model(input_shape):
    inputs = Input(input_shape)
    
    # Encoder
    x1 = Conv2D(32, 3, strides=2, padding='same')(inputs)
    x1 = mamba_block(x1, 32, 16)
    
    x2 = Conv2D(64, 3, strides=2, padding='same')(x1)
    x2 = mamba_block(x2, 64, 32)
    
    # Decoder
    u1 = Conv2DTranspose(32, 3, strides=2, padding='same')(x2)
    u1 = concatenate([u1, x1])
    u1 = Conv2D(32, 3, padding='same')(u1)
    
    outputs = Conv2DTranspose(1, 3, strides=2, padding='same', activation='sigmoid')(u1)
    
    return Model(inputs, outputs)

# Training with memory monitoring
def train():
    # Load data incrementally
    patient_folders = [f for f in os.listdir(DATA_PATH) if f.startswith('patient')]
    X, y = [], []
    
    for folder in tqdm(patient_folders[:20]):  # Limit to 20 patients for memory
        imgs, masks = load_patient(os.path.join(DATA_PATH, folder))
        X.extend(imgs)
        y.extend(masks)
    
    X, y = np.array(X), np.array(y)
    X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)
    
    # Build and train
    model = build_model((IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS))
    model.compile(optimizer=Adam(INIT_LR), loss=dice_loss, metrics=[dice_coef])
    
    model.fit(X_train, y_train,
             batch_size=BATCH_SIZE,
             epochs=EPOCHS,
             validation_data=(X_val, y_val),
             callbacks=[
                 ModelCheckpoint("best_model.keras", save_best_only=True),
                 EarlyStopping(patience=5)
             ])

if __name__ == '__main__':
    # Configure GPU memory growth
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpus:
        try:
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
        except RuntimeError as e:
            print(e)
    
    train()

100%|██████████| 20/20 [00:01<00:00, 15.13it/s]


InvalidArgumentError: Exception encountered when calling Lambda.call().

[1mValue for attr 'T' of int32 is not in the list of allowed values: bfloat16, half, float, double, complex64, complex128
	; NodeDef: {{node Sqrt}}; Op<name=Sqrt; signature=x:T -> y:T; attr=T:type,allowed=[DT_BFLOAT16, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128]> [Op:Sqrt] name: [0m

Arguments received by Lambda.call():
  • args=('<KerasTensor shape=(None, 12544, 32), dtype=float32, sparse=False, name=keras_tensor_111>',)
  • kwargs={'mask': 'None'}