# 🔬 Retinal Blood Vessel Segmentation
## FIVES Dataset · U-Net · PyTorch · MPS Optimized

**Fixes applied:** `tqdm.auto`, `num_workers=0`, MPS device, faster training config, deprecated API updates.


## Cell 1 — Device Check

In [None]:
import platform, sys, torch

# Report the host environment, then choose the compute device.
print(f"✅ Running on: {platform.processor()}")
print(f"Python: {sys.version}")

# Preference order: CUDA first, then Apple MPS, then plain CPU.
# The MPS probe is only evaluated when CUDA is unavailable.
if torch.cuda.is_available():
    DEVICE, _banner = torch.device('cuda'), f"🚀 CUDA GPU: {torch.cuda.get_device_name(0)}"
elif torch.backends.mps.is_available():
    DEVICE, _banner = torch.device('mps'), "🚀 Apple Silicon MPS enabled — GPU acceleration active!"
else:
    DEVICE, _banner = torch.device('cpu'), "⚠️  CPU only"
print(_banner)
del _banner

print(f"Device: {DEVICE}")


## Cell 2 — Download FIVES Dataset

In [None]:
import os, zipfile, urllib.request
from pathlib import Path

# The archive is extracted into the working directory and is expected to
# produce ./archive/train and ./archive/test.
DATA_DIR = Path('./archive')
DATA_DIR.mkdir(exist_ok=True)

FIVES_URL = 'https://figshare.com/ndownloader/articles/19688169/versions/1'
ZIP_PATH  = './archive.zip'

if not (DATA_DIR / 'train').exists():
    print('⬇️  Downloading FIVES dataset (~1.2 GB)...')

    def progress_hook(count, block_size, total_size):
        """Report hook for urlretrieve: print coarse progress every 500 blocks.

        ``total_size`` is non-positive when the server does not report a
        size; in that case the downloaded volume is shown instead of a
        percentage, which would otherwise be negative or divide by zero.
        """
        if count % 500 != 0:
            return
        if total_size > 0:
            pct = min(count * block_size * 100 / total_size, 100)
            print(f'   {pct:.1f}%', end='\r')
        else:
            print(f'   {count * block_size / 1e6:.0f} MB', end='\r')

    urllib.request.urlretrieve(FIVES_URL, ZIP_PATH, progress_hook)
    print('\n📂 Extracting...')
    with zipfile.ZipFile(ZIP_PATH, 'r') as z:
        z.extractall('./')
    # Confirm the archive really produced the layout the rest of the
    # notebook relies on instead of reporting success unconditionally.
    if (DATA_DIR / 'train').exists():
        print('✅ FIVES dataset ready!')
    else:
        print(f"⚠️  Extraction finished but {DATA_DIR / 'train'} is missing — "
              f"check the archive layout.")
else:
    print('✅ FIVES already downloaded — skipping.')

TRAIN_IMG_DIR  = DATA_DIR / 'train' / 'Original'
TRAIN_MASK_DIR = DATA_DIR / 'train' / 'Ground truth'
TEST_IMG_DIR   = DATA_DIR / 'test'  / 'Original'
TEST_MASK_DIR  = DATA_DIR / 'test'  / 'Ground truth'

# Quick inventory of what is on disk (a missing folder simply reports 0 files).
for d in [TRAIN_IMG_DIR, TRAIN_MASK_DIR, TEST_IMG_DIR, TEST_MASK_DIR]:
    imgs = list(d.glob('*.png')) + list(d.glob('*.jpg'))
    print(f'  {d}  →  {len(imgs)} files')


## Cell 3 — EDA: CSV + Visualization

In [None]:
import csv, random
import matplotlib.pyplot as plt
import cv2, numpy as np
from collections import Counter
from pathlib import Path
import os

os.makedirs('./outputs', exist_ok=True)

CATEGORIES = ['Normal', 'AMD', 'DR', 'Glaucoma']
# Single-letter category code found in each filename -> disease label.
cat_map = {'N': 'Normal', 'A': 'AMD', 'D': 'DR', 'G': 'Glaucoma'}

# ── Metadata CSV ─────────────────────────────────────────────────────────
# The split comes from the directory being listed rather than from a
# substring search on the path, which would misfire if any parent folder
# happened to contain the word "train".
rows = []
for split, img_dir in (('train', TRAIN_IMG_DIR), ('test', TEST_IMG_DIR)):
    for img_path in sorted(img_dir.glob('*.png')):
        fname = img_path.name
        # Scan the stem only: the extension ".PNG" itself contains "N" and
        # "G", so scanning the full name could never yield 'Unknown'.
        prefix = next((c for c in img_path.stem.upper() if c in cat_map), None)
        disease = cat_map.get(prefix, 'Unknown')
        rows.append({'image': fname, 'mask': fname, 'disease': disease,
                     'quality': 'Good', 'split': split})

with open('./outputs/fives_metadata.csv', 'w', newline='') as f:
    writer = csv.DictWriter(f, fieldnames=['image','mask','disease','quality','split'])
    writer.writeheader()
    writer.writerows(rows)
print(f"✅ CSV saved: {len(rows)} rows → ./outputs/fives_metadata.csv")

counts = Counter(r['disease'] for r in rows)
print("\n📊 Dataset counts per category:")
for cat, n in counts.items():
    print(f"   {cat}: {n} images")

# ── Up to 6 random image–mask pairs ──────────────────────────────────────
img_paths = sorted(TRAIN_IMG_DIR.glob('*.png'))
if not img_paths:
    raise FileNotFoundError(f'No training images found in {TRAIN_IMG_DIR}')
# random.sample raises if asked for more items than exist.
chosen = random.sample(img_paths, min(6, len(img_paths)))
fig, axes = plt.subplots(2, 6, figsize=(18, 6))
fig.suptitle('FIVES — 6 Random Images + Ground Truth Masks', fontsize=14, fontweight='bold')
for i, ip in enumerate(chosen):
    mp = TRAIN_MASK_DIR / ip.name
    bgr  = cv2.imread(str(ip))
    mask = cv2.imread(str(mp), cv2.IMREAD_GRAYSCALE)
    if bgr is None or mask is None:
        # cv2.imread signals failure by returning None instead of raising.
        print(f'⚠️  Could not read {ip.name} or its mask — skipping.')
        continue
    img  = cv2.resize(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB), (256, 256))
    mask = cv2.resize(mask, (256, 256), interpolation=cv2.INTER_NEAREST)
    axes[0,i].imshow(img)
    axes[0,i].set_title(f'Image {i+1}', fontsize=9)
    axes[1,i].imshow(mask, cmap='gray')
    axes[1,i].set_title(f'Mask {i+1}', fontsize=9)
