# Comprehensive B0 Mapping Pipeline Example

This notebook demonstrates the recommended workflow for B0 mapping from raw multi-coil, multi-echo complex image data. The pipeline involves:
1.  **Preprocessing:** Starting with multi-coil, multi-echo complex images, we use `preprocess_multi_coil_multi_echo_data` from `reconlib.pipeline_utils`. This utility performs:
    *   Coil combination (using `combine_coils_complex_sum`).
    *   Mask generation based on SNR of the first echo.
    *   Multi-echo phase unwrapping using `unwrap_multi_echo_masked_reference`, which takes a user-provided spatial unwrapper (e.g., `unwrap_phase_3d_quality_guided`) to unwrap the reference echo and phase differences.
    The output is coil-combined, masked, and unwrapped phase images for each echo, along with the mask and combined magnitudes.
2.  **B0 Map Calculation:** The processed (unwrapped) phase images are then fed into the refactored B0 mapping functions from `reconlib.b0_mapping.phase_based_b0_field_maps`:
    *   `calculate_b0_map_dual_echo`
    *   `calculate_b0_map_multi_echo_linear_fit`
    These functions now expect pre-processed, phase-only inputs.

In [None]:
import sys
import os
# Make the repository root importable so this example can be run directly from
# the 'examples' folder. For general use, it's recommended to install reconlib
# instead (e.g., `pip install -e .` from the repository root).
# `__file__` is not defined inside a Jupyter kernel, so fall back to the current
# working directory there (the notebook is assumed to be launched from 'examples').
_example_dir = os.path.dirname(os.path.abspath(__file__)) if '__file__' in globals() else os.getcwd()
sys.path.insert(0, os.path.abspath(os.path.join(_example_dir, '..')))

import torch
import numpy as np
import matplotlib.pyplot as plt

from reconlib.pipeline_utils import preprocess_multi_coil_multi_echo_data
from reconlib.phase_unwrapping import unwrap_phase_3d_quality_guided # Example spatial unwrapper
from reconlib.b0_mapping.phase_based_b0_field_maps import (
    calculate_b0_map_dual_echo, 
    calculate_b0_map_multi_echo_linear_fit
)
from reconlib.utils import combine_coils_complex_sum # For visualizing intermediate combined phase

%matplotlib inline

## Setup

In [None]:
# Prefer a CUDA device when one is available; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Value of pi used for the phase plot limits; older torch builds without
# `torch.pi` fall back to NumPy's constant.
pi_val = torch.pi if hasattr(torch, 'pi') else np.pi

## Synthetic Data Generation (Multi-Coil, Multi-Echo)

The helper functions below generate synthetic multi-coil, multi-echo complex data. Each individual echo contains spatial phase wraps, and the true B0 map is large enough that the phase difference between echoes can also wrap (temporal wrapping).

In [None]:
def _wrap_phase_torch(phase_tensor: torch.Tensor) -> torch.Tensor:
    pi = getattr(torch, 'pi', np.pi)
    return (phase_tensor + pi) % (2 * pi) - pi

def _create_spatial_wrap_pattern(shape_spatial, device, max_val_factor=1.5):
    pi = getattr(torch, 'pi', np.pi)
    max_val = max_val_factor * pi
    dim_ramps = []
    for i, dim_size in enumerate(shape_spatial):
        ramp_1d = torch.linspace(0, max_val, dim_size, device=device)
        view_shape = [1] * len(shape_spatial)
        view_shape[i] = dim_size
        dim_ramps.append(ramp_1d.view(view_shape))
    pattern = torch.zeros(shape_spatial, device=device)
    for r in dim_ramps:
        pattern += r
    return pattern

