# Optical Coherence Tomography (OCT) Reconstruction Example
This notebook demonstrates basic OCT image reconstruction. It simulates spectral (k-space) data with a Fourier-based forward model, then reconstructs the image in two ways: with the adjoint (IFFT-based) operator and with Total Variation (TV)-regularized optimization.

## 1. Imports and Setup

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import sys
import os

# Ensure reconlib is importable.
# A notebook has no __file__, so locate the repository by walking upward from
# the current working directory until a directory containing the `reconlib`
# package is found, and put that directory on sys.path. This works whether the
# notebook is started from the repository root or from reconlib/modalities/oct/.
# If no such directory exists, sys.path is left untouched so an installed
# reconlib (if any) is used instead.
_reconlib_root = os.path.abspath(os.getcwd())
while not os.path.isdir(os.path.join(_reconlib_root, 'reconlib')):
    _parent_dir = os.path.dirname(_reconlib_root)
    if _parent_dir == _reconlib_root:  # reached the filesystem root
        _reconlib_root = None
        break
    _reconlib_root = _parent_dir
if _reconlib_root is not None and _reconlib_root not in sys.path:
    sys.path.insert(0, _reconlib_root)

from reconlib.modalities.oct.operators import OCTForwardOperator
from reconlib.modalities.oct.reconstructors import tv_reconstruction_oct
# Note: tv_reconstruction_oct uses UltrasoundTVCustomRegularizer internally

%matplotlib inline

# Run on a CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

## 2. Define OCT Parameters

In [None]:
# --- Image geometry (B-scan: rows are lateral A-scans, columns are depth) ---
num_ascan_lines, depth_pixels = 128, 256  # lateral A-scan count, samples per A-scan
image_shape_oct = (num_ascan_lines, depth_pixels)

# --- Physical OCT parameters (SI units) ---
lambda_w_m = 850e-9         # center wavelength: 850 nm
z_max_m = 0.002             # maximum imaging depth: 2 mm
n_refractive_index = 1.35   # approximate refractive index of tissue

# --- Reconstruction settings ---
lambda_tv_oct = 0.01        # weight of the TV regularization term

# Outer proximal-gradient loop
oct_pg_iterations = 30
oct_pg_step_size = 0.05

# Inner loop of the custom TV proximal operator
oct_tv_prox_iters = 5
oct_tv_prox_step = 0.01

## 3. Create OCT Phantom (Reflectivity Profile)

