### 1. Stack Loading & Selection

In [None]:
import SimpleITK as sitk
import numpy as np
import cv2
import matplotlib.pyplot as plt
import time
import pandas as pd

# --- CONFIGURATION ---
INPUT_STACK_PATH = "path/to/rough_aligned_stack.tif"
FIXED_IDX = 11  # The "Anchor" slide
MOVING_IDX = 12 # The slide that needs residual correction

# --- LOAD VOLUME ---
# We assume the input is a 3D volume (Z, Y, X) where Z is the slice index
print(f"Loading stack: {INPUT_STACK_PATH}...")
vol_img = sitk.ReadImage(INPUT_STACK_PATH)
vol_arr = sitk.GetArrayFromImage(vol_img) # Shape: (Depth, Height, Width)

print(f"Volume Shape: {vol_arr.shape}")

# Fail early with a readable message. A 2D image would come back as (Y, X), and
# indexing it with [IDX, :, :] would either crash or silently pick the wrong axis.
if vol_arr.ndim != 3:
    raise ValueError(
        f"Expected a single-channel 3D stack (Z, Y, X); got array of shape {vol_arr.shape}"
    )
_n_slices = vol_arr.shape[0]
for _label, _idx in (("FIXED_IDX", FIXED_IDX), ("MOVING_IDX", MOVING_IDX)):
    # Negative indices are still allowed (NumPy semantics), e.g. -1 = last slice.
    if not -_n_slices <= _idx < _n_slices:
        raise IndexError(f"{_label}={_idx} is out of range for a stack of {_n_slices} slices")

# Extract Slices
# Cast to float32 for the intensity math. This also guarantees an OpenCV-compatible
# dtype: cv2.normalize rejects some NumPy dtypes (e.g. int64, uint32). Preprocessing
# later converts these slices to uint8 for the feature detectors.
fixed_slice = vol_arr[FIXED_IDX, :, :].astype(np.float32)
moving_slice = vol_arr[MOVING_IDX, :, :].astype(np.float32)

# --- PREPROCESSING FOR OPENCV ---
def prepare_for_features(img_arr, clip_limit=2.0, tile_grid_size=(8, 8)):
    """
    Convert one 2D slice into an 8-bit, contrast-enhanced image for OpenCV feature detectors.

    1. Min-max normalizes intensities to 0-255 (dynamic range compression).
    2. Converts to uint8 with rounding and saturation (the detectors are fed 8-bit images).
    3. Applies CLAHE (Contrast Limited Adaptive Histogram Equalization) to boost local texture.

    Args:
        img_arr: 2D array of any real numeric dtype. A trailing singleton channel axis
            (H, W, 1) is also accepted.
        clip_limit: CLAHE contrast clip limit (2.0 was previously hard-coded).
        tile_grid_size: CLAHE tile grid as (cols, rows); (8, 8) is also OpenCV's default.

    Returns:
        2D uint8 array with the same height and width as the input.

    Raises:
        ValueError: if the input is not a single-channel 2D image.
    """
    img = np.asarray(img_arr)
    if img.ndim == 3 and img.shape[2] == 1:
        img = img[:, :, 0]
    if img.ndim != 2:
        raise ValueError(f"Expected a 2D single-channel image, got shape {img.shape}")

    # Promote to float32 first: cv2.normalize does not accept every NumPy dtype
    # (e.g. int64, uint32, bool), but float32 is always supported.
    img = img.astype(np.float32, copy=False)

    # Normalize straight to 8-bit. dtype=CV_8U makes OpenCV round and saturate,
    # whereas a separate .astype(np.uint8) would truncate (254.9 -> 254).
    norm = cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)

    # CLAHE: enhances local contrast (edges/texture) while the clip limit keeps
    # noise amplification lower than global histogram equalization.
    clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)
    return clahe.apply(norm)

# Build the 8-bit, CLAHE-enhanced inputs for the feature detectors (fixed first, then moving).
cv_fixed, cv_moving = (prepare_for_features(s) for s in (fixed_slice, moving_slice))

# Quick look at the prepared reference slice before registration.
_, preview_ax = plt.subplots(figsize=(10, 10))
preview_ax.imshow(cv_fixed, cmap='gray')
preview_ax.set_title(f"Reference Slice (Idx {FIXED_IDX}) - Prepared for Feature Extraction")
preview_ax.axis('off')
plt.show()

