# Y-axis (Depth) Alignment - 2D Phase Correlation Method

## Problem with previous 1D correlation approach:
- Averaged profile along X-axis: `ref_profile = ref_norm.mean(axis=1)`
- **Lost spatial information** about vessel structure
- Result: 53% confidence but only 0.1% improvement
- High variance in offsets (std=41.87 px)
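To make the loss concrete: averaging along X keeps only one number per depth row. Any two B-scans whose rows share the same means therefore yield the identical 1D profile. A tilted layer is smeared into a flat profile that carries no depth information at all. A toy example on synthetic 4×4 arrays (not OCT data):

```python
import numpy as np

# Two toy "B-scans" (rows = Y/depth, columns = X/lateral) with opposite tilts
bscan_a = np.eye(4)              # bright layer tilting down to the right
bscan_b = np.fliplr(np.eye(4))   # bright layer tilting down to the left

# The 1D method only sees the X-averaged depth profile
profile_a = bscan_a.mean(axis=1)
profile_b = bscan_b.mean(axis=1)

print(profile_a)                             # [0.25 0.25 0.25 0.25]
print(np.array_equal(profile_a, profile_b))  # True: the profiles are identical
print(np.array_equal(bscan_a, bscan_b))      # False: the images are not
```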

## New approach: 2D Phase Correlation (inspired by Phase 3 success)
- Use **full 2D B-scan images** (not averaged profiles)
- Apply **phase correlation** like in X-Z alignment (which worked great!)
- Preserves spatial structure of vessels and retinal layers
- Find Y-offset for each B-scan individually
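The core of phase correlation fits in a few lines:

1. Normalize the cross-power spectrum of the two images to unit magnitude.
2. Invert it back to the spatial domain.
3. Read the shift off the location of the peak.

Below is a minimal NumPy sketch on a synthetic random image. `phase_correlation_shift` is a hypothetical name, not a function from this project. The sketch uses the `F_ref · conj(F_mov)` ordering that this notebook uses later. With that ordering, the peak gives the shift to apply to the moving image, which is minus its displacement.

```python
import numpy as np

def phase_correlation_shift(ref, mov, eps=1e-8):
    """Return the integer (dy, dx) shift that registers `mov` onto `ref`."""
    f_ref = np.fft.fft2(ref)
    f_mov = np.fft.fft2(mov)
    cross = f_ref * np.conj(f_mov)
    cross /= np.abs(cross) + eps                       # keep phase only
    corr = np.fft.fftshift(np.fft.ifft2(cross).real)   # zero shift -> array center
    peak = np.unravel_index(np.argmax(corr), corr.shape)
    center = np.array(corr.shape) // 2
    return tuple(int(p - c) for p, c in zip(peak, center))

rng = np.random.default_rng(0)
ref = rng.random((128, 96))
mov = np.roll(ref, shift=(5, -3), axis=(0, 1))   # mov[y, x] = ref[y - 5, x + 3]

dy, dx = phase_correlation_shift(ref, mov)
print(dy, dx)                                                # -5 3
print(np.allclose(np.roll(mov, (dy, dx), axis=(0, 1)), ref))  # True
```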

## Expected improvement:
- X-Z alignment: 15.8% improvement with 2D phase correlation ✅
- Y alignment: currently 0.1% with 1D correlation ❌
- Goal: >5% improvement with 2D phase correlation 🎯

## Pipeline:
1. Load volumes and X-Z registration results
2. Implement 2D phase correlation for Y-offset
3. Test on 10 sample B-scans
4. Compare with current 1D method
5. If successful: apply to all 360 B-scans
6. Visualize and save results

## 1. Setup and Imports

In [None]:
import sys
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from scipy import ndimage, signal
from scipy.fft import fft2, ifft2, fftshift
import time

# Setup paths (works when run from notebooks/ or from the project root)
notebook_dir = Path.cwd()
project_root = notebook_dir.parent if notebook_dir.name == 'notebooks' else notebook_dir
data_dir = project_root / 'notebooks' / 'data'
oct_data_dir = project_root / 'oct_data'

# Add src to path (resolved from the project root rather than the working directory)
sys.path.append(str(project_root / 'src'))
from oct_volumetric_viewer import OCTImageProcessor, OCTVolumeLoader

plt.rcParams['figure.figsize'] = (15, 10)

print("✓ Imports complete")

## 2. Load X-Z Registration Results
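The `.item()` call is needed because `np.save` stores a Python dict as a 0-dimensional object array. `np.load(..., allow_pickle=True)` returns that array, and `.item()` unwraps the dict. Here is a self-contained round trip with made-up parameter values in a temporary directory:

```python
import tempfile
from pathlib import Path
import numpy as np

# Made-up values; only the key names match what this notebook reads
params = {'offset_x': 3, 'offset_z': -2, 'confidence': 0.5, 'improvement_percent': 10.0}

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / 'params.npy'
    np.save(path, params, allow_pickle=True)   # stored as a 0-d object array
    loaded = np.load(path, allow_pickle=True)
    print(loaded.shape, loaded.dtype)          # () object
    restored = loaded.item()                   # unwrap back to a dict

print(restored['offset_x'], restored['offset_z'])  # 3 -2
```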

In [None]:
# Load X-Z registration parameters from Phase 3
xy_params = np.load(data_dir / 'xy_registration_params.npy', allow_pickle=True).item()

print("üìä X-Z Registration Parameters (from Phase 3):")
print(f"  Method: {xy_params.get('best_method', 'phase_correlation')}")
print(f"  X offset (lateral): {xy_params['offset_x']} pixels")
print(f"  Z offset (B-scan): {xy_params['offset_z']} pixels")
print(f"  Confidence: {xy_params['confidence']:.2f}")
print(f"  Improvement: {xy_params['improvement_percent']:.1f}%")
print(f"\nüí° 2D phase correlation worked great for X-Z! Let's use it for Y too.")

offset_x = xy_params['offset_x']
offset_z = xy_params['offset_z']

## 3. Load Full 3D Volumes

In [None]:
print("Loading full 3D OCT volumes...")
print("This may take 1-2 minutes...")

# Initialize loader
processor = OCTImageProcessor(sidebar_width=250, crop_top=100, crop_bottom=50)
loader = OCTVolumeLoader(processor)

# Find volume directories (every directory containing at least one .bmp file)
all_volume_dirs = sorted({bmp_file.parent for bmp_file in oct_data_dir.rglob('*.bmp')})

# Prefer the F001 volumes
f001_vols = [v for v in all_volume_dirs if 'F001_IP' in str(v)]

if len(f001_vols) >= 2:
    print(f"\nUsing F001 data:")
    print(f"  Volume 0: {f001_vols[0].name}")
    print(f"  Volume 1: {f001_vols[1].name}")
    volume_dirs = f001_vols[:2]
else:
    print(f"\nWarning: Using first 2 available volumes")
    volume_dirs = all_volume_dirs[:2]

# Load volumes
print("\nLoading volume 0...")
volume_0 = loader.load_volume_from_directory(str(volume_dirs[0]))

print("Loading volume 1...")
volume_1 = loader.load_volume_from_directory(str(volume_dirs[1]))

if volume_0 is None or volume_1 is None:
    raise ValueError("Failed to load volumes")

print(f"\n✓ Loaded volumes:")
print(f"  Volume 0: {volume_0.shape} (Y, X, Z)")
print(f"  Volume 1: {volume_1.shape}")
print(f"\n  Y = {volume_0.shape[0]} pixels (depth)")
print(f"  X = {volume_0.shape[1]} pixels (lateral)")
print(f"  Z = {volume_0.shape[2]} B-scans")

## 4. Apply X-Z Translation to Volume 1
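`ndimage.shift` moves content toward higher indices when the shift along an axis is positive. It fills the vacated samples with `cval` (here 0). So `shift=(0, offset_x, offset_z)` leaves depth untouched and translates Volume 1 by `offset_x` columns and `offset_z` B-scans. A 1D illustration on a toy array:

```python
import numpy as np
from scipy import ndimage

row = np.arange(1.0, 6.0)  # [1, 2, 3, 4, 5]

# Positive shift: content moves right, vacated samples are filled with cval
print(ndimage.shift(row, 2, order=1, mode='constant', cval=0))   # [0. 0. 1. 2. 3.]

# Negative shift: content moves left
print(ndimage.shift(row, -1, order=1, mode='constant', cval=0))  # [2. 3. 4. 5. 0.]
```

For whole-pixel offsets like these, `order=1` (linear interpolation) reproduces the input values exactly. The output keeps the input dtype by default.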

In [None]:
# Apply X-Z translation
print("Applying X-Z translation to Volume 1...")
volume_1_xz_aligned = ndimage.shift(
    volume_1,
    shift=(0, offset_x, offset_z),  # (dy=0, dx, dz)
    order=1,
    mode='constant',
    cval=0
)

print(f"‚úì Applied X-Z alignment")
print(f"  Translation: X={offset_x}, Z={offset_z}")
print(f"  Volume 1 XZ-aligned shape: {volume_1_xz_aligned.shape}")

## 5. Implement 2D Phase Correlation for Y-offset

This is the KEY innovation: using 2D phase correlation instead of 1D averaged profiles.
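Phase correlation implicitly treats each B-scan as periodic. The hard image edges, and the zero-filled borders left by the X-Z shift, can therefore add spurious energy to the correlation. A common mitigation, not used in this notebook, is to taper both B-scans with a 2D Hann window before the FFT.

Below is a minimal sketch on synthetic white noise, where `ref` and `mov` are overlapping crops of a larger random array, so the shift is not circular. The function name and test data are hypothetical.

```python
import numpy as np

def y_offset_phase_corr_windowed(ref, mov, max_shift=30, eps=1e-8):
    """Hypothetical variant: 2D Hann window, then phase correlation restricted in Y."""
    window = np.outer(np.hanning(ref.shape[0]), np.hanning(ref.shape[1]))
    ref_w = (ref - ref.mean()) * window
    mov_w = (mov - mov.mean()) * window
    cross = np.fft.fft2(ref_w) * np.conj(np.fft.fft2(mov_w))
    cross /= np.abs(cross) + eps                      # keep phase only
    corr = np.fft.fftshift(np.fft.ifft2(cross).real)
    cy = corr.shape[0] // 2
    band = corr[cy - max_shift:cy + max_shift + 1]    # rows within ±max_shift of zero shift
    peak_y, _ = np.unravel_index(np.argmax(band), band.shape)
    return int(peak_y - max_shift)                    # offset relative to the zero-shift row

rng = np.random.default_rng(1)
scene = rng.random((200, 120))
ref = scene[40:168, 10:106]   # 128 x 96 crop
mov = scene[36:164, 10:106]   # same crop with content moved down by 4 rows
print(y_offset_phase_corr_windowed(ref, mov))  # -4
```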

In [None]:
def calculate_y_offset_2d_phase_correlation(bscan_ref, bscan_mov, max_shift=30):
    """
    Calculate Y-axis offset using 2D phase correlation.
    
    This preserves spatial structure of vessels, unlike 1D averaging.
    Based on successful X-Z alignment method from Phase 3.
    
    Args:
        bscan_ref: Reference B-scan (Y, X)
        bscan_mov: Moving B-scan (Y, X) - already XZ aligned
        max_shift: Maximum Y-shift to search (pixels)
    
    Returns:
        y_offset: Y-axis offset in pixels
        confidence: Quality score based on peak strength
    """
    # Normalize B-scans to zero mean, unit variance
    ref_norm = (bscan_ref - bscan_ref.mean()) / (bscan_ref.std() + 1e-8)
    mov_norm = (bscan_mov - bscan_mov.mean()) / (bscan_mov.std() + 1e-8)
    
    # 2D FFT phase correlation
    f_ref = fft2(ref_norm)
    f_mov = fft2(mov_norm)
    
    # Normalized cross-power spectrum (keep phase, discard magnitude)
    cross_power = f_ref * np.conj(f_mov)
    cross_power /= np.abs(cross_power) + 1e-8
    
    # Inverse FFT gives the correlation surface; fftshift puts zero shift at the center
    correlation = fftshift(ifft2(cross_power).real)
    
    # Find peak in Y-direction (restrict search to max_shift)
    center_y, center_x = np.array(correlation.shape) // 2
    
    # Search region: center_y ± max_shift rows, all X columns
    y_start = max(0, center_y - max_shift)
    y_end = min(correlation.shape[0], center_y + max_shift + 1)
    
    search_region = correlation[y_start:y_end, :]
    
    # Find peak position (its X component is not used: volumes are already XZ-aligned)
    peak_y, peak_x = np.unravel_index(np.argmax(search_region), search_region.shape)
    
    # Convert to offset from center: the Y-shift to apply to bscan_mov
    y_offset = int((y_start + peak_y) - center_y)
    
    # Confidence metric
    peak_value = search_region[peak_y, peak_x]
    
    # Method 1: Peak-to-noise ratio
    noise_level = search_region.std()
    psnr_confidence = peak_value / (noise_level + 1e-8)
    
    # Method 2: Peak sharpness (ratio to the second-highest peak).
    # Mask a small neighborhood around the main peak first; otherwise the
    # "second peak" is often just a pixel adjacent to the main peak.
    exclude = 2  # pixels masked on each side of the main peak
    search_copy = search_region.copy()
    search_copy[max(0, peak_y - exclude):peak_y + exclude + 1,
                max(0, peak_x - exclude):peak_x + exclude + 1] = -np.inf
    second_peak = search_copy.max()
    
    if second_peak > 0:
        peak_sharpness = peak_value / second_peak
        sharpness_confidence = min(1.0, max(0.0, (peak_sharpness - 1.0) / 1.0))
    else:
        sharpness_confidence = 0.0
    
    # Combine confidences
    confidence = 0.5 * min(1.0, psnr_confidence / 10.0) + 0.5 * sharpness_confidence
    
    return y_offset, confidence


# For comparison: also implement the OLD 1D method
def calculate_y_offset_1d_correlation(bscan_ref, bscan_mov, max_shift=30):
    """
    OLD METHOD: 1D correlation on averaged profiles.
    Kept for comparison only.
    """
    # Normalize
    ref_norm = (bscan_ref - bscan_ref.mean()) / (bscan_ref.std() + 1e-8)
    mov_norm = (bscan_mov - bscan_mov.mean()) / (bscan_mov.std() + 1e-8)
    
    # Average along X (THIS IS THE PROBLEM!)
    ref_profile = ref_norm.mean(axis=1)  # (Y,)
    mov_profile = mov_norm.mean(axis=1)  # (Y,)
    
    # Normalize profiles
    ref_profile = ref_profile - ref_profile.mean()
    mov_profile = mov_profile - mov_profile.mean()
    
    # 1D correlation
    correlation = signal.correlate(ref_profile, mov_profile, mode='same')
    
    # Find peak
    center = len(correlation) // 2
    search_start = max(0, center - max_shift)
    search_end = min(len(correlation), center + max_shift + 1)
    
    search_region = correlation[search_start:search_end]
    peak_idx = np.argmax(search_region)
    peak_pos = search_start + peak_idx
    
    y_offset = peak_pos - center
    
    # Confidence
    ref_energy = np.sum(ref_profile ** 2)
    mov_energy = np.sum(mov_profile ** 2)
    
    if ref_energy > 0 and mov_energy > 0:
        ncc_peak = correlation[peak_pos] / np.sqrt(ref_energy * mov_energy)
        confidence = max(0.0, min(1.0, ncc_peak))
    else:
        confidence = 0.0
    
    return y_offset, confidence

print("✓ 2D Phase Correlation function defined")
print("✓ 1D Correlation function (for comparison) defined")

## 6. Test on Sample B-scans

Test both methods on 10 B-scans to compare performance before running on all 360.
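With only 10 test B-scans, the 1D offsets can all coincide, giving a standard deviation of zero. The percentage comparisons in this cell and in Section 8 would then divide by zero; NumPy returns `inf` or `nan` with a `RuntimeWarning` rather than raising. Below is a small guard, sketched as a hypothetical helper with made-up offsets:

```python
import numpy as np

def relative_change(new, old):
    """(new - old) / old, or NaN when the baseline is zero (hypothetical helper)."""
    return float('nan') if old == 0 else (new - old) / old

toy_offsets_1d = np.array([3, 3, 3, 3])   # identical offsets -> std = 0
toy_offsets_2d = np.array([2, 3, 2, 3])

print(relative_change(toy_offsets_2d.std(), toy_offsets_1d.std()))  # nan
print(relative_change(3.0, 2.0))                                    # 0.5
```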

In [None]:
# Select 10 evenly distributed B-scans for testing
num_bscans = volume_0.shape[2]
test_indices = np.linspace(0, num_bscans-1, 10, dtype=int)

print(f"Testing on {len(test_indices)} B-scans: {test_indices.tolist()}")
print(f"\nComparing 2D Phase Correlation vs 1D Correlation...\n")
print("="*80)

results_2d = []
results_1d = []

for i, z in enumerate(test_indices):
    # Extract B-scans
    bscan_ref = volume_0[:, :, z]
    bscan_mov = volume_1_xz_aligned[:, :, z]
    
    # Method 1: 2D Phase Correlation
    offset_2d, conf_2d = calculate_y_offset_2d_phase_correlation(
        bscan_ref, bscan_mov, max_shift=30
    )
    
    # Method 2: 1D Correlation (old method)
    offset_1d, conf_1d = calculate_y_offset_1d_correlation(
        bscan_ref, bscan_mov, max_shift=30
    )
    
    results_2d.append({'z': z, 'offset': offset_2d, 'confidence': conf_2d})
    results_1d.append({'z': z, 'offset': offset_1d, 'confidence': conf_1d})
    
    print(f"B-scan Z={z:3d}:")
    print(f"  2D Phase Corr: offset={offset_2d:+5.1f} px, conf={conf_2d:.2%}")
    print(f"  1D Corr (old): offset={offset_1d:+5.1f} px, conf={conf_1d:.2%}")
    print(f"  Difference:    {abs(offset_2d - offset_1d):5.1f} px")
    print()

# Statistics
offsets_2d = np.array([r['offset'] for r in results_2d])
confs_2d = np.array([r['confidence'] for r in results_2d])

offsets_1d = np.array([r['offset'] for r in results_1d])
confs_1d = np.array([r['confidence'] for r in results_1d])

print("="*80)
print("üìä COMPARISON STATISTICS\n")
print(f"2D Phase Correlation:")
print(f"  Mean offset: {offsets_2d.mean():+.2f} ¬± {offsets_2d.std():.2f} px")
print(f"  Mean confidence: {confs_2d.mean():.2%}")
print(f"  Range: [{offsets_2d.min():.1f}, {offsets_2d.max():.1f}]")

print(f"\n1D Correlation (old):")
print(f"  Mean offset: {offsets_1d.mean():+.2f} ¬± {offsets_1d.std():.2f} px")
print(f"  Mean confidence: {confs_1d.mean():.2%}")
print(f"  Range: [{offsets_1d.min():.1f}, {offsets_1d.max():.1f}]")

print(f"\nüí° Key Metrics:")
print(f"  Offset std reduction: {100*(1 - offsets_2d.std()/offsets_1d.std()):.1f}%")
print(f"  Confidence improvement: {100*(confs_2d.mean() - confs_1d.mean())/confs_1d.mean():.1f}%")
print("="*80)

## 7. Visualize Sample Results

In [None]:
# Visualize comparison for 3 sample B-scans
sample_indices = [test_indices[2], test_indices[5], test_indices[8]]

fig, axes = plt.subplots(3, 4, figsize=(20, 12))

for row, z in enumerate(sample_indices):
    # Get B-scans
    bscan_ref = volume_0[:, :, z]
    bscan_mov = volume_1_xz_aligned[:, :, z]
    
    # Get offsets
    idx = np.where(test_indices == z)[0][0]
    offset_2d = results_2d[idx]['offset']
    offset_1d = results_1d[idx]['offset']
    conf_2d = results_2d[idx]['confidence']
    conf_1d = results_1d[idx]['confidence']
    
    # Apply shifts
    bscan_aligned_2d = ndimage.shift(bscan_mov, shift=(offset_2d, 0), order=1, mode='constant', cval=0)
    bscan_aligned_1d = ndimage.shift(bscan_mov, shift=(offset_1d, 0), order=1, mode='constant', cval=0)
    
    # Column 1: Reference
    axes[row, 0].imshow(bscan_ref, cmap='gray', aspect='auto')
    axes[row, 0].set_title(f'Vol 0 (ref)\nZ={z}', fontweight='bold')
    axes[row, 0].set_ylabel('Y (depth)')
    
    # Column 2: Before alignment
    axes[row, 1].imshow(bscan_mov, cmap='gray', aspect='auto')
    axes[row, 1].set_title('Vol 1 (XZ only)', fontweight='bold')
    
    # Column 3: After 2D phase corr
    diff_2d = np.abs(bscan_ref.astype(float) - bscan_aligned_2d.astype(float)).mean()
    axes[row, 2].imshow(bscan_aligned_2d, cmap='gray', aspect='auto')
    axes[row, 2].set_title(f'2D Phase Corr\nY={offset_2d:+.1f}px, conf={conf_2d:.1%}\nDiff={diff_2d:.2f}',
                          fontweight='bold')
    
    # Column 4: After 1D corr
    diff_1d = np.abs(bscan_ref.astype(float) - bscan_aligned_1d.astype(float)).mean()
    axes[row, 3].imshow(bscan_aligned_1d, cmap='gray', aspect='auto')
    axes[row, 3].set_title(f'1D Corr (old)\nY={offset_1d:+.1f}px, conf={conf_1d:.1%}\nDiff={diff_1d:.2f}',
                          fontweight='bold')

for ax in axes.flat:
    ax.set_xlabel('X (lateral)')

plt.suptitle('Method Comparison: 2D Phase Correlation vs 1D Correlation', 
             fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()

print("\nüí° Look for:")
print("  - Better alignment (clearer tissue boundaries)")
print("  - Lower difference values")
print("  - Higher confidence scores")

## 8. Decision: Apply to All B-scans?

Based on test results, decide whether to use 2D phase correlation for all 360 B-scans.
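The rule below combines the criteria with an OR. The 2D method is adopted either when both relative gains over the 1D method pass, or when the mean 2D confidence exceeds 60% on its own. The second branch can adopt the 2D method even when it does not beat the 1D method. Here is the same rule restated as a pure function for quick checks (the function name is hypothetical; inputs are fractions, e.g. 0.1 = 10%):

```python
def should_use_2d(confidence_improvement, offset_std_reduction, mean_conf_2d):
    """Mirror of the notebook's decision rule."""
    clear_gain = confidence_improvement > 0.1 and offset_std_reduction > 0.2
    return clear_gain or mean_conf_2d > 0.6

print(should_use_2d(0.15, 0.25, 0.40))   # True: both relative criteria pass
print(should_use_2d(-0.05, 0.00, 0.70))  # True: high mean confidence alone is enough
print(should_use_2d(0.15, 0.10, 0.50))   # False: std reduction too small, confidence too low
```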

In [None]:
# Automatically decide based on metrics
confidence_improvement = (confs_2d.mean() - confs_1d.mean()) / confs_1d.mean()
offset_std_reduction = (offsets_1d.std() - offsets_2d.std()) / offsets_1d.std()

print("="*80)
print("DECISION CRITERIA")
print("="*80)
print(f"\n1. Confidence improvement: {100*confidence_improvement:+.1f}%")
print("   ✅ Good (> +10%)" if confidence_improvement > 0.1 else "   ⚠️ Marginal (≤ +10%)")

print(f"\n2. Offset std reduction: {100*offset_std_reduction:+.1f}%")
print("   ✅ Good (> +20%)" if offset_std_reduction > 0.2 else "   ⚠️ Marginal (≤ +20%)")

print(f"\n3. Mean confidence (2D): {100*confs_2d.mean():.1f}%")
print("   ✅ Good (> 60%)" if confs_2d.mean() > 0.6 else "   ⚠️ Marginal (≤ 60%)")

# Make decision
use_2d_method = (confidence_improvement > 0.1 and offset_std_reduction > 0.2) or confs_2d.mean() > 0.6

print("\n" + "="*80)
if use_2d_method:
    print(f"✅ DECISION: Use 2D Phase Correlation for all {num_bscans} B-scans")
    print("   Metrics meet the criteria above (both relative gains, or mean confidence > 60%).")
else:
    print("⚠️ DECISION: Results unclear - need manual review")
    print("   Consider alternative methods, or check whether Y-alignment is needed at all.")
print("="*80)

# Ask user to confirm
print("\n👆 Please review the visualizations and statistics above.")
print("   Then run the next cell to apply to all B-scans.")

## 9. Apply to All 360 B-scans (if successful)

Run this cell only if the test results look promising.

In [None]:
# Calculate Y-offset for ALL B-scans using 2D phase correlation
num_bscans = volume_0.shape[2]
y_offsets_2d = np.zeros(num_bscans, dtype=np.float32)
confidences_2d = np.zeros(num_bscans, dtype=np.float32)

print(f"Calculating Y-offset for each of {num_bscans} B-scans using 2D Phase Correlation...")
print("This may take 2-3 minutes...\n")

start_time = time.time()

for z in range(num_bscans):
    if z % 50 == 0:
        elapsed = time.time() - start_time
        # z B-scans are finished at this point; extrapolate to the remaining ones
        eta = (elapsed / z) * (num_bscans - z) if z > 0 else 0
        print(f"  Processing B-scan {z}/{num_bscans}... (ETA: {eta:.0f}s)")
    
    # Extract B-scans
    bscan_ref = volume_0[:, :, z]
    bscan_mov = volume_1_xz_aligned[:, :, z]
    
    # Calculate Y-offset with 2D phase correlation
    y_offset, confidence = calculate_y_offset_2d_phase_correlation(
        bscan_ref, bscan_mov, max_shift=30
    )
    
    y_offsets_2d[z] = y_offset
    confidences_2d[z] = confidence

elapsed_total = time.time() - start_time

print(f"\n✓ Calculated {len(y_offsets_2d)} per-B-scan Y-offsets in {elapsed_total:.1f}s")
print(f"\n📊 Y-offset statistics (2D Phase Correlation):")
print(f"  Mean offset: {y_offsets_2d.mean():.2f} ± {y_offsets_2d.std():.2f} pixels")
print(f"  Median offset: {np.median(y_offsets_2d):.2f} pixels")
print(f"  Range: [{y_offsets_2d.min():.2f}, {y_offsets_2d.max():.2f}] pixels")
print(f"\n📈 Confidence statistics:")
print(f"  Mean confidence: {confidences_2d.mean():.2%}")
print(f"  Median confidence: {np.median(confidences_2d):.2%}")
print(f"  Range: [{confidences_2d.min():.2%}, {confidences_2d.max():.2%}]")

# Count high confidence B-scans
high_conf = (confidences_2d > 0.6).sum()
print(f"  High confidence B-scans (>60%): {high_conf}/{num_bscans} ({100*high_conf/num_bscans:.1f}%)")

## 10. Visualize Per-B-scan Results
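The offset plot may show isolated spikes rather than smooth drift. One optional post-processing step, not applied in this notebook, is a running median along Z. It removes single-B-scan outliers while keeping gradual trends. Whether this is appropriate depends on whether real motion between neighboring B-scans is expected. Below is a sketch with made-up values:

```python
import numpy as np
from scipy import ndimage

# Made-up per-B-scan offsets with one isolated outlier at index 3
y_offsets = np.array([2, 2, 2, 30, 2, 2, 2], dtype=np.float32)

# Running median over 5 neighboring B-scans; 'nearest' repeats the edge values
smoothed = ndimage.median_filter(y_offsets, size=5, mode='nearest')
print(smoothed)  # [2. 2. 2. 2. 2. 2. 2.]
```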

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(16, 10))

# Plot 1: Y-offsets vs B-scan index
axes[0, 0].plot(y_offsets_2d, linewidth=2, color='blue', label='2D Phase Corr')
axes[0, 0].axhline(y_offsets_2d.mean(), color='red', linestyle='--', 
                  label=f'Mean: {y_offsets_2d.mean():.2f} px')
axes[0, 0].axhline(np.median(y_offsets_2d), color='green', linestyle='--', 
                  label=f'Median: {np.median(y_offsets_2d):.2f} px')
axes[0, 0].set_xlabel('B-scan index (Z)', fontsize=12)
axes[0, 0].set_ylabel('Y-offset (depth, pixels)', fontsize=12)
axes[0, 0].set_title('Per-B-scan Y-offsets (2D Phase Correlation)', fontsize=14, fontweight='bold')
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)

# Plot 2: Confidence vs B-scan index
axes[0, 1].plot(confidences_2d, linewidth=2, color='orange')
axes[0, 1].axhline(0.6, color='red', linestyle='--', alpha=0.5, label='60% threshold')
axes[0, 1].axhline(confidences_2d.mean(), color='blue', linestyle='--', 
                  label=f'Mean: {confidences_2d.mean():.2%}')
axes[0, 1].set_xlabel('B-scan index (Z)', fontsize=12)
axes[0, 1].set_ylabel('Confidence', fontsize=12)
axes[0, 1].set_title('Alignment Confidence per B-scan', fontsize=14, fontweight='bold')
axes[0, 1].set_ylim([0, 1])
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)

# Plot 3: Histogram of Y-offsets
axes[1, 0].hist(y_offsets_2d, bins=50, color='blue', alpha=0.7, edgecolor='black')
axes[1, 0].axvline(y_offsets_2d.mean(), color='red', linestyle='--', linewidth=2, 
                  label=f'Mean: {y_offsets_2d.mean():.2f}')
axes[1, 0].axvline(np.median(y_offsets_2d), color='green', linestyle='--', linewidth=2, 
                  label=f'Median: {np.median(y_offsets_2d):.2f}')
axes[1, 0].set_xlabel('Y-offset (pixels)', fontsize=12)
axes[1, 0].set_ylabel('Count', fontsize=12)
axes[1, 0].set_title('Distribution of Y-offsets', fontsize=14, fontweight='bold')
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3)

# Plot 4: Y-offset vs Confidence scatter
scatter = axes[1, 1].scatter(y_offsets_2d, confidences_2d, c=confidences_2d, 
                            cmap='viridis', s=20, alpha=0.6)
axes[1, 1].axhline(0.6, color='red', linestyle='--', alpha=0.5)
axes[1, 1].set_xlabel('Y-offset (pixels)', fontsize=12)
axes[1, 1].set_ylabel('Confidence', fontsize=12)
axes[1, 1].set_title('Offset vs Confidence', fontsize=14, fontweight='bold')
axes[1, 1].grid(True, alpha=0.3)
plt.colorbar(scatter, ax=axes[1, 1], label='Confidence')

plt.tight_layout()
plt.show()

print("\nüí° Key observations:")
print("  - Smooth offsets across Z indicate stable alignment")
print("  - High confidence = strong 2D correlation peak")
print("  - Large jumps may indicate motion or different structures")

## 11. Apply Per-B-scan Y-alignment

In [None]:
def apply_per_bscan_y_alignment(volume, y_offsets):
    """
    Apply per-B-scan Y-shifts to volume.
    
    Args:
        volume: Volume to align (Y, X, Z)
        y_offsets: Y-offset per B-scan (Z,)
    
    Returns:
        aligned_volume: Y-aligned volume (Y, X, Z)
    """
    Y, X, Z = volume.shape
    aligned_volume = np.zeros_like(volume)
    
    print(f"Applying per-B-scan Y-shifts...")
    
    for z in range(Z):
        if z % 50 == 0:
            print(f"  B-scan {z}/{Z}...")
        
        # Shift this B-scan along Y-axis
        bscan = volume[:, :, z]
        bscan_shifted = ndimage.shift(bscan, shift=(y_offsets[z], 0), 
                                     order=1, mode='constant', cval=0)
        aligned_volume[:, :, z] = bscan_shifted
    
    return aligned_volume

# Apply Y-alignment
print("\nApplying per-B-scan Y-alignment to Volume 1...")
volume_1_fully_aligned = apply_per_bscan_y_alignment(volume_1_xz_aligned, y_offsets_2d)

print(f"\n✓ Volume 1 fully aligned (X, Y, Z)")
print(f"  Shape: {volume_1_fully_aligned.shape}")

## 12. Calculate Alignment Improvement
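The improvement metric is the relative drop in mean absolute difference (MAE) to Volume 0:

improvement = 100 · (1 − MAE_after / MAE_before)

`ndimage.shift` fills vacated rows with 0, and those rows enter the "after" average, so the number can be biased by the zero padding. Below is a sketch of the metric with an optional validity mask; the function name and toy numbers are hypothetical:

```python
import numpy as np

def improvement_percent(ref, before, after, valid=None):
    """100 * (1 - MAE_after / MAE_before), optionally restricted to a boolean mask."""
    err_before = np.abs(ref.astype(float) - before.astype(float))
    err_after = np.abs(ref.astype(float) - after.astype(float))
    if valid is not None:
        err_before, err_after = err_before[valid], err_after[valid]
    return 100.0 * (1.0 - err_after.mean() / err_before.mean())

ref = np.full(4, 10.0)
before = np.full(4, 12.0)                  # MAE = 2
after = np.array([11.0, 11.0, 11.0, 0.0])  # last sample zero-filled by a shift

print(improvement_percent(ref, before, after))         # -62.5 (padding dominates)
valid = np.array([True, True, True, False])
print(improvement_percent(ref, before, after, valid))  # 50.0
```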

In [None]:
# Calculate alignment quality
print("Calculating alignment improvement...")

diff_before_all = np.abs(volume_0.astype(float) - volume_1_xz_aligned.astype(float))
diff_after_all = np.abs(volume_0.astype(float) - volume_1_fully_aligned.astype(float))

improvement = 100 * (1 - diff_after_all.mean() / diff_before_all.mean())

print("\nüìä Overall Alignment Quality:")
print(f"  Before Y-alignment: {diff_before_all.mean():.2f} ¬± {diff_before_all.std():.2f}")
print(f"  After Y-alignment:  {diff_after_all.mean():.2f} ¬± {diff_after_all.std():.2f}")
print(f"  Improvement: {improvement:.2f}%", end="")

if improvement > 10:
    print(" ‚úÖ EXCELLENT!")
elif improvement > 5:
    print(" ‚úÖ GOOD")
elif improvement > 1:
    print(" ‚ö†Ô∏è MODERATE")
else:
    print(" ‚ùå MINIMAL - Y-alignment may not be needed")

print(f"\nüìä Comparison with old 1D method:")
print(f"  Old 1D method improvement: 0.1%")
print(f"  New 2D method improvement: {improvement:.2f}%")
print(f"  Relative improvement: {improvement/0.1:.1f}x better!")

## 13. Visualize Alignment Quality

In [None]:
# Compare B-scans before and after Y-alignment
z_samples = [num_bscans // 4, num_bscans // 2, 3 * num_bscans // 4]

fig, axes = plt.subplots(len(z_samples), 4, figsize=(20, 5*len(z_samples)))

for i, z in enumerate(z_samples):
    # Column 1: Reference B-scan
    axes[i, 0].imshow(volume_0[:, :, z], cmap='gray', aspect='auto')
    axes[i, 0].set_title(f'Vol 0 (ref) - Z={z}', fontweight='bold')
    axes[i, 0].set_ylabel('Y (depth)')
    axes[i, 0].set_xlabel('X (lateral)')
    
    # Column 2: Vol 1 before Y-alignment (XZ only)
    axes[i, 1].imshow(volume_1_xz_aligned[:, :, z], cmap='gray', aspect='auto')
    axes[i, 1].set_title(f'Vol 1 (XZ aligned) - Z={z}', fontweight='bold')
    axes[i, 1].set_xlabel('X (lateral)')
    
    # Column 3: Vol 1 after full XYZ alignment
    axes[i, 2].imshow(volume_1_fully_aligned[:, :, z], cmap='gray', aspect='auto')
    axes[i, 2].set_title(f'Vol 1 (XYZ aligned) - Z={z}\nY-shift: {y_offsets_2d[z]:.2f} px (conf={confidences_2d[z]:.1%})', 
                        fontweight='bold')
    axes[i, 2].set_xlabel('X (lateral)')
    
    # Column 4: Difference map
    diff_before = np.abs(volume_0[:, :, z].astype(float) - volume_1_xz_aligned[:, :, z].astype(float))
    diff_after = np.abs(volume_0[:, :, z].astype(float) - volume_1_fully_aligned[:, :, z].astype(float))
    
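    diff_reduction = diff_before - diff_after
    # Symmetric color limits so that zero (no change) sits at the neutral middle of RdYlGn
    lim = max(np.percentile(np.abs(diff_reduction), 99), 1e-6)
    im = axes[i, 3].imshow(diff_reduction, cmap='RdYlGn', aspect='auto', vmin=-lim, vmax=lim)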
    axes[i, 3].set_title(f'Improvement (pos=better)\nMean: {diff_reduction.mean():.2f}', fontweight='bold')
    axes[i, 3].set_xlabel('X (lateral)')
    plt.colorbar(im, ax=axes[i, 3], label='Difference reduction')

plt.tight_layout()
plt.show()

print("\nüí° Green areas = improvement, red areas = worse (should be mostly green!)")

## 14. Save Results
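The saved `transform_3d` entry records the global X-Z shift and the per-B-scan Y shifts. They are applied in that order: X-Z first, then Y. Below is a sketch of how a later notebook could re-apply it to Volume 1. `apply_registration_3d` is a hypothetical helper, not part of `oct_volumetric_viewer`, and it assumes the (Y, X, Z) axis order used here:

```python
import numpy as np
from scipy import ndimage

def apply_registration_3d(volume, transform_3d):
    """Apply a transform_3d dict: global X-Z shift, then per-B-scan Y shift."""
    xz = ndimage.shift(volume, shift=(0, transform_3d['x_offset'], transform_3d['z_offset']),
                       order=1, mode='constant', cval=0)
    out = np.zeros_like(xz)
    for z, dy in enumerate(transform_3d['y_offsets']):  # one Y offset per B-scan
        out[:, :, z] = ndimage.shift(xz[:, :, z], shift=(dy, 0),
                                     order=1, mode='constant', cval=0)
    return out

rng = np.random.default_rng(0)
vol = rng.random((6, 5, 4))  # toy (Y, X, Z) volume
transform = {'x_offset': 1, 'z_offset': 0, 'y_offsets': [0.0, 1.0, 0.0, 0.0]}
aligned = apply_registration_3d(vol, transform)

# B-scan 1 moved by +1 in X and +1 in Y; the vacated first row is zero
print(np.allclose(aligned[1:, 1:, 1], vol[:-1, :-1, 1]))  # True
print(np.allclose(aligned[0, :, 1], 0))                   # True
```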

In [None]:
# Save aligned volume
print("Saving results...")
np.save(data_dir / 'volume_1_fully_aligned_2d.npy', volume_1_fully_aligned)

# Save Y-alignment parameters
y_alignment_params = {
    'method': 'per_bscan_2d_phase_correlation',
    'y_offsets': y_offsets_2d.tolist(),
    'confidences': confidences_2d.tolist(),
    'mean_offset': float(y_offsets_2d.mean()),
    'median_offset': float(np.median(y_offsets_2d)),
    'offset_std': float(y_offsets_2d.std()),
    'mean_confidence': float(confidences_2d.mean()),
    'median_confidence': float(np.median(confidences_2d)),
    'high_confidence_count': int((confidences_2d > 0.6).sum()),
    'improvement_percent': float(improvement),
    'max_shift': 30
}

np.save(data_dir / 'y_alignment_params_2d.npy', y_alignment_params, allow_pickle=True)

# Save complete 3D registration
registration_3d = {
    # X-Z alignment (global)
    'translation_x': int(offset_x),
    'translation_z': int(offset_z),
    'xz_method': xy_params.get('best_method', 'phase_correlation'),
    'xz_confidence': float(xy_params['confidence']),
    'xz_improvement': float(xy_params['improvement_percent']),
    
    # Y alignment (per-B-scan with 2D phase correlation)
    'y_method': 'per_bscan_2d_phase_correlation',
    'y_offsets_per_bscan': y_offsets_2d.tolist(),
    'y_confidences_per_bscan': confidences_2d.tolist(),
    'y_mean_offset': float(y_offsets_2d.mean()),
    'y_median_offset': float(np.median(y_offsets_2d)),
    'y_mean_confidence': float(confidences_2d.mean()),
    'y_improvement': float(improvement),
    
    # Transform summary
    'transform_3d': {
        'x_offset': int(offset_x),
        'z_offset': int(offset_z),
        'y_offsets': y_offsets_2d.tolist()  # Per-B-scan
    }
}

np.save(data_dir / 'registration_3d_params_2d.npy', registration_3d, allow_pickle=True)

print("\n✓ Saved:")
print("  - volume_1_fully_aligned_2d.npy")
print("  - y_alignment_params_2d.npy")
print("  - registration_3d_params_2d.npy")

print(f"\n{'='*70}")
print("FINAL 3D REGISTRATION SUMMARY (2D Phase Correlation)")
print(f"{'='*70}")
print(f"\nüìç X-Z Alignment (Global):")
print(f"  Method: {xy_params.get('best_method', 'phase_correlation')}")
print(f"  X offset: {offset_x} px")
print(f"  Z offset: {offset_z} px")
print(f"  Confidence: {xy_params['confidence']:.2f}")
print(f"  Improvement: {xy_params['improvement_percent']:.1f}%")
print(f"\nüìè Y-axis Alignment (Per-B-scan 2D Phase Correlation):")
print(f"  Method: 2D phase correlation (preserves spatial structure)")
print(f"  Mean Y-offset: {y_offsets_2d.mean():.2f} ¬± {y_offsets_2d.std():.2f} px")
print(f"  Median Y-offset: {np.median(y_offsets_2d):.2f} px")
print(f"  Range: [{y_offsets_2d.min():.2f}, {y_offsets_2d.max():.2f}] px")
print(f"  Mean confidence: {confidences_2d.mean():.2%}")
print(f"  High confidence B-scans: {(confidences_2d>0.6).sum()}/{num_bscans}")
print(f"  Improvement: {improvement:.2f}%")
print(f"\nüéØ Comparison with old 1D method:")
print(f"  Old improvement: 0.1%")
print(f"  New improvement: {improvement:.2f}%")
print(f"  Relative gain: {improvement/0.1:.1f}x better!")
print(f"\n‚úÖ Full 3D registration complete!")
print(f"{'='*70}")

## Summary

This notebook implemented **2D phase correlation** for Y-axis alignment:

### Key innovation:
- ❌ Old: 1D correlation on averaged X profiles (lost spatial info)
- ✅ New: 2D phase correlation on full B-scan images (preserves vessel structure)

### Method inspiration:
- Adapted from Phase 3 X-Z registration (15.8% improvement)
- Uses full 2D FFT-based phase correlation
- Restricted Y search range (max_shift=30 px) to reject implausibly large shifts (the full 2D FFT is still computed)

### Results:
- See final summary above for detailed metrics
- If improvement > 5%: Success! ✅
- If improvement < 5%: an alternative approach may be needed, or Y-alignment may not be needed at all

### Next steps:
- Use `06_visualize_results.ipynb` to create 3D merged volume
- Compare with old 1D method results
- If needed: Try Method 2 (weighted 1D) or Method 3 (global median)
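Both methods return whole-pixel offsets. If that turns out too coarse, a common refinement (not implemented here) fits a parabola through the correlation peak and its two Y neighbors. The vertex gives a fractional correction, which is added to the integer offset. A minimal sketch, with a hypothetical function name:

```python
def parabolic_peak_offset(c_minus, c_peak, c_plus):
    """Sub-pixel offset of a peak from three samples taken at -1, 0, +1."""
    denom = c_minus - 2.0 * c_peak + c_plus
    if denom == 0:
        return 0.0  # flat neighborhood: no refinement possible
    return 0.5 * (c_minus - c_plus) / denom

# Samples of the parabola -(x - 0.3)**2 at x = -1, 0, +1
print(parabolic_peak_offset(-(1.3**2), -(0.3**2), -(0.7**2)))  # ≈ 0.3
```

In the notebook's 2D function, the three samples would be the `correlation` values one row above, at, and one row below the peak, taken in the peak column. The result would be added to `y_offset`.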