# Hide ticks and frames but keep the axes "on": axis('off') would also hide
# the row labels set below.
for ax in axes.ravel():
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)
axes[0,0].set_ylabel('Fundus Image', fontsize=10)
axes[1,0].set_ylabel('Vessel Mask',  fontsize=10)
plt.tight_layout()
plt.savefig('./outputs/task1_image_mask_pairs.png', dpi=100, bbox_inches='tight')
plt.show()
print("✅ Saved image–mask pairs.")

# ── CLAHE comparison on the first training image ─────────────────────────
sample_bgr = cv2.imread(str(img_paths[0]))
if sample_bgr is None:
    raise OSError(f'Could not read image: {img_paths[0]}')
sample_img = cv2.cvtColor(sample_bgr, cv2.COLOR_BGR2RGB)
sample_img = cv2.resize(sample_img, (512, 512))
# CLAHE is applied to the L (lightness) channel only, leaving colour intact.
lab = cv2.cvtColor(sample_img, cv2.COLOR_RGB2LAB)
clahe_op = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
lab[:,:,0] = clahe_op.apply(lab[:,:,0])
enhanced = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)
fig, (a1, a2) = plt.subplots(1, 2, figsize=(10, 4))
a1.imshow(sample_img); a1.set_title('Original'); a1.axis('off')
a2.imshow(enhanced);   a2.set_title('After CLAHE'); a2.axis('off')
plt.tight_layout()
plt.savefig('./outputs/task2_clahe_comparison.png', dpi=100, bbox_inches='tight')
plt.show()
print("✅ CLAHE comparison saved.")


## Cell 4 — Configuration

In [None]:
# Central experiment configuration, read by the data, model, loss and
# training helpers in the cells below.
CFG = dict(
    # ── Data ── 256px is ~4x faster than 512px with minimal accuracy loss
    image_size=256,
    val_split=0.1,
    num_workers=0,            # 0 workers avoids macOS multiprocessing crashes

    # ── Model ──
    architecture='unet',
    encoder='resnet34',
    encoder_weights='imagenet',

    # ── Loss ──
    loss_type='combined',
    dice_w=0.5,
    tversky_w=0.3,
    focal_w=0.2,
    tversky_alpha=0.3,
    tversky_beta=0.7,
    focal_gamma=2.0,

    # ── Training ──
    epochs=50,
    batch_size=16,
    lr=2e-4,                  # slightly higher LR for faster convergence
    weight_decay=1e-4,
    early_stopping=15,
    mixed_precision=False,    # AMP is kept off for the MPS-oriented setup
    grad_clip=1.0,
    threshold=0.5,

    # ── Paths ──
    save_dir='./outputs',
)

import os
# Make sure the output folder exists, then echo the key settings.
os.makedirs(CFG['save_dir'], exist_ok=True)
print('\n'.join([
    '✅ Config ready.',
    f"   Arch: {CFG['architecture']} | Encoder: {CFG['encoder']}",
    f"   Loss: {CFG['loss_type']} | Epochs: {CFG['epochs']} | Batch: {CFG['batch_size']}",
    f"   Image size: {CFG['image_size']}×{CFG['image_size']} | Device: {DEVICE}",
    f"   num_workers: {CFG['num_workers']} (0 = safe for macOS)",
]))


## Cell 5 — Imports

In [None]:
import os, glob, random, json, csv, time
import numpy as np
import cv2
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from pathlib import Path
from tqdm.auto import tqdm            # FIX: tqdm.auto instead of tqdm.notebook

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torch.amp import GradScaler, autocast  # FIX: updated import (non-deprecated)

import albumentations as A
from albumentations.pytorch import ToTensorV2

import segmentation_models_pytorch as smp
from sklearn.metrics import roc_auc_score, roc_curve, matthews_corrcoef

# Seed Python, NumPy and PyTorch RNGs (in that order) for repeatable runs.
SEED = 42
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(SEED)
del _seed_rng

print(f'✅ Imports done. Device: {DEVICE}')


## Cell 6 — Dataset Class

