# Non-Cartesian Phantom G-Factor Comparison

This notebook demonstrates the comparison of g-factor calculation methods (PMR vs our diagnostic approach) for different numbers of noise replicas (N) using a non-cartesian phantom dataset.

The experiment compares:
- **PMR Method**: Pseudo-Multiple Replica approach using N noise replicas
- **Our Method**: Diagnostic Hutchinson's method using N random vectors

## Setup

In [None]:
import torch
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import os
import time

from mr_recon.gfactor import gfactor_SENSE_diag, gfactor_SENSE_PMR
from mr_recon.fourier import sigpy_nufft
from mr_recon.linops import sense_linop, batching_params
from mr_recon.recons import CG_SENSE_recon, doubleCG_inv_op_builder

# Set dark theme for plots
plt.style.use('dark_background')

## Configuration

Set the experiment parameters:

In [None]:
# --- Experiment Configuration ---
R = 2                    # Acceleration factor
lamda_l2 = 0.1          # L2 regularization parameter
N_values = [10, 20, 50] # N values for comparison (small for demo)
N_ref = 100             # Number of replicas for reference calculation
display_mode = 'inv_g'  # 'g' or 'inv_g' (1/g-factor)
max_display_window = 'ref'  # 'ref', 'pmr', 'ours', or numeric value

# Reconstruction parameters
max_iter = 100
max_eigen = 1  # Estimate max eigenvalue automatically
tol = 1e-2
sigma = 1e-2   # Noise standard deviation

# Output directory
output_dir = f"experiments/noncartesian_phantom/results/R{R}_L{lamda_l2}"
os.makedirs(output_dir, exist_ok=True)

print(f"Experiment configuration:")
print(f"  R = {R}, λ = {lamda_l2}")
print(f"  N values = {N_values}")
print(f"  N_ref = {N_ref}")
print(f"  Display mode = {display_mode}")
print(f"  Output directory = {output_dir}")

## Data Loading

Load the non-cartesian phantom dataset:

In [None]:
# --- Data Loading ---
data_dir = "experiments/noncartesian_phantom/data"
torch_dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
complex_dtype = torch.complex64
print(f"Using device: {torch_dev}")

# Load data files
img = np.load(os.path.join(data_dir, "img.npy"))
mps = np.load(os.path.join(data_dir, "mps.npy"))
ksp = np.load(os.path.join(data_dir, "ksp.npy"))
dcf = np.load(os.path.join(data_dir, "dcf.npy"))
trj = np.load(os.path.join(data_dir, "trj.npy"))

# Convert to torch tensors
img_torch = torch.tensor(img, dtype=complex_dtype, device=torch_dev)
mps_torch = torch.tensor(mps, dtype=complex_dtype, device=torch_dev)
ksp_torch_orig = torch.tensor(ksp, dtype=complex_dtype, device=torch_dev)
dcf_torch_orig = torch.tensor(dcf, dtype=torch.float32, device=torch_dev)
trj_torch_orig = torch.tensor(trj, dtype=torch.float32, device=torch_dev)

# Get dimensions
im_size = img_torch.shape[-2:]
C = mps_torch.shape[0]

print(f"Image size: {im_size}")
print(f"Number of coils: {C}")
print(f"Trajectory shape: {trj_torch_orig.shape}")
print(f"K-space shape: {ksp_torch_orig.shape}")

## Accelerated Data Preparation

Create the accelerated trajectory and k-space data:

In [None]:
# --- SENSE Operator Setup ---
# For non-Cartesian trajectories, acceleration is achieved by retaining every R-th shot/interleave
# Trajectory shape is [points_per_shot, num_shots, 2], so we subsample dimension 1 (shots)
# K-space shape is [num_coils, points_per_shot, num_shots], so we subsample dimension 2 (shots)
# DCF shape is [points_per_shot, num_shots], so we subsample dimension 1 (shots)

trj_acc = trj_torch_orig[:, ::R, :]  # Subsample shots
dcf_acc = dcf_torch_orig[:, ::R]     # Subsample DCF
ksp_acc = ksp_torch_orig[:, :, ::R] # Subsample k-space

