# Physics-Informed Neural Network (PINN) for B0 Off-Resonance Correction in EPI

This notebook demonstrates a simplified example of using a Physics-Informed Neural Network (PINN) to correct for B0 off-resonance artifacts in Echo-Planar Imaging (EPI).

**MRI Off-Resonance Artifacts in EPI:**
EPI is a fast MRI acquisition technique, but it's highly sensitive to magnetic field inhomogeneities (ΔB0). These inhomogeneities cause phase errors during the long EPI readouts, leading to geometric distortions (warping, stretching, signal pile-up/loss) in the reconstructed image, particularly in the phase-encoding direction.
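A common first-order estimate of the resulting shift is given below. It assumes a single-shot acquisition without parallel-imaging acceleration and a constant echo spacing $\Delta t_{\mathrm{esp}}$ between the $N_{\mathrm{PE}}$ phase-encoding lines. Here $\Delta f_0 = \bar{\gamma}\,\Delta B_0$ is the local off-resonance frequency:

$$
\Delta y \;[\text{pixels}] \;=\; \frac{\Delta f_0}{\mathrm{BW}_{\mathrm{PE}}} \;=\; \Delta f_0 \, N_{\mathrm{PE}} \, \Delta t_{\mathrm{esp}},
\qquad \mathrm{BW}_{\mathrm{PE}} = \frac{1}{N_{\mathrm{PE}}\,\Delta t_{\mathrm{esp}}}
$$

With the values used later in this notebook ($\Delta f_0 = 30$ Hz, $N_{\mathrm{PE}} = 64$, $\Delta t_{\mathrm{esp}} = 0.7$ ms), this gives $30 \times 64 \times 0.0007 \approx 1.34$ pixels. The readout direction has a much higher bandwidth per pixel, which is why the distortion is concentrated along phase encoding.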

**PINN Approach:**
A PINN attempts to solve inverse problems by incorporating known physics into the neural network's loss function. For B0 correction, the PINN can:
1. Use a neural network (e.g., a CNN) to represent the 'corrected' image.
2. Include a data fidelity term that ensures the network output, when transformed by the imaging operator (NUFFT with the actual k-space trajectory), matches the acquired k-space data.
3. Add physics-based loss terms that penalize solutions inconsistent with known physics. For B0 off-resonance, this involves using the B0 map to model expected phase distortions. The `B0OffResonanceLoss` term (currently a placeholder) would aim to quantify this inconsistency.
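Written out, the objective has the generic weighted-sum form below. This is only a sketch: the exact norms, normalization, and term names inside `PINNReconstructor.loss_function` are not shown in this notebook and may differ.

$$
\mathcal{L}(\theta) \;=\; \lambda_{\mathrm{DC}} \sum_{c=1}^{N_c} \left\| \mathcal{A}\big(\hat{x}_c(\theta)\big) - y_c \right\|_2^2 \;+\; \sum_{j} w_j\, \mathcal{L}_{\mathrm{phys},j}\big(\hat{x}(\theta)\big)
$$

In this expression:
- $\hat{x}_c(\theta)$ is the CNN output for coil $c$ (the network used below predicts one channel per coil).
- $\mathcal{A}$ is the NUFFT along the actual trajectory.
- $y_c$ is the acquired k-space data.
- $\lambda_{\mathrm{DC}}$ corresponds to `data_fidelity_weight`.
- $w_j$ corresponds to the `weight` passed to each physics term (0.01 for `B0OffResonanceLoss`, 0.001 for `GIRFErrorLoss`).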

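This demo uses placeholder physics loss terms for simplicity and speed. It illustrates the overall framework, not a validated correction method. In `FAST_DEMO_MODE` the NUFFT is also a local placeholder that returns all-zero data, so that mode exercises the pipeline but not the image quality.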

## 1. Setup

In [None]:
import sys
import os
import torch
import numpy as np
import matplotlib.pyplot as plt

# Add project root to sys.path to ensure reconlib can be found
# Assumes this notebook is in notebooks/ and project root is one level up.
project_root = os.path.abspath(os.path.join(os.getcwd(), '..')) # Use os.getcwd() for notebooks
if project_root not in sys.path:
    sys.path.append(project_root)

# --- Configuration for Fast Demo Mode ---
FAST_DEMO_MODE = True
# If True, uses placeholder NUFFT ops defined locally for speed.
# If False, attempts to use reconlib.nufft.NUFFT2D.

try:
    from reconlib.modalities.MRI.pinn_reconstructor import PINNReconstructor, SimpleCNN
    from reconlib.nufft import NUFFT2D
    from reconlib.nufft_multi_coil import MultiCoilNUFFTOperator
    from reconlib.modalities.MRI.physics_loss import PhysicsLossTerm, BlochResidualLoss, GIRFErrorLoss, B0OffResonanceLoss
    RECONLIB_AVAILABLE = True
    print("Successfully imported reconlib modules.")
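except ImportError as e:
    print(f"Warning: Could not import all modules from reconlib. Error: {e}")
    RECONLIB_AVAILABLE = False
    # Bind any optional physics-loss class that failed to import to None, so the
    # `is not None` checks in Section 3 skip it instead of raising NameError.
    # PINNReconstructor and SimpleCNN have no fallback and are still required.
    for _name in ("PhysicsLossTerm", "BlochResidualLoss", "GIRFErrorLoss", "B0OffResonanceLoss"):
        globals().setdefault(_name, None)
    if not FAST_DEMO_MODE:
        print("FATAL: reconlib modules are required when FAST_DEMO_MODE is False.")
        # In a notebook, re-raise instead of calling sys.exit so the failure is clearly visible.
        raise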

# Attempt to import scikit-image for Shepp-Logan phantom
try:
    from skimage.data import shepp_logan_phantom
    from skimage.transform import resize as sk_resize
    SKIMAGE_AVAILABLE = True
    print("Successfully imported scikit-image.")
except ImportError:
    print("Warning: scikit-image not found. Using a simple geometric phantom instead.")
    SKIMAGE_AVAILABLE = False

### Parameters

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

image_shape_2d = (64, 64)  # (Y, X) in pixels; larger grids look better but run more slowly
num_coils = 2
num_k_points_epi = image_shape_2d[0] * image_shape_2d[1]  # one k-space sample per image pixel, for simplicity

# EPI Scan Parameters (for B0 loss)
scan_parameters_epi = {
    'echo_spacing_ms': 0.7,  # Typical echo spacing for EPI
    'phase_encoding_lines': image_shape_2d[0], # Ny, number of phase encoding lines
}
# General scan parameters (for Bloch loss, if used).
# TE, TR, T1 and T2 are presumably in seconds and the flip angle in degrees.
scan_parameters_general = {"TE": 0.03, "TR": 2.0, "flip_angle": 30, "T1_assumed": 1.0, "T2_assumed": 0.08}

pinn_config = {
    "learning_rate": 1e-3,
    "data_fidelity_weight": 1.0,
    "num_epochs": 20, # Can be small for demo, increase for better results
    "device": device
}

if FAST_DEMO_MODE:
    pinn_config["num_epochs"] = 5 # Even fewer for very fast demo
    print("FAST_DEMO_MODE: num_epochs reduced to 5.")

# NUFFT parameters (2D)
nufft_params_reconlib = {
    'oversamp_factor': (1.5, 1.5),
    'kb_J': (4, 4),
    'Ld': (32, 32) 
}
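The keys in `nufft_params_reconlib` look like the usual Kaiser-Bessel gridding settings. These would be the grid oversampling factor, the kernel width `kb_J`, and, presumably, the lookup-table resolution `Ld`. Whether `reconlib` follows that convention is not confirmed here. Under that assumption, the sketch below shows how the oversampling factor and kernel width typically determine the kernel shape. It uses the Beatty et al. (2005) heuristic for the shape parameter β. The function names are hypothetical and the kernel is left unnormalized.

```python
import numpy as np

def kb_beta(kernel_width, oversamp):
    """Kaiser-Bessel shape parameter from the Beatty et al. (2005) heuristic."""
    return np.pi * np.sqrt((kernel_width / oversamp) ** 2 * (oversamp - 0.5) ** 2 - 0.8)

def kb_kernel(u, kernel_width, beta):
    """Unnormalized Kaiser-Bessel weight at offset u (grid units); zero for |u| > W/2."""
    u = np.asarray(u, dtype=float)
    inside = np.abs(u) <= kernel_width / 2
    # Clip so the square root stays real outside the support (those values are masked anyway)
    arg = np.sqrt(np.clip(1.0 - (2.0 * u / kernel_width) ** 2, 0.0, None))
    return np.where(inside, np.i0(beta * arg), 0.0)

# Values matching nufft_params_reconlib (per axis): oversamp_factor=1.5, kb_J=4
beta = kb_beta(kernel_width=4, oversamp=1.5)
offsets = np.linspace(-2.5, 2.5, 11)
weights = kb_kernel(offsets, kernel_width=4, beta=beta)
print(f"beta = {beta:.3f}")
```

A larger oversampling factor lowers β and widens the kernel's passband. This is the usual trade between gridding accuracy and the memory cost of the oversampled grid.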

### Plotting Helper

In [None]:
def show_image(image_tensor, title="Image", cmap="gray", vmin=None, vmax=None, is_complex=True):
    plt.figure(figsize=(5,5))
    if is_complex:
        plot_data = torch.abs(image_tensor.cpu().detach()).numpy()
    else:
        plot_data = image_tensor.cpu().detach().numpy()
    plt.imshow(plot_data, cmap=cmap, vmin=vmin, vmax=vmax)
    plt.title(title)
    plt.axis('off')
    plt.show()

### Placeholder NUFFT Operators (for FAST_DEMO_MODE)

In [None]:
class PlaceholderNUFFT2D:
    """Stand-in single-coil NUFFT for FAST_DEMO_MODE.

    Both directions return all-zero tensors of the correct shape; the `torch.sum(...) * 0.0`
    term keeps the output attached to the input's autograd graph. No real transform is applied.
    """
    def __init__(self, image_shape, k_trajectory, device='cpu', **kwargs):
        self.image_shape = image_shape
        self.k_trajectory = k_trajectory
        self.device = device

    def forward(self, x):  # image (Y, X) -> k-space (K,)
        return torch.sum(x) * 0.0 + torch.zeros(self.k_trajectory.shape[0], dtype=torch.complex64, device=self.device)

    def adjoint(self, y):  # k-space (K,) -> image (Y, X)
        return torch.sum(y) * 0.0 + torch.zeros(self.image_shape, dtype=torch.complex64, device=self.device)

class PlaceholderMultiCoilNUFFTOperator:
    def __init__(self, single_coil_nufft_op):
        self.single_coil_nufft_op = single_coil_nufft_op
        self.device = single_coil_nufft_op.device
        self.image_shape = single_coil_nufft_op.image_shape
        # print("INFO: Using PlaceholderMultiCoilNUFFTOperator.")

    def op(self, multi_coil_image_data): # (C,Y,X) -> (C,K)
        output_kspace_list = []
        for i in range(multi_coil_image_data.shape[0]):
            single_coil_image = multi_coil_image_data[i]
            output_kspace_list.append(self.single_coil_nufft_op.forward(single_coil_image))
        return torch.stack(output_kspace_list, dim=0)

    def op_adj(self, multi_coil_kspace_data): # (C,K) -> (C,Y,X)
        output_image_list = []
        for i in range(multi_coil_kspace_data.shape[0]):
            single_coil_kspace = multi_coil_kspace_data[i]
            output_image_list.append(self.single_coil_nufft_op.adjoint(single_coil_kspace))
        return torch.stack(output_image_list, dim=0)
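The placeholders above return all-zero k-space and images, so in `FAST_DEMO_MODE` every downstream image is blank. A hypothetical alternative that is still cheap snaps each trajectory point to its nearest Cartesian FFT bin. The forward op is an orthonormal FFT followed by sampling. The adjoint scatters the samples back onto the grid and applies the inverse FFT, which makes the pair exact adjoints. This is not a real NUFFT, because it has no interpolation kernel. It assumes trajectories in cycles/pixel within [-0.5, 0.5] with columns (ky, kx), as used in this notebook.

```python
import torch

class NearestGridNUFFT2D:
    """Hypothetical FAST_DEMO_MODE stand-in: orthonormal FFT + nearest-bin sampling."""

    def __init__(self, image_shape, k_trajectory, device='cpu'):
        self.image_shape = image_shape
        self.k_trajectory = k_trajectory
        self.device = device
        ny, nx = image_shape
        # Map each (ky, kx) sample to the nearest FFT bin, wrapping negative frequencies
        iy = torch.remainder(torch.round(k_trajectory[:, 0] * ny).long(), ny)
        ix = torch.remainder(torch.round(k_trajectory[:, 1] * nx).long(), nx)
        self.flat_idx = (iy * nx + ix).to(device)

    def forward(self, x):  # image (Y, X) -> k-space samples (K,)
        kspace = torch.fft.fft2(x, norm="ortho")
        return kspace.reshape(-1)[self.flat_idx]

    def adjoint(self, y):  # k-space samples (K,) -> image (Y, X)
        ny, nx = self.image_shape
        # Scatter-add real and imaginary parts; samples that share a bin accumulate
        real = torch.zeros(ny * nx, dtype=y.real.dtype, device=y.device).index_add(
            0, self.flat_idx, y.real.contiguous())
        imag = torch.zeros(ny * nx, dtype=y.real.dtype, device=y.device).index_add(
            0, self.flat_idx, y.imag.contiguous())
        grid = torch.complex(real, imag).reshape(ny, nx)
        return torch.fft.ifft2(grid, norm="ortho")
```

The class exposes the same `forward`, `adjoint`, `device` and `image_shape` attributes as `PlaceholderNUFFT2D`. It could therefore be wrapped as `PlaceholderMultiCoilNUFFTOperator(NearestGridNUFFT2D(image_shape_2d, trajectory_actual, device))`.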

## 2. Data Simulation

### Ground Truth Phantom

In [None]:
if SKIMAGE_AVAILABLE:
    gt_image_np = shepp_logan_phantom()
    gt_image_np = sk_resize(gt_image_np, image_shape_2d, anti_aliasing=True)
else:
    # Simple geometric phantom if scikit-image is not available
    gt_image_np = np.zeros(image_shape_2d, dtype=np.float32)
    s_y, e_y = image_shape_2d[0]//4, 3*image_shape_2d[0]//4
    s_x, e_x = image_shape_2d[1]//4, 3*image_shape_2d[1]//4
    gt_image_np[s_y:e_y, s_x:e_x] = 1.0
    gt_image_np[s_y+5:e_y-5, s_x+5:e_x-5] = 0.5 # Add some structure

gt_image = torch.tensor(gt_image_np, dtype=torch.complex64, device=device)
show_image(gt_image, title="Ground Truth Phantom")

### Coil Sensitivity Maps

In [None]:
# Coordinate grid shared by all coils, normalized to [-1, 1]
yy, xx = torch.meshgrid(
    torch.linspace(-1, 1, image_shape_2d[0], device=device),
    torch.linspace(-1, 1, image_shape_2d[1], device=device),
    indexing='ij'
)
coil_sensitivities = torch.zeros(num_coils, *image_shape_2d, dtype=torch.complex64, device=device)
for c in range(num_coils):
    # Coil 0 is centred at (+0.3, +0.3); every other coil at (-0.3, -0.3)
    offset = 0.3 if c == 0 else -0.3
    profile = (1.0 - 0.8 * torch.abs(xx - offset)) * (1.0 - 0.8 * torch.abs(yy - offset))
    # Clamp the real-valued profile before casting: torch.clamp does not support complex tensors
    coil_sensitivities[c] = torch.clamp(profile, 0.0, 1.0).to(torch.complex64)

fig, axes = plt.subplots(1, num_coils, figsize=(5*num_coils, 5))
if num_coils == 1:
    axes = [axes]  # plt.subplots returns a single Axes (not an array) for one panel
for i in range(num_coils):
    axes[i].imshow(torch.abs(coil_sensitivities[i]).cpu().numpy(), cmap='viridis')
    axes[i].set_title(f"Coil {i+1} Sensitivity (Mag)")
    axes[i].axis('off')
plt.show()

### B0 Map

In [None]:
b0_freq_max_hz = 30.0  # Max off-resonance in Hz
y_coords = torch.linspace(-1, 1, image_shape_2d[0], device=device)
x_coords = torch.linspace(-1, 1, image_shape_2d[1], device=device)
yy, xx = torch.meshgrid(y_coords, x_coords, indexing='ij')
b0_map = b0_freq_max_hz * yy  # Simple linear gradient in y for b0 map
b0_map = b0_map.to(device)
show_image(b0_map, title="B0 Map (Hz)", cmap='viridis', is_complex=False, vmin=-b0_freq_max_hz, vmax=b0_freq_max_hz)

### K-space Trajectories

In [None]:
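# Random non-Cartesian sampling in [-0.5, 0.5) cycles/pixel; columns are (ky, kx).
# This is a stand-in for simplicity, not a true EPI zig-zag trajectory.
trajectory_ideal = (torch.rand(num_k_points_epi, 2, device=device) - 0.5).float()

# "Actual" trajectory: shear ky in proportion to kx to mimic a gradient-system error
# (used by the GIRF placeholder loss). This is not an off-resonance effect.
trajectory_shear_factor = 0.05
trajectory_actual = trajectory_ideal.clone()
trajectory_actual[:, 0] += trajectory_shear_factor * trajectory_ideal[:, 1]
trajectory_actual = torch.clamp(trajectory_actual, -0.5, 0.5)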

plt.figure(figsize=(6,6))
plt.scatter(trajectory_ideal[:,1].cpu().numpy(), trajectory_ideal[:,0].cpu().numpy(), s=1, label='Ideal Traj.')
plt.scatter(trajectory_actual[:,1].cpu().numpy(), trajectory_actual[:,0].cpu().numpy(), s=1, label='Actual Traj.', alpha=0.5)
plt.xlabel('kx'); plt.ylabel('ky'); plt.title('K-space Trajectories'); plt.legend(); plt.axis('square');
plt.xlim([-0.6, 0.6]); plt.ylim([-0.6, 0.6]);
plt.show()
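The random samples above match EPI only in the number of points. A real single-shot EPI readout traverses k-space line by line in a zig-zag, and that ordering is what makes off-resonance phase accumulate across phase-encoding lines. Below is a minimal sketch of a blipped Cartesian EPI trajectory in the same (ky, kx), cycles/pixel convention. It assumes each line lasts exactly one echo spacing and ignores ramp sampling. The function name is hypothetical.

```python
import numpy as np

def cartesian_epi_trajectory(ny, nx, echo_spacing_s):
    """Hypothetical helper: blipped Cartesian EPI trajectory and sample times.

    Returns traj (ny*nx, 2) with columns (ky, kx) in cycles/pixel and times (ny*nx,) in s.
    Every other line is read out in reverse (the EPI zig-zag); each line is assumed to
    last exactly one echo spacing, with no ramp sampling.
    """
    ky_lines = (np.arange(ny) - ny // 2) / ny  # PE lines from -0.5 up to just below 0.5
    kx_line = (np.arange(nx) - nx // 2) / nx   # readout samples within one line
    traj, times = [], []
    for n, ky in enumerate(ky_lines):
        kx = kx_line if n % 2 == 0 else kx_line[::-1]  # alternate readout direction
        traj.append(np.stack([np.full(nx, ky), kx], axis=1))
        # Line n starts at n * echo spacing; samples are spread evenly across the line
        times.append(n * echo_spacing_s + np.arange(nx) * (echo_spacing_s / nx))
    return np.concatenate(traj), np.concatenate(times)

traj, times = cartesian_epi_trajectory(64, 64, echo_spacing_s=0.7e-3)
```

After conversion with `torch.from_numpy(traj).float().to(device)`, the result has `num_k_points_epi` rows and could replace `trajectory_ideal`. The per-sample `times` are what a time-segmented B0 model would need.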

### Simulate Off-Resonant K-space Data

In [None]:
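# Simplified model: apply the phase accrued at an effective echo time (the centre of the
# EPI echo train) to every voxel. This changes only the image phase, not its magnitude,
# so it does not reproduce the geometric distortion of a real EPI readout.
effective_te_for_simulation_s = (scan_parameters_epi['echo_spacing_ms'] / 1000.0) * \
                                (scan_parameters_epi['phase_encoding_lines'] / 2.0)

phase_shift = 2 * torch.pi * b0_map * effective_te_for_simulation_s  # radians
gt_image_offresonant = gt_image * torch.exp(1j * phase_shift)
show_image(gt_image_offresonant, title="Ground Truth with B0 Phase (Off-Resonant)")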

gt_coil_images_offresonant = gt_image_offresonant.unsqueeze(0) * coil_sensitivities

# NUFFT operator for simulation. The data are simulated along the *actual* trajectory,
# because that is the path the scanner traverses and the one assumed by the adjoint
# and the PINN data-fidelity term below.
if not FAST_DEMO_MODE and RECONLIB_AVAILABLE:
    nufft_sim_single_coil = NUFFT2D(
        image_shape=image_shape_2d,
        k_trajectory=trajectory_actual,
        device=device,
        **nufft_params_reconlib
    )

    class ReconlibNUFFT2DAdapter:
        """Wraps NUFFT2D.forward/adjoint as op/op_adj methods for MultiCoilNUFFTOperator."""
        def __init__(self, nufft2d_instance: NUFFT2D):
            self.nufft_instance = nufft2d_instance
            self.device = nufft2d_instance.device
            self.image_shape = nufft2d_instance.image_shape
            self.k_trajectory = nufft2d_instance.k_trajectory

        def op(self, x):
            return self.nufft_instance.forward(x)

        def op_adj(self, y):
            return self.nufft_instance.adjoint(y)

    adapter_for_simulation = ReconlibNUFFT2DAdapter(nufft_sim_single_coil)
    mc_nufft_for_simulation = MultiCoilNUFFTOperator(adapter_for_simulation)
    print("INFO: Using reconlib.NUFFT2D for k-space simulation.")
else:
    nufft_sim_single_coil = PlaceholderNUFFT2D(image_shape_2d, trajectory_actual, device)
    mc_nufft_for_simulation = PlaceholderMultiCoilNUFFTOperator(nufft_sim_single_coil)
    print("INFO: Using PlaceholderNUFFT for k-space simulation.")

true_kspace_data_mc = mc_nufft_for_simulation.op(gt_coil_images_offresonant)
print(f"Simulated multi-coil k-space data shape: {true_kspace_data_mc.shape}")
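The single-phase model above leaves the image magnitude unchanged. The minimal sketch below is a more faithful Cartesian EPI model. It assumes the phase-encoding line with signed index $k$ is acquired at $t_k = k\,\Delta t_{\mathrm{esp}}$ relative to the echo time, ignores readout time within a line, and applies each line's own off-resonance phase before taking that line of the FFT. For a uniform offset $\Delta f_0$, this shifts the image by $\Delta f_0 N_{\mathrm{PE}} \Delta t_{\mathrm{esp}}$ pixels along the phase-encoding axis, whereas a spatially varying B0 map warps it. The function name is hypothetical. It works on a Cartesian grid with NumPy rather than through the notebook's NUFFT operators.

```python
import numpy as np

def simulate_epi_distortion(image, b0_map_hz, echo_spacing_s):
    """Per-line B0 phase model for Cartesian EPI; axis 0 is the phase-encoding axis.

    With this sign convention, a positive uniform offset shifts the image toward
    lower row indices by b0 * ny * echo_spacing_s pixels.
    """
    ny, nx = image.shape
    k_signed = np.fft.fftfreq(ny) * ny  # signed PE line indices in FFT order
    kspace = np.zeros((ny, nx), dtype=np.complex128)
    for k in range(ny):
        t_k = k_signed[k] * echo_spacing_s  # acquisition time of line k relative to TE (s)
        phased = image * np.exp(1j * 2 * np.pi * b0_map_hz * t_k)
        kspace[k, :] = np.fft.fft2(phased)[k, :]  # keep only the line acquired at t_k
    return np.fft.ifft2(kspace)
```

For example, `np.abs(simulate_epi_distortion(gt_image_np, b0_map.cpu().numpy(), 0.7e-3))` would stretch or compress the phantom along y, with displacements of up to roughly 1.3 pixels near the top and bottom edges for the ±30 Hz gradient used here.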

### Initial Image (Adjoint NUFFT)

In [None]:
# Adjoint NUFFT of the (potentially off-resonant) k-space data
# This uses the 'actual' trajectory because that's what the measurements correspond to.
if not FAST_DEMO_MODE and RECONLIB_AVAILABLE:
    nufft_adj_single_coil = NUFFT2D(
        image_shape=image_shape_2d,
        k_trajectory=trajectory_actual,  # actual trajectory, matching the measurements
        device=device,
        **nufft_params_reconlib
    )
    adapter_adj = ReconlibNUFFT2DAdapter(nufft_adj_single_coil)
    mc_nufft_for_adjoint = MultiCoilNUFFTOperator(adapter_adj)
    print("INFO: Using reconlib.NUFFT2D for initial adjoint image.")
else:
    nufft_adj_single_coil = PlaceholderNUFFT2D(image_shape_2d, trajectory_actual, device)
    mc_nufft_for_adjoint = PlaceholderMultiCoilNUFFTOperator(nufft_adj_single_coil)
    print("INFO: Using PlaceholderNUFFT for initial adjoint image.")

initial_images_mc = mc_nufft_for_adjoint.op_adj(true_kspace_data_mc)
# Combine coils for initial viewing (RSS)
initial_image_rss = torch.sqrt(torch.sum(torch.abs(initial_images_mc)**2, dim=0))
show_image(initial_image_rss, title="Initial Image (Adjoint NUFFT + RSS)")
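RSS discards phase and ignores the known coil maps. Since `coil_sensitivities` is available here, a common alternative is the sensitivity-weighted (Roemer-style) combination $\hat{x} = \sum_c S_c^* x_c / \sum_c |S_c|^2$. It returns the complex image exactly when $x_c = S_c x$, assuming noise-free data and ignoring noise covariance. Below is a minimal sketch. The function name and the `eps` guard for pixels with zero total sensitivity are illustrative choices.

```python
import torch

def sensitivity_weighted_combine(coil_images, sens_maps, eps=1e-8):
    """Combine (C, Y, X) coil images with (C, Y, X) sensitivities into one complex (Y, X) image."""
    numerator = torch.sum(torch.conj(sens_maps) * coil_images, dim=0)
    denominator = torch.sum(torch.abs(sens_maps) ** 2, dim=0)
    combined = numerator / denominator.clamp_min(eps)
    # Pixels with (near-)zero total sensitivity carry no information: set them to zero
    return torch.where(denominator > eps, combined, torch.zeros_like(combined))
```

In this notebook it would be called as `sensitivity_weighted_combine(initial_images_mc, coil_sensitivities)`. The result can be passed to `show_image` like the RSS image.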

## 3. PINN Reconstructor Setup

In [None]:
# NUFFT operator for the PINN (uses actual_trajectory for data fidelity)
mc_nufft_pinn = mc_nufft_for_adjoint # Can reuse the same operator if params match

# CNN Model
cnn_model = SimpleCNN(n_channels_in=1, n_channels_out=num_coils, n_spatial_dims=2).to(device)

# Physics Loss Terms
physics_terms = []
if B0OffResonanceLoss is not None:
    b0_loss = B0OffResonanceLoss(b0_map=b0_map, scan_parameters_epi=scan_parameters_epi, weight=0.01)
    physics_terms.append(b0_loss)
    print(f"Added {b0_loss.name} to physics terms.")

if GIRFErrorLoss is not None: # Example: also include GIRF placeholder
    girf_loss = GIRFErrorLoss(weight=0.001)
    physics_terms.append(girf_loss)
    print(f"Added {girf_loss.name} to physics terms.")

# PINNReconstructor
reconstructor = PINNReconstructor(
    nufft_op=mc_nufft_pinn,
    cnn_model=cnn_model,
    config=pinn_config,
    physics_terms=physics_terms
)
print("PINNReconstructor instantiated.")

## 4. Run Reconstruction

In [None]:
# Training loop (copied and modified from PINNReconstructor.reconstruct for loss logging)
print(f"Starting reconstruction for {pinn_config['num_epochs']} epochs...")
optimizer = torch.optim.Adam(reconstructor.cnn_model.parameters(), lr=reconstructor.config.get("learning_rate", 1e-3))
loss_history = [] # To store loss components per epoch

# Prepare input for CNN (RSS of initial multi-coil image, then batched)
if reconstructor.cnn_model.n_channels_in == 1 and initial_images_mc.shape[0] > 1:
    cnn_input_image = torch.sqrt(torch.sum(torch.abs(initial_images_mc)**2, dim=0, keepdim=True))
elif initial_images_mc.shape[0] == reconstructor.cnn_model.n_channels_in:
    cnn_input_image = initial_images_mc
else:
    print(f"Warning: CNN expects {reconstructor.cnn_model.n_channels_in} input channel(s) "
          f"but got {initial_images_mc.shape[0]} coil image(s). Using RSS.")
    cnn_input_image = torch.sqrt(torch.sum(torch.abs(initial_images_mc)**2, dim=0, keepdim=True))
cnn_input_image_batched = cnn_input_image.unsqueeze(0).to(reconstructor.device)

# Data for loss function (kwargs)
loss_fn_kwargs = {
    "trajectory_ideal": trajectory_ideal.to(reconstructor.device),
    "trajectory_actual": trajectory_actual.to(reconstructor.device),
    "scan_parameters": scan_parameters_general, # For Bloch (if used)
    "b0_map": b0_map.to(reconstructor.device),
    "scan_parameters_epi": scan_parameters_epi # For B0OffResonanceLoss
}

reconstructor.cnn_model.train()
for epoch in range(pinn_config['num_epochs']):
    optimizer.zero_grad()
    
    predicted_output_batched = reconstructor.cnn_model(cnn_input_image_batched)
    current_cnn_output = predicted_output_batched[0] # Remove batch dim
    
    total_loss, loss_comp = reconstructor.loss_function(
        current_cnn_output=current_cnn_output,
        true_kspace_data_mc=true_kspace_data_mc.to(reconstructor.device),
        **loss_fn_kwargs
    )
    
    total_loss.backward()
    optimizer.step()
    
    # Log losses
    epoch_losses = {name: val.item() for name, val in loss_comp.items()}
    loss_history.append(epoch_losses)
    if epoch % max(1, pinn_config['num_epochs'] // 10) == 0 or epoch == pinn_config['num_epochs'] - 1:
        loss_str = ", ".join([f"{k}: {v:.4e}" for k, v in epoch_losses.items()])
        # Report epochs 1-based so the last line reads "Epoch N/N"
        print(f"Epoch {epoch + 1}/{pinn_config['num_epochs']}, Losses: [{loss_str}]")

reconstructor.cnn_model.eval()
with torch.no_grad():
    final_prediction_batched = reconstructor.cnn_model(cnn_input_image_batched)
reconstructed_image_mc = final_prediction_batched[0]

print("Reconstruction finished.")

## 5. Results

In [None]:
# Plot Loss History
if loss_history:
    loss_names = loss_history[0].keys()
    plt.figure(figsize=(10, 6))
    for loss_name in loss_names:
        if loss_name == 'total' or physics_terms: # Plot total and individual physics if physics_terms were added
            plt.plot([epoch_loss[loss_name] for epoch_loss in loss_history], label=loss_name)
    plt.xlabel("Epoch")
    plt.ylabel("Loss Value")
    plt.title("Loss Curve")
    plt.legend()
    plt.grid(True)
    plt.show()
else:
    print("No loss history to plot.")

# Display final reconstructed image (RSS)
reconstructed_image_rss = torch.sqrt(torch.sum(torch.abs(reconstructed_image_mc)**2, dim=0))
show_image(reconstructed_image_rss, title="PINN Reconstructed Image (RSS)")

# Comparison with Ground Truth and Initial Adjoint
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
axes[0].imshow(torch.abs(gt_image).cpu().numpy(), cmap='gray')
axes[0].set_title("Ground Truth (Corrected)"); axes[0].axis('off');

axes[1].imshow(torch.abs(initial_image_rss).cpu().numpy(), cmap='gray')
axes[1].set_title("Initial Adjoint NUFFT (RSS)"); axes[1].axis('off');

axes[2].imshow(torch.abs(reconstructed_image_rss).cpu().numpy(), cmap='gray')
axes[2].set_title("PINN Reconstructed (RSS)"); axes[2].axis('off');

plt.suptitle("B0 Off-Resonance Correction Demo")
plt.tight_layout()
plt.show()
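To put a number on the visual comparison, a common choice is NRMSE against the ground-truth magnitude. RSS images and the CNN output carry an arbitrary intensity scale, so the sketch below first fits the least-squares scale factor $a = \langle r, g\rangle / \langle r, r\rangle$ (with $r$ the estimate and $g$ the reference magnitude) and then computes the error. This normalization choice and the function name are assumptions, not part of `reconlib`.

```python
import numpy as np

def scaled_nrmse(reference, estimate):
    """NRMSE of |estimate| vs |reference| after fitting a least-squares scale to |estimate|."""
    ref = np.abs(np.asarray(reference)).ravel()
    est = np.abs(np.asarray(estimate)).ravel()
    denom = np.dot(est, est)
    scale = np.dot(est, ref) / denom if denom > 0 else 0.0  # all-zero estimate -> scale 0
    return np.linalg.norm(scale * est - ref) / np.linalg.norm(ref)
```

For example, `scaled_nrmse(gt_image.cpu().numpy(), reconstructed_image_rss.cpu().numpy())` can be compared with the same metric for `initial_image_rss`. An all-zero image, as produced by the placeholder NUFFT, scores exactly 1.0.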

## 6. Conclusion

This notebook demonstrated the setup for a PINN-based B0 off-resonance correction.
- We simulated a ground truth image, coil sensitivities, and a B0 map.
- K-space data was generated from an off-resonant version of the ground truth.
- A `PINNReconstructor` was configured with a CNN, a data fidelity term, and placeholder physics loss terms (including `B0OffResonanceLoss`).
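- The reconstruction loop was run, and the resulting image was compared with the ground truth and the initial adjoint image. With `FAST_DEMO_MODE = True`, the placeholder NUFFT returns all-zero k-space and images, so the displayed images are only meaningful when `reconlib.nufft.NUFFT2D` is used.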

**Next Steps:**
- Implement realistic physics models within the `B0OffResonanceLoss` and other physics terms.
- Use more realistic k-space trajectories and coil sensitivity maps.
- Experiment with different CNN architectures and hyperparameters.
- Evaluate the reconstruction quality quantitatively against the ground truth, e.g. with NRMSE, PSNR, or SSIM.
- When running with `FAST_DEMO_MODE = False`, check that the `reconlib` NUFFT implementation is fast enough, or use a GPU or other more powerful environment.