In [None]:
class FIVESDataset(Dataset):
    """Paired fundus-image / vessel-mask dataset read from two directories.

    Images are every ``*.png`` / ``*.jpg`` file in ``images_dir`` (sorted by
    path). Each image is matched to a mask in ``masks_dir`` with the same
    file name, or failing that, the same stem with any extension.

    Args:
        images_dir: Directory with the fundus images (``str`` or ``Path``).
        masks_dir: Directory with the ground-truth masks (``str`` or ``Path``).
        transform: Optional callable invoked as ``transform(image=..., mask=...)``
            returning a mapping with ``'image'`` and ``'mask'`` entries.
        image_size: Side length images and masks are resized to on load.
        preprocess_clahe: Apply CLAHE to the image before ``transform``.

    Raises:
        FileNotFoundError: If no images are found or an image has no mask.
    """

    def __init__(self, images_dir, masks_dir, transform=None,
                 image_size=256, preprocess_clahe=True):
        # Accept plain strings as well as Path objects.
        images_dir = Path(images_dir)
        masks_dir  = Path(masks_dir)

        self.images_dir = images_dir
        self.masks_dir  = masks_dir
        self.transform  = transform
        self.image_size = image_size
        self.preprocess_clahe = preprocess_clahe

        self.image_paths = sorted(
            glob.glob(str(images_dir / '*.png')) +
            glob.glob(str(images_dir / '*.jpg'))
        )
        if not self.image_paths:
            raise FileNotFoundError(f'No images found in {images_dir}')

        self.mask_paths = []
        for img_path in self.image_paths:
            fname     = os.path.basename(img_path)
            mask_path = masks_dir / fname
            if not mask_path.exists():
                # Fall back to "same stem, any extension". The stem is
                # escaped so glob metacharacters in a file name are taken
                # literally, and hits are sorted so the choice is stable.
                pattern    = glob.escape(Path(fname).stem) + '.*'
                candidates = sorted(masks_dir.glob(pattern))
                if not candidates:
                    raise FileNotFoundError(f'Mask not found for {img_path}')
                mask_path = candidates[0]
            self.mask_paths.append(str(mask_path))

        print(f'  FIVESDataset | {len(self.image_paths)} samples from {images_dir.name}/')

    @staticmethod
    def apply_clahe(image_rgb):
        """Return ``image_rgb`` with CLAHE applied to the L channel in LAB space."""
        lab = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2LAB)
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        lab[:, :, 0] = clahe.apply(lab[:, :, 0])
        return cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)

    def _load_image(self, path):
        """Read ``path`` as RGB and resize it to ``image_size`` square."""
        img = cv2.imread(path, cv2.IMREAD_COLOR)
        if img is None:
            # cv2.imread returns None on failure instead of raising.
            raise OSError(f'Could not read image: {path}')
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (self.image_size, self.image_size),
                         interpolation=cv2.INTER_LINEAR)
        return img

    def _load_mask(self, path):
        """Read ``path`` as grayscale, resize, and binarise to {0, 1} uint8."""
        mask = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise OSError(f'Could not read mask: {path}')
        # Nearest-neighbour keeps the mask binary through the resize.
        mask = cv2.resize(mask, (self.image_size, self.image_size),
                          interpolation=cv2.INTER_NEAREST)
        return (mask > 127).astype(np.uint8)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, mask)``; the mask is float with a leading channel dim."""
        image = self._load_image(self.image_paths[idx])
        mask  = self._load_mask(self.mask_paths[idx])
        if self.preprocess_clahe:
            image = self.apply_clahe(image)
        if self.transform is not None:
            aug   = self.transform(image=image, mask=mask)
            image = aug['image']
            mask  = aug['mask']
        else:
            # No transform: convert HWC uint8 -> CHW float in [0, 1].
            image = torch.from_numpy(image).permute(2,0,1).float() / 255.0
            mask  = torch.from_numpy(mask).unsqueeze(0).float()
        if mask.dim() == 2:
            # Transforms may hand back an (H, W) mask; add the channel axis.
            mask = mask.unsqueeze(0)
        return image, mask.float()


def get_train_transforms(size):
    """Build the training augmentation pipeline ending in a ``size``×``size`` tensor.

    Stages run in order: geometric augmentations, photometric augmentations,
    then resize, ImageNet-statistics normalisation and tensor conversion.
    """
    geometric = [
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        # border_mode=4 is passed through as a raw integer flag
        # (presumably OpenCV's reflect-101 border — confirm).
        A.Rotate(limit=30, border_mode=4, p=0.5),
        A.ElasticTransform(alpha=60, sigma=6, p=0.2),
    ]
    photometric = [
        A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
        A.RandomGamma(gamma_limit=(80, 120), p=0.3),
        A.GaussNoise(p=0.2),
    ]
    finalize = [
        A.Resize(size, size),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]
    return A.Compose(geometric + photometric + finalize)

def get_val_transforms(size):
    """Build the deterministic eval pipeline: resize, normalise, to tensor."""
    steps = [A.Resize(size, size)]
    steps.append(A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
    steps.append(ToTensorV2())
    return A.Compose(steps)


def build_dataloaders(cfg):
    """Create train / validation / test loaders for the FIVES folders.

    The training folder is split into train and validation subsets with a
    generator seeded by ``SEED``. The two subsets are backed by *separate*
    ``FIVESDataset`` instances so that training samples are augmented while
    validation samples only get the deterministic eval transforms.

    Args:
        cfg: Mapping providing ``image_size``, ``val_split``, ``batch_size``
            and ``num_workers``.

    Returns:
        ``(train_loader, val_loader, test_loader, test_ds)``.
    """
    size       = cfg['image_size']
    val_split  = cfg['val_split']
    batch_size = cfg['batch_size']
    nw         = cfg['num_workers']

    def _train_folder(transform):
        return FIVESDataset(TRAIN_IMG_DIR, TRAIN_MASK_DIR,
                            transform=transform, image_size=size)

    # Two views over the same files. Setting `.dataset.transform` on one
    # random_split subset would change the other too, because both subsets
    # share a single underlying dataset object — that would silently turn
    # off augmentation for training.
    augmented_view = _train_folder(get_train_transforms(size))
    plain_view     = _train_folder(get_val_transforms(size))

    n_val   = max(1, int(len(augmented_view) * val_split))
    n_train = len(augmented_view) - n_val
    train_ds, val_part = random_split(
        augmented_view, [n_train, n_val],
        generator=torch.Generator().manual_seed(SEED)
    )
    # Reuse the held-out indices, but read them through the un-augmented view.
    val_ds = torch.utils.data.Subset(plain_view, val_part.indices)

    test_ds = FIVESDataset(
        TEST_IMG_DIR, TEST_MASK_DIR,
        transform=get_val_transforms(size),
        image_size=size,
    )

    # Dropping the last partial batch would leave an empty loader when the
    # training subset is smaller than one batch.
    drop_last = len(train_ds) >= batch_size

    # pin_memory stays False to match the MPS-oriented setup of this notebook.
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              num_workers=nw, drop_last=drop_last, pin_memory=False)
    val_loader   = DataLoader(val_ds,   batch_size=batch_size, shuffle=False,
                              num_workers=nw, pin_memory=False)
    test_loader  = DataLoader(test_ds,  batch_size=batch_size, shuffle=False,
                              num_workers=nw, pin_memory=False)

    print(f'\n✅ DataLoaders ready')
    print(f'   Train: {len(train_ds)} | Val: {len(val_ds)} | Test: {len(test_ds)}')
    return train_loader, val_loader, test_loader, test_ds


print('✅ Dataset classes defined.')


## Cell 7 — Loss Functions

In [None]:
class DiceLoss(nn.Module):
    """Soft Dice loss computed over all elements of the batch at once.

    Args:
        smooth: Additive term in numerator and denominator; keeps the ratio
            defined when both prediction and target are empty.
    """

    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, logits, targets):
        """Return ``1 - Dice`` for raw ``logits`` against ``targets`` of the same shape."""
        # reshape (not view) so non-contiguous inputs are accepted too.
        probs   = torch.sigmoid(logits).reshape(-1)
        targets = targets.reshape(-1)
        inter   = (probs * targets).sum()
        return 1 - (2*inter + self.smooth) / (probs.sum() + targets.sum() + self.smooth)

class TverskyLoss(nn.Module):
    """Soft Tversky loss computed over all elements of the batch at once.

    Args:
        alpha: Weight applied to soft false positives.
        beta: Weight applied to soft false negatives.
        smooth: Additive term in numerator and denominator.
    """

    def __init__(self, alpha=0.3, beta=0.7, smooth=1.0):
        super().__init__()
        self.alpha, self.beta, self.smooth = alpha, beta, smooth

    def forward(self, logits, targets):
        """Return ``1 - Tversky index`` for raw ``logits`` against ``targets``."""
        # reshape (not view) so non-contiguous inputs are accepted too.
        probs   = torch.sigmoid(logits).reshape(-1)
        targets = targets.reshape(-1)
        tp = (probs * targets).sum()
        fp = (probs * (1 - targets)).sum()
        fn = ((1 - probs) * targets).sum()
        return 1 - (tp + self.smooth) / (tp + self.alpha*fp + self.beta*fn + self.smooth)

class FocalLoss(nn.Module):
    """Binary focal loss on raw logits, averaged over every element.

    Args:
        alpha: Weight for positive targets; negatives get ``1 - alpha``.
        gamma: Focusing exponent applied to ``1 - p_t``.
    """

    def __init__(self, alpha=0.25, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, logits, targets):
        """Return the mean focal loss of ``logits`` against ``targets``."""
        per_pixel_bce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
        probs = torch.sigmoid(logits)
        # Probability assigned to the true class, and the matching class weight.
        p_t = targets * probs + (1 - targets) * (1 - probs)
        a_t = targets * self.alpha + (1 - targets) * (1 - self.alpha)
        weighted = a_t * (1 - p_t)**self.gamma * per_pixel_bce
        return weighted.mean()

class CombinedLoss(nn.Module):
    """Weighted sum of Dice, Tversky and Focal losses.

    Args:
        dice_w, tversky_w, focal_w: Multipliers for the three terms.
        tversky_alpha, tversky_beta: Forwarded to ``TverskyLoss``.
        focal_gamma: Forwarded to ``FocalLoss``.
    """

    def __init__(self, dice_w=0.5, tversky_w=0.3, focal_w=0.2,
                 tversky_alpha=0.3, tversky_beta=0.7, focal_gamma=2.0):
        super().__init__()
        self.dice_w    = dice_w
        self.tversky_w = tversky_w
        self.focal_w   = focal_w
        self.dice    = DiceLoss()
        self.tversky = TverskyLoss(alpha=tversky_alpha, beta=tversky_beta)
        self.focal   = FocalLoss(gamma=focal_gamma)

    def forward(self, logits, targets):
        """Return ``dice_w*Dice + tversky_w*Tversky + focal_w*Focal``."""
        dice_term    = self.dice_w    * self.dice(logits, targets)
        tversky_term = self.tversky_w * self.tversky(logits, targets)
        focal_term   = self.focal_w   * self.focal(logits, targets)
        return dice_term + tversky_term + focal_term

class _BCEDiceLoss(nn.Module):
    """Unweighted sum of BCE-with-logits and soft Dice loss."""

    def __init__(self):
        super().__init__()
        self.bce  = nn.BCEWithLogitsLoss()
        self.dice = DiceLoss()

    def forward(self, logits, targets):
        return self.bce(logits, targets) + self.dice(logits, targets)


def build_loss(cfg):
    """Instantiate the loss selected by ``cfg['loss_type']``.

    Supported values: ``'dice'``, ``'bce_dice'``, ``'focal'``, ``'tversky'``
    and ``'combined'``. Only the hyper-parameters needed by the selected
    loss are read from ``cfg``.

    Raises:
        ValueError: If ``cfg['loss_type']`` is not one of the values above.
    """
    t = cfg['loss_type']
    if t == 'dice':
        return DiceLoss()
    if t == 'bce_dice':
        # Previously this returned plain BCE, with no Dice term at all.
        return _BCEDiceLoss()
    if t == 'focal':
        return FocalLoss(gamma=cfg['focal_gamma'])
    if t == 'tversky':
        return TverskyLoss(alpha=cfg['tversky_alpha'], beta=cfg['tversky_beta'])
    if t == 'combined':
        return CombinedLoss(
            dice_w=cfg['dice_w'], tversky_w=cfg['tversky_w'], focal_w=cfg['focal_w'],
            tversky_alpha=cfg['tversky_alpha'], tversky_beta=cfg['tversky_beta'],
            focal_gamma=cfg['focal_gamma'])
    raise ValueError(f"Unknown loss: {t}")

print('✅ Loss functions defined.')


## Cell 8 — Model (U-Net via SMP)

In [None]:
def build_model(cfg):
    """Construct the segmentation network named by ``cfg['architecture']``.

    All variants share one encoder configuration, take 3 input channels and
    emit a single-channel logit map (no output activation).

    Raises:
        ValueError: If the architecture name is not recognised.
    """
    arch = cfg['architecture']
    common = {
        'encoder_name'   : cfg['encoder'],
        'encoder_weights': cfg['encoder_weights'],
        'in_channels'    : 3,
        'classes'        : 1,
        'activation'     : None,
    }
    # Name -> zero-argument constructor; only the selected one is invoked.
    factories = {
        'unet'          : lambda: smp.Unet(**common),
        'unetplusplus'  : lambda: smp.UnetPlusPlus(**common),
        'attention_unet': lambda: smp.Unet(**common, decoder_attention_type='scse'),
        'manet'         : lambda: smp.MAnet(**common),
    }
    if arch not in factories:
        raise ValueError(f'Unknown arch: {arch}')
    model = factories[arch]()

    trainable = [p.numel() for p in model.parameters() if p.requires_grad]
    n = sum(trainable)
    print(f'✅ Model: {arch} | Encoder: {cfg["encoder"]} | Params: {n:,}')
    return model

# Smoke test: push one random 2-image batch through a freshly built model
# and show the input/output shapes. The temporaries are removed afterwards.
model = build_model(CFG).to(DEVICE)
_side = CFG['image_size']
_probe = torch.randn(2, 3, _side, _side).to(DEVICE)
with torch.no_grad():
    _probe_out = model(_probe)
for _label, _tensor, _suffix in (('Input ', _probe, ''), ('Output', _probe_out, '  ✓')):
    print(f'   {_label}: {_tensor.shape}{_suffix}')
del _probe, _probe_out, _side, _label, _tensor, _suffix


## Cell 9 — Metrics

In [None]:
def compute_metrics(prob, target, threshold=0.5):
    """Compute binary segmentation metrics from flat probability/label arrays.

    Args:
        prob: NumPy array of predicted probabilities.
        target: NumPy array of ground-truth labels, same length as ``prob``.
        threshold: Probabilities ``>= threshold`` count as positive.

    Returns:
        Dict of Python floats with keys ``dice``, ``auc_roc``, ``sensitivity``,
        ``specificity``, ``accuracy`` and ``mcc``. ``auc_roc`` / ``mcc`` fall
        back to 0.0 when scikit-learn rejects the input with ``ValueError``
        (presumably e.g. when only one class is present — confirm).
    """
    eps  = 1e-8
    pred = prob >= threshold
    gt   = target.astype(bool)

    # Confusion-matrix counts.
    tp = (pred &  gt).sum()
    fp = (pred & ~gt).sum()
    fn = (~pred &  gt).sum()
    tn = (~pred & ~gt).sum()

    # eps keeps the ratios defined when a denominator would be zero.
    dice        = (2*tp + eps) / (2*tp + fp + fn + eps)
    sensitivity = (tp + eps) / (tp + fn + eps)
    specificity = (tn + eps) / (tn + fp + eps)
    accuracy    = (tp + tn) / (tp + tn + fp + fn + eps)

    labels = target.astype(int)
    # Only the "metric undefined for this input" case is swallowed; a bare
    # except would also hide real bugs and KeyboardInterrupt.
    try:
        auc = float(roc_auc_score(labels, prob))
    except ValueError:
        auc = 0.0
    try:
        mcc = float(matthews_corrcoef(labels, pred.astype(int)))
    except ValueError:
        mcc = 0.0

    return dict(dice=float(dice), auc_roc=float(auc),
                sensitivity=float(sensitivity), specificity=float(specificity),
                accuracy=float(accuracy), mcc=float(mcc))

def find_optimal_threshold(all_probs, all_targets, default=0.5):
    """Pick the ROC threshold that maximises Youden's J (``tpr - fpr``).

    Args:
        all_probs: Flat NumPy array of predicted probabilities.
        all_targets: Flat NumPy array of ground-truth labels.
        default: Value returned when no usable threshold exists.

    Returns:
        The selected threshold as a Python float, or ``default`` when J is
        undefined everywhere or the best point is the artificial first ROC
        entry, which lies above every score and would predict nothing.
    """
    fpr, tpr, thresholds = roc_curve(all_targets.astype(int), all_probs)
    j = tpr - fpr
    if not np.isfinite(j).any():
        # tpr or fpr is NaN everywhere, e.g. when one class is absent.
        return float(default)
    best = float(thresholds[np.nanargmax(j)])
    # roc_curve prepends a sentinel threshold (inf or max score + 1,
    # depending on the scikit-learn version) that no sample can reach.
    if not np.isfinite(best) or best > float(np.max(all_probs)):
        return float(default)
    return best

print('✅ Metrics defined.')


## Cell 10 — Training Loop

In [None]:
def train_epoch(model, loader, optimizer, criterion, device, cfg):
    """Run one optimisation pass over ``loader`` and return epoch metrics.

    Gradients are clipped to ``cfg['grad_clip']`` before each optimiser step.
    Every prediction and target of the epoch is collected on the CPU so the
    metrics are computed over the whole epoch at the default threshold.

    Returns:
        The ``compute_metrics`` dict plus ``'loss'`` (mean per-batch loss).
    """
    model.train()
    running_loss = 0.0
    prob_chunks, target_chunks = [], []

    for batch_x, batch_y in tqdm(loader, desc='  Train', leave=False):
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)
        optimizer.zero_grad(set_to_none=True)

        # Plain full-precision forward pass (no autocast).
        logits = model(batch_x)
        loss   = criterion(logits, batch_y)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), cfg['grad_clip'])
        optimizer.step()

        running_loss += loss.item()
        # Detach so bookkeeping does not hold on to the autograd graph.
        prob_chunks.append(torch.sigmoid(logits.detach()).cpu().numpy().flatten())
        target_chunks.append(batch_y.cpu().numpy().flatten())

    epoch_probs   = np.concatenate(prob_chunks)
    epoch_targets = np.concatenate(target_chunks)
    return {**compute_metrics(epoch_probs, epoch_targets),
            'loss': running_loss / len(loader)}


@torch.no_grad()
def val_epoch(model, loader, criterion, device, cfg):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    The decision threshold is re-estimated on this epoch's predictions via
    ``find_optimal_threshold`` and the metrics are computed at that value.

    Returns:
        The ``compute_metrics`` dict plus ``'loss'`` (mean per-batch loss)
        and ``'threshold'`` (the threshold that was used).
    """
    model.eval()
    running_loss = 0.0
    prob_chunks, target_chunks = [], []

    for batch_x, batch_y in tqdm(loader, desc='  Val  ', leave=False):
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)
        logits = model(batch_x)
        running_loss += criterion(logits, batch_y).item()
        prob_chunks.append(torch.sigmoid(logits).cpu().numpy().flatten())
        target_chunks.append(batch_y.cpu().numpy().flatten())

    epoch_probs   = np.concatenate(prob_chunks)
    epoch_targets = np.concatenate(target_chunks)
    best_cut = find_optimal_threshold(epoch_probs, epoch_targets)
    return {**compute_metrics(epoch_probs, epoch_targets, threshold=best_cut),
            'loss': running_loss / len(loader),
            'threshold': best_cut}