### 2. Feature-Based Residual Registration

In [None]:
def solve_residual_features(fixed_8bit, moving_8bit, method="ORB", *,
                            ratio=0.7, min_matches=10, ransac_thresh=3.0):
    """
    Estimate and apply the residual homography that maps `moving_8bit` onto `fixed_8bit`.

    Intended for pairs that are already roughly aligned. Keypoints are detected in both
    images, matched with a brute-force k-NN matcher and filtered with Lowe's ratio test.
    A homography is then fitted with RANSAC, and the moving image is warped into the
    fixed image's frame.

    Args:
        fixed_8bit: reference image; expected to be 2D uint8 (e.g. the output of
            prepare_for_features).
        moving_8bit: image to be corrected; same expectations as `fixed_8bit`.
        method: feature detector, one of "ORB", "AKAZE" or "SIFT".
        ratio: Lowe's ratio-test threshold. Lower values give fewer but more reliable matches.
        min_matches: minimum number of ratio-test survivors required before fitting.
            It is never allowed below 4, the minimum findHomography needs.
        ransac_thresh: maximum reprojection error (pixels) for a match to count as
            a RANSAC inlier.

    Returns:
        dict that always contains "status", "method" and "mse" (np.inf on failure).
        On success ("status" == "Success") it also contains "matches", "inliers",
        "time" (seconds), "warped" (moving image resampled into the fixed frame) and
        "H" (3x3 homography mapping moving -> fixed pixel coordinates).

    Raises:
        ValueError: if `method` is not a supported detector name.
    """
    supported = ("ORB", "AKAZE", "SIFT")
    if method not in supported:
        # Previously an unknown name fell through every branch and crashed later
        # with UnboundLocalError on `det`.
        raise ValueError(f"Unknown method {method!r}; expected one of {supported}")

    t_start = time.perf_counter()  # monotonic clock, suited to measuring durations

    # 1. Feature Detector Factory
    # Trade-off:
    # ORB   = Fast, binary descriptor (Hamming distance), good on edges.
    # AKAZE = Non-linear scale space, binary descriptor (Hamming); often better on tissue texture.
    # SIFT  = Float descriptor (L2 distance), higher precision but slower.
    if method == "ORB":
        det = cv2.ORB_create(nfeatures=5000)
        norm = cv2.NORM_HAMMING
    elif method == "AKAZE":
        det = cv2.AKAZE_create()
        norm = cv2.NORM_HAMMING
    else:  # "SIFT"
        det = cv2.SIFT_create()
        norm = cv2.NORM_L2

    # 2. Detect & Compute
    kp1, des1 = det.detectAndCompute(fixed_8bit, None)
    kp2, des2 = det.detectAndCompute(moving_8bit, None)

    # Fail-safe for feature starvation (no descriptors in one of the images)
    if des1 is None or des2 is None:
        return {"status": "Fail", "method": method, "mse": np.inf}

    # 3. Matching (KNN)
    matcher = cv2.BFMatcher(norm)
    matches = matcher.knnMatch(des1, des2, k=2)

    # 4. Outlier Rejection (Lowe's Ratio)
    # knnMatch can return fewer than two neighbours for a query (e.g. the moving
    # image has a single descriptor). Skip those instead of crashing on unpacking.
    good = [
        pair[0]
        for pair in matches
        if len(pair) == 2 and pair[0].distance < ratio * pair[1].distance
    ]

    if len(good) < max(min_matches, 4):
        return {"status": "Fail (Low Matches)", "method": method, "mse": np.inf}

    # 5. Geometric Verification (RANSAC)
    # queryIdx indexes the fixed image's keypoints; trainIdx indexes the moving image's.
    src_pts = np.float32([kp1[m.queryIdx].pt for m in good]).reshape(-1, 1, 2)
    dst_pts = np.float32([kp2[m.trainIdx].pt for m in good]).reshape(-1, 1, 2)

    # Fit H mapping moving -> fixed coordinates. The threshold is the maximum
    # reprojection error under a candidate H, not the raw distance between matched
    # points. Because the images are roughly aligned, a strict value is appropriate.
    H, mask = cv2.findHomography(dst_pts, src_pts, cv2.RANSAC, ransac_thresh)

    if H is None:
        return {"status": "Fail (No Converge)", "method": method, "mse": np.inf}

    # 6. Apply Residual Warp into the fixed image's frame
    h, w = fixed_8bit.shape[:2]
    warped = cv2.warpPerspective(moving_8bit, H, (w, h))

    # 7. Metrics
    # Promote to float before subtracting. Unsigned 8-bit arithmetic wraps modulo 256
    # (10 - 20 == 246, and squaring wraps again), which made the old MSE meaningless.
    diff = fixed_8bit.astype(np.float64) - warped.astype(np.float64)
    mse = float(np.mean(diff ** 2))
    t_end = time.perf_counter()

    return {
        "status": "Success",
        "method": method,
        "matches": len(good),
        "inliers": int(np.sum(mask)),
        "mse": mse,
        "time": t_end - t_start,
        "warped": warped,
        "H": H
    }