def _generate_base_3d_true_phase(shape_spatial, ramps_scale, device):
    d, h, w = shape_spatial
    pi = getattr(torch, 'pi', np.pi)
    z_coords = torch.linspace(-pi * ramps_scale[0], pi * ramps_scale[0], d, device=device)
    y_coords = torch.linspace(-pi * ramps_scale[1], pi * ramps_scale[1], h, device=device)
    x_coords = torch.linspace(-pi * ramps_scale[2], pi * ramps_scale[2], w, device=device)
    true_phase = z_coords.view(-1, 1, 1) + y_coords.view(1, -1, 1) + x_coords.view(1, 1, -1)
    return true_phase.expand(d,h,w)

def generate_synthetic_multicoil_multiecho_data_notebook(
    spatial_shape=(8, 32, 32), num_echoes=3, num_coils=4,
    snr_thresh_for_mag_pattern=0.2,
    max_b0_hz=220.0, # To ensure dual-echo phase difference wraps
    echo_times_s=(0.0020, 0.0045, 0.0070), # Ensure delta_TE for first two allows wrapping with max_b0_hz
    spatial_wrap_factor_echoes=1.8, # For spatial wraps on individual echoes
    csm_spatial_complexity=(0.1,0.15,0.2),
    device='cpu'
):
    """Synthesize multi-coil, multi-echo complex images with a known B0 map.

    The ground-truth B0 map is a left-right ramp spanning
    ``[-max_b0_hz, +max_b0_hz]`` along W, scaled per slice by
    ``(z_idx + 1) / d``. Each echo's phase is ``2*pi*B0*TE`` plus a fixed
    spatial ramp pattern; the resulting combined signal is distributed to the
    coils through mock, RSS-normalized coil sensitivity maps.

    Args:
        spatial_shape: ``(d, h, w)`` size of the volume.
        num_echoes: Number of echoes to synthesize.
        num_coils: Number of mock receive coils.
        snr_thresh_for_mag_pattern: Sets the magnitude pattern: background is
            half this value, the central signal region is twice this value.
        max_b0_hz: Peak absolute value of the B0 ramp, in Hz.
        echo_times_s: Echo times in seconds; must provide at least
            ``num_echoes`` entries.
        spatial_wrap_factor_echoes: Per-axis peak (in multiples of pi) of the
            spatial ramp pattern added to every echo's phase.
        csm_spatial_complexity: Three per-axis scale factors for the mock coil
            sensitivity ramps.
        device: Torch device on which all tensors are created.

    Returns:
        Tuple ``(multi_coil_complex_images, true_b0_map, echo_times_torch,
        base_magnitude)`` where

        * ``multi_coil_complex_images`` is complex64 with shape
          ``(num_echoes, num_coils, d, h, w)``,
        * ``true_b0_map`` is float32 with shape ``(d, h, w)``,
        * ``echo_times_torch`` is a float32 tensor built from ``echo_times_s``,
        * ``base_magnitude`` has shape ``(d, h, w)``.

    Raises:
        ValueError: If fewer than ``num_echoes`` echo times are supplied.
    """
    d, h, w = spatial_shape
    pi = getattr(torch, 'pi', np.pi)

    # Fail early with a clear message instead of an IndexError inside the echo loop.
    if len(echo_times_s) < num_echoes:
        raise ValueError(
            f"echo_times_s has {len(echo_times_s)} entries but num_echoes={num_echoes}"
        )

    # 1. True B0 map (can cause temporal wraps)
    true_b0_map = torch.zeros(spatial_shape, dtype=torch.float32, device=device)
    x_ramp = torch.linspace(-max_b0_hz, max_b0_hz, w, device=device)
    for z_idx in range(d):
        z_scale = (z_idx + 1.0) / d
        true_b0_map[z_idx, :, :] = x_ramp.view(1, -1) * z_scale

    # 2. Base Magnitude (for mask generation and overall signal)
    base_magnitude = torch.full(spatial_shape, snr_thresh_for_mag_pattern / 2, device=device)
    margin_d, margin_h, margin_w = d // 4, h // 4, w // 4
    # Explicit end indices: a `m:-m` slice would be empty when the margin is 0
    # (any dimension smaller than 4), leaving the volume without a signal region.
    base_magnitude[
        margin_d:d - margin_d,
        margin_h:h - margin_h,
        margin_w:w - margin_w,
    ] = snr_thresh_for_mag_pattern * 2
    # Zero B0 wherever the magnitude is not positive.
    # NOTE(review): the background magnitude is snr_thresh_for_mag_pattern / 2, so for a
    # positive threshold this leaves the B0 map unchanged everywhere.
    true_b0_map *= (base_magnitude > 0)

    # 3. Echo Times
    echo_times_torch = torch.tensor(echo_times_s, dtype=torch.float32, device=device)

    # 4. Create Mock Coil Sensitivity Maps (CSMs)
    mock_csms = torch.zeros((num_coils,) + spatial_shape, dtype=torch.complex64, device=device)
    # The magnitude profile does not depend on the coil index, so compute it once.
    csm_mag_profile = _generate_base_3d_true_phase(spatial_shape, csm_spatial_complexity, device)
    csm_mag_profile = (torch.cos(csm_mag_profile) + 1.5) / 2.5
    for c in range(num_coils):
        # Steeper phase ramps for higher coil indices make the coil phases differ.
        csm_phase_offset = _generate_base_3d_true_phase(
            spatial_shape, tuple(f * (c + 1) for f in csm_spatial_complexity), device
        )
        mock_csms[c, ...] = csm_mag_profile * torch.exp(1j * csm_phase_offset)
    # Normalize so the root-sum-of-squares over coils is ~1 (epsilon avoids divide-by-zero).
    rss_csms = torch.sqrt(torch.sum(torch.abs(mock_csms)**2, dim=0, keepdim=True)) + 1e-9
    mock_csms_normalized = mock_csms / rss_csms

    # 5. Synthesize Multi-Coil, Multi-Echo Complex Images
    multi_coil_complex_images = torch.zeros((num_echoes, num_coils) + spatial_shape, dtype=torch.complex64, device=device)
    spatial_wrapping_field_echoes = _create_spatial_wrap_pattern(spatial_shape, device, max_val_factor=spatial_wrap_factor_echoes)

    for e in range(num_echoes):
        # Phase from B0*TE (this is the true *unwrapped* evolution)
        phase_from_b0_te = (2 * pi * true_b0_map * echo_times_torch[e])
        # Add the spatial ramp pattern; this sum is the phase we ideally recover after unwrapping
        true_unwrapped_phase_for_echo = phase_from_b0_te + spatial_wrapping_field_echoes
        # Magnitude decays slightly with echo index
        current_magnitude = base_magnitude * (1 - 0.1 * e)
        true_combined_complex_signal_e = current_magnitude * torch.exp(1j * true_unwrapped_phase_for_echo)

        # Distribute the *true combined* signal to all coils at once (broadcast over the
        # coil axis), then rebuild each coil image from its magnitude and wrapped phase.
        coil_signals_complex = true_combined_complex_signal_e * mock_csms_normalized
        coil_phases_wrapped = _wrap_phase_torch(torch.angle(coil_signals_complex))
        multi_coil_complex_images[e, ...] = torch.abs(coil_signals_complex) * torch.exp(1j * coil_phases_wrapped)

    return multi_coil_complex_images, true_b0_map, echo_times_torch, base_magnitude