def run_training(cfg):
    """Train a model end to end and checkpoint the best validation epoch.

    Builds loaders, model, loss, AdamW optimiser and a cosine LR schedule
    from ``cfg``; after each epoch the model is validated, and whenever the
    validation Dice improves a checkpoint is written to
    ``<save_dir>/best_model.pth``. Training stops early after
    ``cfg['early_stopping']`` consecutive epochs without improvement.

    Returns:
        ``(model, history, test_loader, test_ds, best_thresh)``. ``model``
        carries the weights of the *last* epoch that ran; the best weights
        live in the checkpoint file. ``history`` maps ``train_loss``,
        ``val_loss``, ``train_dice`` and ``val_dice`` to per-epoch lists.
    """
    train_loader, val_loader, test_loader, test_ds = build_dataloaders(cfg)

    model     = build_model(cfg).to(DEVICE)
    criterion = build_loss(cfg)
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg['lr'],
                                  weight_decay=cfg['weight_decay'])
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=cfg['epochs'], eta_min=1e-7)

    # No GradScaler / autocast: training runs in full precision.

    save_dir  = Path(cfg['save_dir'])
    # Make sure the checkpoint target exists even if it was removed since
    # the configuration cell ran.
    save_dir.mkdir(parents=True, exist_ok=True)
    best_ckpt = save_dir / 'best_model.pth'
    history   = {'train_loss':[], 'val_loss':[], 'train_dice':[], 'val_dice':[]}
    best_dice = 0.0
    best_thresh = 0.5
    no_improve  = 0

    print(f"\n{'='*60}")
    print(f"  Training: {cfg['architecture']} + {cfg['encoder']}")
    print(f"  Epochs: {cfg['epochs']} | Batch: {cfg['batch_size']} | LR: {cfg['lr']}")
    print(f"  Loss: {cfg['loss_type']} | Device: {DEVICE}")
    print(f"  Image size: {cfg['image_size']}×{cfg['image_size']}")
    print(f"{'='*60}\n")

    for epoch in range(1, cfg['epochs'] + 1):
        t0 = time.time()
        # Read the LR before scheduler.step() so the log shows the rate this
        # epoch actually trained with, not the next epoch's.
        lr_now  = optimizer.param_groups[0]['lr']
        train_m = train_epoch(model, train_loader, optimizer, criterion, DEVICE, cfg)
        val_m   = val_epoch(model, val_loader, criterion, DEVICE, cfg)
        scheduler.step()

        elapsed = time.time() - t0

        print(
            f"Ep [{epoch:03d}/{cfg['epochs']}] "
            f"Train Loss={train_m['loss']:.4f} Dice={train_m['dice']:.4f} | "
            f"Val Loss={val_m['loss']:.4f} Dice={val_m['dice']:.4f} "
            f"AUC={val_m['auc_roc']:.4f} Sens={val_m['sensitivity']:.4f} "
            f"Spec={val_m['specificity']:.4f} Thr={val_m['threshold']:.3f} "
            f"LR={lr_now:.2e} [{elapsed:.0f}s]"
        )

        history['train_loss'].append(train_m['loss'])
        history['val_loss'].append(val_m['loss'])
        history['train_dice'].append(train_m['dice'])
        history['val_dice'].append(val_m['dice'])

        if val_m['dice'] > best_dice:
            best_dice   = val_m['dice']
            best_thresh = val_m['threshold']
            no_improve  = 0
            torch.save({
                'epoch'           : epoch,
                'model_state_dict': model.state_dict(),
                'best_dice'       : best_dice,
                'threshold'       : best_thresh,
                'val_metrics'     : val_m,
                'cfg'             : cfg,
            }, best_ckpt)
            print(f'  ✅ New best! Dice={best_dice:.4f}  Threshold={best_thresh:.3f}')
        else:
            no_improve += 1
            if no_improve >= cfg['early_stopping']:
                print(f'\n⛔ Early stopping at epoch {epoch}')
                break

    print(f"\n{'='*60}")
    print(f'  Training complete! Best Val Dice: {best_dice:.4f}')
    print(f"{'='*60}\n")
    return model, history, test_loader, test_ds, best_thresh


