# User Guide: Reconstruction with Regularizers

This tutorial demonstrates how to use regularizers within `reconlib` to improve MRI reconstructions, particularly when dealing with noise or undersampled data. Regularization incorporates prior knowledge about the image to guide the reconstruction process.

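We will focus on the Fast Iterative Shrinkage-Thresholding Algorithm (FISTA). FISTA is well-suited to objectives of the form $f(x) + R(x)$, where:

- the data-fidelity term $f(x) = \tfrac{1}{2}\|Ax - y\|_2^2$ (with $A$ the NUFFT forward operator and $y$ the measured k-space data) is differentiable with a Lipschitz-continuous gradient;
- the regularizer $R(x)$ has a computable proximal operator.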

**Steps Covered:**
1. Setup and necessary imports.
2. Generation of a phantom and non-Cartesian k-space data (as in the basic pipeline tutorial, but with a higher noise level).
3. Introduction to L1 and Total Variation (TV) regularizers.
4. Performing reconstruction using FISTA with L1 regularization.
5. Performing reconstruction using FISTA with TV regularization.
6. Comparing the results and discussing the impact of regularization.

## 1. Setup and Imports

In [None]:
%matplotlib inline
import torch
import numpy as np
import matplotlib.pyplot as plt

# ReconLib imports
try:
    from reconlib.nufft import NUFFT2D
    from reconlib.solvers import fista_reconstruction # Using FISTA for this tutorial
    from reconlib.regularizers.common import L1Regularizer, TVRegularizer
    # For data simulation - assuming functions from iternufft.py are suitable
    from iternufft import generate_phantom_2d, generate_radial_trajectory_2d
    print("Successfully imported reconlib components and simulation utilities.")
except ImportError as e:
    print(f"Error importing modules: {e}")
    print("Please ensure reconlib is installed and iternufft.py is in the Python path.")

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

## 2. Data Generation

We use a data generation process similar to the one in the basic pipeline tutorial. `NOISE_STD_PERCENT` sets the noise standard deviation as a fraction of the largest k-space magnitude; it is set to 0.05 here. To make the effect of regularization more pronounced, increase it (e.g., to 0.1) or reduce `N_SPOKES`.

In [None]:
# --- Configuration ---
IMAGE_SIZE = 128
N_SPOKES = 96  # Slightly fewer spokes to simulate moderate undersampling
N_SAMPLES_PER_SPOKE = int(IMAGE_SIZE * 1.5)  # Samples per spoke = 1.5x the image size
NOISE_STD_PERCENT = 0.05  # Noise std as a FRACTION of max|k-space| (0.05 == 5%); increased noise level
RANDOM_SEED = 0  # Fixed seed so the noise (and every reconstruction below) is reproducible

# Seed before any random draw (placeholder trajectory, noise) so reruns give identical data.
torch.manual_seed(RANDOM_SEED)

# --- Generate Phantom ---
try:
    phantom_img = generate_phantom_2d(size=IMAGE_SIZE, device=device)
    phantom_complex = phantom_img.to(torch.complex64)
    print(f"Phantom generated: {phantom_img.shape}")
except NameError:
    # generate_phantom_2d was not imported: fall back to an all-zero image so the
    # remaining cells can still run (their reconstructions will then be zero too).
    print("generate_phantom_2d not available. Placeholder used.")
    phantom_complex = torch.zeros((IMAGE_SIZE, IMAGE_SIZE), dtype=torch.complex64, device=device)

plt.figure(figsize=(4, 4))
plt.imshow(phantom_complex.abs().cpu().numpy(), cmap='gray')
plt.title(f"Original Phantom ({IMAGE_SIZE}x{IMAGE_SIZE})")
plt.axis('off')
plt.show()

# --- Generate K-Space Trajectory ---
try:
    k_trajectory = generate_radial_trajectory_2d(num_spokes=N_SPOKES, samples_per_spoke=N_SAMPLES_PER_SPOKE, device=device)
    print(f"K-space trajectory generated: {k_trajectory.shape}")
except NameError:
    # Placeholder: uniformly random (kx, ky) points in [-0.5, 0.5).
    print("generate_radial_trajectory_2d not available. Placeholder used.")
    k_trajectory = torch.rand((N_SPOKES * N_SAMPLES_PER_SPOKE, 2), device=device) - 0.5

# --- Simulate K-Space Data (Forward NUFFT) and Add Noise ---
try:
    nufft_params_sim = {
        'image_shape': (IMAGE_SIZE, IMAGE_SIZE), 'k_trajectory': k_trajectory,
        'oversamp_factor': (2.0, 2.0), 'kb_J': (4, 4),
        'kb_alpha': (2.34 * 4, 2.34 * 4), 'Ld': (1024, 1024),
        'kb_m': (0.0, 0.0), 'device': device
    }
    nufft_op_sim = NUFFT2D(**nufft_params_sim)
    k_space_data_clean = nufft_op_sim.forward(phantom_complex)
    # Noise std is set relative to the largest k-space magnitude; the real and
    # imaginary parts each receive independent Gaussian noise with this std.
    noise_std_val = NOISE_STD_PERCENT * torch.max(torch.abs(k_space_data_clean))
    complex_noise = torch.complex(
        torch.randn_like(k_space_data_clean.real) * noise_std_val,
        torch.randn_like(k_space_data_clean.imag) * noise_std_val,
    )
    k_space_data_noisy = k_space_data_clean + complex_noise
    print(f"Noisy k-space data created: {k_space_data_noisy.shape}")
except Exception as e:
    print(f"Error in k-space simulation: {e}. Placeholders used.")
    # k_trajectory is always bound by the trajectory step above, so its sample
    # count can be used directly for the placeholder.
    k_space_data_noisy = torch.zeros(k_trajectory.shape[0], dtype=torch.complex64, device=device)

plt.figure(figsize=(4, 4))
if k_space_data_noisy.numel() > 1:
    # Convert to NumPy explicitly (detached, on CPU) before handing data to matplotlib.
    traj_np = k_trajectory.detach().cpu().numpy()
    log_mag = torch.log(torch.abs(k_space_data_noisy.detach().cpu()) + 1e-9).numpy()
    plt.scatter(traj_np[:, 0], traj_np[:, 1], c=log_mag, s=1, cmap='viridis')
    plt.colorbar(label="log|k-space data|")
plt.title("Simulated Noisy K-Space Data")
plt.xlabel("kx")
plt.ylabel("ky")
plt.axis('square')
plt.show()

## 3. NUFFT Operator Setup for Reconstruction

We define the NUFFT parameters that the FISTA solver will use. In iterative reconstruction, the data-fidelity term $\tfrac{1}{2}\|Ax - y\|_2^2$ should generally use the plain forward model $A$ without density compensation. Density-compensation weights change the least-squares problem being solved; they do not act as a regularizer. We therefore pass no explicit density-compensation weights and rely on `NUFFT2D`'s default (or minimal) internal handling.

In [None]:
# NUFFT settings for the reconstruction operator (a plain dict literal, so no
# try/except is needed). The solver instantiates NUFFT2D itself from
# image_shape and the trajectory, which are passed to it separately, plus these
# kwargs. NOTE(review): 'device' is not listed here, so this assumes the solver
# supplies it (e.g. from the data tensor). TODO confirm.
# These values match the simulation operator. Reconstructing with the exact
# operator that generated the data (an "inverse crime") tends to make results
# look better than they would on real data.
nufft_recon_kwargs = {
    'oversamp_factor': (2.0, 2.0),
    'kb_J': (4, 4),
    'kb_alpha': (2.34 * 4, 2.34 * 4),
    'Ld': (1024, 1024),
    'kb_m': (0.0, 0.0),
    # 'density_comp_weights' is deliberately not passed, so NUFFT2D's default
    # internal handling applies.
}
print(f"NUFFT kwargs for reconstruction solvers: {nufft_recon_kwargs}")

## 4. Reconstruction with L1 Regularization

In [None]:
# --- 4. Reconstruction with L1 Regularization ---
# L1 regularization promotes sparsity in the image domain.
# R(x) = lambda_l1 * ||x||_1

print("\n--- Performing L1 Regularized Reconstruction ---")

# L1 Regularization parameters
LAMBDA_L1 = 0.001  # Regularization strength for L1
FISTA_ITERS_L1 = 50  # Number of FISTA iterations
FISTA_TOL_L1 = 1e-5  # Tolerance for FISTA convergence

try:
    # Fail early, naming exactly what is missing, if earlier cells were not run
    # or their imports failed.
    _required_l1 = (
        'k_space_data_noisy', 'IMAGE_SIZE', 'k_trajectory', 'device',
        'nufft_recon_kwargs', 'NUFFT2D', 'fista_reconstruction', 'L1Regularizer',
    )
    _missing_l1 = [name for name in _required_l1 if name not in globals()]
    if _missing_l1:
        raise NameError(
            f"Required variables/functions not defined: {', '.join(_missing_l1)}. "
            "Please run previous cells."
        )

    # 1. Instantiate L1Regularizer; the regularizer object carries the strength.
    l1_reg = L1Regularizer(lambda_reg=LAMBDA_L1)
    print(f"L1Regularizer instantiated with lambda_l1 = {LAMBDA_L1}")

    # 2. Perform FISTA reconstruction with L1 regularization
    print(f"Starting FISTA with L1 regularization (iters={FISTA_ITERS_L1}, tol={FISTA_TOL_L1:.1e})...")

    recon_l1 = fista_reconstruction(
        kspace_data=k_space_data_noisy,
        sampling_points=k_trajectory,
        image_shape=(IMAGE_SIZE, IMAGE_SIZE),
        nufft_operator_class=NUFFT2D,
        nufft_kwargs=nufft_recon_kwargs,
        regularizer=l1_reg,
        # Same convention as the TV reconstruction: the regularizer already holds
        # the true strength (LAMBDA_L1), so FISTA's own lambda_reg is 1.0. Passing
        # LAMBDA_L1 here too would apply it twice (effective LAMBDA_L1**2) if FISTA
        # scales the proximal step by lambda_reg.
        # NOTE(review): confirm against fista_reconstruction.
        lambda_reg=1.0,
        max_iters=FISTA_ITERS_L1,
        tol=FISTA_TOL_L1,
        verbose=True,  # Enable verbose output to see FISTA progress
    )

    print("FISTA reconstruction with L1 regularization complete.")
    print(f"L1 reconstructed image shape: {recon_l1.shape}")

    # 3. Display the magnitude of the L1-regularized reconstruction.
    # detach() makes the NumPy conversion safe even if the result tracks gradients.
    plt.figure(figsize=(5, 5))
    plt.imshow(torch.abs(recon_l1.detach().cpu()).numpy(), cmap='gray')
    plt.title(f"Recon: FISTA + L1 (lambda={LAMBDA_L1}, {FISTA_ITERS_L1} iters)")
    plt.axis('off')
    plt.show()

except NameError as e:
    print(f"NameError: {e}. Could not perform L1 regularized reconstruction.")
    recon_l1 = torch.zeros((IMAGE_SIZE, IMAGE_SIZE), dtype=torch.complex64, device=device)  # Placeholder
except Exception as e:
    print(f"An error occurred during FISTA L1 reconstruction: {e}")
    recon_l1 = torch.zeros((IMAGE_SIZE, IMAGE_SIZE), dtype=torch.complex64, device=device)  # Placeholder

## 5. Reconstruction with Total Variation (TV) Regularization

In [None]:
# --- 5. Reconstruction with Total Variation (TV) Regularization ---
# Total Variation regularization penalizes the sum of the magnitudes of gradients
# in the image, promoting piecewise-constant solutions and preserving edges.
# R(x) = lambda_tv * TV(x)

print("\n--- Performing TV Regularized Reconstruction ---")

# TV Regularization parameters
LAMBDA_TV = 0.002  # Regularization strength for TV (often needs different scale than L1)
FISTA_ITERS_TV = 75  # TV might need more iterations or different tuning
FISTA_TOL_TV = 1e-5  # Tolerance for FISTA convergence

try:
    # Fail early, naming exactly what is missing, if earlier cells were not run
    # or their imports failed.
    _required_tv = (
        'k_space_data_noisy', 'IMAGE_SIZE', 'k_trajectory', 'device',
        'nufft_recon_kwargs', 'NUFFT2D', 'fista_reconstruction', 'TVRegularizer',
    )
    _missing_tv = [name for name in _required_tv if name not in globals()]
    if _missing_tv:
        raise NameError(
            f"Required variables/functions not defined: {', '.join(_missing_tv)}. "
            "Please run previous cells."
        )

    # 1. Instantiate TVRegularizer. Its proximal operator is solved iteratively
    # (Chambolle's algorithm), hence the inner-solver settings below.
    tv_reg = TVRegularizer(
        lambda_param=LAMBDA_TV,     # Note: TVRegularizer uses lambda_param
        max_chambolle_iter=50,      # Iterations for the inner TV prox solver
        tol_chambolle=1e-4,         # Tolerance for the inner TV prox solver
        verbose_chambolle=False,    # Verbosity for the inner TV prox solver
    )
    print(f"TVRegularizer instantiated with lambda_tv = {LAMBDA_TV}")

    # 2. Perform FISTA reconstruction with TV regularization
    print(f"Starting FISTA with TV regularization (iters={FISTA_ITERS_TV}, tol={FISTA_TOL_TV:.1e})...")

    recon_tv = fista_reconstruction(
        kspace_data=k_space_data_noisy,
        sampling_points=k_trajectory,
        image_shape=(IMAGE_SIZE, IMAGE_SIZE),
        nufft_operator_class=NUFFT2D,
        nufft_kwargs=nufft_recon_kwargs,
        regularizer=tv_reg,
        lambda_reg=1.0,  # Set to 1.0, as TVRegularizer.lambda_param is the true strength.
        max_iters=FISTA_ITERS_TV,
        tol=FISTA_TOL_TV,
        verbose=True,
    )

    print("FISTA reconstruction with TV regularization complete.")
    print(f"TV reconstructed image shape: {recon_tv.shape}")

    # 3. Display the magnitude of the TV-regularized reconstruction.
    # detach() makes the NumPy conversion safe even if the result tracks gradients.
    plt.figure(figsize=(5, 5))
    plt.imshow(torch.abs(recon_tv.detach().cpu()).numpy(), cmap='gray')
    plt.title(f"Recon: FISTA + TV (lambda={LAMBDA_TV}, {FISTA_ITERS_TV} iters)")
    plt.axis('off')
    plt.show()

except NameError as e:
    print(f"NameError: {e}. Could not perform TV regularized reconstruction.")
    recon_tv = torch.zeros((IMAGE_SIZE, IMAGE_SIZE), dtype=torch.complex64, device=device)  # Placeholder
except Exception as e:
    print(f"An error occurred during FISTA TV reconstruction: {e}")
    recon_tv = torch.zeros((IMAGE_SIZE, IMAGE_SIZE), dtype=torch.complex64, device=device)  # Placeholder

## 6. Comparison and Conclusion