# Lab 5 – Medical image segmentation with U-Net

## Tasks
- Split dataset 70/30 (train/val)
- Train U-Net
- Evaluate (Pixel Accuracy, IoU, Dice)
- Implement U-Net v2
- Compare baseline vs v2

In [1]:
# Imports and environment checks
import os
import random
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

# Report the runtime environment before any work is done.
print('TensorFlow version:', tf.__version__)
gpu_devices = tf.config.list_physical_devices('GPU')
print('GPUs:', gpu_devices)

# Reproducibility: give the Python, NumPy and TensorFlow RNGs the same seed.
for _seed_rng in (random.seed, np.random.seed, tf.random.set_seed):
    _seed_rng(42)

ModuleNotFoundError: No module named 'numpy'

# Dataset scan + pairing

In [None]:
# Paths and configuration
from pathlib import Path

# Every location is derived from the resolved working directory, so the
# source itself contains no absolute paths.
REPO_ROOT = Path('.').resolve()
DATA_ROOT, RUNS_DIR, MODELS_DIR = (
    REPO_ROOT / 'data' / 'brain_tumor',
    REPO_ROOT / 'runs',
    REPO_ROOT / 'models',
)

# Text files that list the image/mask pairs belonging to each split.
SPLIT_TRAIN, SPLIT_TEST = (
    RUNS_DIR / 'split_train.txt',
    RUNS_DIR / 'split_test.txt',
)

# Training hyper-parameters.
IMG_SIZE = (256, 256)
BATCH_SIZE, EPOCHS, LR = 8, 30, 1e-4

# Create the output folders up front so later cells can write into them.
for _out_dir in (RUNS_DIR, MODELS_DIR):
    _out_dir.mkdir(parents=True, exist_ok=True)

print('REPO_ROOT:', REPO_ROOT)
print('DATA_ROOT exists:', DATA_ROOT.exists())
print('RUNS_DIR:', RUNS_DIR)
print('MODELS_DIR:', MODELS_DIR)

In [None]:
# Scan dataset and match image↔mask pairs
from pathlib import Path
from collections import defaultdict

# Lower-case file suffixes that are treated as raster images.
IMG_EXTS = {'.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff'}

def is_img_ext(p: Path) -> bool:
    """Tell whether *p* carries one of the suffixes in ``IMG_EXTS`` (case-insensitive)."""
    suffix = p.suffix.lower()
    return suffix in IMG_EXTS

def normalize_stem(stem: str) -> str:
    """Reduce a file stem to the key shared by an image and its mask.

    The stem is lower-cased, the first matching mask suffix (``_mask``,
    ``-mask``, ``mask``, ``_seg``, ``-seg``) is removed, and afterwards a
    leading ``mask_`` is removed as well.
    """
    key = stem.lower()

    # Only the first suffix that matches is stripped; order matters because
    # 'mask' would also match stems ending in '_mask' / '-mask'.
    tail = next(
        (t for t in ('_mask', '-mask', 'mask', '_seg', '-seg') if key.endswith(t)),
        None,
    )
    if tail is not None:
        key = key[:len(key) - len(tail)]

    head = 'mask_'
    if key.startswith(head):
        return key[len(head):]
    return key

def scan_files(root: Path, exts=None):
    """Recursively collect image and mask files below *root*.

    A file counts as a mask when its lower-cased stem contains ``'mask'`` or
    ``'seg'``, or when *root* itself or a folder between *root* and the file
    is named ``masks`` (any case). Every other file with an accepted suffix
    is treated as an image. Images and masks may use different suffixes
    (e.g. JPEG images with PNG masks); both are handled the same way.

    Args:
        root: Directory to walk.
        exts: Optional iterable of file suffixes (with leading dot) to accept.
            Defaults to the module-level ``IMG_EXTS``.

    Returns:
        Tuple ``(images, masks)`` of path lists, each in sorted path order.
    """
    allowed = IMG_EXTS if exts is None else {e.lower() for e in exts}
    images = []
    masks = []
    # Sort so the result does not depend on filesystem enumeration order.
    for p in sorted(root.rglob('*')):
        if not p.is_file() or p.suffix.lower() not in allowed:
            continue
        name = p.stem.lower()
        # Inspect only root and the folders below it. Ancestors of root
        # (e.g. a checkout living under ".../masks/...") must not turn every
        # file into a mask.
        folders = {root.name.lower()}
        folders.update(part.lower() for part in p.relative_to(root).parts[:-1])
        if 'mask' in name or 'seg' in name or 'masks' in folders:
            masks.append(p)
        else:
            images.append(p)
    return images, masks

def rel_to_repo(p: Path) -> Path:
    """Express *p* relative to ``REPO_ROOT`` when it lies inside the repository.

    Args:
        p: Path to convert; it is resolved before the comparison.

    Returns:
        The resolved path relative to the resolved ``REPO_ROOT``, or *p*
        unchanged when it is located outside the repository.
    """
    try:
        return p.resolve().relative_to(REPO_ROOT.resolve())
    except ValueError:
        # relative_to() raises ValueError for paths outside REPO_ROOT; in that
        # case the path is handed back exactly as given.
        return p

