# PyTorch-based B0 Mapping Example

This notebook demonstrates the PyTorch-centric B0 mapping functions from `reconlib.b0_mapping`:
- `calculate_b0_map_dual_echo` (demonstrating `unwrap_method_fn` for phase difference unwrapping)
- `calculate_b0_map_multi_echo_linear_fit` (demonstrating `spatial_unwrap_fn` for unwrapping individual echoes spatially)

We will use synthetic 3D multi-echo data for two scenarios:
1. High B0 field: Causes the *phase difference* (TE2-TE1) to wrap, for `calculate_b0_map_dual_echo`.
2. Spatially wrapped echoes: Individual echoes have spatial wraps, for `calculate_b0_map_multi_echo_linear_fit`.

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from reconlib.b0_mapping import calculate_b0_map_dual_echo, calculate_b0_map_multi_echo_linear_fit
from reconlib.b0_mapping.utils import create_mask_from_magnitude
from reconlib.phase_unwrapping import unwrap_phase_3d_quality_guided

%matplotlib inline

## Helper Functions for Synthetic Data Generation

In [None]:
def _wrap_phase_torch(phase_tensor: torch.Tensor) -> torch.Tensor:
    """Wraps phase values to the interval [-pi, pi) using PyTorch operations."""
    pi = getattr(torch, 'pi', np.pi)
    return (phase_tensor + pi) % (2 * pi) - pi

def _create_spatial_wrap_pattern(shape_spatial, device, max_val_factor=1.5):
    """
    Creates a 3D spatial pattern that can induce phase wrapping.
    The pattern is a sum of linear ramps along each spatial dimension.
    max_val_factor determines how many times pi the ramp reaches.
    """
    pi = getattr(torch, 'pi', np.pi)
    max_val = max_val_factor * pi
    
    dim_ramps = []
    for i, dim_size in enumerate(shape_spatial):
        ramp_1d = torch.linspace(0, max_val, dim_size, device=device)
        view_shape = [1] * len(shape_spatial)
        view_shape[i] = dim_size
        dim_ramps.append(ramp_1d.view(view_shape))
    
    pattern = torch.zeros(shape_spatial, device=device)
    for r in dim_ramps:
        pattern += r # Summing ramps from each dimension
    return pattern