In [None]:
def generate_oct_phantom(shape, device='cpu'):
    """Build a synthetic complex-valued OCT reflectivity phantom.

    The phantom contains:
    - three reflecting layers at roughly 1/4, 1/2 and 3/4 of the depth axis,
      present only in the central 20%-80% of the A-scans;
    - a lateral gap in the middle layer;
    - a small sinusoidal depth offset applied on every 10th A-scan;
    - two isolated complex point scatterers.

    Args:
        shape: ``(num_ascan_lines, depth_pixels)``.
        device: torch device on which the phantom is allocated.

    Returns:
        ``torch.complex64`` tensor of the given shape.

    Layer segments and scatterers that would fall outside the image (which can
    happen for small ``shape``) are clipped or skipped. They do not raise
    ``IndexError`` or wrap around through negative slice indices.
    """
    phantom = torch.zeros(shape, dtype=torch.complex64, device=device)
    n_ascans, n_depth = shape

    # Simulate a few reflecting layers
    layer_depths_pixels = [n_depth // 4, n_depth // 2, n_depth // 4 * 3]
    layer_reflectivities = [1.0, 0.7, 0.9]
    layer_thickness_pixels = 3
    half_thickness = layer_thickness_pixels // 2

    for i in range(n_ascans):
        # Vary layer depths slightly (only on every 10th A-scan) for texture
        depth_offset = int(5 * np.sin(2 * np.pi * i / n_ascans * 2)) if i % 10 == 0 else 0
        for l_idx, depth_px in enumerate(layer_depths_pixels):
            center_px = depth_px + depth_offset
            # Clamp both ends into [0, n_depth]. A layer shifted above the image
            # then gives an empty slice instead of a negative end index, which
            # Python would interpret as counting from the end of the A-scan.
            start_px = min(n_depth, max(0, center_px - half_thickness))
            end_px = max(0, min(n_depth, center_px + half_thickness + 1))
            if n_ascans * 0.2 < i < n_ascans * 0.8:  # Make layers discontinuous laterally
                phantom[i, start_px:end_px] = layer_reflectivities[l_idx]
            if l_idx == 1 and n_ascans * 0.4 < i < n_ascans * 0.6:
                phantom[i, start_px:end_px] = 0  # Create a small gap in 2nd layer

    # Add some sparse scatterers; skip any that would fall outside a small image.
    scatterers = [
        (n_ascans // 3, n_depth // 3 + 10, 1.2 + 0.3j),
        (n_ascans // 2, n_depth // 2 + 20, 0.5 - 0.8j),
    ]
    for row, col, value in scatterers:
        if row < n_ascans and col < n_depth:
            phantom[row, col] = value
    return phantom

oct_phantom = generate_oct_phantom(image_shape_oct, device=device)

# Show the magnitude of the complex reflectivity; rows are A-scans, columns are depth.
fig_phantom, ax_phantom = plt.subplots(figsize=(8, 5))
phantom_im = ax_phantom.imshow(
    oct_phantom.abs().cpu().numpy(),
    cmap='gray',
    aspect='auto',
    extent=[0, image_shape_oct[1], image_shape_oct[0], 0],
)
ax_phantom.set_title('Original OCT Phantom (Reflectivity Magnitude)')
ax_phantom.set_xlabel('Depth Pixels')
ax_phantom.set_ylabel('A-scan Line')
fig_phantom.colorbar(phantom_im, ax=ax_phantom, label='Reflectivity')
plt.show()

## 4. Initialize OCT Forward Operator

In [None]:
# Gather the operator's geometry and physical parameters, then construct it.
oct_operator_kwargs = {
    'image_shape': image_shape_oct,
    'lambda_w': lambda_w_m,
    'z_max_m': z_max_m,
    'n_refractive_index': n_refractive_index,
    'device': device,
}
oct_operator_inst = OCTForwardOperator(**oct_operator_kwargs)
print("OCTForwardOperator initialized.")

## 5. Simulate OCT Spectral Data (Forward Projection)

In [None]:
print("Simulating OCT spectral data (k-space)...")
oct_k_space_clean = oct_operator_inst.op(oct_phantom)
print(f"Simulated clean k-space data shape: {oct_k_space_clean.shape}")

# Complex Gaussian noise: the noise power is split evenly between the real and
# imaginary parts. Each part has variance signal_power * ratio / 2, so the total
# noise power is ratio * signal_power.
signal_power_oct = oct_k_space_clean.abs().pow(2).mean()
noise_power_ratio_oct = 0.1  # noise power = 10% of the mean signal power
noise_std_oct = (signal_power_oct * noise_power_ratio_oct / 2).sqrt()
noise_real = torch.randn_like(oct_k_space_clean.real)
noise_imag = torch.randn_like(oct_k_space_clean.imag)
noise_oct = noise_std_oct * (noise_real + 1j * noise_imag)
oct_k_space_noisy = oct_k_space_clean + noise_oct
print(f"Added complex Gaussian noise. Noise STD: {noise_std_oct.item()}")

# Side-by-side magnitude plots of the clean and noisy spectra.
fig_kspace, axes_kspace = plt.subplots(1, 2, figsize=(10, 5))
kspace_panels = (
    (oct_k_space_clean, 'Clean OCT k-space Data (Magnitude)'),
    (oct_k_space_noisy, 'Noisy OCT k-space Data (Magnitude)'),
)
for ax_k, (k_data, k_title) in zip(axes_kspace, kspace_panels):
    k_im = ax_k.imshow(k_data.abs().cpu().numpy(), aspect='auto', cmap='viridis')
    ax_k.set_title(k_title)
    ax_k.set_xlabel('k-space Samples (Depth Freq.)')
    ax_k.set_ylabel('A-scan Line')
    fig_kspace.colorbar(k_im, ax=ax_k)
plt.tight_layout()
plt.show()

## 6. Image Reconstruction

### 6.1 Adjoint Reconstruction (IFFT of k-space data)

In [None]:
print("Performing Adjoint (IFFT-based) reconstruction...")
oct_adjoint_recon = oct_operator_inst.op_adj(oct_k_space_noisy)
print(f"Adjoint reconstructed image shape: {oct_adjoint_recon.shape}")

# Magnitude of the adjoint image; rows are A-scans, columns are depth.
fig_adj, ax_adj = plt.subplots(figsize=(8, 5))
adj_im = ax_adj.imshow(
    oct_adjoint_recon.abs().cpu().numpy(),
    cmap='gray',
    aspect='auto',
    extent=[0, image_shape_oct[1], image_shape_oct[0], 0],
)
ax_adj.set_title('Adjoint (IFFT) OCT Reconstruction')
ax_adj.set_xlabel('Depth Pixels')
ax_adj.set_ylabel('A-scan Line')
fig_adj.colorbar(adj_im, ax=ax_adj, label='Reflectivity')
plt.show()

### 6.2 TV Regularized Reconstruction

In [None]:
print(f"Performing TV Regularized OCT Reconstruction (lambda_TV={lambda_tv_oct})...This may take some time.")

# Solver settings: outer proximal-gradient loop plus the inner TV-prox loop.
tv_solver_settings = dict(
    lambda_tv=lambda_tv_oct,
    iterations=oct_pg_iterations,
    step_size=oct_pg_step_size,
    tv_prox_iterations=oct_tv_prox_iters,
    tv_prox_step_size=oct_tv_prox_step,
    verbose=True,
)
oct_tv_recon_image = tv_reconstruction_oct(
    y_oct_data=oct_k_space_noisy,
    oct_operator=oct_operator_inst,
    **tv_solver_settings,
)
print(f"TV Reconstructed OCT image shape: {oct_tv_recon_image.shape}")

# Magnitude of the TV-regularized image; rows are A-scans, columns are depth.
fig_tv, ax_tv = plt.subplots(figsize=(8, 5))
tv_im = ax_tv.imshow(
    oct_tv_recon_image.abs().cpu().numpy(),
    cmap='gray',
    aspect='auto',
    extent=[0, image_shape_oct[1], image_shape_oct[0], 0],
)
ax_tv.set_title(f'TV Regularized OCT Recon (lambda={lambda_tv_oct}, {oct_pg_iterations} iters)')
ax_tv.set_xlabel('Depth Pixels')
ax_tv.set_ylabel('A-scan Line')
fig_tv.colorbar(tv_im, ax=ax_tv, label='Reflectivity')
plt.show()

## 7. Comparison of Results

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
fig.suptitle('OCT Reconstruction Comparison', fontsize=16)

# Every panel uses the same display geometry: rows are A-scans, columns are depth.
comparison_extent = [0, image_shape_oct[1], image_shape_oct[0], 0]
comparison_panels = [
    (oct_phantom, 'Original Phantom'),
    (oct_adjoint_recon, 'Adjoint (IFFT) Recon'),
    (oct_tv_recon_image, f'TV Recon (lambda={lambda_tv_oct})'),
]
comparison_images = []
for ax_cmp, (img_cmp, title_cmp) in zip(axes, comparison_panels):
    im_cmp = ax_cmp.imshow(img_cmp.abs().cpu().numpy(), cmap='gray', aspect='auto',
                           extent=comparison_extent)
    ax_cmp.set_title(title_cmp)
    ax_cmp.set_xlabel('Depth Pixels')
    ax_cmp.set_ylabel('A-scan Line')
    fig.colorbar(im_cmp, ax=ax_cmp, shrink=0.8)
    comparison_images.append(im_cmp)
# Keep the per-panel image handles under their original names.
im0, im1, im2 = comparison_images

plt.tight_layout(rect=[0, 0, 1, 0.96])  # leave headroom for the suptitle
plt.show()