# PyTorch-based B0 Mapping Example

This notebook demonstrates the PyTorch-centric B0 mapping functions from `reconlib.b0_mapping`:
- `calculate_b0_map_dual_echo` (with and without internal phase unwrapping via `unwrap_method_fn`)
- `calculate_b0_map_multi_echo_linear_fit`

We use synthetic 3D multi-echo data designed so that the phase difference between the first two echoes can wrap, which shows why `unwrap_method_fn` is useful.

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from reconlib.b0_mapping import calculate_b0_map_dual_echo, calculate_b0_map_multi_echo_linear_fit
from reconlib.b0_mapping.utils import create_mask_from_magnitude
from reconlib.phase_unwrapping import unwrap_phase_3d_quality_guided # Added for new feature
from reconlib.plotting import plot_phase_image, plot_b0_field_map # These are helper functions, but direct plt often used in notebooks

%matplotlib inline

## Helper Function to Generate Synthetic 3D B0 Data

In [None]:
def generate_synthetic_3d_b0_data(shape=(16, 64, 64), tes_list=(0.002, 0.004, 0.006), max_b0_hz=50.0, device='cpu'):
    """
    Generate synthetic 3D multi-echo phase data, the true B0 map, and a magnitude image.

    The true B0 map is a linear ramp along x spanning [-max_b0_hz, max_b0_hz],
    scaled on slice z by (z + 1) / depth, and set to zero outside a cylinder
    (extending along z) of radius min(height, width) // 3 centred in-plane.
    Because of that scaling and masking, the largest |B0| inside the object is
    smaller than ``max_b0_hz``.

    Args:
        shape: (depth, height, width) of the generated volumes.
        tes_list: Echo times in seconds. Any list, tuple, NumPy array or tensor
            accepted by ``torch.as_tensor``.
        max_b0_hz: Peak of the x-ramp (Hz) on the last slice, before masking.
        device: Torch device on which all outputs are created.

    Returns:
        magnitude: (depth, height, width) float32 tensor, 1 inside the cylinder, 0 outside.
        phase_images: (num_echoes, depth, height, width) float32 tensor of
            wrapped phases in [-pi, pi).
        echo_times: (num_echoes,) float32 tensor of echo times in seconds.
        b0_map_true: (depth, height, width) float32 tensor, true B0 in Hz.
    """
    d, h, w = shape
    pi = getattr(torch, 'pi', np.pi)

    # Linear gradient along x, scaled per slice by (z + 1) / d, built by broadcasting
    # instead of a Python loop over slices.
    x_ramp = torch.linspace(-max_b0_hz, max_b0_hz, w, dtype=torch.float32, device=device)
    z_scale = (torch.arange(d, dtype=torch.float32, device=device) + 1.0) / d
    b0_map_true = (z_scale.view(d, 1, 1) * x_ramp.view(1, 1, w)).expand(d, h, w).contiguous()

    # Magnitude: a cylinder along z (the same disk on every slice).
    center_y, center_x = h // 2, w // 2
    radius = min(h, w) // 3
    y_coords, x_coords = torch.meshgrid(
        torch.arange(h, device=device), torch.arange(w, device=device), indexing='ij'
    )
    cylinder_mask_2d = (y_coords - center_y) ** 2 + (x_coords - center_x) ** 2 <= radius ** 2
    magnitude = cylinder_mask_2d.to(torch.float32).expand(d, h, w).contiguous()

    b0_map_true = b0_map_true * magnitude  # B0 only within the magnitude object

    # as_tensor also accepts tuples / arrays / tensors without copying warnings.
    echo_times_torch = torch.as_tensor(tes_list, dtype=torch.float32, device=device).reshape(-1)

    # True phase phi = 2 * pi * B0_hz * TE for all echoes at once (echo axis first).
    # B0 is already zero outside the object, so no second multiplication by the mask is needed.
    true_phase = 2 * pi * b0_map_true.unsqueeze(0) * echo_times_torch.view(-1, 1, 1, 1)
    # Store the wrapped phase, as this is what is typically measured. torch's `%` follows
    # Python semantics (result has the sign of the divisor), so this maps into [-pi, pi).
    phase_images_torch = (true_phase + pi) % (2 * pi) - pi

    return magnitude, phase_images_torch, echo_times_torch, b0_map_true

## Generate Data and Calculate B0 Maps

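We set `max_b0_hz` high enough that the phase difference between the first two echoes wraps inside the object. With echo spacing $\Delta TE = TE_2 - TE_1$, the difference $2\pi B_0 \Delta TE$ wraps once $|B_0| > 1/(2\,\Delta TE)$. Note that `max_b0_hz` is only the peak of the ramp on the last slice. The generator scales each slice by $(z+1)/D$ and sets $B_0 = 0$ outside the cylinder, so the field inside the object is considerably lower than `max_b0_hz`.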

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

data_shape = (8, 32, 32)  # (depth, height, width); small depth keeps the unwrapping example fast
echo_times_s = [0.0020, 0.0045, 0.0070, 0.0095]  # seconds, 4 echoes

# TE1 = 2 ms, TE2 = 4.5 ms  =>  delta_TE = 2.5 ms. The dual-echo phase difference
# 2*pi*B0*delta_TE exceeds pi in magnitude (i.e. wraps) once |B0| > 1 / (2 * delta_TE) = 200 Hz.
#
# `max_b0_hz` is NOT the field inside the object: the generator scales slice z by
# (z + 1) / depth and zeroes B0 outside a cylinder of radius min(h, w) // 3. For this
# data_shape the in-object peak is only ~0.68 * max_b0_hz (last slice) and ~0.42 * max_b0_hz
# on the central slice plotted below, so a value just above 200 Hz never wraps.
# 800 Hz gives roughly 340 Hz at the edge of the central slice.
current_max_b0_hz = 800.0  # Hz

magnitude, phase_images, echo_times, b0_true = generate_synthetic_3d_b0_data(
    shape=data_shape,
    tes_list=echo_times_s,
    max_b0_hz=current_max_b0_hz,
    device=device,
)

# Sanity check: confirm the scenario really exercises phase unwrapping.
delta_te_s = echo_times_s[1] - echo_times_s[0]
alias_limit_hz = 1.0 / (2.0 * delta_te_s)
peak_b0_in_object_hz = b0_true[magnitude > 0].abs().max().item()
print(f"Peak |B0| inside object: {peak_b0_in_object_hz:.1f} Hz "
      f"(dual-echo aliasing limit: {alias_limit_hz:.1f} Hz)")
if peak_b0_in_object_hz <= alias_limit_hz:
    print("Warning: the dual-echo phase difference does not wrap anywhere; "
          "increase current_max_b0_hz to see the effect of unwrapping.")

# Create mask from magnitude (the util is called with a NumPy array)
mask = create_mask_from_magnitude(magnitude.cpu().numpy(), threshold_factor=0.1)
mask_torch = torch.from_numpy(mask).bool().to(device)

# Data for dual-echo (first two echoes)
dual_echo_phase_images = phase_images[:2, ...]
dual_echo_times = echo_times[:2]

# Calculate B0 map using dual-echo method (NO internal unwrapping)
print("\nCalculating B0 map using dual-echo method (no internal unwrapping)...")
b0_map_dual_no_unwrap = calculate_b0_map_dual_echo(
    dual_echo_phase_images,
    dual_echo_times,
    mask=mask_torch,
    unwrap_method_fn=None,  # Explicitly None
)

# Calculate B0 map using dual-echo method (WITH quality-guided unwrapping)
print("\nCalculating B0 map using dual-echo method (with quality-guided unwrapping)...")
b0_map_dual_with_unwrap = calculate_b0_map_dual_echo(
    dual_echo_phase_images,
    dual_echo_times,
    mask=mask_torch,
    unwrap_method_fn=unwrap_phase_3d_quality_guided,  # Pass the unwrapper
)

# Calculate B0 map using multi-echo linear fit (all echoes)
print("\nCalculating B0 map using multi-echo linear fit...")
# The per-echo phases from the generator are wrapped, and (per the original notes) the
# linear fit does not unwrap along the echo dimension. A more robust pipeline would unwrap
# phase_images along the echo axis first; here the wrapped inputs are used as-is, so this
# map is expected to be inaccurate where the accumulated phase exceeds pi.
b0_map_multi = calculate_b0_map_multi_echo_linear_fit(
    phase_images,  # Using all echoes
    echo_times,
    mask=mask_torch,
)
print("\nB0 mapping calculations complete.")

## Visualize Results

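We visualize the central slice of each 3D volume. Where the true field exceeds the dual-echo aliasing limit, the dual-echo map computed without unwrapping is expected to show aliasing. The version with quality-guided unwrapping should follow the true map more closely. The multi-echo linear fit is computed from wrapped per-echo phases, so it may also show errors.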

In [None]:
slice_idx = data_shape[0] // 2
pi = getattr(torch, 'pi', np.pi)


def _slice_np(volume):
    """Return slice `slice_idx` (first axis) of a 3D tensor as a CPU NumPy array."""
    return volume[slice_idx, ...].detach().cpu().numpy()


def _show(ax, image, title, *, cmap, mask, vmin=None, vmax=None, colorbar=True, cbar_label=None):
    """Draw `image * mask` on `ax` with a title, hidden axes and an optional colorbar."""
    im = ax.imshow(image * mask, cmap=cmap, vmin=vmin, vmax=vmax)
    ax.set_title(title)
    ax.axis('off')
    if colorbar:
        label_kwargs = {'label': cbar_label} if cbar_label else {}
        ax.figure.colorbar(im, ax=ax, shrink=0.8, **label_kwargs)
    return im


mag_slice = _slice_np(magnitude)
phase1_slice = _slice_np(phase_images[0])
phase2_slice = _slice_np(phase_images[1])
b0_true_slice = _slice_np(b0_true)
b0_dual_no_unwrap_slice = _slice_np(b0_map_dual_no_unwrap)
b0_dual_with_unwrap_slice = _slice_np(b0_map_dual_with_unwrap)
b0_multi_slice = _slice_np(b0_map_multi)
mask_slice = _slice_np(mask_torch)

# Symmetric colour range for the B0 panels, taken from the true field inside the object on
# this slice. The generator masks B0 to a cylinder and scales it per slice, so the displayed
# field is well below `current_max_b0_hz`; using that value would squash every B0 panel
# into the middle of the colormap.
b0_in_object = np.abs(b0_true_slice[mask_slice.astype(bool)])
b0_lim = float(b0_in_object.max()) if b0_in_object.size else float(current_max_b0_hz)
b0_kwargs = dict(cmap='coolwarm', vmin=-b0_lim, vmax=b0_lim, cbar_label='Hz')

# Absolute error of the unwrapped dual-echo estimate, for judging its accuracy directly.
b0_err_slice = np.abs(b0_dual_with_unwrap_slice - b0_true_slice)

fig, axes = plt.subplots(2, 4, figsize=(22, 10))

_show(axes[0, 0], mag_slice, f"Magnitude (Slice {slice_idx})", cmap='gray', mask=mask_slice, colorbar=False)
_show(axes[0, 1], phase1_slice, f"Phase TE1 (Slice {slice_idx})", cmap='twilight', mask=mask_slice, vmin=-pi, vmax=pi)
_show(axes[0, 2], phase2_slice, f"Phase TE2 (Slice {slice_idx})", cmap='twilight', mask=mask_slice, vmin=-pi, vmax=pi)
_show(axes[0, 3], b0_true_slice, f"True B0 Map (Slice {slice_idx})", mask=mask_slice, **b0_kwargs)

_show(axes[1, 0], b0_dual_no_unwrap_slice, "Dual-Echo B0 (No Unwrap)", mask=mask_slice, **b0_kwargs)
_show(axes[1, 1], b0_dual_with_unwrap_slice, "Dual-Echo B0 (QualityGuided Unwrap)", mask=mask_slice, **b0_kwargs)
_show(axes[1, 2], b0_multi_slice, "Multi-Echo B0 (Linear Fit)", mask=mask_slice, **b0_kwargs)
_show(axes[1, 3], b0_err_slice, "|Dual-Echo (Unwrap) - True|", cmap='magma', mask=mask_slice, vmin=0.0, cbar_label='Hz')

fig.suptitle(f"PyTorch B0 Mapping Results (3D Synthetic Data, Slice {slice_idx})", fontsize=16)
fig.tight_layout(rect=[0, 0, 1, 0.95])
plt.show()