# Voronoi-Weighted Non-Cartesian MRI Reconstruction Comparison

This notebook provides a hands-on demonstration of non-Cartesian Magnetic Resonance Imaging (MRI) reconstruction. We focus on utilizing Voronoi tessellation for density compensation and compare its performance against other common reconstruction approaches.

**Key Concepts Covered:**
*   **Non-Cartesian Sampling:** Unlike conventional Cartesian sampling, non-Cartesian trajectories (e.g., radial, spiral) can offer advantages like motion robustness or faster scanning. However, they require specialized reconstruction algorithms.
*   **Density Compensation:** When gridding non-Cartesian k-space data onto a Cartesian grid for FFT, regions with higher sample density are overrepresented. Density Compensation Factors (DCF) are used to correct for this, typically by down-weighting densely sampled regions.
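*   **Voronoi Density Compensation:** This method calculates DCF from the area (2D) or volume (3D) of the Voronoi cell surrounding each k-space sample. Cell size is inversely proportional to local sampling density, so the DCF is proportional to the cell area itself: larger cells (sparser sampling) get larger weights, and smaller cells (denser sampling) get smaller weights.
*   **NUFFT (Non-Uniform Fast Fourier Transform):** A fast approximation of the non-uniform discrete Fourier transform. It maps between an image on a Cartesian grid and k-space samples at arbitrary (non-Cartesian) locations, typically by interpolating onto an oversampled Cartesian grid with a compact kernel (e.g., Kaiser-Bessel) and applying an FFT.
*   **Iterative Reconstruction:** Algorithms like Conjugate Gradient (CG) iteratively solve the (optionally density-weighted) least-squares reconstruction problem. They often yield better image quality than simple gridding, especially when combined with an accurate system model (NUFFT plus DCF).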
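
The NUFFT approximates the non-uniform discrete Fourier transform (NUDFT), which can be evaluated exactly for small problems at O(M·N²) cost for M samples and an N×N image. A brute-force version is a handy reference for sanity-checking a NUFFT operator. The sketch below is plain NumPy, not reconlib's API; the [-0.5, 0.5) cycles/pixel trajectory convention and the centered pixel grid are assumptions.

```python
import numpy as np

def nudft_forward(image, k_traj):
    """Exact non-uniform DFT: image (N, N) -> k-space samples (M,).

    k_traj holds (kx, ky) in cycles/pixel, assumed to lie in [-0.5, 0.5).
    Pixel coordinates are centered so that index N // 2 is the origin.
    """
    n = image.shape[0]
    coords = np.arange(n) - n // 2
    yy, xx = np.meshgrid(coords, coords, indexing="ij")  # yy varies along rows, xx along columns
    # phase[m, p] = -2*pi*(kx_m * x_p + ky_m * y_p) for every sample m and pixel p
    phase = -2j * np.pi * (np.outer(k_traj[:, 0], xx.ravel()) + np.outer(k_traj[:, 1], yy.ravel()))
    return np.exp(phase) @ image.ravel()

def nudft_adjoint(kdata, k_traj, n):
    """Adjoint (conjugate transpose) of nudft_forward: k-space (M,) -> image (n, n)."""
    coords = np.arange(n) - n // 2
    yy, xx = np.meshgrid(coords, coords, indexing="ij")
    phase = 2j * np.pi * (np.outer(xx.ravel(), k_traj[:, 0]) + np.outer(yy.ravel(), k_traj[:, 1]))
    return (np.exp(phase) @ kdata).reshape(n, n)
```

A quick correctness check for any forward/adjoint pair is the identity <Ax, y> = <x, A^H y>, which a NUFFT should satisfy up to its approximation error.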

**Notebook Workflow:**
1.  **Setup and Imports:** Load necessary libraries and modules from `reconlib`.
2.  **Data Generation:** Create a 2D Shepp-Logan phantom and simulate non-Cartesian (radial) k-space data using a NUFFT forward operation. Noise is added to simulate realistic conditions.
3.  **Voronoi Weight Calculation:** Compute Voronoi density compensation weights from the k-space trajectory.
4.  **Image Reconstruction:** Reconstruct the image using three different methods:
    *   Conjugate Gradient (CG) with Voronoi weights.
    *   Simple gridding (adjoint NUFFT) using the NUFFT operator's default density compensation (if any).
    *   Conjugate Gradient (CG) using the NUFFT operator's default density compensation.
5.  **Comparison:** Visualize the reconstructed images and compare them qualitatively and quantitatively (MSE, PSNR, SSIM).

## 1. Setup and Imports

In [None]:
%matplotlib inline
import torch
import numpy as np
import matplotlib.pyplot as plt
from skimage.data import shepp_logan_phantom
from skimage.transform import resize

# Import reconlib components; generate_radial_trajectory_2d comes from the separate iternufft module
try:
    import reconlib
    from reconlib.nufft import NUFFT2D  # 2D NUFFT operator
    from reconlib.solvers import conjugate_gradient_reconstruction
    from reconlib.voronoi.density_weights_pytorch import compute_voronoi_density_weights_pytorch
    from iternufft import generate_radial_trajectory_2d
except ImportError as e:
    print(f"Error importing reconlib or iternufft: {e}")
    print("Please ensure reconlib and iternufft are installed and on your Python path.")
    print("You might need to run 'pip install -e .' from the root of the repository.")

print(f"PyTorch version: {torch.__version__}")
print(f"NumPy version: {np.__version__}")
if 'reconlib' in globals():
    print(f"reconlib version: {getattr(reconlib, '__version__', 'unknown')}")
# iternufft may not define __version__
try:
    import iternufft
    print(f"iternufft version: {getattr(iternufft, '__version__', 'unknown')}")
except ImportError:
    print("iternufft not found; it is needed for generate_radial_trajectory_2d.")

## 2. Generate Phantom and Simulate K-Space Data

In [None]:
# --- Configuration for Data Generation ---
IMAGE_SIZE = 128
N_SPOKES = IMAGE_SIZE # Number of radial spokes (below the ~(pi/2) * IMAGE_SIZE needed for full radial Nyquist sampling)
N_SAMPLES_PER_SPOKE = int(IMAGE_SIZE * 1.5) # 1.5x oversampling along each spoke

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# --- 1. Generate Phantom ---
print("\n--- Generating Phantom ---")
phantom = shepp_logan_phantom()
phantom_resized = resize(phantom, (IMAGE_SIZE, IMAGE_SIZE), anti_aliasing=True)
phantom_tensor = torch.from_numpy(phantom_resized).float().to(device)
phantom_complex = phantom_tensor.to(torch.complex64) # NUFFT expects complex input

# Display phantom
plt.figure(figsize=(5, 5))
plt.imshow(phantom_tensor.cpu().numpy(), cmap='gray')
plt.title(f"Original Phantom ({IMAGE_SIZE}x{IMAGE_SIZE})")
plt.axis('off')
plt.show()

# --- 2. Generate K-Space Trajectory ---
print("\n--- Generating K-Space Trajectory ---")
try:
    from iternufft import generate_radial_trajectory_2d
    k_trajectory = generate_radial_trajectory_2d(
        num_spokes=N_SPOKES,
        samples_per_spoke=N_SAMPLES_PER_SPOKE,
        device=device
    )
except ImportError:
    print("Failed to import generate_radial_trajectory_2d from iternufft. Trajectory generation will fail.")
    print("Make sure iternufft.py is in the PYTHONPATH or the same directory.")
    k_trajectory = torch.zeros((N_SPOKES * N_SAMPLES_PER_SPOKE, 2), device=device) # Placeholder

# Display trajectory
plt.figure(figsize=(5, 5))
if k_trajectory.shape[0] > 0:
    plt.scatter(k_trajectory[:, 0].cpu().numpy(), k_trajectory[:, 1].cpu().numpy(), s=0.5)
plt.title(f"K-Space Trajectory ({N_SPOKES} spokes, {N_SAMPLES_PER_SPOKE} samples/spoke)")
plt.xlabel("kx")
plt.ylabel("ky")
plt.axis('square')
plt.show()

# --- 3. Setup NUFFT Operator for Simulation (no explicit DCF for forward) ---
print("\n--- Setting up NUFFT for Simulation ---")
# Standard NUFFT parameters (can be adjusted)
# These should be appropriate for NUFFT2D from reconlib.nufft
nufft_params_sim = {
    'image_shape': (IMAGE_SIZE, IMAGE_SIZE),
    'k_trajectory': k_trajectory,
    'oversamp_factor': (2.0, 2.0), # Standard oversampling
    'kb_J': (4, 4),                   # Kaiser-Bessel kernel width (grid samples)
    'kb_alpha': (2.34 * 4, 2.34 * 4), # Kaiser-Bessel shape parameter: 2.34 * J is a common heuristic for os = 2
                                      # (Beatty's formula beta = pi * sqrt((J/os)^2 * (os - 0.5)^2 - 0.8) gives a similar value)
    'Ld': (1024, 1024),               # Interpolation table length for the KB kernel
    'kb_m': (0.0, 0.0),               # Order of the KB kernel
    'device': device
    # No density_comp_weights needed for forward simulation
}

# Ensure NUFFT2D is imported correctly from reconlib.nufft
try:
    nufft_op_sim = NUFFT2D(**nufft_params_sim)
    print("NUFFT2D operator for simulation created successfully.")
except Exception as e:
    print(f"Error creating NUFFT2D operator: {e}")
    nufft_op_sim = None


# --- 4. Simulate K-Space Data ---
print("\n--- Simulating K-Space Data ---")
if nufft_op_sim is not None:
    k_space_data_clean = nufft_op_sim.forward(phantom_complex)
    print(f"Clean k-space data shape: {k_space_data_clean.shape}")

    # Add complex Gaussian noise
    noise_level_real = 0.02 # Noise std per real/imaginary channel, as a fraction (2%) of the peak k-space magnitude
    noise_std_real = noise_level_real * torch.max(torch.abs(k_space_data_clean))
    
    # Generate noise for real and imaginary parts separately
    noise_r = torch.randn_like(k_space_data_clean.real) * noise_std_real
    noise_i = torch.randn_like(k_space_data_clean.imag) * noise_std_real
    complex_noise = torch.complex(noise_r, noise_i).to(device)
    
    k_space_data_noisy = k_space_data_clean + complex_noise
    print(f"Noisy k-space data shape: {k_space_data_noisy.shape}")

    # Store key variables for later cells
    # %store phantom_complex k_trajectory k_space_data_noisy IMAGE_SIZE device
else:
    print("NUFFT simulation operator not available. Cannot simulate k-space data.")
    # Create placeholders if simulation failed
    k_space_data_clean = torch.zeros((k_trajectory.shape[0]), dtype=torch.complex64, device=device)
    k_space_data_noisy = torch.zeros((k_trajectory.shape[0]), dtype=torch.complex64, device=device)

# Display k-space data magnitude (log scale)
plt.figure(figsize=(6, 5))
if k_space_data_noisy.numel() > 0:
    plt.scatter(k_trajectory[:,0].cpu(), k_trajectory[:,1].cpu(), c=torch.log(torch.abs(k_space_data_noisy.cpu()) + 1e-9), s=1, cmap='viridis')
    plt.colorbar(label="log|k-space data|")
else:
    plt.text(0.5, 0.5, "K-space data not generated", ha='center', va='center')
plt.title("Simulated K-Space Data (Log Magnitude)")
plt.xlabel("kx")
plt.ylabel("ky")
plt.axis('square')
plt.show()

print("\n--- Data Generation Cell Complete ---")
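
The exact layout produced by iternufft's `generate_radial_trajectory_2d` is not shown in this notebook. As an illustration only, a radial trajectory with evenly spaced spoke angles can be built as below; the function name, the angle range [0, π), and the [-0.5, 0.5) normalization are assumptions, not iternufft's documented behavior.

```python
import numpy as np

def radial_trajectory_2d(num_spokes, samples_per_spoke):
    """Hypothetical stand-in for iternufft.generate_radial_trajectory_2d.

    Returns a (num_spokes * samples_per_spoke, 2) array of (kx, ky) in
    cycles/pixel. Each spoke passes through the center with radii
    -0.5, -0.5 + 1/samples_per_spoke, ..., 0.5 - 1/samples_per_spoke,
    and the spoke angles are evenly spaced over [0, pi).
    """
    angles = np.arange(num_spokes) * np.pi / num_spokes
    radii = np.arange(samples_per_spoke) / samples_per_spoke - 0.5
    kx = np.outer(np.cos(angles), radii)  # shape (num_spokes, samples_per_spoke)
    ky = np.outer(np.sin(angles), radii)
    # Flatten spoke by spoke into the (N_SPOKES * N_SAMPLES_PER_SPOKE, 2) layout used above
    return np.stack([kx.ravel(), ky.ravel()], axis=1)

traj = radial_trajectory_2d(num_spokes=128, samples_per_spoke=192)
print(traj.shape)  # (24576, 2)
```

With an even number of samples per spoke, every spoke in this layout contains an exact k = 0 sample, so the center is acquired once per spoke. Density compensation code has to cope with these coincident points.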

## 3. Compute Voronoi Weights

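Here, we calculate the density compensation factors (DCF) using Voronoi tessellation. Each k-space sample is treated as the generator site of a Voronoi cell. The area (2D) or volume (3D) of this cell is inversely proportional to the local sampling density, so the DCF is taken as the area/volume itself: samples in sparsely sampled regions (large cells) are weighted up, and samples in densely sampled regions (small cells) are weighted down. Cells at the outer edge of the sampling pattern are unbounded, so they must be clipped to a finite region.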

We use the `compute_voronoi_density_weights_pytorch` function from `reconlib.voronoi.density_weights_pytorch`. This function requires the k-space sample locations (`points`) and the `bounds` of the k-space region being considered.
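
The internals of `compute_voronoi_density_weights_pytorch` are not shown here. As an independent cross-check, a minimal area-based DCF can be sketched with SciPy's Voronoi tessellation. Two choices in it are assumptions rather than reconlib's method: coincident samples share their cell area equally, and unbounded edge cells receive the largest finite area instead of being clipped to `bounds`.

```python
import numpy as np
from scipy.spatial import ConvexHull, Voronoi

def voronoi_dcf(points):
    """Area-based density compensation for 2D k-space samples (sketch).

    points: (M, 2) array of sample locations.
    Returns (M,) weights proportional to each sample's Voronoi cell area.
    """
    pts = np.asarray(points, dtype=np.float64)
    # Qhull cannot build separate cells for coincident points, so tessellate the
    # unique locations and let duplicates share their cell's area equally.
    uniq, inverse, counts = np.unique(pts, axis=0, return_inverse=True, return_counts=True)
    inverse = inverse.reshape(-1)  # keep 1D across NumPy versions

    vor = Voronoi(uniq)
    areas = np.full(len(uniq), np.nan)
    for i, region_index in enumerate(vor.point_region):
        region = vor.regions[region_index]
        if len(region) == 0 or -1 in region:
            continue  # unbounded cell on the outer edge of the sampling pattern
        # A bounded Voronoi cell is convex, so its hull area is the cell area
        # (ConvexHull.volume is the area for 2D input).
        areas[i] = ConvexHull(vor.vertices[region]).volume

    # Heuristic: give unbounded edge cells the largest finite area.
    areas[np.isnan(areas)] = np.nanmax(areas)
    return areas[inverse] / counts[inverse]
```

Tessellating only the unique locations avoids Qhull failures on the repeated k = 0 samples that radial trajectories typically contain. Clipping edge cells to explicit bounds, as the `bounds` argument of the reconlib function suggests, is more principled than the max-area heuristic used here.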

In [None]:
# --- Compute Voronoi Weights ---
print("\n--- Computing Voronoi Weights ---")

# Define bounds for Voronoi tessellation.
# Assumes k_trajectory is normalized to [-0.5, 0.5] in each dimension (check iternufft's convention).
voronoi_bounds = torch.tensor([[-0.5, -0.5], [0.5, 0.5]], dtype=torch.float32, device=device)

# Ensure k_trajectory is on the correct device and is float32 for Voronoi computation.
# It should already be from the previous cell, but as a safeguard:
k_trajectory_voronoi = k_trajectory.to(device=device, dtype=torch.float32)

try:
    # Re-import so this cell can also run on its own after the setup cell
    from reconlib.voronoi.density_weights_pytorch import compute_voronoi_density_weights_pytorch
    
    print(f"Computing Voronoi weights for {k_trajectory_voronoi.shape[0]} points...")
    voronoi_weights = compute_voronoi_density_weights_pytorch(
        points=k_trajectory_voronoi,
        bounds=voronoi_bounds,
        space_dim=2
    )
    print(f"Computed Voronoi weights. Shape: {voronoi_weights.shape}, Sum: {torch.sum(voronoi_weights):.4f}")
    
    # Ensure weights are positive, as expected for DCF
    if torch.any(voronoi_weights <= 0):
        print("Warning: Some Voronoi weights are not positive. Clamping to a small positive value.")
        voronoi_weights = torch.clamp(voronoi_weights, min=1e-9) # Ensure positive for sqrt later if needed by some algos

except ImportError:
    print("Failed to import compute_voronoi_density_weights_pytorch. Voronoi weights cannot be computed.")
    voronoi_weights = torch.ones(k_trajectory.shape[0], dtype=torch.float32, device=device) / k_trajectory.shape[0] # Placeholder
except Exception as e:
    print(f"Error computing Voronoi weights: {e}")
    print("Using placeholder weights (uniform).")
    voronoi_weights = torch.ones(k_trajectory.shape[0], dtype=torch.float32, device=device) / k_trajectory.shape[0] # Placeholder


# Visualize Voronoi weights
plt.figure(figsize=(6, 5))
if k_trajectory.shape[0] > 0 and voronoi_weights.numel() == k_trajectory.shape[0]:
    # Use log of weights for better visualization if range is large
    scatter = plt.scatter(
        k_trajectory[:, 0].cpu().numpy(),
        k_trajectory[:, 1].cpu().numpy(),
        c=torch.log(voronoi_weights.cpu() + 1e-9), # Add epsilon for log
        s=5,
        cmap='viridis'
    )
    plt.colorbar(scatter, label="log(Voronoi Weights)")
else:
    plt.text(0.5, 0.5, "Voronoi weights not computed or incompatible", ha='center', va='center')

plt.title("K-Space Trajectory Colored by Log Voronoi Weights")
plt.xlabel("kx")
plt.ylabel("ky")
plt.axis('square')
plt.show()

# %store voronoi_weights # Optional: store for use in later sessions if notebook is split

print("\n--- Voronoi Weight Computation Cell Complete ---")

## 4. Perform Reconstructions

### 4.a Helper function for NUFFT parameters

In [None]:
# --- Helper function for NUFFT parameters ---
print("\n--- Defining NUFFT Parameter Helper ---")

def get_nufft_params(current_k_trajectory, current_image_size, current_device, dcf_weights=None):
    """
    Helper function to get a dictionary of NUFFT parameters.
    Allows specifying density compensation weights.
    Assumes NUFFT2D is to be used.
    """
    base_params = {
        'image_shape': (current_image_size, current_image_size),
        'k_trajectory': current_k_trajectory,
        'oversamp_factor': (2.0, 2.0),
        'kb_J': (4, 4),
        'kb_alpha': (2.34 * 4, 2.34 * 4), # Heuristic alpha = 2.34 * J for J = 4, os = 2
        'Ld': (1024, 1024), # Table length for interpolation
        'kb_m': (0.0, 0.0),
        'device': current_device
    }
    if dcf_weights is not None:
        base_params['density_comp_weights'] = dcf_weights
    return base_params

# Test the helper function (optional, for immediate feedback in notebook)
# These variables (k_trajectory, IMAGE_SIZE, device, voronoi_weights) must have been defined in previous cells.
try:
    params_with_dcf = get_nufft_params(k_trajectory, IMAGE_SIZE, device, dcf_weights=voronoi_weights)
    dcf_with = params_with_dcf.get('density_comp_weights')
    print(f"NUFFT params with DCF: image_shape={params_with_dcf['image_shape']}, "
          f"density_comp_weights_shape={dcf_with.shape if hasattr(dcf_with, 'shape') else 'Not set'}")
    
    params_no_dcf = get_nufft_params(k_trajectory, IMAGE_SIZE, device)
    dcf_without = params_no_dcf.get('density_comp_weights')
    print(f"NUFFT params without DCF: image_shape={params_no_dcf['image_shape']}, "
          f"density_comp_weights_shape={dcf_without.shape if hasattr(dcf_without, 'shape') else 'Not set'}")
    
    # Try instantiating NUFFT2D with these params to catch errors early
    nufft_op_test_dcf = NUFFT2D(**params_with_dcf)
    print("Successfully created NUFFT2D with DCF weights using helper.")
    nufft_op_test_no_dcf = NUFFT2D(**params_no_dcf)
    print("Successfully created NUFFT2D without DCF weights using helper.")
    
except NameError as e:
    print(f"A required variable (k_trajectory, IMAGE_SIZE, device, or voronoi_weights) might not be defined yet: {e}")
    print("This cell should be run after the data generation and Voronoi weights cells.")
except Exception as e:
    print(f"Error during NUFFT parameter helper function test: {e}")

print("\n--- NUFFT Parameter Helper Cell Complete ---")
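
The NUFFT setup comments mention two ways to choose the Kaiser-Bessel shape parameter. Assuming reconlib's `kb_alpha` plays the role of the Kaiser-Bessel β, the two choices can be compared directly for J = 4 and 2x oversampling:

```python
import math

def kb_beta_beatty(J, os_factor):
    """Kaiser-Bessel shape parameter from Beatty et al. (2005):
    beta = pi * sqrt((J / os)^2 * (os - 0.5)^2 - 0.8),
    with J the kernel width in grid samples and os the grid oversampling factor."""
    return math.pi * math.sqrt((J / os_factor) ** 2 * (os_factor - 0.5) ** 2 - 0.8)

J, os_factor = 4, 2.0
print(f"Beatty formula:     {kb_beta_beatty(J, os_factor):.3f}")  # 8.996
print(f"2.34 * J heuristic: {2.34 * J:.3f}")                     # 9.360
```

For these settings the formula gives about 8.996 versus 9.36 for the heuristic, a difference of roughly 4%, so either is a reasonable starting point.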

### 4.b Method 1: Conjugate Gradient with Voronoi Weights

In [None]:
# --- Method 1: Conjugate Gradient with Voronoi Weights ---
print("\n--- Running Reconstruction Method 1: CG with Voronoi Weights ---")

# Parameters for CG reconstruction
cg_iters = 20 # Number of conjugate gradient iterations
cg_tol = 1e-6 # Tolerance for CG convergence

# Ensure necessary variables from previous cells are available
try:
    if 'k_space_data_noisy' not in globals() or \
       'k_trajectory' not in globals() or \
       'IMAGE_SIZE' not in globals() or \
       'device' not in globals() or \
       'voronoi_weights' not in globals() or \
       'get_nufft_params' not in globals() or \
       'NUFFT2D' not in globals() or \
       'conjugate_gradient_reconstruction' not in globals():
        raise NameError("One or more required variables or functions are not defined. Please run previous cells.")

    # 1. Get NUFFT parameters with Voronoi weights
    nufft_params_voronoi = get_nufft_params(
        current_k_trajectory=k_trajectory,
        current_image_size=IMAGE_SIZE,
        current_device=device,
        dcf_weights=voronoi_weights
    )
    print("NUFFT parameters for Voronoi-weighted CG obtained.")

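    # 2. Instantiate NUFFT2D once to check that the parameters are valid
    #    (the solver below receives the class and its kwargs, not this instance)
    nufft_op_voronoi = NUFFT2D(**nufft_params_voronoi)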
    print("NUFFT2D operator with Voronoi weights created.")

    # 3. Perform Conjugate Gradient reconstruction
    print(f"Starting Voronoi-weighted CG reconstruction (iters={cg_iters}, tol={cg_tol:.1e})...")
    recon_cg_voronoi = conjugate_gradient_reconstruction(
        kspace_data=k_space_data_noisy,
        # The trajectory and image shape are passed via nufft_params_voronoi to NUFFT2D's constructor
        nufft_operator_class=NUFFT2D, # Pass the class itself
        nufft_kwargs=nufft_params_voronoi, # Constructor arguments, including density_comp_weights
        # The Voronoi weights are also passed explicitly. Whether the solver applies them in
        # addition to density_comp_weights depends on reconlib's implementation; verify that
        # the weights are not applied twice.
        use_voronoi=True,
        voronoi_weights=voronoi_weights,
        max_iters=cg_iters,
        tol=cg_tol
    )
    print("Voronoi-weighted CG reconstruction complete.")
    print(f"Reconstructed image shape: {recon_cg_voronoi.shape}")

    # 4. Display the reconstructed image
    plt.figure(figsize=(5, 5))
    plt.imshow(torch.abs(recon_cg_voronoi.cpu()).numpy(), cmap='gray')
    plt.title(f"Recon: CG with Voronoi Weights ({cg_iters} iters)")
    plt.axis('off')
    plt.show()
    
    # %store recon_cg_voronoi # Optional

except NameError as e:
    print(f"NameError: {e}. Ensure all previous cells have been run successfully.")
    recon_cg_voronoi = torch.zeros((IMAGE_SIZE, IMAGE_SIZE), dtype=torch.complex64, device=device) # Placeholder
except Exception as e:
    print(f"An error occurred during Voronoi-weighted CG reconstruction: {e}")
    recon_cg_voronoi = torch.zeros((IMAGE_SIZE, IMAGE_SIZE), dtype=torch.complex64, device=device) # Placeholder

print("\n--- CG with Voronoi Weights Cell Complete ---")

### 4.c Method 2: Simple Gridding (Adjoint NUFFT with default/no explicit DCF)

This is the most basic reconstruction approach. It involves applying the adjoint NUFFT operation directly to the (noisy) k-space data. The `NUFFT2D` operator, when initialized without explicit `density_comp_weights`, might use a simple internal default DCF (like a radial ramp for radial data) or no DCF at all, depending on its implementation. This method is fast but often yields images with density-related artifacts and lower quality compared to iterative methods with accurate DCF.
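
The paragraph above mentions a radial ramp as a possible default DCF. For evenly spaced radial spokes, the ramp follows from geometry: each sample is weighted by its share of the surrounding ring area. The sketch below uses an exact NUDFT (small sizes only) and a hypothetical radial layout rather than reconlib, and compares the adjoint with and without the ramp; the central-disc correction is derived for this layout and is an assumption about how one might handle k = 0.

```python
import numpy as np

def nudft_matrix(k_traj, n):
    """Exact NUDFT system matrix (M samples x n*n pixels) on a centered pixel grid."""
    coords = np.arange(n) - n // 2
    yy, xx = np.meshgrid(coords, coords, indexing="ij")
    return np.exp(-2j * np.pi * (np.outer(k_traj[:, 0], xx.ravel())
                                 + np.outer(k_traj[:, 1], yy.ravel())))

def radial_traj(num_spokes, samples_per_spoke):
    """Evenly spaced spokes through the center, radii in [-0.5, 0.5) cycles/pixel."""
    angles = np.arange(num_spokes) * np.pi / num_spokes
    radii = np.arange(samples_per_spoke) / samples_per_spoke - 0.5
    return np.stack([np.outer(np.cos(angles), radii).ravel(),
                     np.outer(np.sin(angles), radii).ravel()], axis=1)

def ramp_dcf(k_traj, num_spokes, samples_per_spoke):
    """Ramp DCF for this radial layout.

    Between the center and the outer edge, 2 * num_spokes samples share a ring
    of area 2 * pi * |k| * dr, giving pi * |k| * dr / num_spokes per sample.
    The num_spokes samples at k = 0 split the central disc of radius dr / 2.
    """
    dr = 1.0 / samples_per_spoke
    radius = np.hypot(k_traj[:, 0], k_traj[:, 1])
    weights = np.pi * radius * dr / num_spokes
    weights[radius == 0] = np.pi * dr**2 / (4 * num_spokes)
    return weights

def rel_error_after_scaling(recon, ref):
    """Relative L2 error after scaling recon by its least-squares factor."""
    scale = np.vdot(recon, ref) / np.vdot(recon, recon)
    return np.linalg.norm(scale * recon - ref) / np.linalg.norm(ref)

n, num_spokes, samples_per_spoke = 16, 32, 32
image = np.zeros((n, n))
image[4:12, 4:12] = 1.0  # simple square test object

k_traj = radial_traj(num_spokes, samples_per_spoke)
A = nudft_matrix(k_traj, n)
kdata = A @ image.ravel()
dcf = ramp_dcf(k_traj, num_spokes, samples_per_spoke)

recon_plain = A.conj().T @ kdata         # adjoint only, no density compensation
recon_ramp = A.conj().T @ (dcf * kdata)  # gridding with the ramp DCF

err_plain = rel_error_after_scaling(np.abs(recon_plain), image.ravel())
err_ramp = rel_error_after_scaling(np.abs(recon_ramp), image.ravel())
print(f"relative error, adjoint only: {err_plain:.3f}")
print(f"relative error, ramp DCF:     {err_ramp:.3f}")
```

Without the DCF, the densely sampled center dominates and the adjoint is heavily blurred; the ramp flattens the effective k-space weighting, so its error should be noticeably lower.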

In [None]:
# --- Method 2: Simple Gridding (Adjoint NUFFT with default/no explicit DCF) ---
print("\n--- Running Reconstruction Method 2: Simple Gridding ---")

try:
    if 'k_space_data_noisy' not in globals() or \
       'k_trajectory' not in globals() or \
       'IMAGE_SIZE' not in globals() or \
       'device' not in globals() or \
       'get_nufft_params' not in globals() or \
       'NUFFT2D' not in globals():
        raise NameError("One or more required variables or functions are not defined. Please run previous cells.")

    # 1. Get NUFFT parameters *without* explicit DCF weights
    # This will rely on NUFFT2D's default behavior (e.g., simple radial DCF or none)
    nufft_params_basic = get_nufft_params(
        current_k_trajectory=k_trajectory,
        current_image_size=IMAGE_SIZE,
        current_device=device,
        dcf_weights=None # Explicitly None
    )
    print("NUFFT parameters for simple gridding (default DCF) obtained.")

    # 2. Instantiate NUFFT2D operator
    nufft_op_basic = NUFFT2D(**nufft_params_basic)
    print("NUFFT2D operator for simple gridding created.")

    # 3. Perform reconstruction using simple adjoint (gridding)
    print("Starting simple gridding (adjoint NUFFT)...")
    # Depending on NUFFT2D's implementation, the adjoint may apply an internally estimated DCF or none when no density_comp_weights were passed at init.
    recon_gridding = nufft_op_basic.adjoint(k_space_data_noisy)
    print("Simple gridding complete.")
    print(f"Reconstructed image shape: {recon_gridding.shape}")

    # 4. Display the reconstructed image
    plt.figure(figsize=(5, 5))
    plt.imshow(torch.abs(recon_gridding.cpu()).numpy(), cmap='gray')
    plt.title("Recon: Simple Gridding (Adjoint NUFFT)")
    plt.axis('off')
    plt.show()
    
    # %store recon_gridding # Optional

except NameError as e:
    print(f"NameError: {e}. Ensure all previous cells have been run successfully.")
    recon_gridding = torch.zeros((IMAGE_SIZE, IMAGE_SIZE), dtype=torch.complex64, device=device) # Placeholder
except Exception as e:
    print(f"An error occurred during simple gridding reconstruction: {e}")
    recon_gridding = torch.zeros((IMAGE_SIZE, IMAGE_SIZE), dtype=torch.complex64, device=device) # Placeholder

print("\n--- Simple Gridding Cell Complete ---")

### 4.d Method 3: Conjugate Gradient with default/no explicit DCF

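This method uses the Conjugate Gradient algorithm to iteratively solve the least-squares problem `argmin_x ||Ax - y||^2`, where `A` is the NUFFT operator and `y` is the noisy k-space data. No Voronoi weights are provided, so the operator uses its default internal DCF, if any. If the solver applies a DCF `w` as a data weighting, the problem becomes `argmin_x ||W^(1/2)(Ax - y)||^2` with `W = diag(w)`. For noise-free data consistent with the model, the true image minimizes both problems; with noisy data and a fixed number of iterations, the weights change both the result and the convergence speed. This method should improve upon simple gridding, but may not perform as well as CG with a more accurate (Voronoi) DCF.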
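
The least-squares problem above can be solved with conjugate gradient applied to the normal equations `A^H W A x = A^H W y`, where `W` is the diagonal DCF matrix (`W = I` without density compensation). The sketch below uses a dense NumPy matrix in place of the NUFFT operator and is not reconlib's `conjugate_gradient_reconstruction`; treating the DCF as a data weighting is an assumption about how the weights enter.

```python
import numpy as np

def cg_normal_equations(A, y, weights=None, max_iters=20, tol=1e-6):
    """Conjugate gradient on the normal equations A^H W A x = A^H W y.

    A: (M, P) complex system matrix standing in for the NUFFT operator.
    y: (M,) measured data. weights: (M,) positive DCF, or None for W = I.
    Stops once ||residual|| <= tol * ||A^H W y|| or after max_iters iterations.
    """
    w = np.ones(A.shape[0]) if weights is None else np.asarray(weights, dtype=float)
    AH = A.conj().T

    def normal_op(v):
        # Apply A^H W A without forming the P x P matrix
        return AH @ (w * (A @ v))

    b = AH @ (w * y)
    x = np.zeros(A.shape[1], dtype=complex)
    r = b.copy()  # residual b - normal_op(x) for the initial guess x = 0
    p = r.copy()  # first search direction
    rs_old = np.vdot(r, r).real
    b_norm = np.linalg.norm(b)
    for _ in range(max_iters):
        if np.sqrt(rs_old) <= tol * b_norm:
            break
        Ap = normal_op(p)
        alpha = rs_old / np.vdot(p, Ap).real  # exact line search along p
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = np.vdot(r, r).real
        p = r + (rs_new / rs_old) * p  # next A^H W A-conjugate direction
        rs_old = rs_new
    return x
```

With noise-free data, weighted and unweighted solves converge to the same image; with noisy data the weights change the minimizer, which is the difference the Voronoi and default-DCF runs compare.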

In [None]:
# --- Method 3: Conjugate Gradient with default/no explicit DCF ---
print("\n--- Running Reconstruction Method 3: CG with Default/No Explicit DCF ---")

# Parameters for CG reconstruction (can be same as for Voronoi CG)
cg_iters_basic = 20 # Number of conjugate gradient iterations
cg_tol_basic = 1e-6 # Tolerance for CG convergence

try:
    if 'k_space_data_noisy' not in globals() or \
       'k_trajectory' not in globals() or \
       'IMAGE_SIZE' not in globals() or \
       'device' not in globals() or \
       'get_nufft_params' not in globals() or \
       'NUFFT2D' not in globals() or \
       'conjugate_gradient_reconstruction' not in globals():
        raise NameError("One or more required variables or functions are not defined. Please run previous cells.")

    # 1. Get NUFFT parameters *without* explicit DCF weights
    nufft_params_cg_basic = get_nufft_params(
        current_k_trajectory=k_trajectory,
        current_image_size=IMAGE_SIZE,
        current_device=device,
        dcf_weights=None # Explicitly None for default NUFFT DCF behavior
    )
    print("NUFFT parameters for CG with default DCF obtained.")

    # 2. Instantiate NUFFT2D operator
    # Depending on NUFFT2D's implementation, its adjoint may use an internal default DCF when density_comp_weights is None
    nufft_op_cg_basic = NUFFT2D(**nufft_params_cg_basic)
    print("NUFFT2D operator for CG with default DCF created.")

    # 3. Perform Conjugate Gradient reconstruction
    print(f"Starting CG reconstruction with default DCF (iters={cg_iters_basic}, tol={cg_tol_basic:.1e})...")
    recon_cg_basic_dcf = conjugate_gradient_reconstruction(
        kspace_data=k_space_data_noisy,
        nufft_operator_class=NUFFT2D,
        nufft_kwargs=nufft_params_cg_basic, # These kwargs do NOT include explicit Voronoi weights
        use_voronoi=False, # Explicitly False
        voronoi_weights=None, # Explicitly None
        max_iters=cg_iters_basic,
        tol=cg_tol_basic
    )
    print("CG reconstruction with default DCF complete.")
    print(f"Reconstructed image shape: {recon_cg_basic_dcf.shape}")

    # 4. Display the reconstructed image
    plt.figure(figsize=(5, 5))
    plt.imshow(torch.abs(recon_cg_basic_dcf.cpu()).numpy(), cmap='gray')
    plt.title(f"Recon: CG with Default DCF ({cg_iters_basic} iters)")
    plt.axis('off')
    plt.show()
    
    # %store recon_cg_basic_dcf # Optional

except NameError as e:
    print(f"NameError: {e}. Ensure all previous cells have been run successfully.")
    recon_cg_basic_dcf = torch.zeros((IMAGE_SIZE, IMAGE_SIZE), dtype=torch.complex64, device=device) # Placeholder
except Exception as e:
    print(f"An error occurred during CG reconstruction with default DCF: {e}")
    recon_cg_basic_dcf = torch.zeros((IMAGE_SIZE, IMAGE_SIZE), dtype=torch.complex64, device=device) # Placeholder

print("\n--- CG with Default DCF Cell Complete ---")

## 5. Visualize and Compare Results

Finally, we display the original phantom and the images reconstructed by the three different methods. Visual inspection helps in qualitatively assessing the impact of different reconstruction strategies and density compensation techniques. We also compute quantitative metrics (MSE, PSNR, SSIM) to provide an objective measure of reconstruction fidelity against the original phantom. Lower MSE, higher PSNR, and SSIM closer to 1 indicate better reconstruction quality.
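
MSE and PSNR are only meaningful when a reconstruction is on the same intensity scale as the reference, and adjoint-only or CG outputs generally are not. The sketch below rescales each magnitude image by its least-squares factor before computing the metrics. It uses NumPy and scikit-image's `structural_similarity` instead of `reconlib.metrics`, whose signatures are not shown here; the helper names are hypothetical.

```python
import numpy as np
from skimage.metrics import structural_similarity

def match_scale_np(recon, ref):
    """Scale recon by the least-squares factor <recon, ref> / <recon, recon>."""
    denom = np.sum(recon * recon)
    return recon * (np.sum(recon * ref) / denom) if denom > 0 else recon

def compare_to_reference(ref, recon):
    """Return (MSE, PSNR in dB, SSIM) of a magnitude image against a real reference.

    recon is rescaled to ref first; data_range is taken from the reference.
    """
    recon = match_scale_np(recon, ref)
    data_range = float(ref.max() - ref.min()) or 1.0  # guard against a flat reference
    mse_val = float(np.mean((ref - recon) ** 2))
    psnr_val = 10.0 * np.log10(data_range**2 / mse_val) if mse_val > 0 else float("inf")
    ssim_val = structural_similarity(ref, recon, data_range=data_range)
    return mse_val, psnr_val, ssim_val
```

In this notebook, `ref` would be `original_phantom_abs.numpy()` and `recon` one of the magnitude reconstructions converted with `.numpy()`.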

In [None]:
# --- 5. Visualize and Compare Results ---
print("\n--- Visualizing and Comparing Reconstruction Results ---")

try:
    # Ensure all required variables are defined from previous cells
    if 'phantom_complex' not in globals() or \
       'recon_cg_voronoi' not in globals() or \
       'recon_gridding' not in globals() or \
       'recon_cg_basic_dcf' not in globals() or \
       'IMAGE_SIZE' not in globals():
        raise NameError("One or more required reconstruction results or variables are not defined. Please run all previous cells.")

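    # Prepare original phantom for comparison (magnitude, since reconstructions are complex)
    original_phantom_abs = torch.abs(phantom_complex.cpu())

    def match_scale(img, ref):
        # Adjoint and CG outputs need not share the phantom's intensity scale, so rescale
        # each magnitude image by the least-squares factor <img, ref> / <img, img>
        denom = torch.sum(img * img)
        return img * (torch.sum(img * ref) / denom) if denom > 0 else img

    # Prepare reconstructed images (magnitudes, rescaled to the phantom)
    recon_cg_voronoi_abs = match_scale(torch.abs(recon_cg_voronoi.cpu()), original_phantom_abs)
    recon_gridding_abs = match_scale(torch.abs(recon_gridding.cpu()), original_phantom_abs)
    recon_cg_basic_dcf_abs = match_scale(torch.abs(recon_cg_basic_dcf.cpu()), original_phantom_abs)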
    
    reconstructions = {
        "Original Phantom": original_phantom_abs,
        "CG + Voronoi Weights": recon_cg_voronoi_abs,
        "Simple Gridding (Adjoint)": recon_gridding_abs,
        "CG + Default/No DCF": recon_cg_basic_dcf_abs
    }

    # --- Plotting Side-by-Side ---
    num_images = len(reconstructions)
    fig, axes = plt.subplots(1, num_images, figsize=(num_images * 4, 4))
    fig.suptitle("Comparison of Reconstruction Methods", fontsize=16)

    for i, (title, img) in enumerate(reconstructions.items()):
        ax = axes[i]
        im = ax.imshow(img.numpy(), cmap='gray')
        ax.set_title(title)
        ax.axis('off')
        # fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04) # Optional colorbar per image

    plt.tight_layout(rect=[0, 0, 1, 0.96]) # Adjust layout to make space for suptitle
    plt.show()

    # --- Quantitative Comparison: MSE, PSNR, SSIM ---
    print("\n--- Quantitative Metrics (vs Original Phantom) ---")
    try:
        from reconlib.metrics.image_metrics import mse, psnr, ssim  # Assumed to be available in reconlib
        
        # All images above are CPU tensors, which the metric functions are expected to accept.

        data_range = original_phantom_abs.max() - original_phantom_abs.min()
        if data_range == 0:
            data_range = 1.0  # Avoid division by zero if the phantom is flat
        
        print(f"{'Method':<30} | {'MSE':<10} | {'PSNR (dB)':<10} | {'SSIM':<10}")
        print("-" * 69)

        for title, recon_img_abs in reconstructions.items():
            if title == "Original Phantom":
                continue # Skip comparing original to itself for these metrics

            # Ensure recon_img_abs is also a CPU tensor for metric functions
            current_mse = mse(original_phantom_abs, recon_img_abs).item()
            current_psnr = psnr(original_phantom_abs, recon_img_abs, data_range=data_range).item()
            # reconlib's ssim is expected to take PyTorch tensors; add batch and channel
            # dimensions in case it requires (N, C, H, W) input. The keyword arguments
            # gaussian_kernel and kernel_size are assumed from its signature.
            current_ssim = ssim(original_phantom_abs.unsqueeze(0).unsqueeze(0),
                                recon_img_abs.unsqueeze(0).unsqueeze(0),
                                data_range=data_range,
                                gaussian_kernel=True,
                                kernel_size=7
                               ).item()

            print(f"{title:<30} | {current_mse:<10.4e} | {current_psnr:<10.2f} | {current_ssim:<10.4f}")

    except ImportError:
        print("reconlib.metrics.image_metrics not found. Skipping quantitative metrics.")
    except Exception as e:
        print(f"Error calculating metrics: {e}")
        
except NameError as e:
    print(f"NameError: {e}. Ensure all previous cells have been run successfully to generate reconstruction variables.")
except Exception as e:
    print(f"An error occurred during visualization or comparison: {e}")

print("\n--- Visualization and Comparison Cell Complete ---")