print(f"Accelerated trajectory shape: {trj_acc.shape}")
print(f"Accelerated k-space shape: {ksp_acc.shape}")
print(f"Acceleration factor: {trj_torch_orig.shape[1] / trj_acc.shape[1]:.1f}x")

## SENSE Operators

Set up the SENSE linear operators for both fully sampled and accelerated reconstructions:

In [None]:
# Create NUFFT operator
nufft = sigpy_nufft(im_size, width=4)
bparams = batching_params(C)

# SENSE operators
A_acc = sense_linop(im_size, trj_acc, mps_torch, dcf_acc, nufft, bparams=bparams, use_toeplitz=True)
A = sense_linop(im_size, trj_torch_orig, mps_torch, dcf_torch_orig, nufft, bparams=bparams, use_toeplitz=True)

print("SENSE operators created successfully")

## Reconstruct Accelerated Image

Perform the accelerated reconstruction:

In [None]:
# --- Reconstruct Image for Plotting ---
print("Reconstructing accelerated image...")
start_time = time.time()
recon_image = CG_SENSE_recon(A_acc, ksp_acc, max_iter, lamda_l2, max_eigen, verbose=False, tolerance=tol)
recon_time = time.time() - start_time
print(f"Reconstruction completed in {recon_time:.2f} seconds")

# Calculate reconstruction metrics
recon_magnitude = recon_image.abs()
img_magnitude = img_torch.abs()
nmse = torch.norm(recon_magnitude - img_magnitude)**2 / torch.norm(img_magnitude)**2
print(f"NMSE: {nmse.item():.6f}")

## G-Factor Calculation Functions

Define the reconstruction functions for g-factor calculation:

In [None]:
# --- G-Factor Calculation Functions ---
recon_acc_func = lambda ksp_in: CG_SENSE_recon(A_acc, ksp_in, max_iter, lamda_l2, max_eigen, verbose=False, tolerance=tol)
recon_ref_func = lambda ksp_in: CG_SENSE_recon(A, ksp_in, max_iter, lamda_l2, max_eigen, verbose=False, tolerance=tol)

print("Reconstruction functions defined")

## Reference G-Factor Calculation

Calculate the high-quality reference g-factor using PMR method:

In [None]:
# --- Calculate High-Quality Reference ---
print(f"Calculating reference g-factor with N={N_ref}...")
start_time = time.time()
g_ref_raw = gfactor_SENSE_PMR(R_ref=recon_ref_func, R_acc=recon_acc_func, ksp_ref=ksp_torch_orig, ksp_acc=ksp_acc, noise_var=sigma**2, n_replicas=N_ref, verbose=True)
time_ref = time.time() - start_time
print(f"Reference calculation completed in {time_ref:.2f} seconds")

## N Comparison Loop

Compare g-factor calculations for different values of N:

In [None]:
# --- Loop over N values ---
results = {'N': [], 'Time_PMR': [], 'Time_Diag': [], 'g_PMR': [], 'g_Diag': []}

for n in N_values:
    print(f"\nCalculating for N={n}...")
    results['N'].append(n)
    
    # PMR method
    start_time = time.time()
    g_pmr_raw = gfactor_SENSE_PMR(R_ref=recon_ref_func, R_acc=recon_acc_func, ksp_ref=ksp_torch_orig, ksp_acc=ksp_acc, noise_var=sigma**2, n_replicas=n, verbose=False)
    results['Time_PMR'].append(time.time() - start_time)
    results['g_PMR'].append(torch.nan_to_num(g_pmr_raw, nan=1.0))
    print(f"  PMR: {results['Time_PMR'][-1]:.2f}s")
    
    # Our diagnostic method (Hutchinson's)
    AHA_inv = doubleCG_inv_op_builder(A=A, dcf=dcf_torch_orig, max_iter=max_iter, lamda_l2=lamda_l2, max_eigen=max_eigen, tolerance=tol, verbose=False)
    AHA_inv_acc = doubleCG_inv_op_builder(A=A_acc, dcf=dcf_acc, max_iter=max_iter, lamda_l2=lamda_l2, max_eigen=max_eigen, tolerance=tol, verbose=False)

    start_time = time.time()
    g_diag_raw = gfactor_SENSE_diag(AHA_inv_ref=AHA_inv, AHA_inv_acc=AHA_inv_acc, inp_example=torch.zeros(im_size, device=torch_dev, dtype=complex_dtype), n_replicas=n, sigma=sigma, rnd_vec_type='complex', verbose=False)
    results['Time_Diag'].append(time.time() - start_time)
    results['g_Diag'].append(torch.nan_to_num(g_diag_raw, nan=1.0))
    print(f"  Our method: {results['Time_Diag'][-1]:.2f}s")