def match_pairs(images, masks):
    """Pair every image with the mask that shares its normalized stem.

    When several masks share a stem, a mask in the image's own directory is
    preferred; otherwise the mask whose directory shares the longest leading
    path with the image's directory wins (so ``test/images/1.png`` pairs with
    ``test/masks/1.png`` rather than ``train/masks/1.png``). Remaining ties
    go to the mask that appears first in *masks*. Images without any
    candidate are skipped.

    Args:
        images: Iterable of image paths.
        masks: Iterable of mask paths.

    Returns:
        List of ``(image, mask)`` tuples of POSIX-style path strings, each
        passed through ``rel_to_repo``.
    """
    by_stem = defaultdict(list)
    for m in masks:
        by_stem[normalize_stem(m.stem)].append(m)

    def _shared_depth(a: Path, b: Path) -> int:
        # Number of leading path components the two directories have in common.
        depth = 0
        for part_a, part_b in zip(a.parts, b.parts):
            if part_a != part_b:
                break
            depth += 1
        return depth

    pairs = []
    for img in images:
        candidates = by_stem.get(normalize_stem(img.stem))
        if not candidates:
            continue
        img_dir = img.parent
        # max() keeps the first candidate on ties, preserving input order.
        best = max(
            candidates,
            key=lambda m: (m.parent == img_dir, _shared_depth(m.parent, img_dir)),
        )
        pairs.append((rel_to_repo(img).as_posix(), rel_to_repo(best).as_posix()))
    return pairs

# Walk the dataset folder, then pair every image with its mask.
imgs, msks = scan_files(DATA_ROOT)
pairs = match_pairs(imgs, msks)
print(f'Found images: {len(imgs)} | masks: {len(msks)} | pairs: {len(pairs)}')

if not pairs:
    # Nothing matched: show a few stems from each side to help diagnose the
    # naming convention before aborting.
    img_stems, msk_stems = ([p.stem for p in found[:10]] for found in (imgs, msks))
    print('Sample image stems (up to 10):', img_stems)
    print('Sample mask stems (up to 10):', msk_stems)
    raise RuntimeError('No image-mask pairs matched. Please verify naming conventions and folder structure.')

# 70/30 split

In [None]:
# Deterministic split and write files
import random

# The pairing cell must have been executed before this one.
assert 'pairs' in globals(), 'pairs not defined. Run dataset pairing cell first.'

# Sorting first makes the seeded shuffle independent of the scan order.
pairs_sorted = sorted(pairs)
random.Random(42).shuffle(pairs_sorted)

# First 70% of the shuffled pairs -> train, the remainder -> test.
N = len(pairs_sorted)
split_idx = int(0.7 * N)
train_pairs, test_pairs = pairs_sorted[:split_idx], pairs_sorted[split_idx:]

# One "<img_path> <mask_path>" line per pair, POSIX separators.
# NOTE: the two fields are separated by a single space, so an image path that
# itself contains a space cannot be told apart from the separator.
RUNS_DIR.mkdir(parents=True, exist_ok=True)
for _split_file, _split_pairs in ((SPLIT_TRAIN, train_pairs), (SPLIT_TEST, test_pairs)):
    with _split_file.open('w', encoding='utf-8') as f:
        for img, msk in _split_pairs:
            f.write(f"{img} {msk}\n")

# Diagnostics
print(f'Total pairs: {N} | train: {len(train_pairs)} | test: {len(test_pairs)}')

# Peek at the beginning of a split file
def head_lines(path, n=3):
    """Return up to *n* leading lines of *path* with the trailing newline removed.

    A missing file yields ``['<missing>']`` instead of raising.
    """
    try:
        with path.open('r', encoding='utf-8') as handle:
            # zip() stops as soon as range(n) is exhausted, so at most n lines are kept.
            return [text.rstrip('\n') for _, text in zip(range(n), handle)]
    except FileNotFoundError:
        return ['<missing>']

# Echo the first three entries of each split file as a quick sanity check.
for _label, _split_file in (
    ('split_train.txt (first 3):', SPLIT_TRAIN),
    ('split_test.txt  (first 3):', SPLIT_TEST),
):
    print(_label, head_lines(_split_file, 3))

# tf.data pipeline

In [None]:
# Build tf.data pipeline for loading, resizing, normalizing
import tensorflow as tf
from pathlib import Path

# Both split files must already be on disk before the pipeline is built.
assert all(split_file.exists() for split_file in (SPLIT_TRAIN, SPLIT_TEST)), 'Split files missing. Run the split cell first.'

# Parse a split file back into (image, mask) path tuples
def read_split(path: Path):
    """Read a split file written as one ``<img_path> <mask_path>`` line per pair.

    Blank lines are ignored. Each remaining line is split at its first space;
    the left part becomes the image path and the rest the mask path.

    Returns:
        List of ``(Path, Path)`` tuples in file order.
    """
    entries = []
    with path.open('r', encoding='utf-8') as handle:
        for raw in handle:
            record = raw.strip()
            if record:
                img_str, mask_str = record.split(' ', 1)
                entries.append((Path(img_str), Path(mask_str)))
    return entries