def generate_synthetic_3d_b0_data(shape=(16, 64, 64), tes_list=(0.002, 0.004, 0.006),
                                  max_b0_hz=50.0, device='cpu',
                                  add_spatial_wraps_to_echoes=False, spatial_wrap_factor=1.5):
    """
    Generate synthetic 3D multi-echo phase data together with its true B0 map.

    The true B0 map is a linear ramp along the last (width) axis spanning
    [-max_b0_hz, max_b0_hz]. Each slice z scales it by (z + 1) / depth, and it
    is zeroed outside a cylindrical object: a disc of radius min(h, w) // 3,
    centred in-plane and repeated along depth. Echo phases are
    2*pi*B0*TE, optionally plus a spatial ramp pattern, wrapped into [-pi, pi).

    Args:
        shape: (depth, height, width) of the volume.
        tes_list: Echo times in seconds (any sequence accepted by torch.tensor).
        max_b0_hz: Peak of the B0 ramp before per-slice scaling and masking.
            The in-object |B0| is therefore smaller than this value.
        device: Torch device for all returned tensors.
        add_spatial_wraps_to_echoes: If True, add the same pattern from
            ``_create_spatial_wrap_pattern`` to every echo before wrapping.
        spatial_wrap_factor: Per-axis ramp height of that pattern, in multiples of pi.

    Returns:
        magnitude: (d, h, w) float32; 1 inside the cylinder, 0 outside.
        phase_images: (num_echoes, d, h, w) float32 wrapped phases.
        echo_times: (num_echoes,) float32 tensor of echo times.
        b0_map_true: (d, h, w) float32 true B0 map in Hz.
    """
    d, h, w = shape
    pi = getattr(torch, 'pi', np.pi)

    # Cylindrical object: the same in-plane disc in every slice.
    center_y, center_x = h // 2, w // 2
    radius = min(h, w) // 3
    y_coords, x_coords = torch.meshgrid(torch.arange(h, device=device),
                                        torch.arange(w, device=device), indexing='ij')
    disc = ((y_coords - center_y) ** 2 + (x_coords - center_x) ** 2 <= radius ** 2)
    magnitude = disc.to(torch.float32).unsqueeze(0).expand(d, h, w).clone()

    # B0: x-ramp scaled by (z + 1) / d per slice, masked to the object.
    x_ramp = torch.linspace(-max_b0_hz, max_b0_hz, w, device=device)
    z_scale = (torch.arange(d, dtype=torch.float32, device=device) + 1.0) / d
    b0_map_true = x_ramp.view(1, 1, w) * z_scale.view(d, 1, 1) * magnitude

    echo_times_torch = torch.tensor(tes_list, dtype=torch.float32, device=device)
    num_echoes = len(tes_list)

    # Phase accrued at each TE, shape (num_echoes, d, h, w). b0_map_true is already
    # zero outside the object, so no extra masking by magnitude is needed.
    echo_phase = 2 * pi * b0_map_true.unsqueeze(0) * echo_times_torch.view(num_echoes, 1, 1, 1)

    if add_spatial_wraps_to_echoes:
        spatial_wrapping_field = _create_spatial_wrap_pattern(shape, device, max_val_factor=spatial_wrap_factor)
        # Out-of-place add: the same field is broadcast onto every echo.
        echo_phase = echo_phase + spatial_wrapping_field

    phase_images_torch = _wrap_phase_torch(echo_phase)
    return magnitude, phase_images_torch, echo_times_torch, b0_map_true

## Scenario 1: High B0 Field for `calculate_b0_map_dual_echo`

Here, `max_b0_hz` is set high to cause the *phase difference* (TE2-TE1) to wrap. This demonstrates the utility of `unwrap_method_fn`.

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

data_shape_dual_echo = (8, 32, 32) # (depth, height, width)
echo_times_dual = [0.0020, 0.0045] # TE1=2ms, TE2=4.5ms => delta_TE = 2.5ms
# The phase difference 2*pi*B0*delta_TE wraps once |B0| > 0.5 / 0.0025 = 200 Hz.
# `max_b0_hz` is only the peak of the generator's x-ramp, though:
#   - the ramp is scaled by (z+1)/depth per slice;
#   - the cylindrical mask keeps only its central part (about 0.68 of the peak
#     for a 32-pixel width).
# On the displayed middle slice (scale 5/8) the in-mask |B0| is therefore about
# 0.42 * max_b0_hz, so 220 Hz would never wrap. With 600 Hz:
#   - the middle slice reaches ~254 Hz (|dphi| ~ 4.0 rad > pi);
#   - the last slice reaches ~406 Hz (|dphi| ~ 6.4 rad < 3*pi, i.e. one wrap).
high_b0_val = 600.0 # Hz, chosen to induce phase difference wrapping

magnitude_high_b0, phase_images_high_b0, tes_high_b0, b0_true_high_b0 = generate_synthetic_3d_b0_data(
    shape=data_shape_dual_echo, 
    tes_list=echo_times_dual, 
    max_b0_hz=high_b0_val, 
    device=device,
    add_spatial_wraps_to_echoes=False # No additional spatial wraps on individual echoes here
)

mask_high_b0_np = create_mask_from_magnitude(magnitude_high_b0.cpu().numpy(), threshold_factor=0.1)
mask_high_b0_torch = torch.from_numpy(mask_high_b0_np).bool().to(device)

