In [None]:
import tensorflow as tf
from tensorflow.keras import layers, Model, callbacks
import torch
import torch.nn as nn
import numpy as np
import cv2
import os
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import random
from sklearn.metrics import confusion_matrix
from tensorflow.image import ssim as tf_ssim
from tensorflow.image import psnr as tf_psnr

# Module-wide configuration constants.
# Target size handed to cv2.resize, which reads it as (width, height);
# the value is square, so the ordering makes no difference here.
IMAGE_SIZE = (256, 256)
# Root folder of the FIVES fundus dataset as mounted in the Kaggle environment.
DATASET_DIR = "/kaggle/input/fives-a-fundus-image/FIVES A Fundus Image Dataset for AI-based Vessel Segmentation"

# Utility class for He initialization (carried over from the PyTorch implementation).
class InitWeights_He(object):
    """Callable that applies Kaiming-normal initialisation to conv layers.

    Meant to be used with ``torch.nn.Module.apply``. For every ``nn.Conv2d``
    or ``nn.ConvTranspose2d`` it redraws the weight with
    ``nn.init.kaiming_normal_`` (``neg_slope`` is passed as ``a``) and zeroes
    the bias when one exists. Any other module type is left untouched.
    Nothing in the visible code applies it to the Keras model below.
    """

    def __init__(self, neg_slope=1e-2):
        self.neg_slope = neg_slope

    def __call__(self, module):
        if not isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
            return
        # The nn.init.*_ helpers work in place on the existing Parameter, so
        # there is no need to re-bind module.weight / module.bias.
        nn.init.kaiming_normal_(module.weight, a=self.neg_slope)
        if module.bias is not None:
            nn.init.constant_(module.bias, 0)



class ConvBlock(layers.Layer):
    """Two stacked stages of 3x3 conv -> batch norm -> dropout -> LeakyReLU(0.1).

    Args:
        filters: Output channels of both 3x3 convolutions.
        dropout_rate: Rate of the dropout layer used in each stage.
        **kwargs: Forwarded to ``layers.Layer``.
    """

    def __init__(self, filters, dropout_rate=0, **kwargs):
        super().__init__(**kwargs)
        # Convs skip the bias because a BatchNormalization follows directly.
        # Sub-layers are created in the same order as before so Keras assigns
        # them the same default names.
        self.conv1 = layers.Conv2D(filters, 3, padding='same', use_bias=False)
        self.bn1 = layers.BatchNormalization()
        self.dropout1 = layers.Dropout(dropout_rate)
        self.activation1 = layers.LeakyReLU(0.1)

        self.conv2 = layers.Conv2D(filters, 3, padding='same', use_bias=False)
        self.bn2 = layers.BatchNormalization()
        self.dropout2 = layers.Dropout(dropout_rate)
        self.activation2 = layers.LeakyReLU(0.1)

    def call(self, inputs):
        pipeline = (
            self.conv1, self.bn1, self.dropout1, self.activation1,
            self.conv2, self.bn2, self.dropout2, self.activation2,
        )
        features = inputs
        for stage in pipeline:
            features = stage(features)
        return features


class FeatureFuse(layers.Layer):
    """Sum of three parallel convolutions, then batch norm and LeakyReLU(0.1).

    The branches are a 1x1 conv, a 3x3 conv and a 3x3 conv with dilation 2.
    All are bias-free and produce ``filters`` channels with 'same' padding.
    """

    def __init__(self, filters, **kwargs):
        super().__init__(**kwargs)
        self.conv11 = layers.Conv2D(filters, 1, padding='same', use_bias=False)
        self.conv33 = layers.Conv2D(filters, 3, padding='same', use_bias=False)
        self.conv33_di = layers.Conv2D(filters, 3, padding='same', use_bias=False, dilation_rate=2)
        self.bn = layers.BatchNormalization()
        self.activation = layers.LeakyReLU(0.1)

    def call(self, inputs):
        # Run every branch first, then add them left to right.
        point, local, dilated = (
            self.conv11(inputs),
            self.conv33(inputs),
            self.conv33_di(inputs),
        )
        normalised = self.bn(point + local + dilated)
        return self.activation(normalised)


class UpBlock(layers.Layer):
    """Doubles spatial size: 2x2 stride-2 transposed conv -> batch norm -> LeakyReLU(0.1)."""

    def __init__(self, filters, **kwargs):
        super().__init__(**kwargs)
        self.up = layers.Conv2DTranspose(filters, 2, strides=2, padding='same', use_bias=False)
        self.bn = layers.BatchNormalization()
        self.activation = layers.LeakyReLU(0.1)

    def call(self, inputs):
        return self.activation(self.bn(self.up(inputs)))


class DownBlock(layers.Layer):
    """Halves spatial size: 2x2 stride-2 conv -> batch norm -> LeakyReLU(0.1)."""

    def __init__(self, filters, **kwargs):
        super().__init__(**kwargs)
        self.down = layers.Conv2D(filters, 2, strides=2, padding='same', use_bias=False)
        self.bn = layers.BatchNormalization()
        self.activation = layers.LeakyReLU(0.1)

    def call(self, inputs):
        return self.activation(self.bn(self.down(inputs)))


