# MoDL Network Reconstruction Demo

This notebook demonstrates setting up and using a Model-based Deep Learning (MoDL) network for MRI reconstruction using `reconlib`. We will focus on a 2D example for clarity and speed.

The MoDL architecture alternates between a data consistency step (using the NUFFT operator) and a learned regularization step (a CNN denoiser).

## 1. Imports

In [None]:
import torch
import torch.nn as nn
import numpy as np
import math
import matplotlib.pyplot as plt

%matplotlib inline

# Adjust path to import from reconlib (if not installed as a package)
import sys
import os
# Make the repository root (parent of the notebook's working directory) importable.
# Test membership of the exact resolved path that gets appended, so re-running this
# cell does not keep appending duplicate entries to sys.path.
reconlib_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
if reconlib_root not in sys.path:
    sys.path.append(reconlib_root)

from reconlib.operators import NUFFTOperator
from reconlib.deeplearning.models.resnet_denoiser import SimpleResNetDenoiser
from reconlib.deeplearning.models.modl_network import MoDLNet
from reconlib.deeplearning.datasets import MoDLDataset
try:
    from iternufft import generate_phantom_2d, generate_radial_trajectory_2d
except ImportError:
    print("WARN: iternufft.py not found or not in PYTHONPATH. Using dummy data generators.")

    def generate_phantom_2d(size, device='cpu'):
        """Placeholder phantom: a (size, size) tensor of uniform random values in [0, 0.5)."""
        noise = torch.rand((size, size), device=device)
        return 0.5 * noise

    def generate_radial_trajectory_2d(num_spokes, samples_per_spoke, device='cpu'):
        """Placeholder trajectory of shape (num_spokes * samples_per_spoke, 2).

        Coordinates are uniform random in [-pi/2, pi/2); despite the name, the
        points are NOT arranged radially -- this only keeps the pipeline runnable.
        """
        n_samples = num_spokes * samples_per_spoke
        centered = torch.rand((n_samples, 2), device=device) - 0.5
        return np.pi * centered

## 2. Set Up Device and Parameters

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Reproducibility: seed torch (CPU and all CUDA devices) and NumPy so the random
# denoiser weights, the random vectors of the adjointness test, and any torch/NumPy
# randomness used by the data generators are identical on every Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# --- 2D example geometry ---
image_size = 64  # kept small so the demo runs quickly
image_shape_2d = (image_size, image_size)
dim = len(image_shape_2d)

# --- K-space trajectory (2D radial): `image_size` samples along each of 32 spokes ---
k_traj_params_2d = {'num_spokes': 32, 'samples_per_spoke': image_size}

# --- NUFFT operator parameters (one entry per spatial axis) ---
# NOTE(review): kb_* presumably parameterize a Kaiser-Bessel interpolation kernel;
# alpha = 2.34 * J is the usual heuristic for 2x oversampling -- confirm against reconlib.
oversamp_factor = tuple([2.0] * dim)        # grid oversampling factor
kb_J = tuple([4] * dim)                     # kernel width J
kb_alpha = tuple([2.34 * J for J in kb_J])  # kernel shape parameter
Ld_table = tuple([2**8] * dim)              # interpolation-table oversampling
# Oversampled grid size; `factor` (not `os`) avoids reading like the os module.
Kd_grid = tuple(int(N * factor) for N, factor in zip(image_shape_2d, oversamp_factor))

nufft_op_params = {
    'oversamp_factor': oversamp_factor,
    'kb_J': kb_J,
    'kb_alpha': kb_alpha,
    'Ld': Ld_table,
    'Kd': Kd_grid,
    'kb_m': tuple([0.0]*dim),
    'n_shift': tuple([0.0]*dim)
}

# --- MoDL network parameters ---
denoiser_channels = 1  # denoiser in/out channels (1 here; 2 would mean real/imag channels)
denoiser_internal_channels = 32
denoiser_num_blocks = 2
modl_iterations = 3  # number of unrolled iterations (K)
lambda_dc_val = 0.05  # initial data-consistency weight passed to MoDLNet
cg_iterations_dc = 3  # CG iterations per data-consistency step

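## 3. Create the Dataset

The demo draws a single synthetic sample directly from the dataset; no `DataLoader` is needed for one-shot inference.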

In [None]:
# A single synthetic sample is enough to exercise the whole pipeline.
dataset_kwargs = dict(
    image_shape=image_shape_2d,
    dataset_size=1,
    phantom_func=generate_phantom_2d,
    phantom_params={'size': image_size},
    k_trajectory_func=generate_radial_trajectory_2d,
    k_trajectory_params=k_traj_params_2d,
    nufft_op_params=nufft_op_params,
    noise_level_kspace=0.01,
    device=device,
)
demo_dataset = MoDLDataset(**dataset_kwargs)

# Each item unpacks as (initial reconstruction x0, observed k-space y, ground truth).
sample = demo_dataset[0]
x0_initial, y_observed, x_true = sample

print(f"Initial reconstruction (x0) shape: {x0_initial.shape}")
print(f"Observed k-space (y) shape: {y_observed.shape}")
print(f"Ground truth image (x_true) shape: {x_true.shape}")

## 4. Instantiate NUFFTOperator, Denoiser, and MoDLNet

In [None]:
# Share the dataset's NUFFT operator so the network's data-consistency step uses
# the same forward model that produced y_observed.
nufft_op = demo_dataset.nufft_op

# Learned regularizer (the CNN denoiser), configured from the parameter cell.
denoiser_config = {
    'in_channels': denoiser_channels,
    'out_channels': denoiser_channels,
    'num_internal_channels': denoiser_internal_channels,
    'num_blocks': denoiser_num_blocks,
}
denoiser = SimpleResNetDenoiser(**denoiser_config).to(device)

# Unrolled MoDL network: alternates data consistency (via nufft_op) with the denoiser.
modl_config = {
    'nufft_op': nufft_op,
    'denoiser_cnn': denoiser,
    'num_iterations': modl_iterations,
    'lambda_dc_initial': lambda_dc_val,
    'num_cg_iterations_dc': cg_iterations_dc,
}
modl_network = MoDLNet(**modl_config).to(device)

# No training happens in this demo, so switch the network to evaluation mode.
modl_network.eval()
print("MoDLNet instantiated.")

## 5. Perform Reconstruction (Inference)

No pre-trained model is available in this simple demo, so the reconstruction uses the denoiser's randomly initialized weights. The output is therefore not expected to improve on the initial reconstruction $x_0 = A^H y$. The purpose is only to show the structure of the pipeline.

In [None]:
# Run the unrolled network once, with autograd disabled (pure inference).
#
# MoDLNet takes the observed k-space y and the complex initial estimate x0 = A^H y.
# The data-consistency step needs x0 as a complex image of shape image_shape, so x0
# is passed unchanged regardless of `denoiser_channels`. The previous if/else made
# the same call in both branches, so a single call is equivalent.
#
# NOTE(review): mapping complex x0 onto the denoiser's `denoiser_channels` input
# (magnitude for 1 channel, real/imag for 2) and adding the (N, C, H, W) batch/channel
# axes is assumed to happen inside MoDLNet / SimpleResNetDenoiser. Confirm this.
with torch.no_grad():
    if denoiser_channels == 1:
        print(
            "Note: denoiser_channels=1; passing complex x0 to MoDLNet, which is "
            "responsible for adapting it to the 1-channel denoiser."
        )
    reconstructed_image = modl_network(y_observed, x0_initial)

print(f"Reconstructed image shape: {reconstructed_image.shape}")

## 6. Visualize Results

In [None]:
# Magnitude images side by side: ground truth, the adjoint reconstruction A^H y,
# and the (untrained) MoDL output.
# `.detach()` guards against tensors that still carry autograd history. `.squeeze()`
# drops any singleton batch/channel axes so imshow always receives a 2-D array.
panels = [
    (x_true, 'Ground Truth (x_true)'),
    (x0_initial, 'Initial Recon (A^H y)'),
    (reconstructed_image, 'MoDL Reconstructed (Untrained)'),
]

fig, axs = plt.subplots(1, len(panels), figsize=(15, 5))
for ax, (image, title) in zip(axs, panels):
    ax.imshow(image.detach().abs().squeeze().cpu().numpy(), cmap='gray')
    ax.set_title(title)
    ax.axis('off')

plt.tight_layout()
plt.show()

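## 7. NUFFT Adjointness Test (carried over from the previous notebook)

For random $x$ and $y$, a correct adjoint satisfies $\langle Ax, y \rangle = \langle x, A^H y \rangle$. The relative difference should therefore be close to float32 round-off.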

In [None]:
print("--- Adjointness Test for 2D NUFFTOperator ---")
# For an operator A and its adjoint A^H: <Ax, y> == <x, A^H y> for all x, y,
# with <a, b> = sum(a * conj(b)). Compare both sides on random complex inputs.
x_2d_test = torch.randn(image_shape_2d, dtype=torch.complex64, device=device)
Ax_2d = nufft_op.op(x_2d_test)

# Draw y in the operator's output space. Shaping it exactly like A x prevents a
# silent broadcast in the element-wise product if op() returns extra axes.
y_2d_test_shape = tuple(Ax_2d.shape)
y_2d_test = torch.randn(y_2d_test_shape, dtype=torch.complex64, device=device)
Aty_2d = nufft_op.op_adj(y_2d_test)

lhs_2d = torch.sum(Ax_2d * torch.conj(y_2d_test))
rhs_2d = torch.sum(x_2d_test * torch.conj(Aty_2d))

print(f"LHS (<Ax, y>): {lhs_2d.item()}")
print(f"RHS (<x, A*y>): {rhs_2d.item()}")
abs_diff_2d = torch.abs(lhs_2d - rhs_2d).item()
# Normalize by the larger side. If both sides are ~0, report the absolute gap
# instead of a misleading 0.0, which would hide a failing test.
scale_2d = max(torch.abs(lhs_2d).item(), torch.abs(rhs_2d).item())
rel_diff_2d = abs_diff_2d / scale_2d if scale_2d > 1e-9 else abs_diff_2d
print(f"Absolute Difference: {abs_diff_2d:.6e}")
print(f"Relative Difference: {rel_diff_2d:.6e}")

# complex64 round-off over a few thousand terms is well below this tolerance.
ADJOINT_RTOL = 1e-4
verdict_2d = "PASS" if rel_diff_2d < ADJOINT_RTOL else "FAIL"
print(f"Adjointness check: {verdict_2d} (rtol={ADJOINT_RTOL:.0e})")