# ReconLib Self-Contained NUFFT Pipeline Demo

This notebook demonstrates ReconLib's self-contained NUFFT implementation for both 2D and 3D MRI reconstruction. All examples use `NUFFTOperator` backed by the internal, Python-based NUFFT, which evaluates the Kaiser–Bessel interpolation kernel from precomputed lookup tables (the "table-based" method).

## 1. Imports

In [None]:
import torch
import numpy as np
import math
import matplotlib.pyplot as plt

# Ensure plots are displayed inline
%matplotlib inline 

# ReconLib imports
from reconlib.operators import NUFFTOperator

# Helper functions from local scripts (assuming they are in python path or same directory)
# If ReconLib is installed, these might be part of the library, adjust path if necessary
import sys
import os
sys.path.append(os.path.abspath(os.path.join('..'))) # Add parent directory to path to find iternufft and l1l2recon

from iternufft import iterative_recon, generate_phantom_2d, generate_radial_trajectory_2d, generate_phantom_3d, generate_radial_trajectory_3d
from l1l2recon import L1Reconstruction, L2Reconstruction

## 2. Setup

In [None]:
# Select the compute device: GPU when available, otherwise CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# Seed the global RNGs so the simulated k-space noise (and any other random draws
# below) are identical on every Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

## 3. 2D Reconstruction Demo

### a. Phantom and Trajectory (2D)

In [None]:
# 128 x 128 test image and the 2D radial sampling pattern used throughout section 3.
Nx_2d, Ny_2d = 128, 128
image_shape_2d = (Nx_2d, Ny_2d)

# generate_phantom_2d only receives Nx_2d here (square phantom presumably; confirm).
phantom_2d = generate_phantom_2d(Nx_2d, device=device).to(torch.complex64)
k_traj_2d = generate_radial_trajectory_2d(num_spokes=128, samples_per_spoke=256, device=device)

# Left panel: phantom magnitude. Right panel: the sampled k-space locations.
fig, axs = plt.subplots(1, 2, figsize=(10, 5))
phantom_ax, traj_ax = axs

phantom_ax.imshow(phantom_2d.abs().cpu().numpy(), cmap='gray')
phantom_ax.set_title('Original 2D Phantom')
phantom_ax.axis('off')

traj_np_2d = k_traj_2d.cpu().numpy()
traj_ax.scatter(traj_np_2d[:, 0], traj_np_2d[:, 1], s=0.5)
traj_ax.set_title('2D Radial K-space Trajectory')
traj_ax.set_xlabel('kx')
traj_ax.set_ylabel('ky')
traj_ax.set_aspect('equal')

plt.tight_layout()
plt.show()

### b. NUFFTOperator Setup (2D)

In [None]:
# Kaiser-Bessel / table-lookup settings for the 2D operator.
oversamp_factor_2d = (2.0, 2.0)
kb_J_2d = (4, 4)
# Common heuristic for alpha: 2.34 * J per dimension.
kb_alpha_2d = tuple(2.34 * width for width in kb_J_2d)
Ld_2d = (2**10, 2**10)
# Oversampled grid size per dimension. The loop variable is deliberately not
# named `os` so it cannot be confused with the imported module.
Kd_2d = tuple(int(n_pts * factor) for n_pts, factor in zip(image_shape_2d, oversamp_factor_2d))

nufft_kwargs_2d = dict(
    k_trajectory=k_traj_2d,
    image_shape=image_shape_2d,
    oversamp_factor=oversamp_factor_2d,
    kb_J=kb_J_2d,
    kb_alpha=kb_alpha_2d,
    Ld=Ld_2d,
    Kd=Kd_2d,            # optional; NUFFTOperator can compute this itself
    kb_m=(0.0, 0.0),     # explicit MIRT default for m
    n_shift=(0.0, 0.0),  # explicit MIRT default for n_shift
    device=device,
)
nufft_op_2d = NUFFTOperator(**nufft_kwargs_2d)
print("NUFFTOperator for 2D created.")

### c. Simulate K-Space Data (2D)

In [None]:
print("Simulating 2D k-space data...")
kspace_data_2d = nufft_op_2d.op(phantom_2d)

# Complex Gaussian noise: the real and imaginary parts each have a standard
# deviation of 1% of the mean k-space magnitude.
noise_level_2d = 0.01 * torch.mean(torch.abs(kspace_data_2d))
noise_re_2d = torch.randn_like(kspace_data_2d.real)
noise_im_2d = torch.randn_like(kspace_data_2d.real)
kspace_data_2d_noisy = kspace_data_2d + noise_level_2d * (noise_re_2d + 1j * noise_im_2d)
print("Noisy 2D k-space data generated.")

### d. Iterative Reconstruction (CG - 2D)

In [None]:
print("Running 2D iterative reconstruction (CG)...")
recon_img_2d_cg = iterative_recon(
    kspace_data=kspace_data_2d_noisy,
    nufft_op=nufft_op_2d,
    num_iters=10,
)

# Magnitude images side by side: ground truth vs. CG reconstruction.
fig, axs = plt.subplots(1, 2, figsize=(10, 5))
for ax, img, title in zip(axs,
                          (phantom_2d, recon_img_2d_cg),
                          ('Original 2D Phantom', 'Reconstructed 2D Image (CG)')):
    ax.imshow(img.abs().cpu().numpy(), cmap='gray')
    ax.set_title(title)
    ax.axis('off')
plt.tight_layout()
plt.show()

### e. L2 Reconstruction (2D)

In [None]:
print("Running 2D L2 reconstruction...")
l2_recon_module_2d = L2Reconstruction(
    linear_operator=nufft_op_2d,
    num_iterations=15,
    learning_rate=0.1,
)
recon_img_2d_l2 = l2_recon_module_2d.forward(kspace_data_2d_noisy)

# Ground truth on the left, L2 result on the right.
panels_2d_l2 = [
    (phantom_2d, 'Original 2D Phantom'),
    (recon_img_2d_l2, 'Reconstructed 2D Image (L2)'),
]
fig, axs = plt.subplots(1, 2, figsize=(10, 5))
for ax, (img, title) in zip(axs, panels_2d_l2):
    ax.imshow(img.abs().cpu().numpy(), cmap='gray')
    ax.set_title(title)
    ax.axis('off')
plt.tight_layout()
plt.show()

### f. L1 Reconstruction (2D)

In [None]:
print("Running 2D L1 reconstruction...")
l1_recon_module_2d = L1Reconstruction(
    linear_operator=nufft_op_2d,
    num_iterations=20,
    lambda_reg=0.001,   # weight of the L1 term
    learning_rate=0.1,
)
recon_img_2d_l1 = l1_recon_module_2d.forward(kspace_data_2d_noisy)

# Ground truth on the left, L1 result on the right.
panels_2d_l1 = [
    (phantom_2d, 'Original 2D Phantom'),
    (recon_img_2d_l1, 'Reconstructed 2D Image (L1)'),
]
fig, axs = plt.subplots(1, 2, figsize=(10, 5))
for ax, (img, title) in zip(axs, panels_2d_l1):
    ax.imshow(img.abs().cpu().numpy(), cmap='gray')
    ax.set_title(title)
    ax.axis('off')
plt.tight_layout()
plt.show()

## 4. 3D Reconstruction Demo

### a. Phantom and Trajectory (3D)

In [None]:
# 32^3 test volume (z, y, x ordering) and the 3D sampling trajectory for section 4.
Nz_3d, Ny_3d, Nx_3d = 32, 32, 32
image_shape_3d = (Nz_3d, Ny_3d, Nx_3d)

phantom_3d = generate_phantom_3d(shape=image_shape_3d, device=device).to(torch.complex64)
k_traj_3d = generate_radial_trajectory_3d(num_profiles_z=32,
                                        num_spokes_per_profile=32,
                                        samples_per_spoke=32,
                                        shape=image_shape_3d,
                                        device=device)

# Plot center slice of 3D phantom
fig1, ax1 = plt.subplots(1, 1, figsize=(5, 5))
ax1.imshow(phantom_3d[Nz_3d // 2, :, :].abs().cpu().numpy(), cmap='gray')
ax1.set_title(f'Original 3D Phantom (Slice {Nz_3d // 2})')
ax1.axis('off')
plt.show()

# Plot a subset of the 3D k-space trajectory for clarity.
# The full trajectory has num_profiles_z * num_spokes_per_profile * samples_per_spoke
# points (32768 here). Taking only the *first* N points would show just the start
# of the generator's output (probably only one or two kz profiles, if points are
# emitted profile by profile; not verified here). An evenly strided subset covers
# the whole trajectory regardless of point ordering.
MAX_TRAJ_POINTS_TO_PLOT = 2000
traj_plot_stride = max(1, math.ceil(k_traj_3d.shape[0] / MAX_TRAJ_POINTS_TO_PLOT))
k_traj_3d_plot = k_traj_3d[::traj_plot_stride].cpu().numpy()
num_points_to_plot = k_traj_3d_plot.shape[0]  # at most MAX_TRAJ_POINTS_TO_PLOT

fig2 = plt.figure(figsize=(6, 6))
ax2 = fig2.add_subplot(111, projection='3d')
ax2.scatter(k_traj_3d_plot[:, 0],
            k_traj_3d_plot[:, 1],
            k_traj_3d_plot[:, 2], s=0.5)
ax2.set_title('3D Stack-of-Stars K-space Trajectory (Subset)')
ax2.set_xlabel('kx')
ax2.set_ylabel('ky')
ax2.set_zlabel('kz')
plt.show()

### b. NUFFTOperator Setup (3D)

In [None]:
# Kaiser-Bessel / table-lookup settings for the 3D operator. Oversampling and
# table resolution are smaller than in the 2D demo to save time and memory.
oversamp_factor_3d = (1.5, 1.5, 1.5) # Reduced for speed/memory
kb_J_3d = (4, 4, 4)
kb_alpha_3d = tuple(2.34 * J_d for J_d in kb_J_3d)  # alpha = 2.34 * J heuristic
Ld_3d = (2**8, 2**8, 2**8) # Reduced for speed/memory
# Oversampled grid size per dimension. `os_factor` avoids reusing the name of
# the imported `os` module.
Kd_3d = tuple(int(N * os_factor) for N, os_factor in zip(image_shape_3d, oversamp_factor_3d))
n_shift_3d = (0.0, 0.0, 0.0)
kb_m_3d = (0.0, 0.0, 0.0)

# interpolation_order=1 (linear table interpolation) is passed explicitly. This
# operator is labelled "linear" below and serves as the baseline for the
# nearest-neighbour (interpolation_order=0) comparison, so it must not depend on
# whatever NUFFTOperator's default happens to be.
nufft_op_3d = NUFFTOperator(k_trajectory=k_traj_3d,
                            image_shape=image_shape_3d,
                            oversamp_factor=oversamp_factor_3d,
                            kb_J=kb_J_3d,
                            kb_alpha=kb_alpha_3d,
                            Ld=Ld_3d,
                            Kd=Kd_3d,
                            kb_m=kb_m_3d,
                            n_shift=n_shift_3d,
                            interpolation_order=1, # Linear table interpolation
                            device=device,
                            nufft_type_3d='table') # Explicitly use table-based NUFFT
print("NUFFTOperator for 3D (table-based, linear interpolation) created: nufft_op_3d")

#### Demonstrating NUFFT with Nearest Neighbor Interpolation
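Setting `interpolation_order=0` selects nearest-neighbor lookup in the precomputed kernel table instead of linear interpolation between table entries (`interpolation_order=1`). Nearest-neighbor lookup can be faster but is less accurate; the simulations and reconstructions below compare the two.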

In [None]:
# Same trajectory, grid and kernel settings as nufft_op_3d. The only change is
# interpolation_order=0, which selects nearest-neighbour table lookup.
nn_operator_kwargs_3d = dict(
    k_trajectory=k_traj_3d,
    image_shape=image_shape_3d,
    oversamp_factor=oversamp_factor_3d,
    kb_J=kb_J_3d,
    kb_alpha=kb_alpha_3d,
    Ld=Ld_3d,
    Kd=Kd_3d,
    kb_m=kb_m_3d,
    n_shift=n_shift_3d,
    interpolation_order=0,  # Specify Nearest Neighbor
    device=device,
    nufft_type_3d='table',
)
nufft_op_3d_nn = NUFFTOperator(**nn_operator_kwargs_3d)
print("NUFFTOperator for 3D (Nearest Neighbor Interpolation) created: nufft_op_3d_nn")

### c. Simulate K-Space Data (3D - Linear Interpolation)

In [None]:
print("Simulating 3D k-space data (using linear interpolation NUFFT)...")
kspace_data_3d = nufft_op_3d.op(phantom_3d)

# Complex Gaussian noise: the real and imaginary parts each have a standard
# deviation of 1% of the mean k-space magnitude. noise_level_3d is reused by
# the nearest-neighbour simulation cell.
noise_level_3d = 0.01 * torch.mean(torch.abs(kspace_data_3d))
noise_re_3d = torch.randn_like(kspace_data_3d.real)
noise_im_3d = torch.randn_like(kspace_data_3d.real)
kspace_data_3d_noisy = kspace_data_3d + noise_level_3d * (noise_re_3d + 1j * noise_im_3d)
print("Noisy 3D k-space data (linear interp) generated.")

### c.2. Simulate K-Space Data (3D - Nearest Neighbor)

In [None]:
print("Simulating 3D k-space data (using nearest neighbor NUFFT)...")
kspace_data_3d_nn = nufft_op_3d_nn.op(phantom_3d)

# Same noise scale as the linear-interpolation data (noise_level_3d from the
# previous cell) for a fair comparison; fresh noise samples are drawn.
noise_re_3d_nn = torch.randn_like(kspace_data_3d_nn.real)
noise_im_3d_nn = torch.randn_like(kspace_data_3d_nn.real)
kspace_data_3d_nn_noisy = kspace_data_3d_nn + noise_level_3d * (noise_re_3d_nn + 1j * noise_im_3d_nn)
print("Noisy 3D k-space data (nearest neighbor) generated.")

### d. Iterative Reconstruction (CG - 3D - Linear Interpolation)

In [None]:
print("Running 3D iterative reconstruction (CG - Linear Interpolation)...")
recon_img_3d_cg = iterative_recon(
    kspace_data=kspace_data_3d_noisy,
    nufft_op=nufft_op_3d,
    num_iters=5,  # Reduced iterations for speed
)

# Middle slice along the first axis. This index is reused by the later 3D
# plotting cells.
center_slice_idx = image_shape_3d[0] // 2

fig, axs = plt.subplots(1, 2, figsize=(10, 5))
for ax, vol, title in zip(
        axs,
        (phantom_3d, recon_img_3d_cg),
        (f'Original 3D Phantom (Slice {center_slice_idx})',
         f'Reconstructed 3D (CG - Linear) (Slice {center_slice_idx})')):
    ax.imshow(vol[center_slice_idx, :, :].abs().cpu().numpy(), cmap='gray')
    ax.set_title(title)
    ax.axis('off')
plt.tight_layout()
plt.show()

### d.2. Iterative Reconstruction (CG - 3D - Nearest Neighbor)

In [None]:
print("Running 3D iterative reconstruction (CG - Nearest Neighbor)...")
recon_img_3d_cg_nn = iterative_recon(
    kspace_data=kspace_data_3d_nn_noisy,
    nufft_op=nufft_op_3d_nn,
    num_iters=5,  # Reduced iterations for speed
)

# Uses center_slice_idx defined in the linear-interpolation CG cell.
slices_nn = [
    (phantom_3d, f'Original 3D Phantom (Slice {center_slice_idx})'),
    (recon_img_3d_cg_nn, f'Reconstructed 3D (CG - NN) (Slice {center_slice_idx})'),
]
fig, axs = plt.subplots(1, 2, figsize=(10, 5))
for ax, (vol, title) in zip(axs, slices_nn):
    ax.imshow(vol[center_slice_idx, :, :].abs().cpu().numpy(), cmap='gray')
    ax.set_title(title)
    ax.axis('off')
plt.tight_layout()
plt.show()

### e. L2 Reconstruction (3D)

In [None]:
print("Running 3D L2 reconstruction...")
l2_recon_module_3d = L2Reconstruction(
    linear_operator=nufft_op_3d,
    num_iterations=10,
    learning_rate=0.1,
)
recon_img_3d_l2 = l2_recon_module_3d.forward(kspace_data_3d_noisy)

# Uses center_slice_idx defined in the linear-interpolation CG cell.
slices_l2 = [
    (phantom_3d, f'Original 3D Phantom (Slice {center_slice_idx})'),
    (recon_img_3d_l2, f'Reconstructed 3D (L2) (Slice {center_slice_idx})'),
]
fig, axs = plt.subplots(1, 2, figsize=(10, 5))
for ax, (vol, title) in zip(axs, slices_l2):
    ax.imshow(vol[center_slice_idx, :, :].abs().cpu().numpy(), cmap='gray')
    ax.set_title(title)
    ax.axis('off')
plt.tight_layout()
plt.show()

### f. L1 Reconstruction (3D)

In [None]:
print("Running 3D L1 reconstruction...")
l1_recon_module_3d = L1Reconstruction(
    linear_operator=nufft_op_3d,
    num_iterations=15,
    lambda_reg=0.005,   # weight of the L1 term
    learning_rate=0.1,
)
recon_img_3d_l1 = l1_recon_module_3d.forward(kspace_data_3d_noisy)

# Uses center_slice_idx defined in the linear-interpolation CG cell.
slices_l1 = [
    (phantom_3d, f'Original 3D Phantom (Slice {center_slice_idx})'),
    (recon_img_3d_l1, f'Reconstructed 3D (L1) (Slice {center_slice_idx})'),
]
fig, axs = plt.subplots(1, 2, figsize=(10, 5))
for ax, (vol, title) in zip(axs, slices_l1):
    ax.imshow(vol[center_slice_idx, :, :].abs().cpu().numpy(), cmap='gray')
    ax.set_title(title)
    ax.axis('off')
plt.tight_layout()
plt.show()

## 5. Conclusion

This notebook demonstrated `NUFFTOperator` with the self-contained, Python-based (table method) NUFFT engine for both 2D and 3D MRI reconstruction. It showed iterative Conjugate Gradient reconstruction (from `iternufft.py`), L2- and L1-regularized reconstructions (from `l1l2recon.py`), and, in 3D, a comparison of linear and nearest-neighbor table interpolation. Section 6 below numerically checks that the forward and adjoint operators are consistent with each other.

## 6. Adjointness Test

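The adjoint $A^H$ of a linear operator $A$ is defined by $\langle Ax, y \rangle = \langle x, A^H y \rangle$ for all $x$ and $y$. Here $\langle u, v \rangle = \sum_i u_i \overline{v_i}$, implemented below as `sum(u * conj(v))` (and printed as `A*` in the cell output). Evaluating both sides with random complex inputs verifies that the forward and adjoint NUFFT implementations match, which iterative reconstruction algorithms such as CG rely on. The computations run in single precision (`complex64`), so expect a small but nonzero relative difference rather than exact equality.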

In [None]:
print("--- Adjointness Test for 2D NUFFTOperator ---")
# Requires nufft_op_2d, image_shape_2d and k_traj_2d from the 2D setup cells.

# A dedicated, seeded generator makes the random probe vectors reproducible no
# matter which cells ran before, and leaves the global RNG stream untouched.
ADJOINT_TEST_SEED_2D = 1234
# Loose tolerance for single-precision arithmetic; tune if needed.
ADJOINT_REL_TOL_2D = 1e-3
adj_gen_2d = torch.Generator(device=device).manual_seed(ADJOINT_TEST_SEED_2D)

# Random complex probes: x_2d in image space, y_2d in k-space (one value per
# trajectory point).
x_2d = torch.randn(image_shape_2d, dtype=torch.complex64, device=device, generator=adj_gen_2d)
y_2d_shape = (k_traj_2d.shape[0],)
y_2d = torch.randn(y_2d_shape, dtype=torch.complex64, device=device, generator=adj_gen_2d)

# Compute A x and A^H y
Ax_2d = nufft_op_2d.op(x_2d)
Aty_2d = nufft_op_2d.op_adj(y_2d)

# Inner products <u, v> = sum(u * conj(v)) on each side of the identity
lhs_2d = torch.sum(Ax_2d * torch.conj(y_2d))
rhs_2d = torch.sum(x_2d * torch.conj(Aty_2d))

print(f"LHS (<Ax, y>): {lhs_2d.item()}")
print(f"RHS (<x, A*y>): {rhs_2d.item()}")

abs_diff_2d = torch.abs(lhs_2d - rhs_2d).item()
print(f"Absolute Difference: {abs_diff_2d}")

lhs_mag_2d = torch.abs(lhs_2d).item()
if lhs_mag_2d > 1e-9:
    rel_diff_2d = abs_diff_2d / lhs_mag_2d
    print(f"Relative Difference: {rel_diff_2d}")
    verdict_2d = "PASS" if rel_diff_2d < ADJOINT_REL_TOL_2D else "FAIL"
    print(f"Adjointness check: {verdict_2d} (relative tolerance {ADJOINT_REL_TOL_2D:g})")
else:
    print("LHS is near zero, relative difference is not meaningful.")

In [None]:
print("\n--- Adjointness Test for 3D NUFFTOperator ---")
# Requires nufft_op_3d, image_shape_3d and k_traj_3d from the 3D setup cells.

# A dedicated, seeded generator makes the random probe vectors reproducible no
# matter which cells ran before, and leaves the global RNG stream untouched.
ADJOINT_TEST_SEED_3D = 1234
# Loose tolerance for single-precision arithmetic; tune if needed.
ADJOINT_REL_TOL_3D = 1e-3
adj_gen_3d = torch.Generator(device=device).manual_seed(ADJOINT_TEST_SEED_3D)

# Random complex probes: x_3d in image space, y_3d in k-space (one value per
# trajectory point).
x_3d = torch.randn(image_shape_3d, dtype=torch.complex64, device=device, generator=adj_gen_3d)
y_3d_shape = (k_traj_3d.shape[0],)
y_3d = torch.randn(y_3d_shape, dtype=torch.complex64, device=device, generator=adj_gen_3d)

# Compute A x and A^H y
Ax_3d = nufft_op_3d.op(x_3d)
Aty_3d = nufft_op_3d.op_adj(y_3d)

# Inner products <u, v> = sum(u * conj(v)) on each side of the identity
lhs_3d = torch.sum(Ax_3d * torch.conj(y_3d))
rhs_3d = torch.sum(x_3d * torch.conj(Aty_3d))

print(f"LHS (<Ax, y>): {lhs_3d.item()}")
print(f"RHS (<x, A*y>): {rhs_3d.item()}")

abs_diff_3d = torch.abs(lhs_3d - rhs_3d).item()
print(f"Absolute Difference: {abs_diff_3d}")

lhs_mag_3d = torch.abs(lhs_3d).item()
if lhs_mag_3d > 1e-9:
    rel_diff_3d = abs_diff_3d / lhs_mag_3d
    print(f"Relative Difference: {rel_diff_3d}")
    verdict_3d = "PASS" if rel_diff_3d < ADJOINT_REL_TOL_3D else "FAIL"
    print(f"Adjointness check: {verdict_3d} (relative tolerance {ADJOINT_REL_TOL_3D:g})")
else:
    print("LHS is near zero, relative difference is not meaningful.")