class FRUNetBlock(layers.Layer):
    """One FR-UNet node: fusion/1x1 conv, a ConvBlock, optional up/down outputs.

    Args:
        filters: Channel width of the node.
        dropout_rate: Forwarded to the inner ``ConvBlock``.
        is_up: Also emit an upsampled copy (``filters // 2`` channels).
        is_down: Also emit a downsampled copy (``filters * 2`` channels).
        fuse: Use ``FeatureFuse`` as the first layer; otherwise a 1x1 ``Conv2D``.

    Returns (from ``call``):
        The features alone when neither flag is set. Otherwise a list that
        starts with the features, followed by the up output (if ``is_up``)
        and then the down output (if ``is_down``).
    """

    def __init__(self, filters, dropout_rate=0, is_up=False, is_down=False, fuse=True, **kwargs):
        super().__init__(**kwargs)
        self.is_up = is_up
        self.is_down = is_down

        # Sub-layers are constructed in the original order (fuse, conv, up, down).
        self.fuse_layer = FeatureFuse(filters) if fuse else layers.Conv2D(filters, 1, padding='same')
        self.conv_block = ConvBlock(filters, dropout_rate)

        if is_up:
            self.up_block = UpBlock(filters // 2)
        if is_down:
            self.down_block = DownBlock(filters * 2)

    def call(self, inputs):
        features = self.conv_block(self.fuse_layer(inputs))

        if not (self.is_up or self.is_down):
            return features

        outputs = [features]
        if self.is_up:
            outputs.append(self.up_block(features))
        if self.is_down:
            outputs.append(self.down_block(features))
        return outputs


def fr_unet_model(input_shape=(256, 256, 1), num_classes=1, feature_scale=2, dropout=0.2, fuse=True, out_ave=True):
    """Build the FR-UNet segmentation network as a Keras functional model.

    Every stage is an ``FRUNetBlock``. Blocks named ``blockR_C`` / ``blockRC``
    sit in row R, and each row runs at half the spatial resolution of the one
    above it (``DownBlock`` and ``UpBlock`` both use stride 2). The wiring
    appears to follow the PyTorch FR-UNet that this file references; confirm
    against that implementation before changing any connection.

    Args:
        input_shape: Shape of one input sample, passed to ``layers.Input``.
        num_classes: Output channels of each 1x1 prediction head.
        feature_scale: Divisor applied to the base widths 64/128/256/512/1024.
        dropout: Dropout rate forwarded to every ``FRUNetBlock``.
        fuse: Forwarded to every block. True selects ``FeatureFuse`` as the
            block's first layer; False selects a plain 1x1 ``Conv2D``.
        out_ave: If True, average the five full-resolution heads. Otherwise
            use only the head on ``x13``.

    Returns:
        An uncompiled ``Model`` mapping ``inputs`` to sigmoid outputs.
    """
    inputs = layers.Input(input_shape)

    # Define filters
    # Per-row channel widths. Only filters[0]..filters[3] are used below.
    filters = [64, 128, 256, 512, 1024]
    filters = [int(x / feature_scale) for x in filters]

    # NOTE(review): every concatenation needs matching spatial sizes. That
    # presumably requires H and W divisible by 8, because three stride-2
    # downsamplings are chained before block40. Confirm before using other
    # input shapes.
    # NOTE(review): Keras names un-named layers by creation order. Reordering
    # the block constructions below likely renames layers and may break
    # loading previously saved checkpoints.

    # Encoder path
    block1_3 = FRUNetBlock(filters[0], dropout, is_up=False, is_down=True, fuse=fuse)
    x1_3, x_down1_3 = block1_3(inputs)

    block1_2 = FRUNetBlock(filters[0], dropout, is_up=False, is_down=True, fuse=fuse)
    x1_2, x_down1_2 = block1_2(x1_3)

    block2_2 = FRUNetBlock(filters[1], dropout, is_up=True, is_down=True, fuse=fuse)
    x2_2, x_up2_2, x_down2_2 = block2_2(x_down1_3)

    # Concatenate features and continue encoding
    # Each block's input joins same-resolution features from the neighbouring
    # rows (upsampled from below, downsampled from above) along channels.
    block1_1 = FRUNetBlock(filters[0], dropout, is_up=False, is_down=True, fuse=fuse)
    x1_1, x_down1_1 = block1_1(layers.concatenate([x1_2, x_up2_2], axis=-1))

    block2_1 = FRUNetBlock(filters[1], dropout, is_up=True, is_down=True, fuse=fuse)
    x2_1, x_up2_1, x_down2_1 = block2_1(layers.concatenate([x_down1_2, x2_2], axis=-1))

    block3_1 = FRUNetBlock(filters[2], dropout, is_up=True, is_down=True, fuse=fuse)
    x3_1, x_up3_1, x_down3_1 = block3_1(x_down2_2)

    # Middle pathway
    block10 = FRUNetBlock(filters[0], dropout, is_up=False, is_down=True, fuse=fuse)
    x10, x_down10 = block10(layers.concatenate([x1_1, x_up2_1], axis=-1))

    block20 = FRUNetBlock(filters[1], dropout, is_up=True, is_down=True, fuse=fuse)
    x20, x_up20, x_down20 = block20(layers.concatenate([x_down1_1, x2_1, x_up3_1], axis=-1))

    block30 = FRUNetBlock(filters[2], dropout, is_up=True, is_down=False, fuse=fuse)
    x30, x_up30 = block30(layers.concatenate([x_down2_1, x3_1], axis=-1))

    # Deepest node (row 4). Only its upsampled output is consumed.
    block40 = FRUNetBlock(filters[3], dropout, is_up=True, is_down=False, fuse=fuse)
    _, x_up40 = block40(x_down3_1)

    # Decoder path
    block11 = FRUNetBlock(filters[0], dropout, is_up=False, is_down=True, fuse=fuse)
    x11, x_down11 = block11(layers.concatenate([x10, x_up20], axis=-1))

    block21 = FRUNetBlock(filters[1], dropout, is_up=True, is_down=False, fuse=fuse)
    x21, x_up21 = block21(layers.concatenate([x_down10, x20, x_up30], axis=-1))

    block31 = FRUNetBlock(filters[2], dropout, is_up=True, is_down=False, fuse=fuse)
    _, x_up31 = block31(layers.concatenate([x_down20, x30, x_up40], axis=-1))

    block12 = FRUNetBlock(filters[0], dropout, is_up=False, is_down=False, fuse=fuse)
    x12 = block12(layers.concatenate([x11, x_up21], axis=-1))

    block22 = FRUNetBlock(filters[1], dropout, is_up=True, is_down=False, fuse=fuse)
    _, x_up22 = block22(layers.concatenate([x_down11, x21, x_up31], axis=-1))

    block13 = FRUNetBlock(filters[0], dropout, is_up=False, is_down=False, fuse=fuse)
    x13 = block13(layers.concatenate([x12, x_up22], axis=-1))

    # Final outputs
    # 1x1 heads on the row-1 (full-resolution) features. They have no
    # activation, so these are pre-sigmoid values.
    final1 = layers.Conv2D(num_classes, 1, padding='same')(x1_1)
    final2 = layers.Conv2D(num_classes, 1, padding='same')(x10)
    final3 = layers.Conv2D(num_classes, 1, padding='same')(x11)
    final4 = layers.Conv2D(num_classes, 1, padding='same')(x12)
    final5 = layers.Conv2D(num_classes, 1, padding='same')(x13)

    if out_ave:
        # Average the outputs (averaging happens before the sigmoid)
        output = layers.average([final1, final2, final3, final4, final5])
    else:
        output = final5

    # Apply sigmoid activation for binary segmentation
    output = layers.Activation('sigmoid')(output)

    model = Model(inputs=inputs, outputs=output)
    return model


# Data loading functions (from the original script)
def _preprocess_pair(img, mask):
    """Resize, reduce to the green channel, scale and binarise one image/mask pair."""
    img = cv2.resize(img, IMAGE_SIZE)
    mask = cv2.resize(mask, IMAGE_SIZE)

    # cv2 loads colour images as BGR; index 1 is the green channel, which
    # carries the most vessel contrast in fundus photographs.
    if len(img.shape) == 3 and img.shape[2] == 3:
        img = img[:, :, 1]

    img = img.astype(np.float32) / 255.0
    # Re-binarise: resizing interpolates, so mask values are no longer just 0/255.
    mask = (mask > 127).astype(np.uint8)

    # Trailing channel axis -> (H, W, 1) for the Keras model.
    return img[..., np.newaxis], mask[..., np.newaxis]


def _load_split(img_dir, mask_dir, images, masks):
    """Append every preprocessed (image, mask) pair of one dataset split.

    Only ``.png`` files in ``img_dir`` are considered, in sorted order. A
    mask is looked up under the same file name in ``mask_dir``, and images
    without a mask are skipped.

    Raises:
        ValueError: If a listed image or its mask cannot be decoded by cv2.
    """
    for img_file in sorted(os.listdir(img_dir)):
        if not img_file.endswith(".png"):
            continue
        mask_path = os.path.join(mask_dir, img_file)
        if not os.path.exists(mask_path):
            continue

        img_path = os.path.join(img_dir, img_file)
        img = cv2.imread(img_path)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        # cv2.imread reports failure by returning None rather than raising.
        if img is None:
            raise ValueError(f"Could not read image: {img_path}")
        if mask is None:
            raise ValueError(f"Could not read mask: {mask_path}")

        img, mask = _preprocess_pair(img, mask)
        images.append(img)
        masks.append(mask)


def load_png_files(dataset_dir):
    """Load and preprocess all FIVES image/mask pairs into one pool.

    Reads ``train/Original`` + ``train/Ground truth`` (required). It also reads
    ``test/Original`` + ``test/Ground truth`` when both exist. The dataset's own
    train/test division is NOT preserved; all pairs are concatenated.

    Args:
        dataset_dir: Root directory of the dataset.

    Returns:
        ``(images, masks)`` as NumPy arrays. Images are float32 in [0, 1] and
        masks are uint8 in {0, 1}, both with a trailing channel axis of size 1.
    """
    images, masks = [], []

    _load_split(
        os.path.join(dataset_dir, "train/Original"),
        os.path.join(dataset_dir, "train/Ground truth"),
        images,
        masks,
    )

    # The test split is optional: it is only used when both folders exist.
    test_img_dir = os.path.join(dataset_dir, "test/Original")
    test_mask_dir = os.path.join(dataset_dir, "test/Ground truth")
    if os.path.exists(test_img_dir) and os.path.exists(test_mask_dir):
        _load_split(test_img_dir, test_mask_dir, images, masks)

    return np.array(images), np.array(masks)

# Data Augmentation function
def augment_data(images, masks):
    """Expand every (image, mask) pair into five variants, in a fixed order.

    Per input pair the output gets, in this order:
      1. the original pair;
      2. a 90-degree rotation over axes (0, 1), applied to image and mask;
      3. a flip along axis 1, applied to image and mask;
      4. the image scaled by a factor drawn from U(0.8, 1.2), clipped to [0, 1];
      5. ``alpha * image + beta`` with alpha ~ U(0.9, 1.1) and beta ~ U(-0.1, 0.1),
         clipped to [0, 1].
    Variants 4 and 5 keep the unmodified mask. Random values come from NumPy's
    global RNG, drawn per pair in the order scale, alpha, beta.

    Returns:
        ``(images, masks)`` stacked with ``np.array``, five entries per input pair.
    """
    out_images, out_masks = [], []

    for image, label in zip(images, masks):
        variants = [
            (image, label),
            (np.rot90(image, k=1, axes=(0, 1)), np.rot90(label, k=1, axes=(0, 1))),
            (np.flip(image, axis=1), np.flip(label, axis=1)),
        ]

        scale = np.random.uniform(0.8, 1.2)
        variants.append((np.clip(image * scale, 0, 1), label))

        gain = np.random.uniform(0.9, 1.1)
        offset = np.random.uniform(-0.1, 0.1)
        variants.append((np.clip(gain * image + offset, 0, 1), label))

        for variant_image, variant_mask in variants:
            out_images.append(variant_image)
            out_masks.append(variant_mask)

    return np.array(out_images), np.array(out_masks)

# Custom Dice Loss
def dice_loss(y_true, y_pred):
    """Soft Dice loss, ``1 - Dice``, computed over all elements of the batch at once.

    Both inputs are cast to float32 before use, as ``dice_coefficient`` does.
    Integer masks (the loader produces uint8) can therefore be passed directly
    instead of failing on a dtype mismatch in the product below.

    Args:
        y_true: Ground-truth mask(s), any numeric dtype.
        y_pred: Predicted probabilities, same number of elements as ``y_true``.

    Returns:
        Scalar float32 tensor in [0, 1].
    """
    smooth = 1e-6  # avoids 0/0 when both masks are empty
    y_true_f = tf.reshape(tf.cast(y_true, tf.float32), [-1])
    y_pred_f = tf.reshape(tf.cast(y_pred, tf.float32), [-1])
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
    dice = (2. * intersection + smooth) / (tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth)
    return 1 - dice

# Combined loss function
def bce_dice_loss(y_true, y_pred):
    """Binary cross-entropy plus the batch-level soft Dice loss.

    ``binary_crossentropy`` averages over the last axis only. The scalar Dice
    term is broadcast onto that result, and Keras reduces the sum afterwards.
    """
    bce_term = tf.keras.losses.binary_crossentropy(y_true, y_pred)
    dice_term = dice_loss(y_true, y_pred)
    return bce_term + dice_term

def dice_coefficient(y_true, y_pred, smooth=1e-6):
    """Dice overlap ``(2|A.B| + s) / (|A| + |B| + s)`` over all elements.

    Inputs are cast to float32 and flattened. ``smooth`` keeps the ratio
    defined (equal to 1) when both inputs are all zeros.
    """
    truth = tf.reshape(tf.cast(y_true, tf.float32), [-1])
    prediction = tf.reshape(tf.cast(y_pred, tf.float32), [-1])

    overlap = tf.reduce_sum(truth * prediction)
    total = tf.reduce_sum(truth) + tf.reduce_sum(prediction)
    return (2. * overlap + smooth) / (total + smooth)

def iou_score(y_true, y_pred, smooth=1e-6):
    """Intersection-over-union ``(|A.B| + s) / (|A| + |B| - |A.B| + s)`` over all elements.

    Inputs are cast to float32 and flattened before summation.
    """
    truth = tf.reshape(tf.cast(y_true, tf.float32), [-1])
    prediction = tf.reshape(tf.cast(y_pred, tf.float32), [-1])

    overlap = tf.reduce_sum(truth * prediction)
    combined = tf.reduce_sum(truth) + tf.reduce_sum(prediction) - overlap
    return (overlap + smooth) / (combined + smooth)

# Additional metrics
def mse(y_true, y_pred):
    """Mean squared error over every element, after casting both inputs to float32."""
    residual = tf.cast(y_true, tf.float32) - tf.cast(y_pred, tf.float32)
    return tf.reduce_mean(tf.square(residual))

def mae(y_true, y_pred):
    """Mean absolute error over every element, after casting both inputs to float32."""
    residual = tf.cast(y_true, tf.float32) - tf.cast(y_pred, tf.float32)
    return tf.reduce_mean(tf.abs(residual))

def psnr(y_true, y_pred):
    """Peak signal-to-noise ratio via ``tf.image.psnr`` with ``max_val=1.0``.

    Inputs are cast to float32 and are assumed to lie in [0, 1]. The result
    is whatever ``tf.image.psnr`` returns (not averaged here).
    """
    return tf_psnr(
        tf.cast(y_true, tf.float32),
        tf.cast(y_pred, tf.float32),
        max_val=1.0,
    )

def ssim(y_true, y_pred):
    """Structural similarity via ``tf.image.ssim`` (``max_val=1.0``), averaged to a scalar."""
    per_image = tf_ssim(
        tf.cast(y_true, tf.float32),
        tf.cast(y_pred, tf.float32),
        max_val=1.0,
    )
    return tf.reduce_mean(per_image)

# Load and process data
print("Loading dataset...")
images, masks = load_png_files(DATASET_DIR)
print(f"Original dataset: {len(images)} images, {len(masks)} masks")

# Split the ORIGINAL images first (70% train / 20% validation / 10% test).
# Augmenting before splitting would put rotated/flipped/brightened copies of
# the same fundus image on both sides of the split. That leaks test content
# into training and inflates validation/test scores.
X_train, X_temp, y_train, y_temp = train_test_split(images, masks, test_size=0.3, random_state=42)
X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=1/3, random_state=42)

