# X-ray Phase-Contrast Imaging (XPCI) Reconstruction Example
This notebook demonstrates basic XPCI phase-image reconstruction from simulated differential phase-contrast data. It uses a simplified gradient-based forward model and compares a simple adjoint baseline with a Total Variation (TV)-regularized reconstruction.

## 1. Imports and Setup

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import sys
import os

# Ensure reconlib is importable from this notebook.
# Notebooks have no usable __file__, so the candidate root is resolved
# relative to the current working directory.
# Layout heuristic (NOTE(review): confirm against the repo structure): if the
# working-directory path contains 'reconlib', the importable root is one level
# up; otherwise the notebook is assumed to sit three levels below it.
_cwd = os.getcwd()
_reconlib_root = os.path.abspath(
    os.path.join(_cwd, '..' if 'reconlib' in _cwd else os.path.join('..', '..', '..'))
)
# Re-executing a notebook cell would otherwise keep prepending duplicates.
if _reconlib_root not in sys.path:
    sys.path.insert(0, _reconlib_root)

from reconlib.modalities.xray_phase_contrast.operators import XRayPhaseContrastOperator
from reconlib.modalities.xray_phase_contrast.reconstructors import tv_reconstruction_xrpc
# tv_reconstruction_xrpc uses UltrasoundTVCustomRegularizer by default

%matplotlib inline

# Pick the compute device: use CUDA when PyTorch reports it available, else CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

## 2. Define XPCI Parameters

In [None]:
# --- Image geometry: a square 2D phase map ---
H = W = 128  # image height (H) and width (W), in pixels
image_shape_xrpc = (H, W)

# --- X-ray / detector physics ---
lambda_xray_m = 1e-10                       # wavelength in metres (0.1 nm, roughly 12.4 keV)
k_wave_number = 2 * np.pi / lambda_xray_m   # k = 2*pi / lambda
pixel_size_m = 1e-5                         # detector pixel pitch in metres (10 microns)
print(f"X-ray k-wave number: {k_wave_number:.2e}")

# --- Proximal-gradient reconstruction with a custom TV proximal step ---
lambda_tv_xrpc = 0.001      # TV weight; results are sensitive to this value
xrpc_pg_iterations = 50     # outer proximal-gradient iterations
xrpc_pg_step_size = 0.01    # outer proximal-gradient step size
xrpc_tv_prox_iters = 5      # inner iterations of the TV proximal operator
xrpc_tv_prox_step = 0.01    # inner step size of the TV proximal operator

## 3. Create XPCI Phase Phantom

In [None]:
def generate_xrpc_phantom(shape, device='cpu'):
    """Build a synthetic 2D phase phantom for the XPCI demo.

    The phantom consists of:
      * a large disk (phase +0.5 rad) centred in the image,
      * a smaller disk (phase -0.3 rad) in the upper-right region, which
        overwrites any overlap with the large disk,
      * a horizontal linear ramp from 0 to 0.2 rad added across the columns,
        giving non-trivial differential data everywhere.

    Args:
        shape: ``(H, W)`` image size in pixels.
        device: torch device on which the phantom is allocated.

    Returns:
        A float32 tensor of shape ``(H, W)`` with phase values in radians.
    """
    h, w = shape
    phantom = torch.zeros((h, w), dtype=torch.float32, device=device)

    # Pixel coordinate grids; 'ij' indexing makes Y vary along rows, X along columns.
    Y, X = torch.meshgrid(torch.arange(h, device=device),
                          torch.arange(w, device=device), indexing='ij')

    # Large central disk (e.g. the projection of a sphere or cylinder).
    center_y, center_x = h // 2, w // 2
    radius1 = min(h, w) // 4
    mask1 = (X - center_x) ** 2 + (Y - center_y) ** 2 < radius1 ** 2
    phantom[mask1] = 0.5

    # Smaller disk with a negative phase shift. `w // 1.5` alone is a float;
    # cast so the centre is an integer pixel coordinate like the others.
    center_y2, center_x2 = h // 4, int(w // 1.5)
    radius2 = min(h, w) // 8
    mask2 = (X - center_x2) ** 2 + (Y - center_y2) ** 2 < radius2 ** 2
    phantom[mask2] = -0.3

    # Horizontal ramp: a (W,) vector broadcasts over all H rows, so no
    # repeated (H, W) copy needs to be materialised.
    phantom += torch.linspace(0, 0.2, w, device=device)
    return phantom

xrpc_phantom_phase = generate_xrpc_phantom(image_shape_xrpc, device=device)

# Display the ground-truth phase map using the object-oriented matplotlib API.
fig, ax = plt.subplots(figsize=(6, 5))
im = ax.imshow(xrpc_phantom_phase.cpu().numpy(), cmap='viridis')
ax.set_title('Original XPCI Phase Phantom')
ax.set_xlabel('X-pixels')
ax.set_ylabel('Y-pixels')
fig.colorbar(im, ax=ax, label='Phase Shift (radians)')
plt.show()

## 4. Initialize X-ray Phase-Contrast Operator

In [None]:
# Forward operator mapping a phase image to differential phase data (see its
# use with .op / .op_adj below), configured from the parameters cell.
xrpc_operator_inst = XRayPhaseContrastOperator(
    image_shape=image_shape_xrpc,
    pixel_size_m=pixel_size_m,
    k_wave_number=k_wave_number,
    device=device,
)
print("XRayPhaseContrastOperator initialized.")

## 5. Simulate Differential Phase-Contrast Data