# Load both split lists from disk and report their sizes.
train_list, test_list = read_split(SPLIT_TRAIN), read_split(SPLIT_TEST)
print(f'Read split lists: train={len(train_list)} test={len(test_list)}')

# Target height/width used by the loader when resizing.
(IMG_H, IMG_W) = IMG_SIZE

# Single-sample loader built from tf.io / tf.image ops
def load_one(img_path: tf.Tensor, mask_path: tf.Tensor):
    """Load one image/mask pair and bring both to ``(IMG_H, IMG_W)``.

    The image is decoded with 3 channels, resized bilinearly and divided by
    255. The mask is decoded with 1 channel, resized with nearest-neighbour
    and binarized: every value greater than 0 becomes 1.0.

    Returns:
        Tuple ``(img, msk)`` of float32 tensors.
    """
    def _decode(path, channels):
        # decode_image leaves the static shape unknown, so pin the channel count.
        raw = tf.io.read_file(path)
        decoded = tf.image.decode_image(raw, channels=channels, expand_animations=False)
        decoded.set_shape([None, None, channels])
        return decoded

    target = [IMG_H, IMG_W]
    image = tf.image.resize(_decode(img_path, 3), target, method='bilinear')
    mask = tf.image.resize(_decode(mask_path, 1), target, method='nearest')

    image = tf.cast(image, tf.float32) / 255.0
    mask = tf.cast(mask > 0, tf.float32)
    return image, mask

