# Eye Vessel Segmentation - U-Net Model Training

This notebook contains the training pipeline for the U-Net model used in the hackathon solution.

## Overview
- **Task**: Semantic segmentation of blood vessels in slit-lamp eye images
- **Model**: U-Net encoder–decoder (four down-/up-sampling levels joined by skip connections), a standard architecture for medical image segmentation
- **Data**: Slit-lamp images with GeoJSON annotations
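- **Metric**: F1 score, which for binary masks equals the Dice coefficient; training minimizes a combined binary cross-entropy + Dice loss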

In [None]:
import os
import json
import numpy as np
import matplotlib.pyplot as plt
import cv2
from PIL import Image
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from sklearn.model_selection import train_test_split
import glob
from tqdm import tqdm

# Report the runtime environment: installed TensorFlow version and visible GPUs.
print(f"TensorFlow version: {tf.__version__}")
gpu_devices = tf.config.list_physical_devices('GPU')
print(f"GPU Available: {gpu_devices}")

## Data Loading and Preprocessing

In [None]:
# ---- Training configuration -------------------------------------------------
IMG_SIZE: int = 512          # images and masks are resized to IMG_SIZE x IMG_SIZE
BATCH_SIZE: int = 16         # samples per gradient step
EPOCHS: int = 50             # upper bound on epochs passed to model.fit
LEARNING_RATE: float = 1e-4  # initial learning rate given to the Adam optimizer

# ---- File-system locations --------------------------------------------------
TRAIN_DATA_PATH: str = '../dataset/train_dataset_mc'
MODEL_SAVE_PATH: str = '../backend/models/unet_eye_segmentation.keras'

In [None]:
def load_geojson_mask(geojson_path, image_shape):
    """
    Rasterize the Polygon / MultiPolygon features of a GeoJSON file into a binary mask.

    Coordinates are used directly as (x, y) pixel positions and rounded to the
    nearest integer. For each polygon the first ring is the outer boundary and
    any further rings are holes. All rings of one polygon are passed to a single
    ``cv2.fillPoly`` call so that holes stay unfilled. Features with a missing
    or null geometry, any other geometry type, or a degenerate outer ring
    (fewer than 3 points) are ignored.

    Args:
        geojson_path: Path to the ``.geojson`` annotation file.
        image_shape: Shape of the source image. Only the first two entries
            (height, width) are used.

    Returns:
        ``np.uint8`` array of shape ``image_shape[:2]`` with 255 on annotated
        pixels and 0 elsewhere.
    """
    with open(geojson_path, 'r', encoding='utf-8') as f:
        geojson_data = json.load(f)

    mask = np.zeros(image_shape[:2], dtype=np.uint8)

    def _to_contour(ring):
        # Convert one linear ring to an int32 (N, 2) contour. Extra position
        # components (e.g. altitude) are dropped. Returns None for rings that
        # cannot form a polygon.
        pts = np.asarray(ring, dtype=np.float64)
        if pts.ndim != 2 or pts.shape[0] < 3 or pts.shape[1] < 2:
            return None
        return np.rint(pts[:, :2]).astype(np.int32)

    def _fill_polygon(rings):
        # Fill one GeoJSON polygon (outer ring followed by optional holes).
        contours = [_to_contour(ring) for ring in rings]
        if not contours or contours[0] is None:
            return
        # Filling the outer ring and its holes together leaves the holes
        # empty. Filling each ring separately would paint the holes too.
        cv2.fillPoly(mask, [c for c in contours if c is not None], 255)

    for feature in geojson_data.get('features', []):
        # "geometry": null is valid GeoJSON, so fall back to an empty dict.
        geometry = feature.get('geometry') or {}
        geom_type = geometry.get('type')
        coordinates = geometry.get('coordinates') or []
        if geom_type == 'Polygon':
            _fill_polygon(coordinates)
        elif geom_type == 'MultiPolygon':
            for polygon in coordinates:
                _fill_polygon(polygon)

    return mask

def load_dataset(data_path, img_size=IMG_SIZE):
    """
    Load every ``*.png`` image in ``data_path`` that has a sibling ``.geojson`` file.

    Images are converted BGR -> RGB, resized to ``img_size`` x ``img_size`` and
    scaled to [0, 1]. Masks are rasterized with ``load_geojson_mask``, resized
    to the same size and binarized at 127. Files are processed in sorted order,
    so the sample order (and therefore any seeded train/validation split)
    does not depend on the filesystem's directory listing order. Images without
    an annotation file, or that OpenCV cannot read, are skipped and counted.

    Args:
        data_path: Directory containing ``<name>.png`` / ``<name>.geojson`` pairs.
        img_size: Side length of the square output. Defaults to ``IMG_SIZE``.

    Returns:
        ``(images, masks)``: float32 arrays of shape ``(N, img_size, img_size, 3)``
        and ``(N, img_size, img_size)``. When nothing is loaded, both arrays
        still have these ranks with ``N == 0``.
    """
    images = []
    masks = []
    skipped = 0

    # glob() returns files in filesystem-dependent order. Sort for reproducibility.
    image_files = sorted(glob.glob(os.path.join(data_path, '*.png')))

    for image_path in tqdm(image_files, desc="Loading dataset"):
        base_name = os.path.splitext(os.path.basename(image_path))[0]
        geojson_path = os.path.join(data_path, f"{base_name}.geojson")

        # Check for the annotation first so unannotated images are never decoded.
        if not os.path.exists(geojson_path):
            skipped += 1
            continue

        image = cv2.imread(image_path)
        if image is None:
            skipped += 1
            continue
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # The mask is rasterized at the original resolution, then resized with the image.
        mask = load_geojson_mask(geojson_path, image.shape)

        image_resized = cv2.resize(image, (img_size, img_size))
        mask_resized = cv2.resize(mask, (img_size, img_size))

        images.append(image_resized.astype(np.float32) / 255.0)
        masks.append((mask_resized > 127).astype(np.float32))

    if skipped:
        print(f"Skipped {skipped} image(s) that were unreadable or had no .geojson annotation")

    if not images:
        # Keep the documented ranks so callers can still inspect .shape safely.
        return (np.empty((0, img_size, img_size, 3), dtype=np.float32),
                np.empty((0, img_size, img_size), dtype=np.float32))
    return np.array(images), np.array(masks)

# Load the dataset
print("Loading dataset...")
X, y = load_dataset(TRAIN_DATA_PATH)
if len(X) == 0:
    # Later cells index X[0..] directly. Fail here with an actionable message instead.
    raise RuntimeError(f"No annotated .png/.geojson pairs found in {TRAIN_DATA_PATH!r}")
print(f"Loaded {len(X)} images with shape {X.shape[1:]}")
print(f"Masks shape: {y.shape}")

## Data Visualization

In [None]:
# Visualize up to three training examples. The top row shows image and mask;
# the bottom row shows a red overlay of the mask and the vessel/background
# pixel ratio.
n_examples = min(3, len(X))
fig, axes = plt.subplots(2, 6, figsize=(18, 6))

for i in range(n_examples):
    # Original image
    axes[0, i*2].imshow(X[i])
    axes[0, i*2].set_title(f'Image {i+1}')
    axes[0, i*2].axis('off')

    # Mask
    axes[0, i*2+1].imshow(y[i], cmap='gray')
    axes[0, i*2+1].set_title(f'Mask {i+1}')
    axes[0, i*2+1].axis('off')

    # Overlay: vessel pixels are pushed to full red
    overlay = X[i].copy()
    overlay[:,:,0] = np.where(y[i] > 0.5, 1.0, overlay[:,:,0])
    axes[1, i*2].imshow(overlay)
    axes[1, i*2].set_title(f'Overlay {i+1}')
    axes[1, i*2].axis('off')

    # Vessel statistics: fraction of mask pixels labelled as vessel
    vessel_ratio = np.mean(y[i])
    axes[1, i*2+1].bar(['Vessel', 'Background'], [vessel_ratio, 1-vessel_ratio])
    axes[1, i*2+1].set_title(f'Vessel Ratio: {vessel_ratio:.3f}')

# Hide panels left unused when fewer than three samples are available.
for j in range(n_examples * 2, 6):
    axes[0, j].axis('off')
    axes[1, j].axis('off')

plt.tight_layout()
plt.show()

## U-Net Model Architecture

In [None]:
def conv_block(inputs, num_filters, dropout_rate=0.0):
    """Two stacked (3x3 Conv2D -> BatchNormalization -> ReLU) stages.

    Both convolutions use ``padding="same"``, so the spatial size is unchanged.
    Dropout is not applied unless ``dropout_rate`` > 0. With the default of 0.0
    no Dropout layer is created, so the layer graph is unchanged.

    Args:
        inputs: Input Keras tensor.
        num_filters: Number of filters in each of the two convolutions.
        dropout_rate: If > 0, a ``Dropout`` layer with this rate is appended
            after the second activation.

    Returns:
        Output Keras tensor with ``num_filters`` channels.
    """
    x = layers.Conv2D(num_filters, 3, padding="same")(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    x = layers.Conv2D(num_filters, 3, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    if dropout_rate > 0:
        x = layers.Dropout(dropout_rate)(x)

    return x

def encoder_block(inputs, num_filters):
    """Apply ``conv_block`` and then 2x2 max pooling.

    Returns:
        ``(features, pooled)``. ``features`` is the conv output at the input
        resolution, used as the U-Net skip connection. ``pooled`` is its 2x2
        max-pooled version, fed to the next encoder level.
    """
    features = conv_block(inputs, num_filters)
    pool = layers.MaxPool2D(pool_size=(2, 2))
    return features, pool(features)

def decoder_block(inputs, skip_features, num_filters):
    """Upsample ``inputs`` 2x, concatenate the skip features, then ``conv_block``.

    The upsampling is a learned 2x2 ``Conv2DTranspose`` with stride 2. The
    concatenation is along the channel axis (Keras default), so ``inputs``
    upsampled must match ``skip_features`` spatially.
    """
    upsample = layers.Conv2DTranspose(num_filters, kernel_size=(2, 2), strides=2, padding="same")
    merge = layers.Concatenate()
    merged = merge([upsample(inputs), skip_features])
    return conv_block(merged, num_filters)

def build_unet(input_shape):
    """Assemble a four-level U-Net for single-channel binary segmentation.

    The encoder levels use 64/128/256/512 filters and the bridge uses 1024.
    The decoder mirrors the encoder, consuming the stored skip connections in
    reverse order. A 1x1 convolution with sigmoid activation produces the
    one-channel probability map.

    Args:
        input_shape: ``(height, width, channels)``. Height and width must be
            divisible by 16 so the four 2x poolings line up with the skips.

    Returns:
        An uncompiled ``keras.Model`` named "U-Net".
    """
    encoder_filters = (64, 128, 256, 512)

    inputs = layers.Input(input_shape)

    # Contracting path: remember each level's pre-pooling features.
    skips = []
    x = inputs
    for filters in encoder_filters:
        skip, x = encoder_block(x, filters)
        skips.append(skip)

    # Bridge at the lowest resolution.
    x = conv_block(x, 1024)

    # Expanding path: deepest skip first.
    for filters, skip in zip(reversed(encoder_filters), reversed(skips)):
        x = decoder_block(x, skip, filters)

    outputs = layers.Conv2D(1, 1, padding="same", activation="sigmoid")(x)
    return keras.Model(inputs, outputs, name="U-Net")

# Instantiate the U-Net for RGB inputs of IMG_SIZE x IMG_SIZE and print its layer table.
model = build_unet(input_shape=(IMG_SIZE, IMG_SIZE, 3))
model.summary()

## Loss Functions and Metrics

In [None]:
def dice_coefficient(y_true, y_pred, smooth=1e-6):
    """Soft Dice coefficient computed over the whole batch.

    Dice = (2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth), taken over
    the flattened tensors using the raw (un-thresholded) predictions, so it is
    differentiable and can back ``dice_loss``. ``smooth`` avoids 0/0 when both
    masks are empty, in which case the result is 1.

    Uses plain ``tf`` ops instead of ``tf.keras.backend.flatten`` / ``sum``,
    which were removed from ``keras.backend`` in Keras 3. This also matches the
    ops used by ``iou_metric``.

    Args:
        y_true: Ground-truth mask tensor; cast to the prediction dtype.
        y_pred: Predicted probabilities.
        smooth: Additive smoothing constant.

    Returns:
        Scalar tensor.
    """
    y_pred = tf.convert_to_tensor(y_pred)
    y_true_f = tf.reshape(tf.cast(y_true, y_pred.dtype), [-1])
    y_pred_f = tf.reshape(y_pred, [-1])
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
    denominator = tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f)
    return (2.0 * intersection + smooth) / (denominator + smooth)

def dice_loss(y_true, y_pred):
    """Return ``1 - dice_coefficient(y_true, y_pred)``; 0 means perfect overlap."""
    overlap = dice_coefficient(y_true, y_pred)
    return 1 - overlap

def combined_loss(y_true, y_pred):
    """Binary cross-entropy plus Dice loss.

    The scalar Dice loss is broadcast onto whatever tensor
    ``binary_crossentropy`` returns. Keras applies the final reduction to the
    sum.
    """
    return (tf.keras.losses.binary_crossentropy(y_true, y_pred)
            + dice_loss(y_true, y_pred))

def iou_metric(y_true, y_pred, threshold=0.5, smooth=1e-6):
    """Batch-level Intersection over Union of thresholded predictions.

    Predictions are binarized at ``threshold``. ``smooth`` is added to both the
    numerator and the denominator, mirroring ``dice_coefficient``. As a result,
    an all-background batch predicted as all-background scores 1 instead of 0.
    With only the denominator smoothed, that perfect result would score 0.

    Args:
        y_true: Ground-truth mask tensor; cast to float32.
        y_pred: Predicted probabilities.
        threshold: Probability above which a pixel counts as vessel.
        smooth: Additive smoothing constant.

    Returns:
        Scalar float32 tensor in [0, 1].
    """
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred > threshold, tf.float32)
    intersection = tf.reduce_sum(y_true * y_pred)
    union = tf.reduce_sum(y_true) + tf.reduce_sum(y_pred) - intersection
    return (intersection + smooth) / (union + smooth)

## Model Training

In [None]:
# Split the data 80/20 into training and validation sets (fixed random_state)
X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.2, random_state=42
)

print(f"Training set: {X_train.shape[0]} samples")
print(f"Validation set: {X_val.shape[0]} samples")

# Compile the model
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=LEARNING_RATE),
    loss=combined_loss,
    metrics=[dice_coefficient, iou_metric, 'binary_accuracy']
)

# ModelCheckpoint writes directly to MODEL_SAVE_PATH. Make sure its directory
# exists so the first save does not fail mid-training.
os.makedirs(os.path.dirname(MODEL_SAVE_PATH) or '.', exist_ok=True)

# Callbacks. The checkpoint and early stopping track validation Dice; the
# learning-rate schedule tracks validation loss.
callbacks = [
    keras.callbacks.ModelCheckpoint(
        MODEL_SAVE_PATH,
        monitor='val_dice_coefficient',
        mode='max',
        save_best_only=True,
        verbose=1
    ),
    keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.5,
        patience=5,
        min_lr=1e-7,
        verbose=1
    ),
    keras.callbacks.EarlyStopping(
        monitor='val_dice_coefficient',
        mode='max',
        patience=10,
        restore_best_weights=True,
        verbose=1
    )
]

# Train the model
history = model.fit(
    X_train, y_train,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    validation_data=(X_val, y_val),
    callbacks=callbacks,
    verbose=1
)

## Training Results Visualization

In [None]:
# Plot training vs. validation curves, one panel per tracked quantity.
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# (panel, history key, train label, validation label, title, y-axis label)
curve_panels = [
    (axes[0, 0], 'loss', 'Training Loss', 'Validation Loss', 'Model Loss', 'Loss'),
    (axes[0, 1], 'dice_coefficient', 'Training Dice', 'Validation Dice', 'Dice Coefficient', 'Dice Coefficient'),
    (axes[1, 0], 'iou_metric', 'Training IoU', 'Validation IoU', 'IoU Metric', 'IoU'),
    (axes[1, 1], 'binary_accuracy', 'Training Accuracy', 'Validation Accuracy', 'Binary Accuracy', 'Accuracy'),
]
for ax, key, train_label, val_label, title, ylabel in curve_panels:
    ax.plot(history.history[key], label=train_label)
    ax.plot(history.history[f'val_{key}'], label=val_label)
    ax.set_title(title)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.legend()

plt.tight_layout()
plt.show()

# Best per-epoch validation scores (also written to the model metadata later)
best_val_dice = max(history.history['val_dice_coefficient'])
best_val_iou = max(history.history['val_iou_metric'])
print(f"\nBest Validation Dice Coefficient: {best_val_dice:.4f}")
print(f"Best Validation IoU: {best_val_iou:.4f}")

## Model Evaluation

In [None]:
# Load the best checkpoint. The custom loss and metric functions are not
# built-in Keras objects, so they must be supplied by name.
best_model = keras.models.load_model(
    MODEL_SAVE_PATH,
    custom_objects={
        'combined_loss': combined_loss,
        'dice_coefficient': dice_coefficient,
        'iou_metric': iou_metric
    }
)

# Evaluate on validation set. return_dict=True pairs each value with its
# metric name directly. Zipping metrics_names with the flat result list
# mislabels values on Keras 3, where compiled metrics are grouped under a
# single "compile_metrics" entry.
val_results = best_model.evaluate(X_val, y_val, verbose=0, return_dict=True)
print("Validation Results:")
for name, value in val_results.items():
    print(f"{name}: {value:.4f}")

## Prediction Visualization

In [None]:
# Predict only the validation samples that are displayed (at most four).
n_show = min(4, len(X_val))
predictions = best_model.predict(X_val[:n_show])

# Visualize predictions. squeeze=False keeps `axes` 2-D even for a single row.
fig, axes = plt.subplots(n_show, 6, figsize=(18, 3 * n_show), squeeze=False)

for i in range(n_show):
    # Original image
    axes[i, 0].imshow(X_val[i])
    axes[i, 0].set_title('Original')
    axes[i, 0].axis('off')

    # Ground truth
    axes[i, 1].imshow(y_val[i], cmap='gray')
    axes[i, 1].set_title('Ground Truth')
    axes[i, 1].axis('off')

    # Prediction (probability map, channel 0 of the model output)
    pred = predictions[i, :, :, 0]
    axes[i, 2].imshow(pred, cmap='gray')
    axes[i, 2].set_title('Prediction')
    axes[i, 2].axis('off')

    # Binary prediction at the same 0.5 threshold used by iou_metric
    binary_pred = (pred > 0.5).astype(float)
    axes[i, 3].imshow(binary_pred, cmap='gray')
    axes[i, 3].set_title('Binary Pred')
    axes[i, 3].axis('off')

    # Overlay: predicted vessel pixels pushed to full red
    overlay = X_val[i].copy()
    overlay[:,:,0] = np.where(binary_pred > 0.5, 1.0, overlay[:,:,0])
    axes[i, 4].imshow(overlay)
    axes[i, 4].set_title('Overlay')
    axes[i, 4].axis('off')

    # Difference: 1 where ground truth and binary prediction disagree
    diff = np.abs(y_val[i] - binary_pred)
    axes[i, 5].imshow(diff, cmap='Reds')
    axes[i, 5].set_title('Difference')
    axes[i, 5].axis('off')

plt.tight_layout()
plt.show()

## Model Metadata Export for Production (the best model itself is saved by the training checkpoint)

In [None]:
# Save model info. The Dice/IoU values are the best per-epoch validation scores,
# and learning_rate is the initial rate, since ReduceLROnPlateau may lower it.
model_info = {
    'model_type': 'U-Net',
    'input_shape': [IMG_SIZE, IMG_SIZE, 3],
    'output_shape': [IMG_SIZE, IMG_SIZE, 1],
    'best_val_dice': float(best_val_dice),
    'best_val_iou': float(best_val_iou),
    'training_samples': len(X_train),
    'validation_samples': len(X_val),
    'epochs_trained': len(history.history['loss']),
    'hyperparameters': {
        'batch_size': BATCH_SIZE,
        'learning_rate': LEARNING_RATE,
        'image_size': IMG_SIZE,
        'epochs': EPOCHS
    }
}

# Write the metadata next to the model file, deriving the directory from
# MODEL_SAVE_PATH so the two locations cannot drift apart.
model_dir = os.path.dirname(MODEL_SAVE_PATH) or '.'
os.makedirs(model_dir, exist_ok=True)
model_info_path = os.path.join(model_dir, 'model_info.json')
with open(model_info_path, 'w', encoding='utf-8') as f:
    json.dump(model_info, f, indent=2)

print("Model training completed and saved successfully!")
print(f"Model saved to: {MODEL_SAVE_PATH}")
print(f"Model info saved to: {model_info_path}")
print(f"Best validation Dice coefficient: {best_val_dice:.4f}")
print(f"Best validation IoU: {best_val_iou:.4f}")