# Augment the training set only; validation and test remain real, unmodified images.
print("Augmenting data...")
augmented_images, augmented_masks = augment_data(X_train, y_train)
print(f"Augmented training set: {len(augmented_images)} images, {len(augmented_masks)} masks")
X_train, y_train = augmented_images, augmented_masks
print(f"Train: {len(X_train)}, Validation: {len(X_val)}, Test: {len(X_test)}")

# Create FR-UNet model: single-channel 256x256 input, one sigmoid output map.
model_kwargs = dict(
    input_shape=(256, 256, 1),
    num_classes=1,
    feature_scale=2,
    dropout=0.2,
    fuse=True,
    out_ave=True,
)
model = fr_unet_model(**model_kwargs)

# Adam (lr 1e-4) on BCE + Dice. Every listed metric is logged per epoch.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss=bce_dice_loss,
    metrics=['accuracy', dice_coefficient, iou_score, mse, mae, psnr, ssim],
)

# Display model summary
model.summary()

# Callbacks: model selection tracks validation Dice, where higher is better.
_monitor = dict(monitor='val_dice_coefficient', mode='max')

early_stopping = callbacks.EarlyStopping(
    patience=10, restore_best_weights=True, verbose=1, **_monitor,
)
lr_scheduler = callbacks.ReduceLROnPlateau(
    factor=0.5, patience=3, min_lr=1e-6, verbose=1, **_monitor,
)
checkpoint = callbacks.ModelCheckpoint(
    "best_fr_unet_retina_vessel_segmentation_model.keras",
    save_best_only=True, verbose=1, **_monitor,
)
# TensorBoard logging, including per-epoch weight histograms.
tensorboard_callback = callbacks.TensorBoard(
    log_dir="./logs", histogram_freq=1, update_freq='epoch',
)