print("\nAll N calculations completed!")

## Process G-Factor Results

Apply scaling and display mode transformations:

In [None]:
# --- Process g-factors based on display_mode ---
g_ref_unscaled = results['g_PMR'][N_values.index(min(N_values))] if min(N_values) in N_values else g_ref_raw

# --- Apply sqrt(R) scaling ---
g_ref = g_ref_unscaled / (R ** 0.5)
results['g_PMR'] = [g / (R ** 0.5) for g in results['g_PMR']]
results['g_Diag'] = [g / (R ** 0.5) for g in results['g_Diag']]

if display_mode == 'inv_g':
    g_ref = 1 / g_ref
    results['g_PMR'] = [1 / g for g in results['g_PMR']]
    results['g_Diag'] = [1 / g for g in results['g_Diag']]
    title_prefix = "1/G-Factor"
else: # display_mode == 'g'
    title_prefix = "G-Factor"

print(f"Results processed for display mode: {display_mode}")
print(f"Title prefix: {title_prefix}")

## Create Comparison Plot

Generate the visualization comparing different N values:

In [None]:
# --- Plotting ---
print("Generating comparison plot...")
num_comparisons = len(N_values)

fig, axes = plt.subplots(2, num_comparisons + 1, figsize=(5 * (num_comparisons + 1), 11))
fig.patch.set_facecolor('black')

# Determine global vmin and vmax for color consistency
vmin = g_ref.min()

try:
    # Check if max_display_window is a number
    vmax = float(max_display_window)
except ValueError:
    # It's a string like 'ref', 'pmr', or 'ours'
    if max_display_window == 'ref':
        vmax = g_ref.max()
    elif max_display_window == 'pmr':
        vmax = max(g.max() for g in results['g_PMR'])
    elif max_display_window == 'ours':
        vmax = max(g.max() for g in results['g_Diag'])
    else: # Default fallback to reference
        print(f"Warning: Invalid string for max_display_window. Defaulting to 'ref'.")
        vmax = g_ref.max()

print(f"Colorbar range: [{vmin:.3f}, {vmax:.3f}]")

# --- Row 1: Our Method + N values + Reference G-factor ---
# Columns 0-4: Our Method results
for i, n in enumerate(N_values):
    ax = axes[0, i]
    im = ax.imshow(results['g_Diag'][i].cpu().numpy(), cmap='jet', vmin=vmin, vmax=vmax)
    ax.set_title(f"Our Method (N={n})")
    ax.axis('off')
    