print('✅ Training loop defined.')


## Cell 11 — 🚀 Run Training

In [None]:
# ── Quick test toggle ────────────────────────────────────────────────────
# True  → short 5-epoch smoke run with early stopping effectively disabled.
# False → full run using the epoch count already stored in CFG.
QUICK_TEST = True

if QUICK_TEST:
    print('⚡ QUICK TEST MODE — 5 epochs only')
    # Note: this edits CFG in place, so later cells see the reduced values.
    CFG.update(epochs=5, early_stopping=99)

model, history, test_loader, test_ds, best_threshold = run_training(CFG)


## Cell 12 — Training Curves

In [None]:
# Plot loss (left) and Dice (right) for the epochs that actually ran.
epochs_ran = range(1, len(history['train_loss']) + 1)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 4))

_panels = (
    # (axis, train key, train label, val key, val label, y-label, title)
    (ax1, 'train_loss', 'Train Loss', 'val_loss', 'Val Loss', 'Loss',       'Loss Curve'),
    (ax2, 'train_dice', 'Train Dice', 'val_dice', 'Val Dice', 'Dice Score', 'Dice Score Curve'),
)
for _ax, _tr_key, _tr_label, _va_key, _va_label, _ylabel, _title in _panels:
    _ax.plot(epochs_ran, history[_tr_key], label=_tr_label, color='steelblue', lw=2)
    _ax.plot(epochs_ran, history[_va_key], label=_va_label, color='tomato',    lw=2)
    if _ax is ax2:
        # Reference line for the Dice target.
        _ax.axhline(0.82, linestyle='--', color='green', alpha=0.7, label='Target (0.82)')
    _ax.set_xlabel('Epoch')
    _ax.set_ylabel(_ylabel)
    _ax.set_title(_title)
    _ax.legend()
    _ax.grid(alpha=0.3)