## Step 1: Preprocessing with `preprocess_multi_coil_multi_echo_data`

In [None]:
# Simulation parameters
spatial_shape_g = (8, 40, 40)  # (D, H, W)
num_echoes_g, num_coils_g = 3, 4
echo_times_s_g = [0.0020, 0.0045, 0.0070]  # seconds; TE1=2ms, TE2=4.5ms => delta_TE=2.5ms
max_b0_hz_g = 220.0  # Makes the dual-echo phase difference wrap (220*0.0025 = 0.55 cycles > 0.5)
snr_thresh_g = 0.15

# Synthesize the multi-coil, multi-echo complex images together with the
# ground-truth B0 map, the echo times as a tensor, and the base magnitude.
(
    mc_me_images,
    true_b0,
    echo_times_val,
    base_mag_viz,
) = generate_synthetic_multicoil_multiecho_data_notebook(
    spatial_shape=spatial_shape_g,
    num_echoes=num_echoes_g,
    num_coils=num_coils_g,
    # Scale the magnitude pattern up relative to the mask threshold to ensure a good mask region.
    snr_thresh_for_mag_pattern=snr_thresh_g * 1.5,
    max_b0_hz=max_b0_hz_g,
    echo_times_s=echo_times_s_g,
    # Adds spatial wraps to the individual echoes.
    spatial_wrap_factor_echoes=1.8,
    device=device,
)