### 3. Execution & Method Comparison

In [None]:
# --- RUN EXPERIMENT ---
# SIFT is exposed as cv2.SIFT_create in the main OpenCV package since 4.4.0; older
# builds only ship it via opencv-contrib. Skip it when absent instead of crashing mid-run.
methods = ["ORB", "AKAZE", "SIFT"]
if not hasattr(cv2, "SIFT_create"):
    print("SIFT not available in this OpenCV build - skipping it.")
    methods.remove("SIFT")
results = []
images = {}  # method name -> refined (warped) moving slice

# 1. Calculate Baseline (Rough Alignment) Error
# Promote to float before subtracting. uint8 arithmetic wraps modulo 256
# (10 - 20 == 246, and squaring wraps again), which would corrupt the baseline.
baseline_mse = np.mean((cv_fixed.astype(np.float64) - cv_moving.astype(np.float64)) ** 2)
print(f"BASELINE (Rough) MSE: {baseline_mse:.2f}\n" + "-"*40)

# 2. Test Methods
for m in methods:
    print(f"Testing {m}...", end=" ")
    res = solve_residual_features(cv_fixed, cv_moving, method=m)

    if res["status"] == "Success":
        # Positive improvement = the refinement lowered the error vs. the rough alignment.
        imp = baseline_mse - res["mse"]
        print(f"Done. Matches: {res['matches']} | Inliers: {res['inliers']} | MSE: {res['mse']:.2f} (Imp: {imp:.2f})")
        results.append(res)
        images[m] = res["warped"]
    else:
        print(f"Failed. Reason: {res['status']}")


# --- VISUALIZATION ---
def make_overlay(img1, img2):
    """
    Return an RGB float overlay with img1 in green and img2 in magenta.

    Both inputs are scaled by 1/255 (i.e. assumed to be 8-bit). Where the images agree
    the overlay is grey; misaligned structure shows up as green or magenta fringes.
    Green/magenta is a high-contrast pairing that is commonly used as colorblind-friendly.
    """
    green = img1.astype(float) / 255.0
    magenta = img2.astype(float) / 255.0
    return np.dstack((magenta, green, magenta))  # R, G, B


if results:
    fig, axes = plt.subplots(1, len(results) + 1, figsize=(20, 8))

    # Plot Baseline
    axes[0].imshow(make_overlay(cv_fixed, cv_moving))
    axes[0].set_title(f"Baseline (Rough)\nMSE: {baseline_mse:.1f}")
    axes[0].axis('off')

    # Plot Results
    for i, res in enumerate(results):
        overlay = make_overlay(cv_fixed, res["warped"])
        axes[i+1].imshow(overlay)
        axes[i+1].set_title(f"{res['method']} (Refined)\nMSE: {res['mse']:.1f}\nTime: {res['time']:.3f}s")
        axes[i+1].axis('off')

    plt.tight_layout()
    plt.show()

    # Table Summary (image and homography columns are dropped; they are not tabular)
    df = pd.DataFrame(results).drop(columns=["warped", "H"])
    print("\nSummary Table:")
    print(df.to_string(index=False))
else:
    print("No method succeeded - nothing to visualize.")