plt.suptitle('Training History — FIVES Retinal Vessel Segmentation', fontsize=13, y=1.02)
plt.tight_layout()
plt.savefig(f"{CFG['save_dir']}/training_curves.png", dpi=150, bbox_inches='tight')
plt.show()
print('✅ Training curves saved.')


## Cell 13 — Test Set Evaluation

In [None]:
@torch.no_grad()
def evaluate_test_set(model, test_loader, threshold, device, cfg):
    """Score ``model`` on the test loader and print a pass/fail report.

    Predictions for the whole loader are pooled and handed to
    ``compute_metrics`` at the given ``threshold``. Dice, AUC-ROC,
    sensitivity and specificity are compared with fixed targets; MCC and
    accuracy are reported without a target.

    Returns:
        ``(metrics, all_probs, all_targets)`` — the metrics dict and the
        pooled flat probability / target arrays.
    """
    model.eval()
    prob_chunks, target_chunks = [], []

    for images, masks in tqdm(test_loader, desc='Evaluating test set'):
        logits = model(images.to(device))
        prob_chunks.append(torch.sigmoid(logits).cpu().numpy().flatten())
        target_chunks.append(masks.numpy().flatten())

    all_probs   = np.concatenate(prob_chunks)
    all_targets = np.concatenate(target_chunks)
    metrics     = compute_metrics(all_probs, all_targets, threshold)

    rule = '=' * 55
    print(f"\n{rule}")
    print(f"  FIVES Test Set — Final Results  (threshold={threshold:.3f})")
    print(f"{rule}")

    # (display name, value, minimum required or None, target caption)
    report = [
        ('Dice Score',  metrics['dice'],        0.82, '≥ 0.82'),
        ('AUC-ROC',     metrics['auc_roc'],     0.98, '≥ 0.98'),
        ('Sensitivity', metrics['sensitivity'], 0.80, '≥ 0.80'),
        ('Specificity', metrics['specificity'], 0.97, '≥ 0.97'),
        ('MCC',         metrics['mcc'],         None, ''),
        ('Accuracy',    metrics['accuracy'],    None, ''),
    ]
    for name, val, target, label in report:
        if target is None or val >= target:
            passed = '✅'
        else:
            passed = '❌'
        bar = f"(target {label})" if label else ''
        print(f"  {passed} {name:<14}: {val:.4f}  {bar}")

    # Overall verdict: every metric that has a target must reach it.
    all_pass = all(val >= target
                   for _, val, target, _ in report
                   if target is not None)
    print(f"{rule}")
    print(f"  PRD Criteria: {'✅ ALL PASSED' if all_pass else '❌ Train more epochs'}")
    print(f"{rule}\n")
    return metrics, all_probs, all_targets


