# User Guide: Compressed Sensing MRI Reconstruction with L1-Wavelet Regularization

This tutorial demonstrates Compressed Sensing (CS) for MRI reconstruction using `reconlib`. We will specifically focus on L1 regularization in the Wavelet domain to reconstruct an image from undersampled k-space data.

**Key Concepts:**
- **Undersampling:** Acquiring fewer k-space samples than required by the Nyquist criterion, enabling faster scans.
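- **Sparsity:** Many images are sparse, or at least compressible, in a suitable transform domain such as wavelets. Most of their coefficients there are zero or negligibly small. CS exploits this, together with incoherent (noise-like) undersampling artifacts, to recover the image from far fewer measurements.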
- **L1 Regularization:** Promotes sparsity by penalizing the L1 norm of the coefficients in the sparsifying domain.
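- **Iterative Reconstruction:** Proximal-gradient solvers such as FISTA solve the resulting optimization problem, typically of the form $\hat{x} = \arg\min_x \tfrac{1}{2}\|A x - y\|_2^2 + \lambda \|\Psi x\|_1$. Here $A$ is the NUFFT operator on the undersampled trajectory, $y$ the measured k-space data, $\Psi$ the wavelet transform, and $\lambda$ the regularization strength.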

**Notebook Outline:**
1. Setup and Imports.
2. Define simulation parameters (image size, undersampling, noise, wavelet, CS strength).
3. Generate phantom and fully sampled k-space trajectory.
4. Simulate full k-space data, then retrospectively undersample it and add noise.
5. Set up NUFFT operator for reconstruction from undersampled data.
6. Reconstruct using Zero-Filled Adjoint (for baseline comparison).
7. Reconstruct using Conjugate Gradient (for non-CS iterative comparison).
8. Perform Compressed Sensing reconstruction using FISTA and L1-Wavelet regularization.
9. Visualize and compare all results.
10. Conclude with a discussion of the results.

## 1. Setup and Imports

In [None]:
%matplotlib inline
import torch
import numpy as np
import matplotlib.pyplot as plt

# ReconLib imports
try:
    from reconlib.nufft import NUFFT2D
    from reconlib.solvers import fista_reconstruction, conjugate_gradient_reconstruction
    from reconlib.wavelets_scratch import WaveletRegularizationTerm 
    # For data simulation
    from iternufft import generate_phantom_2d, generate_radial_trajectory_2d
    print("Successfully imported reconlib components.")
except ImportError as e:
    print(f"Error importing modules: {e}")
    print("Please ensure reconlib is installed and iternufft.py is in the Python path.")

# Prefer the GPU when one is visible to PyTorch; otherwise run on the CPU.
_use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if _use_cuda else torch.device("cpu")
print(f"Using device: {device}")

## 2. Simulation Parameters

In [None]:
IMAGE_SIZE = 128
# Reference spoke count treated as the "fully sampled" acquisition in this tutorial.
# NOTE: strict radial Nyquist sampling needs roughly (pi / 2) * IMAGE_SIZE (~201)
# spokes, so this reference is itself mildly undersampled in the strict sense.
FULL_N_SPOKES = 128
# Samples along each spoke (1.5x the image matrix size).
N_SAMPLES_PER_SPOKE = int(IMAGE_SIZE * 1.5)

# Undersampling parameters
UNDERSAMPLING_FACTOR = 4  # e.g., 4x acceleration relative to FULL_N_SPOKES
N_SPOKES_UNDERSAMPLED = FULL_N_SPOKES // UNDERSAMPLING_FACTOR

# Noise level as a FRACTION (0.02 == 2%) of the peak clean k-space magnitude.
NOISE_STD_PERCENT = 0.02

# Seed the global torch RNG so the random spoke selection and the simulated
# noise (both drawn later with torch's default generators) are reproducible
# from run to run. torch.manual_seed seeds the CPU and all CUDA generators.
RANDOM_SEED = 0
torch.manual_seed(RANDOM_SEED)

# Wavelet and CS parameters
WAVELET_NAME = 'db4'  # Daubechies 4
WAVELET_LEVELS = 3
LAMBDA_CS = 0.002  # Regularization strength for L1-wavelet CS

print(f"Image Size: {IMAGE_SIZE}x{IMAGE_SIZE}")
print(f"Fully Sampled Spokes: {FULL_N_SPOKES}")
print(f"Undersampled Spokes: {N_SPOKES_UNDERSAMPLED} ({UNDERSAMPLING_FACTOR}x acceleration)")
print(f"Noise Level: {NOISE_STD_PERCENT * 100}%")
print(f"Wavelet: {WAVELET_NAME}, Levels: {WAVELET_LEVELS}")
print(f"CS Lambda: {LAMBDA_CS}")
print(f"Random Seed: {RANDOM_SEED}")

## 3. Generate Phantom and Full K-Space Trajectory

In [None]:
# Ground-truth phantom, cast to complex64 for use with the NUFFT.
phantom_img = generate_phantom_2d(size=IMAGE_SIZE, device=device)
phantom_complex = phantom_img.to(torch.complex64)

# Reference radial trajectory used to simulate the "fully sampled" acquisition.
k_trajectory_full = generate_radial_trajectory_2d(
    num_spokes=FULL_N_SPOKES,
    samples_per_spoke=N_SAMPLES_PER_SPOKE,
    device=device,
)

# Left: phantom magnitude. Right: k-space sample locations.
# Assumes k_trajectory_full is (num_points, 2) with kx, ky columns — TODO confirm.
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
ax_phantom, ax_traj = axes

ax_phantom.imshow(phantom_complex.abs().cpu().numpy(), cmap='gray')
ax_phantom.set_title(f"Original Phantom ({IMAGE_SIZE}x{IMAGE_SIZE})")
ax_phantom.axis('off')

traj_full_np = k_trajectory_full.cpu().numpy()
ax_traj.scatter(traj_full_np[:, 0], traj_full_np[:, 1], s=0.1)
ax_traj.set_title(f"Fully Sampled Trajectory ({FULL_N_SPOKES} spokes)")
ax_traj.set_xlabel("kx")
ax_traj.set_ylabel("ky")
ax_traj.axis('square')

plt.tight_layout()
plt.show()

## 4. Simulate Full K-space Data, Undersample, and Add Noise

In [None]:
# Setup NUFFT for simulation (using full trajectory)
nufft_params_sim = {
    'image_shape': (IMAGE_SIZE, IMAGE_SIZE), 'k_trajectory': k_trajectory_full,
    'oversamp_factor': (2.0, 2.0), 'kb_J': (4, 4),
    'kb_alpha': (2.34 * 4, 2.34 * 4), 'Ld': (1024, 1024),
    'kb_m': (0.0, 0.0), 'device': device,
}
nufft_op_sim = NUFFT2D(**nufft_params_sim)

print("Simulating fully sampled k-space data...")
k_space_data_full_clean = nufft_op_sim.forward(phantom_complex)

# Undersampling: keep a random subset of whole spokes.
num_total_points_full = k_trajectory_full.shape[0]
points_per_spoke = N_SAMPLES_PER_SPOKE

# The spoke-wise mask below assumes the trajectory is stored spoke-major, i.e.
# FULL_N_SPOKES contiguous runs of points_per_spoke samples (TODO confirm against
# generate_radial_trajectory_2d). Fail loudly if the point count does not even
# match that layout, rather than silently selecting the wrong samples.
expected_points = FULL_N_SPOKES * points_per_spoke
if num_total_points_full != expected_points:
    raise ValueError(
        f"Expected {expected_points} trajectory points "
        f"({FULL_N_SPOKES} spokes x {points_per_spoke} samples), "
        f"got {num_total_points_full}."
    )

selected_spoke_indices = torch.randperm(FULL_N_SPOKES)[:N_SPOKES_UNDERSAMPLED]

# Flag the selected spokes, then expand each flag to all samples of its spoke.
spoke_mask = torch.zeros(FULL_N_SPOKES, dtype=torch.bool)
spoke_mask[selected_spoke_indices] = True
undersampling_mask = spoke_mask.repeat_interleave(points_per_spoke).to(device)

k_trajectory_undersampled = k_trajectory_full[undersampling_mask, :]
# Assumes the simulated data is one value per trajectory point — TODO confirm.
k_space_data_undersampled_clean = k_space_data_full_clean[undersampling_mask]
print(f"Undersampled trajectory shape: {k_trajectory_undersampled.shape}")
print(f"Undersampled clean k-space shape: {k_space_data_undersampled_clean.shape}")

# Complex Gaussian noise: the real and imaginary parts EACH get std noise_std_val,
# which is NOISE_STD_PERCENT of the peak |clean undersampled k-space| value.
# randn_like already allocates on the data's device, so no extra .to() is needed.
noise_std_val = NOISE_STD_PERCENT * torch.max(torch.abs(k_space_data_undersampled_clean))
complex_noise = torch.complex(
    torch.randn_like(k_space_data_undersampled_clean.real) * noise_std_val,
    torch.randn_like(k_space_data_undersampled_clean.imag) * noise_std_val,
)
k_space_data_undersampled_noisy = k_space_data_undersampled_clean + complex_noise

# Plot inputs are converted to numpy once and reused.
traj_us_np = k_trajectory_undersampled.cpu().numpy()
log_kspace_mag = torch.log(torch.abs(k_space_data_undersampled_noisy) + 1e-9).cpu().numpy()

fig, axes = plt.subplots(1, 2, figsize=(10, 4))
axes[0].scatter(traj_us_np[:, 0], traj_us_np[:, 1], s=0.1, color='blue')
axes[0].set_title(f"Undersampled Trajectory ({N_SPOKES_UNDERSAMPLED} spokes)")
axes[0].set_xlabel("kx")
axes[0].set_ylabel("ky")
axes[0].axis('square')

im = axes[1].scatter(traj_us_np[:, 0], traj_us_np[:, 1], c=log_kspace_mag, s=1, cmap='viridis')
axes[1].set_title("Undersampled Noisy K-Space")
axes[1].set_xlabel("kx")
axes[1].set_ylabel("ky")
axes[1].axis('square')
fig.colorbar(im, ax=axes[1], label="log|k-space data|")
plt.tight_layout()
plt.show()

## 5. NUFFT Operator for Reconstruction

In [None]:
# --- 5. NUFFT Operator for Reconstruction ---
# Builds nufft_op_recon on the UNDERSAMPLED trajectory. The keyword arguments
# are kept in nufft_recon_kwargs_cs so the solver cells can reuse them. The
# gridding parameters match the simulation operator of section 4. No
# density_comp_weights are passed here.

print("\n--- Setting up NUFFT Operator for Reconstruction (Undersampled Data) ---")
try:
    # Everything below depends on earlier cells having been executed.
    _required_for_nufft = ('k_trajectory_undersampled', 'IMAGE_SIZE', 'device', 'NUFFT2D')
    if any(name not in globals() for name in _required_for_nufft):
        raise NameError("Required variables for NUFFT setup are not defined. Run previous cells.")

    nufft_recon_kwargs_cs = dict(
        oversamp_factor=(2.0, 2.0),
        kb_J=(4, 4),
        kb_alpha=(2.34 * 4, 2.34 * 4),
        Ld=(1024, 1024),
        kb_m=(0.0, 0.0),
    )

    nufft_op_recon = NUFFT2D(
        image_shape=(IMAGE_SIZE, IMAGE_SIZE),
        k_trajectory=k_trajectory_undersampled,  # must be the undersampled trajectory
        device=device,
        **nufft_recon_kwargs_cs,
    )
    print("NUFFT2D operator for reconstruction created successfully.")
    print(f"Configured with undersampled trajectory shape: {k_trajectory_undersampled.shape}")

    # Report whether the operator carries density-compensation weights.
    _recon_dcf = getattr(nufft_op_recon, 'density_comp_weights', None)
    if _recon_dcf is None:
        print("NUFFT op configured without explicit external density_comp_weights (will use internal default if any).")
    else:
        print(f"NUFFT op has density_comp_weights of shape: {_recon_dcf.shape}")

except NameError as e:
    print(f"NameError: {e}. Could not set up NUFFT operator.")
    nufft_op_recon = None  # sentinel checked by later cells
except Exception as e:
    print(f"An error occurred setting up the NUFFT operator: {e}")
    nufft_op_recon = None  # sentinel checked by later cells

## 6. Comparison: Zero-Filled Adjoint Reconstruction

In [None]:
# --- 6. Comparison: Zero-Filled Adjoint Reconstruction ---
# Simplest baseline: apply nufft_op_recon.adjoint directly to the undersampled
# noisy data, with no iterative solve and no regularization.
# NOTE(review): section 5 passes no density_comp_weights. Whether NUFFT2D
# applies any internal weighting in adjoint() is not visible here, so confirm
# this before comparing intensity scales with the other reconstructions.

print("\n--- Performing Zero-Filled Adjoint Reconstruction ---")
try:
    _required_for_adjoint = ('nufft_op_recon', 'k_space_data_undersampled_noisy', 'IMAGE_SIZE')
    if not all(name in globals() for name in _required_for_adjoint):
        raise NameError("Required variables (nufft_op_recon, k_space_data_undersampled_noisy, IMAGE_SIZE) not found.")
    # Section 5 leaves nufft_op_recon as None when operator construction fails.
    if nufft_op_recon is None:
        raise ValueError("nufft_op_recon was not properly initialized in a previous cell.")

    recon_adj = nufft_op_recon.adjoint(k_space_data_undersampled_noisy)
    print(f"Zero-filled adjoint reconstruction complete. Shape: {recon_adj.shape}")

    plt.figure(figsize=(5, 5))
    plt.imshow(recon_adj.cpu().abs().numpy(), cmap='gray')
    plt.title("Recon: Zero-Filled Adjoint (Undersampled)")
    plt.axis('off')
    plt.show()

except NameError as e:
    print(f"NameError: {e}. Cannot perform adjoint reconstruction.")
    # All-zero placeholder so the comparison cells still have an image.
    recon_adj = torch.zeros((IMAGE_SIZE, IMAGE_SIZE), dtype=torch.complex64, device=device)
except Exception as e:
    print(f"An error occurred during adjoint reconstruction: {e}")
    # All-zero placeholder so the comparison cells still have an image.
    recon_adj = torch.zeros((IMAGE_SIZE, IMAGE_SIZE), dtype=torch.complex64, device=device)

## 7. Comparison: Conjugate Gradient Reconstruction

In [None]:
# --- 7. Comparison: Conjugate Gradient Reconstruction ---
# Unregularized iterative baseline. conjugate_gradient_reconstruction receives:
# - the undersampled noisy data and the undersampled trajectory;
# - the NUFFT class and kwargs (nufft_recon_kwargs_cs) that the CS cell also uses;
# - no sparsity penalty.
# The prebuilt nufft_op_recon is NOT passed to it.
# NOTE(review): presumably it solves the normal equations A^H A x = A^H y, with
# A the NUFFT built from these arguments. Confirm against reconlib.solvers.

print("\n--- Performing Conjugate Gradient Reconstruction (Undersampled Data) ---")

# CG parameters
CG_ITERS_COMPARISON = 20  # number of CG iterations
CG_TOL_COMPARISON = 1e-6  # convergence tolerance

try:
    _required_for_cg = (
        'k_space_data_undersampled_noisy',
        'IMAGE_SIZE',
        'k_trajectory_undersampled',
        'device',
        'nufft_recon_kwargs_cs',
        'NUFFT2D',
        'conjugate_gradient_reconstruction',
    )
    if not all(name in globals() for name in _required_for_cg):
        raise NameError("One or more required variables/functions for CG recon are not defined. Run previous cells.")

    print(f"Starting CG reconstruction (iters={CG_ITERS_COMPARISON}, tol={CG_TOL_COMPARISON:.1e})...")

    recon_cg_undersampled = conjugate_gradient_reconstruction(
        kspace_data=k_space_data_undersampled_noisy,
        sampling_points=k_trajectory_undersampled,
        image_shape=(IMAGE_SIZE, IMAGE_SIZE),
        nufft_operator_class=NUFFT2D,
        nufft_kwargs=nufft_recon_kwargs_cs,  # same NUFFT params as the CS recon
        max_iters=CG_ITERS_COMPARISON,
        tol=CG_TOL_COMPARISON,
    )

    print("CG reconstruction (undersampled data) complete.")
    print(f"CG reconstructed image shape: {recon_cg_undersampled.shape}")

    plt.figure(figsize=(5, 5))
    plt.imshow(recon_cg_undersampled.cpu().abs().numpy(), cmap='gray')
    plt.title(f"Recon: CG (Undersampled, {CG_ITERS_COMPARISON} iters)")
    plt.axis('off')
    plt.show()

except NameError as e:
    print(f"NameError: {e}. Could not perform CG reconstruction.")
    # All-zero placeholder so the comparison cells still have an image.
    recon_cg_undersampled = torch.zeros((IMAGE_SIZE, IMAGE_SIZE), dtype=torch.complex64, device=device)
except Exception as e:
    print(f"An error occurred during CG reconstruction: {e}")
    # All-zero placeholder so the comparison cells still have an image.
    recon_cg_undersampled = torch.zeros((IMAGE_SIZE, IMAGE_SIZE), dtype=torch.complex64, device=device)

## 8. Compressed Sensing Reconstruction (L1-Wavelet)

In [None]:
# --- 8. Compressed Sensing Reconstruction (L1-Wavelet) ---
# FISTA with WaveletRegularizationTerm: an L1 penalty on the wavelet
# coefficients of the image, which promotes sparsity in the wavelet domain.

print("\n--- Performing Compressed Sensing (L1-Wavelet) Reconstruction ---")

# FISTA parameters for CS. LAMBDA_CS, WAVELET_NAME and WAVELET_LEVELS are
# defined in the "Simulation Parameters" section.
FISTA_ITERS_CS = 75  # May need more iterations for CS
FISTA_TOL_CS = 1e-5

try:
    # Check exactly the names this cell uses. The solver is handed NUFFT2D and
    # nufft_recon_kwargs_cs, not the prebuilt nufft_op_recon, so those must exist.
    # The missing names are listed to make the failure actionable.
    _required_for_cs = (
        'k_space_data_undersampled_noisy',
        'IMAGE_SIZE',
        'k_trajectory_undersampled',
        'device',
        'NUFFT2D',
        'nufft_recon_kwargs_cs',
        'fista_reconstruction',
        'WaveletRegularizationTerm',
        'LAMBDA_CS',
        'WAVELET_NAME',
        'WAVELET_LEVELS',
    )
    _missing_for_cs = [name for name in _required_for_cs if name not in globals()]
    if _missing_for_cs:
        raise NameError(
            "One or more required variables/classes for CS recon are not defined: "
            f"{', '.join(_missing_for_cs)}. Run previous cells."
        )

    # 1. L1-wavelet regularizer; its lambda_reg carries the CS strength LAMBDA_CS.
    cs_regularizer = WaveletRegularizationTerm(
        lambda_reg=LAMBDA_CS,
        wavelet_name=WAVELET_NAME,
        level=WAVELET_LEVELS,
        device=device,
    )
    print(f"WaveletRegularizationTerm instantiated: wavelet='{WAVELET_NAME}', levels={WAVELET_LEVELS}, lambda_cs={LAMBDA_CS}")

    # 2. FISTA reconstruction. nufft_recon_kwargs_cs (section 5) carries no
    # density-compensation weights; the sparsity prior is relied on to suppress
    # undersampling artifacts instead.
    print(f"Starting FISTA for CS (iters={FISTA_ITERS_CS}, tol={FISTA_TOL_CS:.1e})...")

    recon_cs = fista_reconstruction(
        kspace_data=k_space_data_undersampled_noisy,
        sampling_points=k_trajectory_undersampled,
        image_shape=(IMAGE_SIZE, IMAGE_SIZE),
        nufft_operator_class=NUFFT2D,
        nufft_kwargs=nufft_recon_kwargs_cs,
        regularizer=cs_regularizer,
        # The regularizer already holds LAMBDA_CS, so the solver-level weight
        # stays at 1.0 to avoid scaling the penalty twice.
        # NOTE(review): assumes fista_reconstruction combines lambda_reg with the
        # regularizer's own lambda multiplicatively. Confirm in reconlib.solvers.
        lambda_reg=1.0,
        max_iters=FISTA_ITERS_CS,
        tol=FISTA_TOL_CS,
        verbose=True,  # per-iteration progress from FISTA
    )

    print("Compressed Sensing reconstruction complete.")
    print(f"CS reconstructed image shape: {recon_cs.shape}")

    # 3. Display the CS reconstructed image
    plt.figure(figsize=(5, 5))
    plt.imshow(torch.abs(recon_cs.cpu()).numpy(), cmap='gray')
    plt.title(f"CS Recon: FISTA + L1-Wavelet (lambda={LAMBDA_CS}, {FISTA_ITERS_CS} iters)")
    plt.axis('off')
    plt.show()

except NameError as e:
    print(f"NameError: {e}. Could not perform CS reconstruction.")
    # All-zero placeholder so the comparison cells still have an image.
    recon_cs = torch.zeros((IMAGE_SIZE, IMAGE_SIZE), dtype=torch.complex64, device=device)
except Exception as e:
    print(f"An error occurred during CS reconstruction: {e}")
    import traceback
    traceback.print_exc()
    # All-zero placeholder so the comparison cells still have an image.
    recon_cs = torch.zeros((IMAGE_SIZE, IMAGE_SIZE), dtype=torch.complex64, device=device)

## 9. Results Visualization and Comparison

## 10. Conclusion and Discussion