# Train model
print("Training FR-UNet model...")
history = model.fit(
    X_train, y_train,
    validation_data=(X_val, y_val),
    epochs=100,
    batch_size=16,
    callbacks=[early_stopping, lr_scheduler, checkpoint, tensorboard_callback],
)

# Save the model as it stands after fit() returns.
model.save("fr_unet_retina_vessel_segmentation_final.keras")

# Visualize training history
def plot_training_history(history):
    """Plot training vs. validation curves for the loss and each compiled metric.

    Draws eight panels on a 3x3 grid (the ninth cell stays empty), saves the
    figure as 'frunet_training_history.png' and displays it.

    Args:
        history: ``History`` object from ``model.fit``. Its ``history`` dict
            must hold each key listed below plus the matching ``val_`` key.
    """
    # (history key, legend noun, panel title, y-axis label), in panel order.
    panels = [
        ('loss', 'Loss', 'Loss', 'Loss'),
        ('accuracy', 'Accuracy', 'Accuracy', 'Accuracy'),
        ('dice_coefficient', 'Dice', 'Dice Coefficient', 'Dice'),
        ('iou_score', 'IoU', 'IoU Score', 'IoU'),
        ('mse', 'MSE', 'Mean Squared Error', 'MSE'),
        ('mae', 'MAE', 'Mean Absolute Error', 'MAE'),
        ('psnr', 'PSNR', 'Peak Signal-to-Noise Ratio', 'PSNR (dB)'),
        ('ssim', 'SSIM', 'Structural Similarity Index', 'SSIM'),
    ]
    logs = history.history

    plt.figure(figsize=(20, 15))
    for position, (key, noun, title, ylabel) in enumerate(panels, start=1):
        plt.subplot(3, 3, position)
        plt.plot(logs[key], label=f'Training {noun}')
        plt.plot(logs[f'val_{key}'], label=f'Validation {noun}')
        plt.title(title)
        plt.xlabel('Epoch')
        plt.ylabel(ylabel)
        plt.legend()

    plt.tight_layout()
    plt.savefig('frunet_training_history.png')
    plt.show()

print("Plotting training history...")
plot_training_history(history)

# Model evaluation
print("Evaluating FR-UNet model...")
# Request a name -> value dict. The order of the default positional list
# depends on the Keras version (some Keras 3 releases append metrics sorted by
# name, which would put 'mae' before 'mse'), so positional indexing is fragile.
eval_results = model.evaluate(X_test, y_test, return_dict=True)
# Keep the positional `evaluation` list that later code indexes into, but build
# it explicitly in the order the printouts below assume.
evaluation = [
    eval_results[name]
    for name in ('loss', 'accuracy', 'dice_coefficient', 'iou_score', 'mse', 'mae', 'psnr', 'ssim')
]
print(f"Test Loss: {evaluation[0]}")
print(f"Test Accuracy: {evaluation[1]}")
print(f"Test Dice Coefficient: {evaluation[2]}")
print(f"Test IoU Score: {evaluation[3]}")
print(f"Test MSE: {evaluation[4]}")
print(f"Test MAE: {evaluation[5]}")
print(f"Test PSNR: {evaluation[6]}")
print(f"Test SSIM: {evaluation[7]}")

# Generate predictions
y_pred = model.predict(X_test)
# Hard 0/1 versions of prediction and ground truth at a 0.5 threshold.
y_pred_bin = (y_pred > 0.5).astype(np.uint8)
y_test_bin = (y_test > 0.5).astype(np.uint8)


def _per_sample(metric_fn, truths, preds):
    """Apply a TF metric to each (truth, prediction) pair; return NumPy values."""
    return [metric_fn(truth, pred).numpy() for truth, pred in zip(truths, preds)]


# Overlap metrics compare binarised masks. Error and image-quality metrics
# compare the raw ground truth with the sigmoid probabilities.
dice_scores = _per_sample(dice_coefficient, y_test_bin, y_pred_bin)
iou_scores = _per_sample(iou_score, y_test_bin, y_pred_bin)
mse_scores = _per_sample(mse, y_test, y_pred)
mae_scores = _per_sample(mae, y_test, y_pred)
psnr_scores = _per_sample(psnr, y_test, y_pred)
ssim_scores = _per_sample(ssim, y_test, y_pred)

# Means over the test samples
dice_mean, iou_mean, mse_mean, mae_mean, psnr_mean, ssim_mean = (
    np.mean(scores)
    for scores in (dice_scores, iou_scores, mse_scores, mae_scores, psnr_scores, ssim_scores)
)

print(f"Mean Dice Score: {dice_mean:.4f}")
print(f"Mean IoU Score: {iou_mean:.4f}")
print(f"Mean MSE: {mse_mean:.4f}")
print(f"Mean MAE: {mae_mean:.4f}")
print(f"Mean PSNR: {psnr_mean:.4f} dB")
print(f"Mean SSIM: {ssim_mean:.4f}")

# Confusion matrix (pixel level, over the whole test set)
flat_y_test = y_test_bin.flatten()
flat_y_pred = y_pred_bin.flatten()
# labels=[0, 1] forces a 2x2 matrix. Without it, a test set where truth and
# prediction contain only one class yields a 1x1 matrix, and the 4-way unpack
# below raises.
cm = confusion_matrix(flat_y_test, flat_y_pred, labels=[0, 1])

tn, fp, fn, tp = cm.ravel()


def _ratio_or_nan(numerator, denominator):
    """Return numerator / denominator as a float, or NaN when the denominator is 0."""
    return float(numerator) / denominator if denominator else float("nan")


specificity = _ratio_or_nan(tn, tn + fp)
sensitivity = _ratio_or_nan(tp, tp + fn)
precision = _ratio_or_nan(tp, tp + fp)
accuracy = _ratio_or_nan(tp + tn, tp + tn + fp + fn)

print("Confusion Matrix:")
print(cm)
print(f"Accuracy: {accuracy:.4f}")
print(f"Sensitivity (Recall): {sensitivity:.4f}")
print(f"Specificity: {specificity:.4f}")
print(f"Precision: {precision:.4f}")