print(f"Input multi-coil complex images shape: {mc_me_images.shape}")

# Spatial unwrapper handed to the preprocessing pipeline
def quality_guided_adapter_for_pipeline(phase_tensor, mask_tensor):
    """Adapt `unwrap_phase_3d_quality_guided` to a ``(phase, mask)`` call signature.

    `mask_tensor` is accepted but intentionally unused: the quality-guided
    unwrapper is called on the whole volume with ``sigma_blur=1.0``.
    NOTE(review): this assumes `unwrap_multi_echo_masked_reference` applies the
    generated mask to the final result itself -- confirm against the pipeline.
    An unwrapper that can exploit a mask (e.g. for its quality map) would
    receive it here instead.
    """
    unwrapped_phase = unwrap_phase_3d_quality_guided(phase_tensor, sigma_blur=1.0)
    return unwrapped_phase

spatial_unwrapper_for_pipeline = quality_guided_adapter_for_pipeline

print("\nRunning preprocessing pipeline...")
# Positional arguments: multi-coil images, mask threshold, spatial unwrapper.
pipeline_outputs = preprocess_multi_coil_multi_echo_data(
    mc_me_images,
    snr_thresh_g,
    spatial_unwrapper_for_pipeline
)
processed_unwrapped_phases, generated_mask, combined_magnitudes = pipeline_outputs
print("Preprocessing complete.")

# Report the shape of each pipeline output.
for output_label, output_tensor in (
    ("Processed unwrapped phases", processed_unwrapped_phases),
    ("Generated mask", generated_mask),
    ("Combined magnitudes", combined_magnitudes),
):
    print(f"{output_label} shape: {output_tensor.shape}")

### Visualize Intermediate Preprocessing Results

In [None]:
# Show the first echo on the central slice along D.
echo_idx_viz = 0
slice_d_viz = spatial_shape_g[0] // 2

# The pipeline does not return the wrapped coil-combined phase, so recompute it
# here for the selected echo (display only).
wrapped_combined_phase_viz, _ = combine_coils_complex_sum(mc_me_images[echo_idx_viz])

fig_preproc, axes_preproc = plt.subplots(1, 4, figsize=(20, 5))
fig_preproc.suptitle(f"Preprocessing Results (Echo: {echo_idx_viz}, Slice: {slice_d_viz})", fontsize=16)

# Mask slice as a NumPy array; also used to blank out the phase panels.
mask_slice_viz = generated_mask[slice_d_viz, ...].cpu().numpy()

# One entry per panel: (image, title, imshow kwargs, colorbar kwargs or None).
preproc_panels = [
    (
        combined_magnitudes[echo_idx_viz, slice_d_viz, ...].cpu().numpy(),
        f"Combined Mag Echo {echo_idx_viz}",
        dict(cmap='gray'),
        dict(shrink=0.8),
    ),
    (
        mask_slice_viz,
        "Generated SNR Mask",
        dict(cmap='gray'),
        None,  # the mask panel is drawn without a colorbar
    ),
    (
        wrapped_combined_phase_viz[slice_d_viz, ...].cpu().numpy() * mask_slice_viz,
        f"Wrapped Combined Phase Echo {echo_idx_viz} (Masked)",
        dict(cmap='twilight', vmin=-pi_val, vmax=pi_val),
        dict(label='rad', shrink=0.8),
    ),
    (
        processed_unwrapped_phases[echo_idx_viz, slice_d_viz, ...].cpu().numpy() * mask_slice_viz,
        f"Processed Unwrapped Phase Echo {echo_idx_viz} (Masked)",
        dict(cmap='viridis'),
        dict(label='rad', shrink=0.8),
    ),
]