In [None]:
print("Simulating XPCI differential phase data...")
diff_phase_data_clean = xrpc_operator_inst.op(xrpc_phantom_phase)
print(f"Simulated clean differential phase data shape: {diff_phase_data_clean.shape}")

# Additive Gaussian noise with std = 5% of the mean absolute clean signal.
# Seed the RNG so that re-running the notebook reproduces the same noise
# (and therefore the same reconstructions downstream).
noise_seed_xrpc = 0
torch.manual_seed(noise_seed_xrpc)
signal_mean_abs_xrpc = torch.mean(torch.abs(diff_phase_data_clean))
noise_std_xrpc = 0.05 * signal_mean_abs_xrpc
noise_xrpc = noise_std_xrpc * torch.randn_like(diff_phase_data_clean)
diff_phase_data_noisy = diff_phase_data_clean + noise_xrpc
print(f"Added Gaussian noise. Noise STD: {noise_std_xrpc.item()}")

# Plot clean vs. noisy data on a shared colour scale that is symmetric about
# zero, so the diverging 'coolwarm' map puts 0 at its neutral midpoint and the
# two panels are directly comparable (independent autoscaling hides the noise).
diff_clean_np = diff_phase_data_clean.cpu().numpy()
diff_noisy_np = diff_phase_data_noisy.cpu().numpy()
diff_vlim = float(max(abs(diff_clean_np).max(), abs(diff_noisy_np).max()))

plt.figure(figsize=(10, 5))
plt.subplot(1,2,1)
plt.imshow(diff_clean_np, cmap='coolwarm', aspect='auto', vmin=-diff_vlim, vmax=diff_vlim)
plt.title('Clean Differential Phase Data')
plt.xlabel('X-pixels'); plt.ylabel('Y-pixels')
plt.colorbar()
plt.subplot(1,2,2)
plt.imshow(diff_noisy_np, cmap='coolwarm', aspect='auto', vmin=-diff_vlim, vmax=diff_vlim)
plt.title('Noisy Differential Phase Data')
plt.xlabel('X-pixels'); plt.ylabel('Y-pixels')
plt.colorbar()
plt.tight_layout()
plt.show()

## 6. Phase Image Reconstruction

### 6.1 Adjoint Reconstruction (Approximate Phase Integration)

In [None]:
print("Performing Adjoint (approx. integration) reconstruction...")
# Single application of the operator's adjoint to the noisy data: a quick,
# non-iterative baseline for comparison with the TV result.
xrpc_adjoint_recon_phase = xrpc_operator_inst.op_adj(diff_phase_data_noisy)
print(f"Adjoint reconstructed phase image shape: {xrpc_adjoint_recon_phase.shape}")

fig, ax = plt.subplots(figsize=(6, 5))
im = ax.imshow(xrpc_adjoint_recon_phase.cpu().numpy(), cmap='viridis')
ax.set_title('Adjoint XPCI Reconstruction (Approx. Phase)')
ax.set_xlabel('X-pixels')
ax.set_ylabel('Y-pixels')
fig.colorbar(im, ax=ax, label='Phase Shift (radians)')
plt.show()

### 6.2 TV Regularized Phase Reconstruction

In [None]:
print(f"Performing TV Regularized XPCI Phase Reconstruction (lambda_TV={lambda_tv_xrpc})...")

# Gather the solver settings in one place, then hand them to the reconstructor.
tv_solver_kwargs = dict(
    y_differential_phase_data=diff_phase_data_noisy,
    xrpc_operator=xrpc_operator_inst,
    lambda_tv=lambda_tv_xrpc,
    iterations=xrpc_pg_iterations,
    step_size=xrpc_pg_step_size,
    tv_prox_iterations=xrpc_tv_prox_iters,
    tv_prox_step_size=xrpc_tv_prox_step,
    verbose=True,
)
xrpc_tv_recon_phase = tv_reconstruction_xrpc(**tv_solver_kwargs)
print(f"TV Reconstructed XPCI Phase Image shape: {xrpc_tv_recon_phase.shape}")

fig, ax = plt.subplots(figsize=(6, 5))
im = ax.imshow(xrpc_tv_recon_phase.cpu().numpy(), cmap='viridis')
ax.set_title(f'TV Regularized XPCI Recon (lambda={lambda_tv_xrpc}, {xrpc_pg_iterations} iters)')
ax.set_xlabel('X-pixels')
ax.set_ylabel('Y-pixels')
fig.colorbar(im, ax=ax, label='Phase Shift (radians)')
plt.show()

## 7. Comparison of Results

In [None]:
# Side-by-side view of ground truth and both reconstructions; each panel is
# described once as (image tensor, title) and rendered by the same loop.
comparison_panels_xrpc = [
    (xrpc_phantom_phase, 'Original Phase Phantom'),
    (xrpc_adjoint_recon_phase, 'Adjoint Recon (Approx. Phase)'),
    (xrpc_tv_recon_phase, f'TV Regularized Recon (lambda={lambda_tv_xrpc})'),
]

fig, axes = plt.subplots(1, 3, figsize=(18, 6))
fig.suptitle('XPCI Phase Reconstruction Comparison', fontsize=16)

for panel_ax, (panel_img, panel_title) in zip(axes, comparison_panels_xrpc):
    panel_im = panel_ax.imshow(panel_img.cpu().numpy(), cmap='viridis')
    panel_ax.set_title(panel_title)
    panel_ax.set_xlabel('X-pixels')
    panel_ax.set_ylabel('Y-pixels')
    fig.colorbar(panel_im, ax=panel_ax, shrink=0.8, label='Phase (rad)')

# Leave the top 5% of the figure for the suptitle.
plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.show()