# Function to visualize random predictions
def visualize_random_predictions(X_test, y_test, y_pred, num_samples=5, threshold=0.5):
    """Plot input, ground truth and predicted map for randomly chosen test samples.

    Each row shows one sample. The prediction panel title reports Dice (on
    ``y_pred`` binarised at ``threshold``) and PSNR / SSIM (on ``y_pred`` as
    given). The figure is saved as 'random_predictions.png' and shown.

    Args:
        X_test: Indexable collection of input images.
        y_test: Matching ground-truth masks.
        y_pred: Matching predicted probability maps.
        num_samples: Number of rows; clamped to ``len(X_test)``.
        threshold: Cut-off used to binarise ``y_pred`` for the Dice score.
    """
    num_samples = min(num_samples, len(X_test))
    if num_samples == 0:
        print("No samples to visualize.")
        return
    indices = random.sample(range(len(X_test)), num_samples)

    plt.figure(figsize=(15, 4*num_samples))

    for i, idx in enumerate(indices):
        # Original image
        plt.subplot(num_samples, 3, i*3 + 1)
        plt.imshow(X_test[idx].squeeze(), cmap='gray')
        plt.title(f"Original Image {idx}")
        plt.axis('off')

        # Ground truth mask
        plt.subplot(num_samples, 3, i*3 + 2)
        plt.imshow(y_test[idx].squeeze(), cmap='viridis')
        plt.title(f"Ground Truth Mask {idx}")
        plt.axis('off')

        # Predicted mask
        plt.subplot(num_samples, 3, i*3 + 3)
        plt.imshow(y_pred[idx].squeeze(), cmap='viridis')
        # Binarise this function's own y_pred argument instead of relying on
        # a module-level global that may describe different predictions.
        pred_bin = (y_pred[idx] > threshold).astype(np.uint8)
        dice_val = dice_coefficient(y_test[idx], pred_bin).numpy()
        psnr_val = psnr(y_test[idx], y_pred[idx]).numpy()
        ssim_val = ssim(y_test[idx], y_pred[idx]).numpy()
        plt.title(f"Predicted Mask {idx}\nDice: {dice_val:.4f}\nPSNR: {psnr_val:.2f}, SSIM: {ssim_val:.4f}")
        plt.axis('off')

    plt.tight_layout()
    plt.savefig('random_predictions.png')
    plt.show()

# Visualize predictions on random test samples
print("Visualizing random predictions...")
visualize_random_predictions(X_test, y_test, y_pred)

# Function to visualize overlaid predictions
def visualize_overlaid_predictions(X_test, y_test, y_pred, num_samples=5, threshold=0.5):
    """Overlay ground truth, prediction and their agreement on random test images.

    For each sample, three panels are drawn on an RGB copy of the grayscale
    input: the true mask in red, the predicted mask (``y_pred > threshold``)
    in green, and a comparison with true positives yellow, false positives
    green and false negatives red. The comparison title reports Dice (on the
    thresholded prediction) and MSE / SSIM (on ``y_pred`` as passed in). The
    figure is saved as 'overlaid_predictions.png' and shown.

    Args:
        X_test: Indexable collection of input images.
        y_test: Matching ground-truth masks (non-zero = vessel).
        y_pred: Matching predictions (probabilities or already-binary maps).
        num_samples: Number of rows; clamped to ``len(X_test)``.
        threshold: Cut-off used to binarise ``y_pred``.
    """
    num_samples = min(num_samples, len(X_test))
    if num_samples == 0:
        print("No samples to visualize.")
        return
    indices = random.sample(range(len(X_test)), num_samples)

    plt.figure(figsize=(15, 6*num_samples))

    red, green, yellow = (1.0, 0.0, 0.0), (0.0, 1.0, 0.0), (1.0, 1.0, 0.0)

    for i, idx in enumerate(indices):
        img = X_test[idx].squeeze()
        true_mask = y_test[idx].squeeze() > 0
        # Binarise this function's own argument rather than a module-level global.
        pred_bin = (y_pred[idx] > threshold).astype(np.uint8)
        pred_mask = pred_bin.squeeze() > 0

        # RGB copy of the grayscale image so masks can be painted in colour.
        img_rgb = np.repeat(img[:, :, np.newaxis], 3, axis=2)

        true_overlay = img_rgb.copy()
        true_overlay[true_mask] = red

        pred_overlay = img_rgb.copy()
        pred_overlay[pred_mask] = green

        comp_overlay = img_rgb.copy()
        comp_overlay[true_mask & pred_mask] = yellow   # true positives
        comp_overlay[~true_mask & pred_mask] = green   # false positives
        comp_overlay[true_mask & ~pred_mask] = red     # false negatives

        # Original image with true mask overlay
        plt.subplot(num_samples, 3, i*3 + 1)
        plt.imshow(true_overlay)
        plt.title("True Mask Overlay")
        plt.axis('off')

        # Original image with predicted mask overlay
        plt.subplot(num_samples, 3, i*3 + 2)
        plt.imshow(pred_overlay)
        plt.title("Predicted Mask Overlay")
        plt.axis('off')

        # Comparison overlay
        plt.subplot(num_samples, 3, i*3 + 3)
        plt.imshow(comp_overlay)
        dice_val = dice_coefficient(y_test[idx], pred_bin).numpy()
        mse_val = mse(y_test[idx], y_pred[idx]).numpy()
        ssim_val = ssim(y_test[idx], y_pred[idx]).numpy()
        plt.title(f"Comparison Overlay\nYellow: TP, Green: FP, Red: FN\nDice: {dice_val:.4f}, MSE: {mse_val:.4f}, SSIM: {ssim_val:.4f}")
        plt.axis('off')

    plt.tight_layout()
    plt.savefig('overlaid_predictions.png')
    plt.show()

# Visualize overlaid predictions
print("Visualizing overlaid predictions...")
visualize_overlaid_predictions(X_test, y_test, y_pred_bin)