for ax, (panel_image, panel_title, imshow_kwargs, colorbar_kwargs) in zip(axes_preproc, preproc_panels):
    im = ax.imshow(panel_image, **imshow_kwargs)
    ax.set_title(panel_title)
    ax.axis('off')
    if colorbar_kwargs is not None:
        fig_preproc.colorbar(im, ax=ax, **colorbar_kwargs)

# Leave headroom at the top for the suptitle.
plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.show()

## Step 2: B0 Mapping using Refactored Functions

In [None]:
# Dual-echo estimate: uses only the first two echoes and their echo times.
dual_echo_tes = echo_times_val[:2]
dual_echo_input_phases = processed_unwrapped_phases[:2, ...]
b0_map_dual = calculate_b0_map_dual_echo(
    dual_echo_input_phases,
    dual_echo_tes,
    mask=generated_mask,
)
print(f"Dual-echo B0 map calculated, shape: {b0_map_dual.shape}")

# Multi-echo estimate: linear fit over all processed echoes.
b0_map_multi = calculate_b0_map_multi_echo_linear_fit(
    processed_unwrapped_phases,
    echo_times_val,
    mask=generated_mask,
)
print(f"Multi-echo linear fit B0 map calculated, shape: {b0_map_multi.shape}")

### Visualize B0 Mapping Results

In [None]:
fig_b0, axes_b0 = plt.subplots(1, 4, figsize=(20, 5))
fig_b0.suptitle(f"B0 Mapping Results (Slice: {slice_d_viz})", fontsize=16)

# Shared colour limits so the three B0 panels are directly comparable.
vmin_b0, vmax_b0 = -max_b0_hz_g, max_b0_hz_g

mask_slice_b0 = generated_mask[slice_d_viz, ...].cpu().numpy()


def _masked_slice_np(volume):
    # Selected slice of `volume` as a NumPy array, multiplied by the mask slice.
    return volume[slice_d_viz, ...].cpu().numpy() * mask_slice_b0


# Difference map (Multi-Echo vs True)
diff_b0_multi = _masked_slice_np(b0_map_multi - true_b0)

# One entry per panel: (image, title, (vmin, vmax)).
b0_panels = [
    (_masked_slice_np(true_b0), "True B0 Map (Masked)", (vmin_b0, vmax_b0)),
    (_masked_slice_np(b0_map_dual), "Dual-Echo B0 Map (Masked)", (vmin_b0, vmax_b0)),
    (_masked_slice_np(b0_map_multi), "Multi-Echo Fit B0 Map (Masked)", (vmin_b0, vmax_b0)),
    # Smaller colour range for the difference map
    (diff_b0_multi, "Difference (Multi-Echo - True)", (-max_b0_hz_g/10, max_b0_hz_g/10)),
]

for ax, (panel_image, panel_title, (panel_vmin, panel_vmax)) in zip(axes_b0, b0_panels):
    im = ax.imshow(panel_image, cmap='coolwarm', vmin=panel_vmin, vmax=panel_vmax)
    ax.set_title(panel_title)
    ax.axis('off')
    fig_b0.colorbar(im, ax=ax, label='Hz', shrink=0.8)

# Leave headroom at the top for the suptitle.
plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.show()

## Conclusion

This notebook demonstrated the new recommended workflow for B0 mapping, starting from multi-coil, multi-echo complex image data. 
1. The `preprocess_multi_coil_multi_echo_data` function was used to perform coil combination and robust multi-echo phase unwrapping, yielding pre-processed phase images ready for B0 calculation.
2. These processed phases were then used with the simplified `calculate_b0_map_dual_echo` and `calculate_b0_map_multi_echo_linear_fit` functions to estimate the B0 field map.

This approach modularizes the processing, separating the complex preprocessing steps from the core B0 calculation logic.