# Restore the best weights written during training, together with the
# decision threshold stored next to them (CFG['threshold'] if absent),
# then score the test set.
best_ckpt_path = "{}/best_model.pth".format(CFG['save_dir'])
ckpt = torch.load(best_ckpt_path, map_location=DEVICE)
model.load_state_dict(ckpt['model_state_dict'])
best_threshold = ckpt.get('threshold', CFG['threshold'])
print('✅ Loaded best model (epoch {}, val Dice={:.4f})'.format(
    ckpt["epoch"], ckpt["best_dice"]))

test_metrics, test_probs, test_targets = evaluate_test_set(
    model=model, test_loader=test_loader, threshold=best_threshold,
    device=DEVICE, cfg=CFG)


## Cell 14 — Qualitative Visualization

In [None]:
def denormalize(tensor):
    """Undo the mean/std normalisation of a CHW tensor and return an HWC uint8 image.

    Values are mapped back with ``x * std + mean``, clipped to [0, 1] and
    scaled to 0–255 (fractional parts are truncated by the uint8 cast).
    """
    channel_mean = np.array([0.485, 0.456, 0.406])
    channel_std  = np.array([0.229, 0.224, 0.225])
    hwc = tensor.permute(1, 2, 0).cpu().numpy()
    restored = np.clip(hwc * channel_std + channel_mean, 0, 1)
    return (restored * 255).astype(np.uint8)

@torch.no_grad()
def visualize_predictions(model, test_ds, threshold, device, cfg, n_samples=6):
    """Plot and save a qualitative grid of test predictions.

    ``n_samples`` images are taken at evenly spaced indices of ``test_ds``.
    Each row shows: the image, ground-truth overlay, prediction overlay,
    error map (TP/FP/FN) and probability heat-map. The figure is saved to
    ``<save_dir>/qualitative_results.png`` and each prediction overlay to
    ``<save_dir>/overlays/img_<idx>_overlay.png``.
    """
    model.eval()
    indices = np.linspace(0, len(test_ds)-1, n_samples, dtype=int)
    # squeeze=False keeps `axes` two-dimensional even when n_samples == 1.
    fig, axes = plt.subplots(n_samples, 5, figsize=(20, 4*n_samples), squeeze=False)
    for ax, t in zip(axes[0], ['Original','GT Overlay','Pred Overlay','Error Map','Prob Heatmap']):
        ax.set_title(t, fontsize=11, fontweight='bold')

    os.makedirs(f"{cfg['save_dir']}/overlays", exist_ok=True)

    # Colours painted into the overlays / error map (RGB, 0–255).
    green, blue, red = [0, 200, 0], [0, 0, 255], [255, 0, 0]

    for row, idx in enumerate(indices):
        img_tensor, mask_tensor = test_ds[idx]
        logit = model(img_tensor.unsqueeze(0).to(device))
        prob  = torch.sigmoid(logit).squeeze().cpu().numpy()
        pred  = (prob >= threshold).astype(np.uint8)
        gt    = mask_tensor.squeeze().numpy().astype(np.uint8)
        img   = denormalize(img_tensor)
        m     = compute_metrics(prob.flatten(), gt.flatten(), threshold)

        # Column 0: hide ticks and frame instead of axis('off'), which would
        # also hide the per-row Dice label.
        ax0 = axes[row,0]
        ax0.imshow(img)
        ax0.set_ylabel(f'img_{idx}\nDice={m["dice"]:.3f}', fontsize=8)
        ax0.set_xticks([])
        ax0.set_yticks([])
        for spine in ax0.spines.values():
            spine.set_visible(False)

        # Columns 1–2: 50/50 blend of the image with green on vessel pixels.
        gt_ov = img.copy()
        gt_ov[gt==1] = (gt_ov[gt==1]*0.5 + np.array(green)*0.5).astype(np.uint8)
        axes[row,1].imshow(gt_ov); axes[row,1].axis('off')

        pr_ov = img.copy()
        pr_ov[pred==1] = (pr_ov[pred==1]*0.5 + np.array(green)*0.5).astype(np.uint8)
        axes[row,2].imshow(pr_ov); axes[row,2].axis('off')

        # Column 3: error map.
        gt_b, pred_b = gt.astype(bool), pred.astype(bool)
        err = img.copy().astype(np.float32)
        err[gt_b & pred_b]  = green
        err[~gt_b & pred_b] = blue
        err[gt_b & ~pred_b] = red
        axes[row,3].imshow(err.astype(np.uint8))
        # Legend colours are derived from the painted colours so they match.
        axes[row,3].legend(handles=[
            mpatches.Patch(color=tuple(c / 255 for c in green), label='TP'),
            mpatches.Patch(color=tuple(c / 255 for c in blue),  label='FP'),
            mpatches.Patch(color=tuple(c / 255 for c in red),   label='FN')
        ], loc='lower right', fontsize=7)
        axes[row,3].axis('off')

        # Column 4: probability heat-map.
        im = axes[row,4].imshow(prob, cmap='hot', vmin=0, vmax=1)
        plt.colorbar(im, ax=axes[row,4], fraction=0.046, pad=0.04)
        axes[row,4].axis('off')

        cv2.imwrite(f"{cfg['save_dir']}/overlays/img_{idx}_overlay.png",
                    cv2.cvtColor(pr_ov, cv2.COLOR_RGB2BGR))

    plt.suptitle(f'Qualitative Results — FIVES | Threshold={threshold:.3f}',
                 fontsize=13, y=1.01)
    plt.tight_layout()
    fig_path = f"{cfg['save_dir']}/qualitative_results.png"
    plt.savefig(fig_path, dpi=100, bbox_inches='tight')
    plt.show()
    print(f'✅ Saved → {fig_path}')