# Visualize metric distributions
def visualize_metric_distributions():
    """Plot per-sample metric histograms, pixel-level rates and the confusion matrix.

    Reads module-level results computed earlier: the ``*_scores`` lists and
    their ``*_mean`` values, ``sensitivity`` / ``specificity`` / ``precision``
    / ``accuracy``, and the ``tn`` / ``fp`` / ``fn`` / ``tp`` counts. Draws a
    2x4 grid, saves it as 'metric_distributions.png' and shows it.
    """
    plt.figure(figsize=(20, 10))

    # (scores, mean, colour, title, x-axis label) for panels 1-6.
    histograms = [
        (dice_scores, dice_mean, 'blue', f'Dice Coefficient (Mean: {dice_mean:.4f})', 'Dice Coefficient'),
        (iou_scores, iou_mean, 'green', f'IoU Score (Mean: {iou_mean:.4f})', 'IoU Score'),
        (mse_scores, mse_mean, 'purple', f'MSE (Mean: {mse_mean:.4f})', 'MSE'),
        (mae_scores, mae_mean, 'orange', f'MAE (Mean: {mae_mean:.4f})', 'MAE'),
        (psnr_scores, psnr_mean, 'brown', f'PSNR (Mean: {psnr_mean:.4f} dB)', 'PSNR (dB)'),
        (ssim_scores, ssim_mean, 'cyan', f'SSIM (Mean: {ssim_mean:.4f})', 'SSIM'),
    ]
    for panel, (scores, mean_value, colour, title, xlabel) in enumerate(histograms, start=1):
        plt.subplot(2, 4, panel)
        plt.hist(scores, bins=20, alpha=0.75, color=colour)
        plt.axvline(mean_value, color='red', linestyle='dashed', linewidth=2)
        plt.title(title)
        plt.xlabel(xlabel)
        plt.ylabel('Frequency')

    # Panel 7: pixel-level classification rates as bars.
    plt.subplot(2, 4, 7)
    bar_names = ['Sensitivity', 'Specificity', 'Precision', 'Accuracy']
    bar_values = [sensitivity, specificity, precision, accuracy]
    plt.bar(bar_names, bar_values, color=['blue', 'green', 'purple', 'orange'])
    plt.ylim(0, 1.0)
    plt.title('Binary Classification Metrics')
    plt.ylabel('Score')

    # Panel 8: confusion matrix heatmap (rows = actual, columns = predicted).
    plt.subplot(2, 4, 8)
    plt.imshow(np.array([[tn, fp], [fn, tp]]), cmap='Blues')
    plt.colorbar()
    tick_names = ['Negative', 'Positive']
    plt.xticks([0, 1], tick_names)
    plt.yticks([0, 1], tick_names)
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.title('Confusion Matrix')

    # Cell annotations, keyed by (x, y) plot coordinates.
    cell_text = {
        (0, 0): f'TN: {tn}',
        (1, 0): f'FP: {fp}',
        (0, 1): f'FN: {fn}',
        (1, 1): f'TP: {tp}',
    }
    for (x, y), text in cell_text.items():
        plt.text(x, y, text, ha='center', va='center')

    plt.tight_layout()
    plt.savefig('metric_distributions.png')
    plt.show()

# Call the visualization function
print("Visualizing metric distributions...")
visualize_metric_distributions()

# Save final results to a text file: Keras test metrics, per-sample means,
# then the pixel-level confusion matrix and derived rates.
report_lines = [
    f"Test Loss: {evaluation[0]:.4f}\n",
    f"Test Accuracy: {evaluation[1]:.4f}\n",
    f"Test Dice Coefficient: {evaluation[2]:.4f}\n",
    f"Test IoU Score: {evaluation[3]:.4f}\n",
    f"Test MSE: {evaluation[4]:.4f}\n",
    f"Test MAE: {evaluation[5]:.4f}\n",
    f"Test PSNR: {evaluation[6]:.4f}\n",
    f"Test SSIM: {evaluation[7]:.4f}\n\n",
    f"Mean Dice Score: {dice_mean:.4f}\n",
    f"Mean IoU Score: {iou_mean:.4f}\n",
    f"Mean MSE: {mse_mean:.4f}\n",
    f"Mean MAE: {mae_mean:.4f}\n",
    f"Mean PSNR: {psnr_mean:.4f} dB\n",
    f"Mean SSIM: {ssim_mean:.4f}\n\n",
    "Confusion Matrix:\n",
    f"{cm}\n\n",
    f"Accuracy: {accuracy:.4f}\n",
    f"Sensitivity (Recall): {sensitivity:.4f}\n",
    f"Specificity: {specificity:.4f}\n",
    f"Precision: {precision:.4f}\n",
]

with open('retina_segmentation_results.txt', 'w') as report_file:
    report_file.writelines(report_lines)

print("Analysis complete. Results saved to 'retina_segmentation_results.txt'")

2025-05-12 03:08:09.629702: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1747019289.816672      31 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1747019289.888256      31 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Loading dataset...
Original dataset: 800 images, 800 masks
Augmenting data...
Augmented dataset: 4000 images, 4000 masks
Train: 2800, Validation: 800, Test: 400


I0000 00:00:1747019459.339156      31 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 15513 MB memory:  -> device: 0, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:04.0, compute capability: 6.0


Training FR-UNet model...
Epoch 1/100


I0000 00:00:1747019523.756427      88 service.cc:148] XLA service 0x7a7d10028130 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1747019523.757159      88 service.cc:156]   StreamExecutor device (0): Tesla P100-PCIE-16GB, Compute Capability 6.0
W0000 00:00:1747019526.708190      88 assert_op.cc:38] Ignoring Assert operator PSNR/Assert/Assert
W0000 00:00:1747019526.709089      88 assert_op.cc:38] Ignoring Assert operator PSNR/Assert_1/Assert
W0000 00:00:1747019526.712137      88 assert_op.cc:38] Ignoring Assert operator SSIM/Assert/Assert
W0000 00:00:1747019526.712940      88 assert_op.cc:38] Ignoring Assert operator SSIM/Assert_1/Assert
W0000 00:00:1747019526.713865      88 assert_op.cc:38] Ignoring Assert operator SSIM/Assert_2/Assert
W0000 00:00:1747019526.714464      88 assert_op.cc:38] Ignoring Assert operator SSIM/Assert_3/Assert
I0000 00:00:1747019528.538683      88 cuda_dnn.cc:529] Loaded cuDNN version 90300
I0000 00:00:1747019

[1m175/175[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 908ms/step - accuracy: 0.7816 - dice_coefficient: 0.1895 - iou_score: 0.1059 - loss: 1.3274 - mae: 0.3822 - mse: 0.1681 - psnr: 8.0263 - ssim: 0.0313

W0000 00:00:1747019757.517159      88 assert_op.cc:38] Ignoring Assert operator PSNR/Assert/Assert
W0000 00:00:1747019757.518068      88 assert_op.cc:38] Ignoring Assert operator PSNR/Assert_1/Assert
W0000 00:00:1747019757.520835      88 assert_op.cc:38] Ignoring Assert operator SSIM/Assert/Assert
W0000 00:00:1747019757.521494      88 assert_op.cc:38] Ignoring Assert operator SSIM/Assert_1/Assert
W0000 00:00:1747019757.522242      88 assert_op.cc:38] Ignoring Assert operator SSIM/Assert_2/Assert
W0000 00:00:1747019757.522790      88 assert_op.cc:38] Ignoring Assert operator SSIM/Assert_3/Assert