# Dataset factory
def make_ds(pairs_list, batch=BATCH_SIZE, shuffle=False):
    """Turn a list of (image, mask) repo-relative paths into a batched ``tf.data`` dataset.

    Paths are joined onto ``REPO_ROOT``. When *shuffle* is true the file
    list is shuffled (seed 42, reshuffled each iteration) before loading;
    samples are then loaded with ``load_one``, batched and prefetched.
    """
    img_paths = [str(REPO_ROOT / entry[0]) for entry in pairs_list]
    msk_paths = [str(REPO_ROOT / entry[1]) for entry in pairs_list]

    dataset = tf.data.Dataset.from_tensor_slices((img_paths, msk_paths))
    if shuffle:
        dataset = dataset.shuffle(buffer_size=len(pairs_list), seed=42, reshuffle_each_iteration=True)
    return (
        dataset
        .map(load_one, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(batch)
        .prefetch(tf.data.AUTOTUNE)
    )

# Training data is shuffled and batched; the test data is served one image at a time.
train_ds = make_ds(train_list, batch=BATCH_SIZE, shuffle=True)
test_ds  = make_ds(test_list, batch=1, shuffle=False)

# Print the shapes of one batch from each dataset.
# The loop variables are named *_batch so they do not overwrite the `imgs` /
# `msks` path lists produced by the dataset scan cell.
for img_batch, msk_batch in train_ds.take(1):
    print('Train batch shapes:', img_batch.shape, msk_batch.shape)
for img_batch, msk_batch in test_ds.take(1):
    print('Test batch shapes:', img_batch.shape, msk_batch.shape)

In [None]:
# Sanity visualization: 3 samples from train_ds
import matplotlib.pyplot as plt

# Pull three individual (image, mask) examples out of the batched training stream.
samples = list(train_ds.unbatch().take(3))

fig, axes = plt.subplots(3, 3, figsize=(12, 12))
for i, (img, msk) in enumerate(samples):
    img_np, msk_np = img.numpy(), msk.numpy().squeeze()

    # One row per sample: raw image, binary mask, and the mask alpha-blended
    # over the image. Each column lists what is drawn onto its axis, in order.
    columns = (
        ('Image', ((img_np, {}),)),
        ('Mask', ((msk_np, {'cmap': 'gray'}),)),
        ('Overlay', ((img_np, {}), (msk_np, {'cmap': 'Reds', 'alpha': 0.4}))),
    )
    for ax, (label, drawables) in zip(axes[i], columns):
        for data, opts in drawables:
            ax.imshow(data, **opts)
        ax.set_title(f'Sample {i+1}: {label}')
        ax.axis('off')

plt.tight_layout()
out_path = RUNS_DIR / 'dataset_check.png'
plt.savefig(out_path, dpi=150, bbox_inches='tight')
print(f'Saved visualization to {out_path}')
plt.show()

# Baseline U-Net

In [None]:
# Baseline U-Net architecture using Keras
from tensorflow.keras import layers, models

def conv_block(x, filters):
    """Apply two 3x3, same-padded, ReLU-activated convolutions with *filters* channels each."""
    for _ in range(2):
        x = layers.Conv2D(filters, 3, padding='same', activation='relu')(x)
    return x

def encoder_block(x, filters):
    """Run a conv_block and 2x2 max-pool it; return ``(pre-pool features, pooled output)``."""
    features = conv_block(x, filters)
    return features, layers.MaxPooling2D(2)(features)

def decoder_block(x, skip, filters):
    """Upsample *x* by 2 with a transposed conv, concatenate *skip*, then apply a conv_block."""
    upsampled = layers.Conv2DTranspose(filters, 2, strides=2, padding='same')(x)
    merged = layers.Concatenate()([upsampled, skip])
    return conv_block(merged, filters)

def build_unet(input_shape=(256, 256, 3), base_filters=32, depth=4):
    """Assemble the baseline U-Net.

    The encoder has *depth* levels whose channel count doubles per level
    starting at *base_filters*; the bottleneck uses ``base_filters * 2**depth``
    channels; the decoder mirrors the encoder and consumes the skips in
    reverse order. The output is a 1-channel sigmoid map.
    """
    inputs = layers.Input(shape=input_shape)
    widths = [base_filters * (2 ** level) for level in range(depth)]

    # Contracting path: remember the pre-pooling features of each level.
    skips = []
    x = inputs
    for width in widths:
        skip, x = encoder_block(x, width)
        skips.append(skip)

    # Bottleneck
    x = conv_block(x, base_filters * (2 ** depth))

    # Expanding path: deepest skip first.
    for skip, width in zip(reversed(skips), reversed(widths)):
        x = decoder_block(x, skip, width)

    # Single-channel binary mask
    outputs = layers.Conv2D(1, 1, activation='sigmoid', padding='same')(x)
    return models.Model(inputs, outputs, name='UNet_baseline')

# Instantiate the baseline network for the configured image size and print its layer table.
unet = build_unet(input_shape=tuple(IMG_SIZE) + (3,), base_filters=32, depth=4)
unet.summary()

In [None]:
# Custom metrics: IoU and Dice (TensorFlow-safe)
import tensorflow as tf

def iou_metric(y_true, y_pred, threshold=0.5, epsilon=1e-7):
    """Intersection over Union of the thresholded tensors.

    Both inputs are binarized at *threshold* and all elements are summed
    together; *epsilon* is added to numerator and denominator.
    """
    pred_mask = tf.cast(y_pred > threshold, tf.float32)
    true_mask = tf.cast(y_true > threshold, tf.float32)

    overlap = tf.reduce_sum(true_mask * pred_mask)
    total = tf.reduce_sum(true_mask) + tf.reduce_sum(pred_mask)

    return (overlap + epsilon) / (total - overlap + epsilon)

def dice_metric(y_true, y_pred, threshold=0.5, epsilon=1e-7):
    """Dice coefficient of the thresholded tensors.

    Both inputs are binarized at *threshold* and all elements are summed
    together; *epsilon* is added to numerator and denominator.
    """
    pred_mask = tf.cast(y_pred > threshold, tf.float32)
    true_mask = tf.cast(y_true > threshold, tf.float32)

    overlap = tf.reduce_sum(true_mask * pred_mask)
    total = tf.reduce_sum(true_mask) + tf.reduce_sum(pred_mask)
    return (2.0 * overlap + epsilon) / (total + epsilon)

print('IoU and Dice metrics defined.')

# Training baseline

In [None]:
# Compile and train baseline U-Net
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard
import matplotlib.pyplot as plt

# Compile: Adam + binary cross-entropy, tracking the thresholded IoU / Dice metrics.
unet.compile(optimizer=Adam(learning_rate=LR), loss='binary_crossentropy', metrics=[iou_metric, dice_metric])
print('Model compiled.')

# Output locations for the best checkpoint and the TensorBoard event files.
checkpoint_path = MODELS_DIR / 'unet_best.keras'
tb_log_dir = RUNS_DIR / 'tb' / 'baseline'

# Keep only the weights with the lowest validation loss; log scalars to TensorBoard.
best_ckpt = ModelCheckpoint(str(checkpoint_path), monitor='val_loss', save_best_only=True, verbose=1)
tb_logger = TensorBoard(log_dir=str(tb_log_dir), histogram_freq=0)
callbacks = [best_ckpt, tb_logger]

print(f'Checkpoint: {checkpoint_path}')
print(f'TensorBoard logs: {tb_log_dir}')

# Train; the 30% split doubles as validation data.
history = unet.fit(train_ds, validation_data=test_ds, epochs=EPOCHS, callbacks=callbacks, verbose=1)

# One panel per tracked quantity: (history key, panel title, legend suffix).
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
curves = (
    ('loss', 'Loss', 'loss'),
    ('iou_metric', 'IoU', 'iou'),
    ('dice_metric', 'Dice', 'dice'),
)
for ax, (key, title, short) in zip(axes, curves):
    ax.plot(history.history[key], label=f'train_{short}')
    ax.plot(history.history[f'val_{key}'], label=f'val_{short}')
    ax.set_title(title)
    ax.set_xlabel('Epoch')
    ax.legend()

plt.tight_layout()
hist_path = RUNS_DIR / 'baseline_history.png'
plt.savefig(hist_path, dpi=150, bbox_inches='tight')
print(f'Saved training history to {hist_path}')
plt.show()

# Evaluation baseline (lab formulas)

In [None]:
# Load best model and evaluate on test set using lab formulas
from tensorflow.keras.models import load_model
import numpy as np
import matplotlib.pyplot as plt

# Load best checkpoint (the custom metric functions must be supplied for deserialization)
best_model_path = MODELS_DIR / 'unet_best.keras'
model_eval = load_model(str(best_model_path), custom_objects={'iou_metric': iou_metric, 'dice_metric': dice_metric})
print(f'Loaded best model from {best_model_path}')

# Per-image metric values
pa_list = []
iou_list = []
dice_list = []

# Up to 10 (image, ground truth, binarized prediction) triples for the figure
qual_samples = []

for img_batch, gt_batch in test_ds:
    # Call the model directly instead of predict(): test_ds yields one image
    # per batch, and predict() carries per-call overhead when used in a loop.
    pred_batch = np.asarray(model_eval(img_batch, training=False))

    # Process each image in batch (batch=1 for test_ds)
    for img, gt, pred in zip(img_batch.numpy(), gt_batch.numpy(), pred_batch):
        # Binarize prediction and ground truth at 0.5
        pred_bin = (pred > 0.5).astype(np.float32)
        gt_bin = (gt > 0.5).astype(np.float32)

        # Flatten for metric computation
        gt_flat = gt_bin.flatten()
        pred_flat = pred_bin.flatten()

        # Confusion-matrix counts over all pixels of this image
        TP = np.sum((gt_flat == 1) & (pred_flat == 1))
        TN = np.sum((gt_flat == 0) & (pred_flat == 0))
        FP = np.sum((gt_flat == 0) & (pred_flat == 1))
        FN = np.sum((gt_flat == 1) & (pred_flat == 0))

        # Lab formulas; 1e-7 only guards against a zero denominator.
        # NOTE(review): when ground truth and prediction are both empty,
        # IoU and Dice evaluate to 0 here, whereas iou_metric / dice_metric
        # (epsilon in the numerator too) give 1 - confirm which is intended.
        pa = (TP + TN) / (TP + TN + FP + FN + 1e-7)
        iou = TP / (TP + FP + FN + 1e-7)
        dice = (2 * TP) / (2 * TP + FP + FN + 1e-7)

        pa_list.append(pa)
        iou_list.append(iou)
        dice_list.append(dice)

        # Store first 10 samples for visualization
        if len(qual_samples) < 10:
            qual_samples.append((img, gt.squeeze(), pred_bin.squeeze()))

# Print mean metrics
print(f'\n=== Baseline Evaluation (test set, n={len(pa_list)}) ===')
print(f'Pixel Accuracy: {np.mean(pa_list):.4f} ± {np.std(pa_list):.4f}')
print(f'IoU:            {np.mean(iou_list):.4f} ± {np.std(iou_list):.4f}')
print(f'Dice:           {np.mean(dice_list):.4f} ± {np.std(dice_list):.4f}')

# Save qualitative results: one row per stored sample (image / GT / pred / overlay)
eval_dir = RUNS_DIR / 'eval' / 'baseline'
eval_dir.mkdir(parents=True, exist_ok=True)
qual_path = eval_dir / 'qualitative_results.png'

n_rows = len(qual_samples)
if n_rows:
    # Size the grid to the samples actually collected so a small test split
    # does not leave empty framed axes; squeeze=False keeps axes 2-D for n_rows == 1.
    fig, axes = plt.subplots(n_rows, 4, figsize=(16, 4 * n_rows), squeeze=False)
    for i, (img, gt, pred) in enumerate(qual_samples):
        # Each column lists what is drawn onto its axis, in order.
        columns = (
            ('Image', ((img, {}),)),
            ('GT', ((gt, {'cmap': 'gray'}),)),
            ('Pred', ((pred, {'cmap': 'gray'}),)),
            ('Overlay', ((img, {}), (pred, {'cmap': 'Reds', 'alpha': 0.4}))),
        )
        for ax, (label, drawables) in zip(axes[i], columns):
            for data, opts in drawables:
                ax.imshow(data, **opts)
            ax.set_title(f'Sample {i+1}: {label}')
            ax.axis('off')

    plt.tight_layout()
    plt.savefig(qual_path, dpi=150, bbox_inches='tight')
    print(f'Saved qualitative results to {qual_path}')
    plt.show()
else:
    print('No test samples available; skipping qualitative results figure.')

# U-Net v2

In [None]:
# U-Net v2 with BatchNorm, Dropout, and combined BCE+Dice loss
from tensorflow.keras import layers, models, backend as K
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard
import matplotlib.pyplot as plt

def conv_block_v2(x, filters, dropout=False):
    """Two bias-free 3x3 convs, each followed by BatchNorm then ReLU; optional 0.3 Dropout at the end."""
    for _ in range(2):
        x = layers.Conv2D(filters, 3, padding='same', use_bias=False)(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation('relu')(x)

    return layers.Dropout(0.3)(x) if dropout else x

def encoder_block_v2(x, filters, dropout=False):
    """Run conv_block_v2 and 2x2 max-pool it; return ``(pre-pool features, pooled output)``."""
    features = conv_block_v2(x, filters, dropout=dropout)
    return features, layers.MaxPooling2D(2)(features)

def decoder_block_v2(x, skip, filters):
    """Upsample *x* by 2 with a transposed conv, concatenate *skip*, then apply conv_block_v2 without dropout."""
    upsampled = layers.Conv2DTranspose(filters, 2, strides=2, padding='same')(x)
    merged = layers.Concatenate()([upsampled, skip])
    return conv_block_v2(merged, filters, dropout=False)

def build_unet_v2(input_shape=(256, 256, 3), base_filters=32, depth=4):
    """Assemble U-Net v2: the baseline layout with BatchNorm blocks and Dropout.

    Dropout is enabled in the deepest encoder level and in the bottleneck
    only. Channel counts double per encoder level starting at *base_filters*;
    the output is a 1-channel sigmoid map.
    """
    inputs = layers.Input(shape=input_shape)
    widths = [base_filters * (2 ** level) for level in range(depth)]
    deepest = depth - 1

    # Contracting path: only the last level uses dropout.
    skips = []
    x = inputs
    for level, width in enumerate(widths):
        skip, x = encoder_block_v2(x, width, dropout=(level == deepest))
        skips.append(skip)

    # Bottleneck, always with dropout.
    x = conv_block_v2(x, base_filters * (2 ** depth), dropout=True)

    # Expanding path: deepest skip first.
    for skip, width in zip(reversed(skips), reversed(widths)):
        x = decoder_block_v2(x, skip, width)

    # Output layer
    outputs = layers.Conv2D(1, 1, activation='sigmoid', padding='same')(x)
    return models.Model(inputs, outputs, name='UNet_v2')

# Soft (un-thresholded) Dice loss
def dice_loss(y_true, y_pred, epsilon=1e-7):
    """Return ``1 - Dice`` computed on the raw values of the flattened tensors."""
    truth = K.flatten(y_true)
    probs = K.flatten(y_pred)
    overlap = K.sum(truth * probs)
    numerator = 2.0 * overlap + epsilon
    denominator = K.sum(truth) + K.sum(probs) + epsilon
    return 1.0 - numerator / denominator

# Training objective for U-Net v2
def combined_loss(y_true, y_pred):
    """Sum of the mean binary cross-entropy and ``dice_loss``."""
    mean_bce = K.mean(K.binary_crossentropy(y_true, y_pred))
    return mean_bce + dice_loss(y_true, y_pred)

# Instantiate U-Net v2 for the configured image size and print its layer table.
unet_v2 = build_unet_v2(input_shape=tuple(IMG_SIZE) + (3,), base_filters=32, depth=4)
unet_v2.summary()

# Compile: Adam + BCE/Dice objective, tracking the thresholded IoU / Dice metrics.
unet_v2.compile(optimizer=Adam(learning_rate=LR), loss=combined_loss, metrics=[iou_metric, dice_metric])
print('U-Net v2 compiled with combined BCE+Dice loss.')

# Output locations for the best checkpoint and the TensorBoard event files.
checkpoint_path_v2 = MODELS_DIR / 'unet_v2_best.keras'
tb_log_dir_v2 = RUNS_DIR / 'tb' / 'v2'

# Keep only the weights with the lowest validation loss; log scalars to TensorBoard.
best_ckpt_v2 = ModelCheckpoint(str(checkpoint_path_v2), monitor='val_loss', save_best_only=True, verbose=1)
tb_logger_v2 = TensorBoard(log_dir=str(tb_log_dir_v2), histogram_freq=0)
callbacks_v2 = [best_ckpt_v2, tb_logger_v2]

print(f'Checkpoint v2: {checkpoint_path_v2}')
print(f'TensorBoard logs v2: {tb_log_dir_v2}')

# Train U-Net v2; the 30% split doubles as validation data.
history_v2 = unet_v2.fit(train_ds, validation_data=test_ds, epochs=EPOCHS, callbacks=callbacks_v2, verbose=1)

# One panel per tracked quantity: (history key, panel title, legend suffix).
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
curves_v2 = (
    ('loss', 'Loss (BCE+Dice)', 'loss'),
    ('iou_metric', 'IoU', 'iou'),
    ('dice_metric', 'Dice', 'dice'),
)
for ax, (key, title, short) in zip(axes, curves_v2):
    ax.plot(history_v2.history[key], label=f'train_{short}')
    ax.plot(history_v2.history[f'val_{key}'], label=f'val_{short}')
    ax.set_title(title)
    ax.set_xlabel('Epoch')
    ax.legend()

plt.tight_layout()
hist_path_v2 = RUNS_DIR / 'v2_history.png'
plt.savefig(hist_path_v2, dpi=150, bbox_inches='tight')
print(f'Saved U-Net v2 training history to {hist_path_v2}')
plt.show()

In [None]:
# Evaluate U-Net v2 and compare with baseline
from tensorflow.keras.models import load_model
import numpy as np
import matplotlib.pyplot as plt

# Load best v2 checkpoint (custom metrics and losses must be supplied for deserialization)
best_model_path_v2 = MODELS_DIR / 'unet_v2_best.keras'
model_eval_v2 = load_model(
    str(best_model_path_v2),
    custom_objects={
        'iou_metric': iou_metric,
        'dice_metric': dice_metric,
        'combined_loss': combined_loss,
        'dice_loss': dice_loss
    }
)
print(f'Loaded U-Net v2 best model from {best_model_path_v2}')

# Per-image metric values for v2
pa_list_v2 = []
iou_list_v2 = []
dice_list_v2 = []

# Up to 10 (image, ground truth, binarized prediction) triples for the figure
qual_samples_v2 = []

for img_batch, gt_batch in test_ds:
    # Call the model directly instead of predict(): test_ds yields one image
    # per batch, and predict() carries per-call overhead when used in a loop.
    # training=False keeps BatchNorm / Dropout in inference mode.
    pred_batch = np.asarray(model_eval_v2(img_batch, training=False))

    # Process each image in batch (batch=1 for test_ds)
    for img, gt, pred in zip(img_batch.numpy(), gt_batch.numpy(), pred_batch):
        # Binarize prediction and ground truth at 0.5
        pred_bin = (pred > 0.5).astype(np.float32)
        gt_bin = (gt > 0.5).astype(np.float32)

        # Flatten for metric computation
        gt_flat = gt_bin.flatten()
        pred_flat = pred_bin.flatten()

        # Confusion-matrix counts over all pixels of this image
        TP = np.sum((gt_flat == 1) & (pred_flat == 1))
        TN = np.sum((gt_flat == 0) & (pred_flat == 0))
        FP = np.sum((gt_flat == 0) & (pred_flat == 1))
        FN = np.sum((gt_flat == 1) & (pred_flat == 0))

        # Lab formulas; 1e-7 only guards against a zero denominator.
        # NOTE(review): when ground truth and prediction are both empty,
        # IoU and Dice evaluate to 0 here, whereas iou_metric / dice_metric
        # (epsilon in the numerator too) give 1 - confirm which is intended.
        pa = (TP + TN) / (TP + TN + FP + FN + 1e-7)
        iou = TP / (TP + FP + FN + 1e-7)
        dice = (2 * TP) / (2 * TP + FP + FN + 1e-7)

        pa_list_v2.append(pa)
        iou_list_v2.append(iou)
        dice_list_v2.append(dice)

        # Store first 10 samples for visualization
        if len(qual_samples_v2) < 10:
            qual_samples_v2.append((img, gt.squeeze(), pred_bin.squeeze()))

# Print U-Net v2 metrics
print(f'\n=== U-Net v2 Evaluation (test set, n={len(pa_list_v2)}) ===')
print(f'Pixel Accuracy: {np.mean(pa_list_v2):.4f} ± {np.std(pa_list_v2):.4f}')
print(f'IoU:            {np.mean(iou_list_v2):.4f} ± {np.std(iou_list_v2):.4f}')
print(f'Dice:           {np.mean(dice_list_v2):.4f} ± {np.std(dice_list_v2):.4f}')

# Print comparison table (pa_list / iou_list / dice_list come from the baseline evaluation cell)
print('\n' + '='*50)
print('                  COMPARISON')
print('='*50)
print(f'{"Model":<15} {"PixelAcc":<12} {"IoU":<12} {"Dice":<12}')
print('-'*50)
print(f'{"Baseline":<15} {np.mean(pa_list):<12.4f} {np.mean(iou_list):<12.4f} {np.mean(dice_list):<12.4f}')
print(f'{"U-Net v2":<15} {np.mean(pa_list_v2):<12.4f} {np.mean(iou_list_v2):<12.4f} {np.mean(dice_list_v2):<12.4f}')
print('='*50)

# Save qualitative results for v2: one row per stored sample
eval_dir_v2 = RUNS_DIR / 'eval' / 'v2'
eval_dir_v2.mkdir(parents=True, exist_ok=True)
qual_path_v2 = eval_dir_v2 / 'qualitative_results.png'

n_rows_v2 = len(qual_samples_v2)
if n_rows_v2:
    # Size the grid to the samples actually collected so a small test split
    # does not leave empty framed axes; squeeze=False keeps axes 2-D for one row.
    fig, axes = plt.subplots(n_rows_v2, 4, figsize=(16, 4 * n_rows_v2), squeeze=False)
    for i, (img, gt, pred) in enumerate(qual_samples_v2):
        # Each column lists what is drawn onto its axis, in order.
        columns = (
            ('Image', ((img, {}),)),
            ('GT', ((gt, {'cmap': 'gray'}),)),
            ('Pred', ((pred, {'cmap': 'gray'}),)),
            ('Overlay', ((img, {}), (pred, {'cmap': 'Reds', 'alpha': 0.4}))),
        )
        for ax, (label, drawables) in zip(axes[i], columns):
            for data, opts in drawables:
                ax.imshow(data, **opts)
            ax.set_title(f'Sample {i+1}: {label}')
            ax.axis('off')

    plt.tight_layout()
    plt.savefig(qual_path_v2, dpi=150, bbox_inches='tight')
    print(f'\nSaved U-Net v2 qualitative results to {qual_path_v2}')
    plt.show()
else:
    print('No test samples available; skipping U-Net v2 qualitative results figure.')

# Optional: Pratheepan

In [None]:
# Optional: Evaluate on Pratheepan dataset
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt

# The optional dataset is expected under data/pratheepan; the cell is a no-op without it.
PRATHEEPAN_ROOT = REPO_ROOT / 'data' / 'pratheepan'

if PRATHEEPAN_ROOT.exists():
    print(f'Pratheepan dataset found at {PRATHEEPAN_ROOT}')

    # Scan and match pairs with the same helpers used for the main dataset
    prath_imgs, prath_msks = scan_files(PRATHEEPAN_ROOT)
    prath_pairs = match_pairs(prath_imgs, prath_msks)

    if len(prath_pairs) > 0:
        print(f'Found {len(prath_pairs)} Pratheepan pairs')

        # Build dataset (batch=1 for evaluation)
        prath_ds = make_ds([(Path(img), Path(msk)) for img, msk in prath_pairs], batch=1, shuffle=False)

        # Load best baseline model (or use v2 if preferred)
        from tensorflow.keras.models import load_model
        best_model_path = MODELS_DIR / 'unet_best.keras'
        model_prath = load_model(
            str(best_model_path),
            custom_objects={'iou_metric': iou_metric, 'dice_metric': dice_metric}
        )
        print(f'Loaded model from {best_model_path}')

        # Per-image metric values and up to 10 samples for the figure
        pa_list_prath = []
        iou_list_prath = []
        dice_list_prath = []
        qual_samples_prath = []

        for img_batch, gt_batch in prath_ds:
            # Call the model directly instead of predict(): prath_ds yields one
            # image per batch, and predict() carries per-call overhead in a loop.
            pred_batch = np.asarray(model_prath(img_batch, training=False))

            for img, gt, pred in zip(img_batch.numpy(), gt_batch.numpy(), pred_batch):
                # Binarize prediction and ground truth at 0.5
                pred_bin = (pred > 0.5).astype(np.float32)
                gt_bin = (gt > 0.5).astype(np.float32)

                # Flatten
                gt_flat = gt_bin.flatten()
                pred_flat = pred_bin.flatten()

                # Confusion-matrix counts over all pixels of this image
                TP = np.sum((gt_flat == 1) & (pred_flat == 1))
                TN = np.sum((gt_flat == 0) & (pred_flat == 0))
                FP = np.sum((gt_flat == 0) & (pred_flat == 1))
                FN = np.sum((gt_flat == 1) & (pred_flat == 0))

                # 1e-7 only guards against a zero denominator
                pa = (TP + TN) / (TP + TN + FP + FN + 1e-7)
                iou = TP / (TP + FP + FN + 1e-7)
                dice = (2 * TP) / (2 * TP + FP + FN + 1e-7)

                pa_list_prath.append(pa)
                iou_list_prath.append(iou)
                dice_list_prath.append(dice)

                # Store first 10 samples
                if len(qual_samples_prath) < 10:
                    qual_samples_prath.append((img, gt.squeeze(), pred_bin.squeeze()))

        # Print metrics
        print(f'\n=== Pratheepan Evaluation (n={len(pa_list_prath)}) ===')
        print(f'Pixel Accuracy: {np.mean(pa_list_prath):.4f} ± {np.std(pa_list_prath):.4f}')
        print(f'IoU:            {np.mean(iou_list_prath):.4f} ± {np.std(iou_list_prath):.4f}')
        print(f'Dice:           {np.mean(dice_list_prath):.4f} ± {np.std(dice_list_prath):.4f}')

        # Save qualitative results
        eval_dir_prath = RUNS_DIR / 'eval' / 'pratheepan'
        eval_dir_prath.mkdir(parents=True, exist_ok=True)

        n_samples = min(10, len(qual_samples_prath))
        # squeeze=False keeps axes 2-D even when there is a single row.
        fig, axes = plt.subplots(n_samples, 4, figsize=(16, 4 * n_samples), squeeze=False)

        for i, (img, gt, pred) in enumerate(qual_samples_prath[:n_samples]):
            # Each column lists what is drawn onto its axis, in order.
            columns = (
                ('Image', ((img, {}),)),
                ('GT', ((gt, {'cmap': 'gray'}),)),
                ('Pred', ((pred, {'cmap': 'gray'}),)),
                ('Overlay', ((img, {}), (pred, {'cmap': 'Reds', 'alpha': 0.4}))),
            )
            for ax, (label, drawables) in zip(axes[i], columns):
                for data, opts in drawables:
                    ax.imshow(data, **opts)
                ax.set_title(f'Sample {i+1}: {label}')
                ax.axis('off')

        plt.tight_layout()
        qual_path_prath = eval_dir_prath / 'qualitative_results.png'
        plt.savefig(qual_path_prath, dpi=150, bbox_inches='tight')
        print(f'Saved Pratheepan qualitative results to {qual_path_prath}')
        plt.show()
    else:
        print('Skipping Pratheepan evaluation (no image-mask pairs found).')
else:
    print('Skipping Pratheepan evaluation (dataset not found).')