# Sanity check: the scenario only demonstrates unwrapping if the true phase
# difference inside the object actually exceeds pi somewhere.
delta_te_dual = float(tes_high_b0[1] - tes_high_b0[0])
max_true_phase_diff = (2 * np.pi * delta_te_dual * b0_true_high_b0[mask_high_b0_torch].abs().max()).item()
print(f"Max |true phase difference| inside mask: {max_true_phase_diff:.2f} rad (wraps if > pi)")

print("\nCalculating B0 map (dual-echo, high B0, no internal unwrapping)...")
b0_dual_high_b0_no_unwrap = calculate_b0_map_dual_echo(
    phase_images_high_b0, 
    tes_high_b0,   
    mask=mask_high_b0_torch,
    unwrap_method_fn=None
)

print("\nCalculating B0 map (dual-echo, high B0, with quality-guided unwrapping)...")
b0_dual_high_b0_with_unwrap = calculate_b0_map_dual_echo(
    phase_images_high_b0, 
    tes_high_b0, 
    mask=mask_high_b0_torch,
    unwrap_method_fn=unwrap_phase_3d_quality_guided
)

### Visualize Results for Scenario 1 (Dual-Echo with High B0)

In [None]:
slice_idx_dual = data_shape_dual_echo[0] // 2
pi_val = getattr(torch, 'pi', np.pi)
slice_mask_s1 = mask_high_b0_np[slice_idx_dual]

# Wrapped TE2 - TE1 phase difference, shown re-wrapped into [-pi, pi).
phase_diff_high_b0 = phase_images_high_b0[1,...] - phase_images_high_b0[0,...]

# One entry per panel: (volume, title, colormap, symmetric colour limit, colorbar label).
panels_s1 = [
    (b0_true_high_b0, f"S1: True B0 (Slice {slice_idx_dual})", 'coolwarm', high_b0_val, 'Hz'),
    (_wrap_phase_torch(phase_diff_high_b0), f"S1: Wrapped Phase Diff (Slice {slice_idx_dual})", 'twilight', pi_val, 'rad'),
    (b0_dual_high_b0_no_unwrap, "S1: Dual-Echo B0 (No Unwrap)", 'coolwarm', high_b0_val, 'Hz'),
    (b0_dual_high_b0_with_unwrap, "S1: Dual-Echo B0 (Unwrap Diff)", 'coolwarm', high_b0_val, 'Hz'),
]

fig_s1, axes_s1 = plt.subplots(1, 4, figsize=(20, 5))
for panel_ax, (volume, title, cmap_name, limit, unit) in zip(axes_s1, panels_s1):
    panel_ax.imshow(volume[slice_idx_dual].cpu().numpy() * slice_mask_s1,
                    cmap=cmap_name, vmin=-limit, vmax=limit)
    panel_ax.set_title(title)
    panel_ax.axis('off')
    fig_s1.colorbar(panel_ax.images[0], ax=panel_ax, label=unit, shrink=0.8)

plt.suptitle(f"Scenario 1: Dual-Echo B0 Mapping with Phase Difference Unwrapping (Slice {slice_idx_dual})", fontsize=14)
plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.show()

## Scenario 2: Spatially Wrapped Echoes for `calculate_b0_map_multi_echo_linear_fit`

Here, the true B0 field is simple (low `max_b0_hz`), but we add artificial spatial wraps to each echo's phase. This demonstrates the utility of `spatial_unwrap_fn`.

In [None]:
data_shape_multi_echo = (8, 32, 32)  # (depth, height, width)
echo_times_multi = [0.0020, 0.0045, 0.0070, 0.0095]  # 4 echoes, 2.5 ms apart
low_b0_val = 50.0  # Hz, simple B0 field

# Gentle B0 field, but every echo gets the same artificial spatial ramp added
# (factor 2.0 => each per-axis ramp reaches 2*pi), so each echo is spatially wrapped.
(magnitude_spat_wrap, phase_images_spat_wrap,
 tes_spat_wrap, b0_true_spat_wrap) = generate_synthetic_3d_b0_data(
    shape=data_shape_multi_echo,
    tes_list=echo_times_multi,
    max_b0_hz=low_b0_val,
    device=device,
    add_spatial_wraps_to_echoes=True,
    spatial_wrap_factor=2.0,
)