# Column 5: Reference G-Factor (moved to top row)
im = axes[0, num_comparisons].imshow(g_ref.cpu().numpy(), cmap='jet', vmin=vmin, vmax=vmax)
axes[0, num_comparisons].set_title(f"Reference {title_prefix}
(PMR, N={N_ref})")
axes[0, num_comparisons].axis('off')

# --- Row 2: PMR Method + N values + Reconstructed Image ---
# Columns 0-4: PMR Method results
for i, n in enumerate(N_values):
    ax = axes[1, i]
    im = ax.imshow(results['g_PMR'][i].cpu().numpy(), cmap='jet', vmin=vmin, vmax=vmax)
    ax.set_title(f"PMR (N={n})")
    ax.axis('off')
    
# Column 5: Reconstructed Image (moved to bottom row)
axes[1, num_comparisons].imshow(img_torch.abs().cpu().numpy(), cmap='gray')
axes[1, num_comparisons].set_title(f"Reconstructed Image (R={R})")
axes[1, num_comparisons].axis('off')

# Adjust layout to make space for the colorbar and title
fig.subplots_adjust(right=0.92, top=0.90)

# Add a single colorbar for all g-factor maps on the right side
cbar_ax = fig.add_axes([0.94, 0.15, 0.015, 0.7])
fig.colorbar(im, cax=cbar_ax)

fig.suptitle(f"{title_prefix} Calculation Convergence (R={R}, λ={lamda_l2})", fontsize=24)
plt.show()

print("Plot generated successfully!")

## Timing Analysis

Display the timing results:

In [None]:
# Display timing results
timing_df = pd.DataFrame({
    'N': results['N'],
    'Time_PMR': results['Time_PMR'],
    'Time_Our_Method': results['Time_Diag']
})

print(f"Reference G-Factor (PMR, N={N_ref}) Time: {time_ref:.4f}s\n")
print("--- Comparison Times ---")
print(timing_df.to_string(index=False))

# Plot timing comparison
plt.figure(figsize=(10, 6))
plt.plot(results['N'], results['Time_PMR'], 'o-', label='PMR Method', linewidth=2, markersize=8)
plt.plot(results['N'], results['Time_Diag'], 's-', label='Our Method', linewidth=2, markersize=8)
plt.xlabel('Number of Replicas (N)')
plt.ylabel('Computation Time (seconds)')
plt.title(f'G-Factor Computation Time vs N (R={R}, λ={lamda_l2})')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

## Save Results (Optional)

Save the plots and timing data:

In [None]:
# --- Save Outputs ---
mode_tag = "g" if display_mode == "g" else "invG"
N_str = "N" + "-".join(str(n) for n in N_values)

# Save plot
plot_path_png = os.path.join(output_dir, f"N_comparison_plot_R{R}_L{lamda_l2}_{N_str}_mode_{mode_tag}_notebook.png")
plot_path_svg = os.path.join(output_dir, f"N_comparison_plot_R{R}_L{lamda_l2}_{N_str}_mode_{mode_tag}_notebook.svg")
plt.savefig(plot_path_png, dpi=300, bbox_inches='tight')
plt.savefig(plot_path_svg, format='svg', bbox_inches='tight')
print(f"Plot saved to {plot_path_png}")

# Save reconstructed image
recon_img_np = recon_image.abs().detach().cpu().numpy()
plt.figure(figsize=(6, 6))
plt.imshow(recon_img_np, cmap='gray')
plt.axis('off')
plt.title(f'Reconstructed Image (R={R})')
recon_png_path = os.path.join(output_dir, f"recon_image_R{R}_L{lamda_l2}_{N_str}_mode_{mode_tag}_notebook.png")
plt.savefig(recon_png_path, bbox_inches='tight', pad_inches=0, dpi=300)
plt.close()
print(f"Reconstructed image saved to {recon_png_path}")

# Save timing data
times_path = os.path.join(output_dir, f"N_comparison_times_R{R}_L{lamda_l2}_{N_str}_mode_{mode_tag}_notebook.txt")
with open(times_path, 'w') as f:
    f.write(f"Reference G-Factor (PMR, N={N_ref}) Time: {time_ref:.4f}s\n\n")
    f.write("--- Comparison Times ---\n")
    f.write(timing_df.to_string(index=False))
print(f"Timings saved to {times_path}")

print("\nAll results saved!")

## Summary

This notebook demonstrates:

1. **Non-cartesian phantom reconstruction** with SENSE acceleration (R={R})
2. **G-factor calculation comparison** between PMR and our diagnostic methods
3. **Convergence analysis** showing how results improve with increasing N
4. **Performance comparison** between the two approaches

### Key Findings:
- Both methods converge as N increases
- Our diagnostic method provides comparable accuracy to PMR
- Performance depends on the specific implementation and problem size

### Next Steps:
- Try different acceleration factors (R values)
- Test with real data instead of phantom
- Compare with Cartesian acquisitions