Epoch 1: val_dice_coefficient improved from -inf to 0.13123, saving model to best_fr_unet_retina_vessel_segmentation_model.keras
[1m175/175[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m306s[0m 1s/step - accuracy: 0.7821 - dice_coefficient: 0.1899 - iou_score: 0.1061 - loss: 1.3264 - mae: 0.3818 - mse: 0.1679 - psnr: 8.0350 - ssim: 0.0315 - val_accuracy: 0.8217 - val_dice_coefficient: 0.1312 - val_iou_score: 0.0702 - val_loss: 1.4315 - val_mae: 0.4195 - val_mse: 0.1883 - val_psnr: 7.2867 - val_ssim: 0.0054 - learning_rate: 1.0000e-04
Epoch 2/100
[1m175/175[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 907ms/step - accuracy: 0.9398 - dice_coefficient: 0.3679 - iou_score: 0.2258 - loss: 0.8893 - mae: 0.1970 - mse: 0.0616 - psnr: 12.1854 - ssim: 0.1332
Epoch 2: val_dice_coefficient improved from 0.13123 to 0.22754, saving model to best_fr_unet_retina_vessel_segmentation_model.keras
[1m175/175[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m174s[0m 994ms/step - accuracy: 0.93

In [None]:
def visualize_random_predictions(X_test, y_test, y_pred, num_samples=5,
                                 y_pred_bin=None, threshold=0.5):
    """Plot randomly chosen test samples next to their ground truth and prediction.

    Each figure row shows the input image, the ground-truth mask and the raw
    predicted mask, all in grayscale. The prediction's title reports Dice on
    the binarized prediction and PSNR/SSIM on the raw prediction, using the
    notebook-level ``dice_coefficient``, ``psnr`` and ``ssim`` callables.
    The figure is saved to ``random_predictions.png`` and then shown.

    Args:
        X_test: Indexable collection of input images.
        y_test: Ground-truth masks, indexed in parallel with ``X_test``.
        y_pred: Raw model predictions, indexed in parallel with ``X_test``.
        num_samples: How many distinct samples to draw; capped at ``len(X_test)``.
        y_pred_bin: Optional pre-binarized predictions. When omitted, each drawn
            sample is binarized as ``y_pred[idx] > threshold`` instead of relying
            on a global ``y_pred_bin`` that may not exist yet.
        threshold: Cut-off used only when ``y_pred_bin`` is not supplied.
    """
    # random.sample raises ValueError when asked for more items than exist.
    num_samples = min(num_samples, len(X_test))
    if num_samples <= 0:
        return
    indices = random.sample(range(len(X_test)), num_samples)

    plt.figure(figsize=(15, 4 * num_samples))

    for i, idx in enumerate(indices):
        # Original image
        plt.subplot(num_samples, 3, i * 3 + 1)
        plt.imshow(X_test[idx].squeeze(), cmap='gray')
        plt.title(f"Original Image {idx}")
        plt.axis('off')

        # Ground truth mask
        plt.subplot(num_samples, 3, i * 3 + 2)
        plt.imshow(y_test[idx].squeeze(), cmap='gray')
        plt.title(f"Ground Truth Mask {idx}")
        plt.axis('off')

        # Predicted mask (raw, un-thresholded values)
        plt.subplot(num_samples, 3, i * 3 + 3)
        plt.imshow(y_pred[idx].squeeze(), cmap='gray')
        if y_pred_bin is None:
            pred_mask = (np.asarray(y_pred[idx]) > threshold).astype(np.float32)
        else:
            pred_mask = y_pred_bin[idx]
        # float() accepts both TF eager scalars and NumPy scalars, whereas
        # .numpy() breaks if psnr/ssim have been rebound to skimage functions.
        # NOTE(review): psnr/ssim receive y_test[idx]/y_pred[idx] unsqueezed;
        # confirm this matches whichever psnr/ssim implementation is bound.
        dice_val = float(dice_coefficient(y_test[idx], pred_mask))
        psnr_val = float(psnr(y_test[idx], y_pred[idx]))
        ssim_val = float(ssim(y_test[idx], y_pred[idx]))
        plt.title(f"Predicted Mask {idx}\nDice: {dice_val:.4f}\nPSNR: {psnr_val:.2f}, SSIM: {ssim_val:.4f}")
        plt.axis('off')

    plt.tight_layout()
    plt.savefig('random_predictions.png')
    plt.show()

# Visualize predictions on random test samples
print("Visualizing random predictions...")
visualize_random_predictions(X_test, y_test, y_pred)


In [None]:
def visualize_random_predictions(X_test, y_test, y_pred, num_samples=5,
                                 y_pred_bin=None, threshold=0.5):
    """Plot randomly chosen test samples next to their ground truth and prediction.

    Each figure row shows the input image (grayscale), the ground-truth mask
    (``viridis`` colormap) and the raw predicted mask (grayscale). The
    prediction's title reports Dice on the binarized prediction and PSNR/SSIM
    on the raw prediction, using the notebook-level ``dice_coefficient``,
    ``psnr`` and ``ssim`` callables. The figure is saved to
    ``random_predictions.png`` and then shown.

    Args:
        X_test: Indexable collection of input images.
        y_test: Ground-truth masks, indexed in parallel with ``X_test``.
        y_pred: Raw model predictions, indexed in parallel with ``X_test``.
        num_samples: How many distinct samples to draw; capped at ``len(X_test)``.
        y_pred_bin: Optional pre-binarized predictions. When omitted, each drawn
            sample is binarized as ``y_pred[idx] > threshold`` instead of relying
            on a global ``y_pred_bin`` that may not exist yet.
        threshold: Cut-off used only when ``y_pred_bin`` is not supplied.
    """
    # random.sample raises ValueError when asked for more items than exist.
    num_samples = min(num_samples, len(X_test))
    if num_samples <= 0:
        return
    indices = random.sample(range(len(X_test)), num_samples)

    plt.figure(figsize=(15, 4 * num_samples))

    for i, idx in enumerate(indices):
        # Original image
        plt.subplot(num_samples, 3, i * 3 + 1)
        plt.imshow(X_test[idx].squeeze(), cmap='gray')
        plt.title(f"Original Image {idx}")
        plt.axis('off')

        # Ground truth mask
        plt.subplot(num_samples, 3, i * 3 + 2)
        plt.imshow(y_test[idx].squeeze(), cmap='viridis')
        plt.title(f"Ground Truth Mask {idx}")
        plt.axis('off')

        # Predicted mask (raw, un-thresholded values)
        plt.subplot(num_samples, 3, i * 3 + 3)
        plt.imshow(y_pred[idx].squeeze(), cmap='gray')
        if y_pred_bin is None:
            pred_mask = (np.asarray(y_pred[idx]) > threshold).astype(np.float32)
        else:
            pred_mask = y_pred_bin[idx]
        # float() accepts both TF eager scalars and NumPy scalars, whereas
        # .numpy() breaks if psnr/ssim have been rebound to skimage functions.
        # NOTE(review): psnr/ssim receive y_test[idx]/y_pred[idx] unsqueezed;
        # confirm this matches whichever psnr/ssim implementation is bound.
        dice_val = float(dice_coefficient(y_test[idx], pred_mask))
        psnr_val = float(psnr(y_test[idx], y_pred[idx]))
        ssim_val = float(ssim(y_test[idx], y_pred[idx]))
        plt.title(f"Predicted Mask {idx}\nDice: {dice_val:.4f}\nPSNR: {psnr_val:.2f}, SSIM: {ssim_val:.4f}")
        plt.axis('off')

    plt.tight_layout()
    plt.savefig('random_predictions.png')
    plt.show()

# Visualize predictions on random test samples
print("Visualizing random predictions...")
visualize_random_predictions(X_test, y_test, y_pred)


In [None]:
import tensorflow as tf
import numpy as np

# Evaluation helpers: binary cross-entropy, Dice and the combined Dice+BCE loss.
def bce_loss(y_true, y_pred):
    """Return binary cross-entropy reduced to a single scalar.

    Keras' ``binary_crossentropy`` averages over the last axis only; the
    outer ``tf.reduce_mean`` then averages whatever axes remain.
    """
    per_element = tf.keras.losses.binary_crossentropy(y_true, y_pred)
    return tf.reduce_mean(per_element)

import tensorflow as tf

def dice_coefficient(y_true, y_pred, smooth=1e-6):
    """Return the Dice coefficient, summed over every element of the inputs.

    Computes ``(2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth)`` after
    casting both inputs to float32. Because the sums run over all elements,
    a whole batch produces one scalar; soft predictions are used as given.
    """
    truth = tf.cast(y_true, tf.float32)
    pred = tf.cast(y_pred, tf.float32)

    overlap = tf.reduce_sum(truth * pred)
    # Sum of both masks' mass (not the set union, despite the Dice formula).
    total_mass = tf.reduce_sum(truth) + tf.reduce_sum(pred)

    numerator = 2.0 * overlap + smooth
    return numerator / (total_mass + smooth)


def dice_bce_loss(y_true, y_pred, smooth=1e-6):
    """Return BCE plus the Dice loss ``1 - dice``.

    Uses the notebook-level ``bce_loss`` and ``dice_coefficient``; ``smooth``
    is forwarded positionally to ``dice_coefficient``.
    """
    cross_entropy = bce_loss(y_true, y_pred)
    dice_loss = 1 - dice_coefficient(y_true, y_pred, smooth)
    return cross_entropy + dice_loss

# Mean per-class Dice. Note: this is class-wise Dice, NOT the topology-aware
# centerline Dice ("clDice") from the vessel-segmentation literature.
def classwise_dice(y_true, y_pred, num_classes=2, smooth=1e-6):
    """Return the mean Dice over the labels ``0 .. num_classes - 1``.

    For each label, both inputs are turned into float32 indicator masks via
    exact equality with that label, scored with ``dice_coefficient``, and the
    per-label scores are averaged with ``np.mean``. Inputs are therefore
    expected to hold exact integer-valued labels (e.g. 0/1 masks).
    """
    scores = [
        dice_coefficient(
            tf.cast(tf.equal(y_true, label), tf.float32),
            tf.cast(tf.equal(y_pred, label), tf.float32),
            smooth,
        )
        for label in range(num_classes)
    ]
    return np.mean(scores)

# Evaluate metrics after training
def evaluate_metrics(model, X_test, y_test, threshold=0.5):
    """Predict on ``X_test`` and print BCE, Dice, Dice+BCE and class-wise Dice.

    Args:
        model: Object exposing ``predict`` (the trained Keras model here).
        X_test: Test inputs passed to ``model.predict``.
        y_test: Ground-truth masks matching the model output.
        threshold: Probability cut-off used to binarize predictions for the
            overlap metrics (Dice and class-wise Dice).

    Returns:
        dict mapping metric name to its float value, so results can be logged
        or compared programmatically in addition to being printed.
    """
    y_pred = model.predict(X_test)
    y_pred_bin = (y_pred > threshold).astype(np.float32)  # hard 0/1 masks

    # Losses are evaluated on the predicted probabilities. Feeding hard 0/1
    # predictions to BCE hits Keras' epsilon clipping (each wrong pixel costs
    # about -log(eps)), which inflated DiceBCE and made it inconsistent with
    # the BCE reported on the line above it.
    bce = float(bce_loss(y_test, y_pred))
    dice_bce = float(dice_bce_loss(y_test, y_pred))
    # Overlap metrics are evaluated on the binarized masks.
    dice = float(dice_coefficient(y_test, y_pred_bin))
    cl_dice = float(classwise_dice(y_test, y_pred_bin))

    print("\nMetrics after evaluation:")
    print(f"BCE: {bce:.4f}")
    print(f"Dice: {dice:.4f}")
    print(f"DiceBCE: {dice_bce:.4f}")
    # Mean of per-class Dice; this is not the centerline Dice (clDice) metric.
    print(f"Class-wise Dice: {cl_dice:.4f}")

    return {
        "bce": bce,
        "dice": dice,
        "dice_bce": dice_bce,
        "classwise_dice": cl_dice,
    }

# Example of evaluating the model after training
print("Evaluating model...\n")
# Assign the result so the notebook does not echo the dict as cell output.
eval_results = evaluate_metrics(model, X_test, y_test)


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
import tensorflow as tf

# Define Dice coefficient.
# NOTE: this redefines the earlier `dice_coefficient`. `smooth` is accepted
# (defaulting to this cell's 1e-7) so callers that pass it positionally, such
# as `dice_bce_loss` and `classwise_dice`, keep working after this cell runs.
def dice_coefficient(y_true, y_pred_bin, smooth=1e-7):
    """Return the Dice coefficient of two masks, computed over all elements.

    Both inputs are flattened and cast to float32, so they must contain the
    same number of elements. ``smooth`` keeps the ratio defined (equal to 1)
    when both masks are empty.
    """
    y_true_f = tf.cast(tf.reshape(y_true, [-1]), tf.float32)
    y_pred_f = tf.cast(tf.reshape(y_pred_bin, [-1]), tf.float32)
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (
        tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth
    )

# Function to visualize the first N predictions
def visualize_first_predictions(X_test, y_test, y_pred, y_pred_bin, num_samples=5):
    """Plot the first ``num_samples`` test samples with their predictions.

    Each row shows the input image, the ground-truth mask and the raw
    predicted mask. The prediction's title reports Dice on ``y_pred_bin`` and
    skimage PSNR/SSIM on the raw ``y_pred``, both computed with a data range
    of 1.0 (i.e. assuming intensities in [0, 1]). The figure is saved to
    ``first_5_predictions.png`` and then shown.

    Args:
        X_test: Indexable collection of input images.
        y_test: Ground-truth masks, indexed in parallel with ``X_test``.
        y_pred: Raw model predictions, indexed in parallel with ``X_test``.
        y_pred_bin: Binarized predictions, indexed in parallel with ``X_test``.
        num_samples: Number of leading samples to show; capped at ``len(X_test)``.
    """
    # Avoid an IndexError when the test set is smaller than requested.
    num_samples = min(num_samples, len(X_test))
    if num_samples <= 0:
        return

    plt.figure(figsize=(15, 4 * num_samples))

    for i in range(num_samples):
        true_mask = y_test[i].squeeze()
        pred_map = y_pred[i].squeeze()

        # Original image
        plt.subplot(num_samples, 3, i * 3 + 1)
        plt.imshow(X_test[i].squeeze(), cmap='gray')
        plt.title(f"Original Image {i}")
        plt.axis('off')

        # Ground truth mask
        plt.subplot(num_samples, 3, i * 3 + 2)
        plt.imshow(true_mask, cmap='gray')
        plt.title(f"Ground Truth Mask {i}")
        plt.axis('off')

        # Predicted mask
        plt.subplot(num_samples, 3, i * 3 + 3)
        plt.imshow(pred_map, cmap='gray')
        dice_val = float(dice_coefficient(y_test[i], y_pred_bin[i]))
        # data_range must be explicit: without it skimage infers the range
        # from the dtype (e.g. 255 for uint8 masks), inflating PSNR. It also
        # has to match the range used for SSIM.
        psnr_val = psnr(true_mask, pred_map, data_range=1.0)
        ssim_val = ssim(true_mask, pred_map, data_range=1.0, win_size=3)
        plt.title(f"Predicted Mask {i}\nDice: {dice_val:.4f}\nPSNR: {psnr_val:.2f}, SSIM: {ssim_val:.4f}")
        plt.axis('off')

    plt.tight_layout()
    plt.savefig('first_5_predictions.png')
    plt.show()

# === Run prediction and visualize ===
# Produces the notebook-level `y_pred` (raw model output) and `y_pred_bin`
# (uint8 0/1 masks thresholded at 0.5), then plots the first test samples.

print("Generating predictions on test set...")
y_pred = model.predict(X_test)
y_pred_bin = np.where(y_pred > 0.5, 1, 0).astype(np.uint8)

print("Visualizing first five predictions...")
visualize_first_predictions(X_test, y_test, y_pred, y_pred_bin=y_pred_bin)