visualize_predictions(model, test_ds, best_threshold, DEVICE, CFG, n_samples=6)


## Cell 15 — Thin Vessel Zoom

In [None]:
@torch.no_grad()
def thin_vessel_zoom(model, test_ds, threshold, device, cfg, sample_idx=0, n_patches=4):
    """Show zoomed crops around pixels that a 5×5 erosion removes from the mask.

    Pixels present in the ground truth but gone after one erosion with a
    5×5 elliptical kernel are treated as "thin" candidates. Up to
    ``n_patches`` of them are sampled at random, and for each a crop of at
    most 96×96 pixels is drawn for the image, the ground truth and the
    prediction. The figure is saved to ``<save_dir>/thin_vessel_zoom.png``.
    Prints a warning and returns ``None`` when no such pixel exists.
    """
    model.eval()
    img_tensor, mask_tensor = test_ds[sample_idx]
    logit = model(img_tensor.unsqueeze(0).to(device))
    prob  = torch.sigmoid(logit).squeeze().cpu().numpy()
    pred  = (prob >= threshold).astype(np.uint8)
    gt    = mask_tensor.squeeze().numpy().astype(np.uint8)
    img   = denormalize(img_tensor)

    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5,5))
    thin   = gt - cv2.erode(gt, kernel, iterations=1)
    ys, xs = np.where(thin > 0)
    if len(ys) == 0:
        print('⚠️  No thin vessels found. Try a different sample_idx.')
        return

    pad  = 48
    h, w = thin.shape
    idxs = random.sample(range(len(ys)), min(n_patches, len(ys)))
    # Size the grid by the patches actually drawn so that no empty framed
    # axes remain when fewer candidates than n_patches exist; squeeze=False
    # keeps `axes` two-dimensional for a single row.
    n_rows = len(idxs)
    fig, axes = plt.subplots(n_rows, 3, figsize=(9, 3*n_rows), squeeze=False)

    for row, i in enumerate(idxs):
        cy, cx = int(ys[i]), int(xs[i])
        # Clamp the crop window to the image bounds.
        y0,y1 = max(0,cy-pad), min(h,cy+pad)
        x0,x1 = max(0,cx-pad), min(w,cx+pad)
        panels = (
            (img[y0:y1,x0:x1],  None,   'Fundus'),
            (gt[y0:y1,x0:x1],   'gray', 'GT Mask'),
            (pred[y0:y1,x0:x1], 'gray', 'Predicted'),
        )
        for ax, (crop, cmap, title) in zip(axes[row], panels):
            ax.imshow(crop, cmap=cmap)
            ax.set_title(title, fontsize=9)
            ax.axis('off')

    plt.suptitle(f'Thin Vessel Zoom — Test Image {sample_idx}', fontsize=11)
    plt.tight_layout()
    plt.savefig(f"{cfg['save_dir']}/thin_vessel_zoom.png", dpi=150, bbox_inches='tight')
    plt.show()
    print('✅ Thin vessel zoom saved.')


thin_vessel_zoom(model, test_ds, best_threshold, DEVICE, CFG, sample_idx=0, n_patches=4)


## Cell 16 — Save Outputs

In [None]:
import json, shutil

# Bundle the test metrics with the threshold, test-set size and a
# stringified copy of the configuration, then write them as JSON.
final_metrics = dict(test_metrics)
final_metrics.update(
    threshold=best_threshold,
    n_test=len(test_ds),
    cfg={key: str(value) for key, value in CFG.items()},
)
metrics_path = f"{CFG['save_dir']}/final_metrics.json"
with open(metrics_path, 'w') as f:
    f.write(json.dumps(final_metrics, indent=2))
print(f'✅ Metrics saved → {metrics_path}')

# Archive the whole output folder next to the notebook.
zip_path = './retinal_segmentation_outputs'
shutil.make_archive(zip_path, 'zip', CFG['save_dir'])
print(f'✅ Zipped → {zip_path}.zip')

# List every file below the output folder, relative to it.
print('\n📁 Output files:')
for fp in sorted(Path(CFG['save_dir']).rglob('*')):
    if not fp.is_file():
        continue
    print(f'   {fp.relative_to(CFG["save_dir"])}')
