# Phase 2: Vessel Segmentation and XY Registration

## Objective
Use anatomical vessel structures to align OCT volumes in the XY plane.

## Approach
1. Load MIP en-face projections from Phase 1
2. Segment vessels (classical Frangi filter as the baseline here; pre-trained deep-learning models as an option)
3. Extract vessel centerlines and bifurcations
4. Register volumes based on vessel patterns

## Pre-trained Models Available
- **OCTA-autosegmentation** (GAN-based, OCTA-500 dataset)
- **BOE2020-OCTA-vessel-segmentation** (Deep iterative, pretrained weights)
- **OCT2Former** (Transformer-based, state-of-the-art on OCTA-500)
- **SGL-Retinal-Vessel-Segmentation** (SOTA on the DRIVE dataset of color fundus photographs, so it may need adaptation for OCT en-face images)

---

## Setup and Imports

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import sys
from scipy import ndimage
from skimage import filters, morphology, measure
from skimage.filters import frangi
from skimage.morphology import skeletonize
import cv2

# Configure matplotlib defaults: large figures, grayscale images unless a cmap is given
plt.rcParams['figure.figsize'] = (15, 10)
plt.rcParams['image.cmap'] = 'gray'

print("✓ Imports successful")

## Load MIP En-Face Projections

Load the MIP en-face projection of volume 0 that was saved in Phase 1.

In [None]:
# Phase 1 wrote the volume-0 MIP en-face projection into this directory
data_dir = Path('../notebooks/data')
enface_mip = np.load(data_dir / 'enface_mip_volume0.npy')

print("Loaded MIP en-face projection")
print(f"Shape: {enface_mip.shape}")
print("Value range: [{:.1f}, {:.1f}]".format(enface_mip.min(), enface_mip.max()))

# Display transposed: array axis 0 runs horizontally (labelled X),
# array axis 1 runs vertically (labelled Z / B-scan number)
plt.figure(figsize=(12, 10))
plt.imshow(enface_mip.T, aspect='auto', cmap='gray')
plt.title('MIP En-Face Projection (from Phase 1)', fontsize=14)
plt.xlabel('X (lateral)')
plt.ylabel('Z (B-scan number)')
plt.colorbar(label='Intensity')
plt.show()

---
## Method 1: Classical Frangi Filter (No Pre-trained Model)

**Fast baseline**: Use Frangi vesselness filter to enhance vessels without deep learning.

**Pros**: No model download, fast, deterministic  
**Cons**: Less accurate than deep learning

In [None]:
def enhance_vessels_frangi(enface, sigmas=range(1, 6)):
    """
    Enhance vessels using the Frangi vesselness filter (classical method).

    The image is min-max normalised to [0, 1] and then filtered with
    ``black_ridges=False``, i.e. bright ridge-like structures are enhanced.

    Args:
        enface: 2D en-face image (any real numeric dtype).
        sigmas: Gaussian scales passed to ``frangi``; these set the range of
            vessel widths the filter responds to.

    Returns:
        vessel_enhanced: Float image of the same shape holding the Frangi
            vesselness response. A constant input contains no vessels and
            returns an all-zero image (normalising it would otherwise divide
            by zero and produce NaNs).
    """
    img = np.asarray(enface, dtype=np.float64)
    lo = img.min()
    span = img.max() - lo

    # Flat image: nothing to enhance, and (img - lo) / 0 would be all NaN
    if span == 0:
        return np.zeros(img.shape, dtype=np.float64)

    # Normalize to 0-1
    img_norm = (img - lo) / span

    return frangi(img_norm, sigmas=sigmas, black_ridges=False)


# Run the Frangi baseline on the loaded MIP and compare it with the input
vessels_frangi = enhance_vessels_frangi(enface_mip)

fig, axes = plt.subplots(1, 2, figsize=(18, 8))

axes[0].imshow(enface_mip.T, aspect='auto', cmap='gray')
axes[0].set_title('Original MIP', fontsize=14)
axes[0].set_xlabel('X')
axes[0].set_ylabel('Z')

axes[1].imshow(vessels_frangi.T, aspect='auto', cmap='hot')
axes[1].set_title('Frangi Vessel Enhancement', fontsize=14)
axes[1].set_xlabel('X')
axes[1].set_ylabel('Z')

plt.tight_layout()
plt.show()

print("✓ Frangi filter applied")
print(f"Vessel response range: [{vessels_frangi.min():.3f}, {vessels_frangi.max():.3f}]")

---
## Method 2: Pre-trained Deep Learning Model

**Options**:
1. **BOE2020-OCTA-vessel-segmentation** - Easy to use, pretrained weights available
2. **OCTA-autosegmentation** - GAN-based, Docker container
3. **OCT2Former** - Transformer, state-of-the-art

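We start with **BOE2020** because it ships pretrained weights with a simple loading example. The cell below only prints the setup steps; the rest of this notebook continues with the Frangi baseline.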

In [None]:
# Install required packages for the pre-trained model route
# !pip install torch torchvision -q

print("📋 TO USE PRE-TRAINED MODEL:")
print("1. Clone repository: git clone https://github.com/RViMLab/BOE2020-OCTA-vessel-segmentation.git")
print("2. Follow example_use_pretrained.py")
print("3. Load pretrained weights and run inference")
print("\nFor now, we'll continue with Frangi filter as baseline.")

---
## Extract Vessel Binary Mask

Threshold the vessel-enhanced image at a percentile to get a binary vessel map, then remove small connected components as noise.

In [None]:
def extract_vessel_mask(vessel_enhanced, threshold_percentile=75, *, min_size=20):
    """
    Create a binary vessel mask from a vessel-enhanced image.

    Pixels strictly above the given percentile of ``vessel_enhanced`` are
    marked as vessel; connected components smaller than ``min_size`` pixels
    are then discarded as noise.

    Args:
        vessel_enhanced: Vessel-enhanced image (e.g. Frangi output).
        threshold_percentile: Percentile (0-100) of ``vessel_enhanced`` used
            as the threshold.
        min_size: Keyword-only. Smallest connected component, in pixels, kept
            in the mask (passed to ``morphology.remove_small_objects``).
            Defaults to 20, the previously hard-coded value.

    Returns:
        vessel_mask: Boolean mask (True = vessel).
    """
    threshold = np.percentile(vessel_enhanced, threshold_percentile)
    vessel_mask = vessel_enhanced > threshold

    # Clean up small noise blobs that survive the percentile threshold
    vessel_mask = morphology.remove_small_objects(vessel_mask, min_size=min_size)

    return vessel_mask


# Create vessel mask from the Frangi response (keep the top 20% of responses)
vessel_mask = extract_vessel_mask(vessels_frangi, threshold_percentile=80)

fig, axes = plt.subplots(1, 3, figsize=(20, 6))

axes[0].imshow(enface_mip.T, aspect='auto', cmap='gray')
axes[0].set_title('Original MIP', fontsize=14)

axes[1].imshow(vessels_frangi.T, aspect='auto', cmap='hot')
axes[1].set_title('Frangi Enhancement', fontsize=14)

axes[2].imshow(vessel_mask.T, aspect='auto', cmap='gray')
axes[2].set_title('Binary Vessel Mask', fontsize=14)

plt.tight_layout()
plt.show()

# Percentage of pixels classified as vessel
vessel_density = vessel_mask.sum() / vessel_mask.size * 100
print("✓ Vessel mask created")
print(f"Vessel density: {vessel_density:.2f}%")

---
## Extract Vessel Centerlines and Bifurcations

Skeletonize the vessel mask to get centerlines, then detect bifurcation points (skeleton pixels with 3 or more skeleton neighbours) to use as registration landmarks.

In [None]:
def extract_vessel_skeleton(vessel_mask):
    """
    Thin a binary vessel mask down to its centerlines.

    Direct wrapper around ``skeletonize``; no pre- or post-processing.

    Args:
        vessel_mask: Binary vessel mask

    Returns:
        skeleton: Binary image holding the skeleton of ``vessel_mask``
    """
    return skeletonize(vessel_mask)


def find_bifurcations(skeleton, merge_clusters=False):
    """
    Find vessel bifurcation points (junctions) on a skeleton.

    A skeleton pixel is flagged as a junction when 3 or more of its 8
    neighbours are also skeleton pixels. Near a single junction several
    adjacent pixels usually satisfy this, so one anatomical bifurcation can
    yield a small cluster of points; pass ``merge_clusters=True`` to collapse
    each 8-connected cluster into one point.

    Args:
        skeleton: 2D skeleton image. Any nonzero pixel counts as skeleton, so
            boolean masks and 0/255 images (e.g. from OpenCV) both work.
        merge_clusters: If True, replace each 8-connected group of junction
            pixels by its centroid, rounded to the nearest pixel.
            Defaults to False (one point per flagged pixel).

    Returns:
        bifurcations: Integer array of shape (N, 2) with (axis-0, axis-1)
            indices of the junction points, which this notebook treats as
            (x, z). Shape is (0, 2) when no junctions are found.
    """
    # Normalise to 0/1: a 0/255 input would overflow the uint8 neighbour
    # counts and break the bitwise AND below
    skel = np.asarray(skeleton).astype(bool)

    # 8-neighbour kernel that does not count the centre pixel
    kernel = np.ones((3, 3), dtype=np.uint8)
    kernel[1, 1] = 0

    # Counts are in 0..8, so uint8 is safe; pixels outside the image count as 0
    neighbor_count = ndimage.convolve(skel.astype(np.uint8), kernel, mode='constant', cval=0)

    # Junctions: skeleton pixels with 3+ skeleton neighbours
    bifurcation_mask = (neighbor_count >= 3) & skel

    if not merge_clusters:
        return np.argwhere(bifurcation_mask)

    # Group touching junction pixels (8-connectivity) and keep one point per group
    labels, n_clusters = ndimage.label(bifurcation_mask, structure=np.ones((3, 3), dtype=int))
    if n_clusters == 0:
        return np.empty((0, 2), dtype=np.intp)
    centroids = ndimage.center_of_mass(
        bifurcation_mask.astype(np.float64), labels, np.arange(1, n_clusters + 1)
    )
    return np.rint(np.asarray(centroids, dtype=np.float64)).astype(np.intp)


# Extract skeleton and bifurcations
skeleton = extract_vessel_skeleton(vessel_mask)
bifurcations = find_bifurcations(skeleton)

fig, axes = plt.subplots(1, 3, figsize=(20, 6))

axes[0].imshow(vessel_mask.T, aspect='auto', cmap='gray')
axes[0].set_title('Vessel Mask', fontsize=14)

axes[1].imshow(skeleton.T, aspect='auto', cmap='gray')
axes[1].set_title('Vessel Skeleton (Centerlines)', fontsize=14)

# Overlay skeleton and bifurcations on the MIP. Images are drawn transposed,
# so column 0 of `bifurcations` (array axis 0) is the horizontal coordinate.
axes[2].imshow(enface_mip.T, aspect='auto', cmap='gray', alpha=0.7)
axes[2].imshow(skeleton.T, aspect='auto', cmap='hot', alpha=0.5)
if len(bifurcations) > 0:
    axes[2].scatter(bifurcations[:, 0], bifurcations[:, 1], c='cyan', s=50, marker='o',
                    edgecolors='yellow', linewidths=2, label=f'{len(bifurcations)} bifurcations')
    # Only build a legend when a labelled artist exists; otherwise matplotlib warns
    axes[2].legend()
axes[2].set_title('Vessels + Bifurcation Points', fontsize=14)

plt.tight_layout()
plt.show()

print(f"✓ Extracted {len(bifurcations)} bifurcation points")
print(f"Skeleton density: {skeleton.sum() / skeleton.size * 100:.2f}%")

---
## Save Results for Registration

Save vessel features for the next step: multi-volume registration.

In [None]:
# Save vessel analysis results next to the Phase 1 outputs
output_dir = Path('../notebooks/data')
# Create the directory if it is missing so np.save does not fail
output_dir.mkdir(parents=True, exist_ok=True)

np.save(output_dir / 'vessels_frangi_volume0.npy', vessels_frangi)
np.save(output_dir / 'vessel_mask_volume0.npy', vessel_mask)
np.save(output_dir / 'vessel_skeleton_volume0.npy', skeleton)
np.save(output_dir / 'bifurcations_volume0.npy', bifurcations)

print(f"✓ Saved vessel features to {output_dir}")
print("\n📋 Next Steps:")
print("  1. Load second volume and extract vessel features")
print("  2. Match bifurcation patterns between volumes")
print("  3. Calculate XY offset for alignment")
print("  4. Apply registration and verify")
print("\n✓ Phase 2 vessel segmentation complete!")