# PyTorch-based B0 Mapping Example

This notebook demonstrates the PyTorch-centric B0 mapping functions from `reconlib.b0_mapping`:
- `calculate_b0_map_dual_echo`
- `calculate_b0_map_multi_echo_linear_fit`

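We use synthetic 3D multi-echo phase data generated from a known ground-truth B0 map, so each estimate can be compared directly against the truth.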

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from reconlib.b0_mapping import calculate_b0_map_dual_echo, calculate_b0_map_multi_echo_linear_fit
from reconlib.b0_mapping.utils import create_mask_from_magnitude # Still useful for mask creation from magnitude
from reconlib.plotting import plot_phase_image, plot_b0_field_map

%matplotlib inline

## Helper Function to Generate Synthetic 3D B0 Data

In [None]:
def generate_synthetic_3d_b0_data(shape=(16, 64, 64), tes_list=(0.002, 0.004, 0.006), max_b0_hz=50.0,
                                  device='cpu', wrap_phase=False):
    """
    Generate synthetic 3D multi-echo phase data with a known ground-truth B0 map.

    The true B0 map is a linear ramp along x (from -max_b0_hz to +max_b0_hz) whose
    amplitude is scaled per z-slice by (z + 1) / depth. It is zeroed outside a
    cylindrical object that runs along z. Each echo's phase is 2*pi * B0 * TE.

    Args:
        shape: (depth, height, width) of the volume.
        tes_list: Echo times in seconds, one per echo. Any sequence works; the
            default is a tuple so it cannot be shared/mutated between calls.
        max_b0_hz: Peak magnitude (Hz) of the x ramp before z scaling and masking.
        device: Torch device on which every returned tensor is created.
        wrap_phase: If True, wrap the phases into [-pi, pi), mimicking measured
            phase. Defaults to False, which returns unwrapped phase.

    Returns:
        magnitude: float32 (depth, height, width) cylinder image, 1 inside / 0 outside.
        phase_images: float32 (num_echoes, depth, height, width) phase in radians.
        echo_times: float32 (num_echoes,) echo times in seconds.
        b0_map_true: float32 (depth, height, width) ground-truth B0 map in Hz.

    Raises:
        ValueError: if `shape` does not have exactly three entries.
    """
    if len(shape) != 3:
        raise ValueError(f"shape must be (depth, height, width), got {tuple(shape)}")
    d, h, w = shape
    pi = getattr(torch, 'pi', np.pi)

    # B0 ramp: b0[z, y, x] = x_ramp[x] * (z + 1) / d, broadcast over y -> (d, 1, w).
    x_ramp = torch.linspace(-max_b0_hz, max_b0_hz, w, device=device)
    z_scale = torch.arange(1, d + 1, dtype=torch.float32, device=device) / d
    b0_ramp = z_scale.view(d, 1, 1) * x_ramp.view(1, 1, w)

    # Magnitude: the same filled disc on every slice, i.e. a cylinder along z.
    center_y, center_x = h // 2, w // 2
    radius = min(h, w) // 3
    y_coords, x_coords = torch.meshgrid(torch.arange(h, device=device), torch.arange(w, device=device), indexing='ij')
    cylinder_mask_2d = (y_coords - center_y) ** 2 + (x_coords - center_x) ** 2 <= radius ** 2
    magnitude = cylinder_mask_2d.to(torch.float32).repeat(d, 1, 1)

    # Restrict B0 to the object; broadcasting (d, 1, w) * (d, h, w) -> (d, h, w).
    b0_map_true = b0_ramp * magnitude

    echo_times_torch = torch.tensor(tes_list, dtype=torch.float32, device=device)

    # All echoes at once: (E, 1, 1, 1) * (1, d, h, w) -> (E, d, h, w).
    # B0 is already zero outside the object, so the phase needs no extra masking.
    phase_images_torch = 2 * pi * echo_times_torch.view(-1, 1, 1, 1) * b0_map_true.unsqueeze(0)

    if wrap_phase:
        # Map into [-pi, pi) the way measured phase is reported.
        phase_images_torch = torch.remainder(phase_images_torch + pi, 2 * pi) - pi
    # Otherwise the phase is left unwrapped. The dual-echo method copes with wrapped
    # phase differences, while the multi-echo linear fit ideally wants unwrapped phase.

    return magnitude, phase_images_torch, echo_times_torch, b0_map_true

## Generate Data and Calculate B0 Maps

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Experiment configuration (bound at notebook scope so later cells can reuse it).
data_shape = (16, 48, 48)  # (depth, height, width)
echo_times_s = [0.002, 0.0045, 0.007, 0.0095]  # seconds, 4 echoes
max_b0_hz = 60.0  # peak of the synthetic B0 ramp, in Hz
assert len(echo_times_s) >= 2, "the dual-echo method below needs at least two echoes"

magnitude, phase_images, echo_times, b0_true = generate_synthetic_3d_b0_data(
    shape=data_shape,
    tes_list=echo_times_s,
    max_b0_hz=max_b0_hz,
    device=device,
)

# Create mask from magnitude
mask = create_mask_from_magnitude(magnitude.cpu().numpy(), threshold_factor=0.1)  # util expects numpy
mask_torch = torch.from_numpy(mask).bool().to(device)

# Calculate B0 map using dual-echo method (first two echoes)
print("\nCalculating B0 map using dual-echo method...")
b0_map_dual = calculate_b0_map_dual_echo(
    phase_images[:2],  # use first two echoes
    echo_times[:2],    # corresponding TEs
    mask=mask_torch,
)

# Calculate B0 map using multi-echo linear fit (all echoes)
print("\nCalculating B0 map using multi-echo linear fit...")
b0_map_multi = calculate_b0_map_multi_echo_linear_fit(
    phase_images,
    echo_times,
    mask=mask_torch,
)
print("\nB0 mapping calculations complete.")

## Visualize Results

We visualize the central slice (along the depth axis) of each 3D volume, with everything outside the object mask set to zero.

In [None]:
slice_idx = data_shape[0] // 2  # central slice along the depth axis
pi = getattr(torch, 'pi', np.pi)

mag_slice = magnitude[slice_idx].cpu().numpy()
phase1_slice = phase_images[0, slice_idx].cpu().numpy()
phase2_slice = phase_images[1, slice_idx].cpu().numpy()
b0_true_slice = b0_true[slice_idx].cpu().numpy()
b0_dual_slice = b0_map_dual[slice_idx].cpu().numpy()
b0_multi_slice = b0_map_multi[slice_idx].cpu().numpy()
mask_slice = mask_torch[slice_idx].cpu().numpy()

# Symmetric colour limit for all B0 panels, derived from the ground truth so the
# true map is never clipped. Deriving it here means this cell only depends on
# variables created by the cells above.
b0_limit_hz = float(b0_true.abs().max()) or 1.0


def masked_mae_hz(estimate):
    """Mean absolute error (Hz) of an estimated B0 slice vs. the true slice, inside the mask."""
    if not mask_slice.any():
        return float('nan')
    return float(np.nanmean(np.abs(estimate - b0_true_slice)[mask_slice]))


def show_panel(ax, image, title, cmap, vmin=None, vmax=None, colorbar=True, cbar_label=''):
    """Draw one masked 2D slice on `ax` (axis ticks hidden), optionally with a colour bar."""
    im = ax.imshow(image * mask_slice, cmap=cmap, vmin=vmin, vmax=vmax)
    ax.set_title(title)
    ax.axis('off')
    if colorbar:
        fig.colorbar(im, ax=ax, label=cbar_label)
    return im


fig, axes = plt.subplots(2, 3, figsize=(18, 10))

# Top row: magnitude and the first two echo phases (wrapped colour scale).
show_panel(axes[0, 0], mag_slice, f"Magnitude (Slice {slice_idx})", 'gray', colorbar=False)
show_panel(axes[0, 1], phase1_slice, f"Phase TE1 (Slice {slice_idx})", 'twilight', vmin=-pi, vmax=pi)
show_panel(axes[0, 2], phase2_slice, f"Phase TE2 (Slice {slice_idx})", 'twilight', vmin=-pi, vmax=pi)

# Bottom row: ground truth and the two estimates on a shared Hz scale.
b0_kwargs = dict(cmap='coolwarm', vmin=-b0_limit_hz, vmax=b0_limit_hz, cbar_label='Hz')
show_panel(axes[1, 0], b0_true_slice, f"True B0 Map (Slice {slice_idx})", **b0_kwargs)
for ax, label, est_slice in ((axes[1, 1], "Dual-Echo", b0_dual_slice),
                             (axes[1, 2], "Multi-Echo", b0_multi_slice)):
    title = f"{label} B0 (Slice {slice_idx})\nMAE vs. truth: {masked_mae_hz(est_slice):.2f} Hz"
    show_panel(ax, est_slice, title, **b0_kwargs)

plt.suptitle("PyTorch B0 Mapping Results (3D Synthetic Data)", fontsize=16)
plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.show()