# Phase 3: XY Registration Using Phase Correlation

This notebook performs XY-plane registration of two OCT volumes using **phase correlation** instead of bifurcation matching.

**⚠️ NOTE: Phase 2 (vessel segmentation) can be SKIPPED!**  
This notebook uses MIP directly from Phase 1. Vessel segmentation is only needed for vessel analysis, not for registration.

## Why Phase Correlation?
- **No feature detection needed** - uses entire vessel pattern
- **Fast** - single correlation operation
- **Robust** - works even with noise
- **Accurate** - whole-pixel precision (sub-pixel peak refinement is not implemented in this notebook)

## Workflow:
1. Load MIP en-face images from **Phase 1** (not Phase 2!)
2. Apply phase correlation to find XY translation
3. Validate alignment quality
4. Visualize before/after registration
5. Save registration parameters

## Setup and Imports

In [None]:
import json
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
from scipy import signal, ndimage

# Setup paths: every input/output file of this phase is read from / written to data/,
# relative to the notebook's working directory.
data_dir = Path('data')

print("✓ Imports complete")

## Step 1: Load MIP Data from Both Volumes

In [None]:
# Load the MIP en-face projections produced by Phase 1
mip_v0 = np.load(data_dir / 'enface_mip_volume0.npy')
mip_v1 = np.load(data_dir / 'enface_mip_volume1.npy')

for vol_idx, vol_mip in enumerate((mip_v0, mip_v1)):
    print(f"Volume {vol_idx} MIP shape: {vol_mip.shape}")

# Show each MIP on its own, then both overlaid. Arrays are transposed and drawn with
# origin='lower' so X runs horizontally and Z vertically.
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

single_panels = zip(axes[:2], (mip_v0, mip_v1), ('Volume 0: MIP En-face', 'Volume 1: MIP En-face'))
for panel_ax, vol_mip, panel_title in single_panels:
    panel_ax.imshow(vol_mip.T, cmap='gray', origin='lower')
    panel_ax.set_title(panel_title)

# Overlay before registration: red = Volume 0, green = Volume 1
overlay_ax = axes[2]
overlay_ax.imshow(mip_v0.T, cmap='Reds', alpha=0.5, origin='lower')
overlay_ax.imshow(mip_v1.T, cmap='Greens', alpha=0.5, origin='lower')
overlay_ax.set_title('Overlay Before Registration\n(Red=Vol0, Green=Vol1)')

# All three panels share the same axis convention
for panel_ax in axes:
    panel_ax.set_xlabel('X (width)')
    panel_ax.set_ylabel('Z (depth/B-scan)')

plt.tight_layout()
plt.show()

## Step 2: Phase Correlation Registration

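Phase correlation finds the translation offset by:
1. Normalizing both images (subtract the mean, divide by the standard deviation)
2. Computing their 2D cross-correlation map
3. Finding the peak → its displacement from the map center is the (whole-pixel) offset!

Strictly speaking, step 2 is normalized cross-correlation; classical phase correlation additionally normalizes the cross-power spectrum to unit magnitude before locating the peak.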

In [None]:
def register_mip_phase_correlation(mip1, mip2, whiten=False):
    """
    Register two MIP en-face images by FFT-based correlation.

    Both images are z-scored (mean removed, divided by std) and their linear
    (zero-padded, non-circular) cross-correlation is evaluated with FFTs in
    O(N log N) for N pixels, instead of the O(N^2) direct sum that
    ``scipy.signal.correlate2d`` performs. With ``whiten=True`` the cross-power
    spectrum is normalized to unit magnitude first (classical phase
    correlation), which sharpens the peak.

    Args:
        mip1: Reference MIP from Volume 0 (2-D array; axis 0 = X, axis 1 = Z).
        mip2: MIP to align from Volume 1; must have the same shape as ``mip1``.
        whiten: If True, use spectrum-whitened phase correlation instead of
            plain normalized cross-correlation. Default False.

    Returns:
        (dx, dz): Integer translation offset (lateral X, B-scan Z). Shifting
            ``mip2`` by this amount (e.g. ``ndimage.shift``) aligns it to ``mip1``.
        confidence: Match quality score = peak value / std of the map.
        correlation: Correlation map with the same shape as ``mip1``; index
            ``i`` along each axis corresponds to a shift of ``i - shape // 2``.

    Raises:
        ValueError: If the inputs are not 2-D arrays of the same shape.
    """
    mip1 = np.asarray(mip1, dtype=float)
    mip2 = np.asarray(mip2, dtype=float)
    # The shape // 2 centering below is only meaningful when both images match.
    if mip1.ndim != 2 or mip1.shape != mip2.shape:
        raise ValueError(
            f"expected two 2-D images of the same shape, got {mip1.shape} and {mip2.shape}"
        )

    # Normalize images (remove mean and scale by std)
    mip1_norm = (mip1 - mip1.mean()) / (mip1.std() + 1e-8)
    mip2_norm = (mip2 - mip2.mean()) / (mip2.std() + 1e-8)

    # Zero-pad to the full linear-correlation size so the FFT product yields
    # linear (not circular) correlation:
    #   full[k mod s] = sum_n mip1_norm[n + k] * mip2_norm[n]
    # for every lag k in [-(M - 1), M - 1] along each axis.
    full_shape = tuple(2 * n - 1 for n in mip1.shape)
    cross_power = np.fft.rfft2(mip1_norm, full_shape) * np.conj(np.fft.rfft2(mip2_norm, full_shape))
    if whiten:
        # Keep only phase information (classical phase correlation)
        cross_power /= np.abs(cross_power) + 1e-12
    full = np.fft.irfft2(cross_power, full_shape)

    # Keep the mip1-sized window of lags centered on zero shift:
    # output index i  <->  lag i - M // 2 (negative lags wrap to the end of `full`).
    lags_x = np.arange(mip1.shape[0]) - mip1.shape[0] // 2
    lags_z = np.arange(mip1.shape[1]) - mip1.shape[1] // 2
    correlation = full[np.ix_(lags_x % full_shape[0], lags_z % full_shape[1])]

    # Find peak (strongest match position); its distance from the center is the shift
    peak_x, peak_z = np.unravel_index(np.argmax(correlation), correlation.shape)
    center_x, center_z = np.array(correlation.shape) // 2
    offset_x = int(peak_x - center_x)  # Lateral shift
    offset_z = int(peak_z - center_z)  # B-scan shift

    # Confidence = peak strength relative to the spread of the whole map
    confidence = float(correlation.max() / (correlation.std() + 1e-8))

    return (offset_x, offset_z), confidence, correlation

# Perform registration: Volume 0 is the reference, Volume 1 is the image to align
print("Performing phase correlation registration...")
(offset_x, offset_z), confidence, correlation_map = register_mip_phase_correlation(mip_v0, mip_v1)

print("\n✓ Registration complete!")
print("\n📍 Translation offset:")
print(f"   X (lateral): {offset_x} pixels")
print(f"   Z (B-scan):  {offset_z} pixels")
print(f"\n✨ Confidence score: {confidence:.2f}")
print("   (Higher is better, typically > 5.0 is good)")

## Step 3: Visualize Correlation Map

In [None]:
# Visualize correlation map in shift coordinates.
# Map index i along each axis corresponds to a shift of i - shape // 2 pixels (the map
# center is zero shift), so an imshow `extent` is used to label the axes with shifts
# instead of raw array indices. The map is transposed and drawn with origin='lower' so
# that, like the MIP figures, X runs horizontally and Z vertically.
n_x, n_z = correlation_map.shape
center_x, center_z = n_x // 2, n_z // 2
peak_x_idx, peak_z_idx = np.unravel_index(np.argmax(correlation_map), correlation_map.shape)
peak_dx, peak_dz = peak_x_idx - center_x, peak_z_idx - center_z

fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# Full correlation map; extent edges sit half a pixel outside the first/last shift
full_extent = (-center_x - 0.5, n_x - center_x - 0.5, -center_z - 0.5, n_z - center_z - 0.5)
im1 = axes[0].imshow(correlation_map.T, cmap='hot', aspect='auto', origin='lower', extent=full_extent)
axes[0].plot(peak_dx, peak_dz, 'g*', markersize=20, label=f'Peak: offset=({peak_dx}, {peak_dz})')
axes[0].set_title('Correlation Map\n(Brighter = better match)')
axes[0].set_xlabel('X shift (lateral, px)')
axes[0].set_ylabel('Z shift (B-scans, px)')
axes[0].legend()
plt.colorbar(im1, ax=axes[0], label='Correlation strength')

# Zoomed view around the peak: zoom_size map pixels on each side, clipped to the map
zoom_size = 50
x_min = max(0, peak_x_idx - zoom_size)
x_max = min(n_x, peak_x_idx + zoom_size)
z_min = max(0, peak_z_idx - zoom_size)
z_max = min(n_z, peak_z_idx + zoom_size)

zoomed = correlation_map[x_min:x_max, z_min:z_max]
# x_max / z_max are exclusive, so the last shown shift is (x_max - 1 - center_x)
zoom_extent = (x_min - center_x - 0.5, x_max - center_x - 0.5,
               z_min - center_z - 0.5, z_max - center_z - 0.5)
im2 = axes[1].imshow(zoomed.T, cmap='hot', aspect='auto', origin='lower', extent=zoom_extent)
axes[1].plot(peak_dx, peak_dz, 'g*', markersize=20)
axes[1].set_title(f'Zoomed Around Peak\nConfidence: {confidence:.2f}')
axes[1].set_xlabel('X shift (px)')
axes[1].set_ylabel('Z shift (px)')
plt.colorbar(im2, ax=axes[1], label='Correlation')

plt.tight_layout()
plt.show()

## Step 4: Apply Registration and Visualize Results

In [None]:
# Apply translation to volume 1 MIP
# Note: shift order is (x, z) matching our MIP array shape (axis 0 = X, axis 1 = Z).
# Pixels shifted in from outside the image are filled with cval=0.
mip_v1_aligned = ndimage.shift(mip_v1, shift=(offset_x, offset_z), order=1, mode='constant', cval=0)

# Overlap mask: True where mip_v1_aligned holds real Volume 1 data rather than zero fill
# (shifting a ones-image the same way marks exactly those pixels). Both difference maps
# are evaluated only over this region, so the zero-filled border is not counted as
# misalignment and "before" and "after" are compared over the same pixels.
overlap = ndimage.shift(np.ones(mip_v1.shape), shift=(offset_x, offset_z),
                        order=1, mode='constant', cval=0) > 0.999

# Absolute difference maps, masked outside the overlap (masked pixels are left blank in
# the plots and ignored by .mean())
diff_before = np.ma.masked_array(np.abs(mip_v0.astype(float) - mip_v1.astype(float)), mask=~overlap)
diff_after = np.ma.masked_array(np.abs(mip_v0.astype(float) - mip_v1_aligned.astype(float)), mask=~overlap)

# Visualize before/after
fig, axes = plt.subplots(2, 2, figsize=(16, 12))

# Top row: red/green overlays before and after alignment
overlay_panels = [
    (axes[0, 0], mip_v1, 'Before Registration\n(Red=Vol0, Green=Vol1)'),
    (axes[0, 1], mip_v1_aligned, f'After Registration\nOffset: ({offset_x}, {offset_z}) px'),
]
for ax, moving, title in overlay_panels:
    ax.imshow(mip_v0.T, cmap='Reds', alpha=0.5, origin='lower')
    ax.imshow(moving.T, cmap='Greens', alpha=0.5, origin='lower')
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_xlabel('X (width)')
    ax.set_ylabel('Z (B-scan)')

# Bottom row: absolute difference over the overlap region
for ax, diff, label in [(axes[1, 0], diff_before, 'Before'), (axes[1, 1], diff_after, 'After')]:
    im = ax.imshow(diff.T, cmap='hot', origin='lower')
    ax.set_title(f'Difference {label} (overlap)\nMean: {diff.mean():.2f}', fontsize=14, fontweight='bold')
    ax.set_xlabel('X (width)')
    ax.set_ylabel('Z (B-scan)')
    plt.colorbar(im, ax=ax)

plt.tight_layout()
plt.show()

# Calculate improvement: percentage reduction of the mean absolute difference over the
# overlap, guarded against a zero "before" difference (e.g. identical inputs)
mean_before = float(diff_before.mean())
mean_after = float(diff_after.mean())
improvement = 100 * (1 - mean_after / mean_before) if mean_before > 0 else 0.0

print("\n📊 Registration Quality Metrics:")
print(f"  Translation: dx={offset_x}, dz={offset_z} pixels")
print(f"  Confidence: {confidence:.2f}")
print(f"  Mean difference before (overlap): {mean_before:.2f}")
print(f"  Mean difference after (overlap): {mean_after:.2f}")
print(f"  Improvement: {improvement:.1f}%")

if improvement > 0:
    print(f"\n✅ Registration successful! Vessel alignment improved by {improvement:.1f}%")
else:
    print("\n⚠️  Warning: Registration may not have improved alignment")

## Step 5: Checkerboard Visualization

In [None]:
# Create checkerboard pattern for better visualization: tiles alternate between
# Volume 0 and Volume 1, so residual misalignment shows up as vessels that break
# at tile boundaries.
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

checker_size = 100  # tile edge length in pixels
h, w = mip_v0.shape
# True on tiles whose (tile row + tile column) index sum is even -> Volume 0 shown there.
# Broadcasting the per-axis tile indices replaces a nested Python loop over tiles.
checkerboard = ((np.arange(h)[:, None] // checker_size)
                + (np.arange(w)[None, :] // checker_size)) % 2 == 0

# Before: Volume 0 on "True" tiles, unaligned Volume 1 elsewhere
checker_before = mip_v0.copy()
checker_before[~checkerboard] = mip_v1[~checkerboard]

# After: Volume 0 on "True" tiles, aligned Volume 1 elsewhere
checker_after = mip_v0.copy()
checker_after[~checkerboard] = mip_v1_aligned[~checkerboard]

panels = [
    (axes[0], checker_before, 'Checkerboard - Before Registration'),
    (axes[1], checker_after, 'Checkerboard - After Registration'),
]
for ax, composite, title in panels:
    ax.imshow(composite.T, cmap='gray', origin='lower')
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_xlabel('X (width)')
    ax.set_ylabel('Z (B-scan)')

plt.tight_layout()
plt.show()

print("\n💡 Tip: In the checkerboard view, vessels should appear continuous")
print("   across tile boundaries when properly aligned!")

## Step 6: Save Registration Results

In [None]:
# Save registration parameters (plain Python int/float/str values so they serialize cleanly)
registration_params = {
    'method': 'phase_correlation',
    'offset_x': int(offset_x),
    'offset_z': int(offset_z),
    'confidence': float(confidence),
    'mean_diff_before': float(diff_before.mean()),
    'mean_diff_after': float(diff_after.mean()),
    'improvement_percent': float(improvement)
}

# Save results.
# np.save stores the dict as a 0-d object array, so reading it back requires
# np.load(path, allow_pickle=True).item(). The JSON copy holds the same values and
# can be read without unpickling.
np.save(data_dir / 'xy_registration_params.npy', registration_params)
(data_dir / 'xy_registration_params.json').write_text(json.dumps(registration_params, indent=2))
np.save(data_dir / 'mip_v1_xy_aligned.npy', mip_v1_aligned)

print("✓ Saved registration results:")
print("  - xy_registration_params.npy")
print("  - xy_registration_params.json")
print("  - mip_v1_xy_aligned.npy")

print("\n✓ Phase 3 XY registration complete!")
print("\nNext steps:")
print("  1. Phase 4: Z-axis alignment (if needed)")
print("  2. Phase 5: Y-axis (depth) alignment using retinal surfaces")
print("  3. Phase 6: Apply full 3D transformation and visualize merged volumes")

print(f"\n\n{'='*70}")
print("SUMMARY")
print(f"{'='*70}")
print("Method: Phase Correlation (no feature detection!)")
print(f"XY Translation: ({offset_x}, {offset_z}) pixels")
print(f"Confidence: {confidence:.2f}")
print(f"Alignment Improvement: {improvement:.1f}%")
print(f"{'='*70}")