mask_spat_wrap_np = create_mask_from_magnitude(magnitude_spat_wrap.cpu().numpy(), threshold_factor=0.1)
mask_spat_wrap_torch = torch.from_numpy(mask_spat_wrap_np).bool().to(device)

# Fit twice: first on the raw wrapped echoes, then with each echo spatially
# unwrapped beforehand by the quality-guided unwrapper.
fit_runs_s2 = (
    ("\nCalculating B0 map (multi-echo, spatially wrapped echoes, no spatial unwrapping)...", None),
    ("\nCalculating B0 map (multi-echo, spatially wrapped echoes, with spatial unwrapping)...", unwrap_phase_3d_quality_guided),
)
fit_results_s2 = []
for status_message, echo_unwrapper in fit_runs_s2:
    print(status_message)
    fit_results_s2.append(
        calculate_b0_map_multi_echo_linear_fit(
            phase_images_spat_wrap,
            tes_spat_wrap,
            mask=mask_spat_wrap_torch,
            spatial_unwrap_fn=echo_unwrapper,
        )
    )
b0_multi_spat_wrap_no_unwrap, b0_multi_spat_wrap_with_unwrap = fit_results_s2

### Visualize Results for Scenario 2 (Multi-Echo with Spatially Wrapped Echoes)

In [None]:
slice_idx_multi = data_shape_multi_echo[0] // 2
slice_mask_s2 = mask_spat_wrap_np[slice_idx_multi]

# One entry per panel: (volume, title, colormap, symmetric colour limit, colorbar label).
# The second panel shows the first echo, which carries the artificial spatial wraps.
panels_s2 = [
    (b0_true_spat_wrap, f"S2: True B0 (Slice {slice_idx_multi})", 'coolwarm', low_b0_val, 'Hz'),
    (phase_images_spat_wrap[0], f"S2: Spatially Wrapped Echo1 (Sl {slice_idx_multi})", 'twilight', pi_val, 'rad'),
    (b0_multi_spat_wrap_no_unwrap, "S2: Multi-Echo B0 (No Echo Unwrap)", 'coolwarm', low_b0_val, 'Hz'),
    (b0_multi_spat_wrap_with_unwrap, "S2: Multi-Echo B0 (SpatUnwrap Echoes)", 'coolwarm', low_b0_val, 'Hz'),
]

fig_s2, axes_s2 = plt.subplots(1, 4, figsize=(20, 5))
for panel_ax, (volume, title, cmap_name, limit, unit) in zip(axes_s2, panels_s2):
    panel_ax.imshow(volume[slice_idx_multi].cpu().numpy() * slice_mask_s2,
                    cmap=cmap_name, vmin=-limit, vmax=limit)
    panel_ax.set_title(title)
    panel_ax.axis('off')
    fig_s2.colorbar(panel_ax.images[0], ax=panel_ax, label=unit, shrink=0.8)

plt.suptitle(f"Scenario 2: Multi-Echo B0 Mapping with Spatial Unwrapping of Echoes (Slice {slice_idx_multi})", fontsize=14)
plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.show()

## Note: Multi-Echo Fit Without Spatial Unwrapping
For completeness: if the echoes are spatially smooth (no artificial spatial wraps beyond those produced by B0·TE itself), the multi-echo linear fit is typically accurate without `spatial_unwrap_fn`. This holds as long as the echo phases `phi_echo = 2*pi*B0*TE_echo` do not alias *across TEs* too severely, i.e. the phase accrued between consecutive echoes, `2*pi*B0*delta_TE`, stays well inside (-pi, pi). `spatial_unwrap_fn` is intended for the case where *each individual echo's spatial phase pattern* is wrapped.