# Multi-Coil, Multi-Echo Phase Preprocessing Pipeline Example

This notebook demonstrates the full pipeline from raw multi-coil, multi-echo complex image data to coil-combined, masked, and unwrapped phase images. These processed phase images are suitable for subsequent B0 mapping or Quantitative Susceptibility Mapping (QSM).

The key function showcased is `preprocess_multi_coil_multi_echo_data` from `reconlib.pipeline_utils`, which internally uses `combine_coils_complex_sum` for coil combination and `unwrap_multi_echo_masked_reference` for reference-based multi-echo phase unwrapping.

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt

from reconlib.pipeline_utils import preprocess_multi_coil_multi_echo_data
from reconlib.phase_unwrapping import unwrap_phase_3d_quality_guided # Example spatial unwrapper
from reconlib.utils import combine_coils_complex_sum # For visualizing intermediate combined phase

%matplotlib inline

## Setup

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")
# Older PyTorch releases have no ``torch.pi``; NumPy's constant is the fallback.
# Only used for plotting limits further down.
pi_val = torch.pi if hasattr(torch, 'pi') else np.pi

## Synthetic Data Generation

We'll define a few helper functions that generate multi-coil, multi-echo complex data with known ground truth. They are adapted from the unit tests.

In [None]:
def _wrap_phase_torch(phase_tensor: torch.Tensor) -> torch.Tensor:
    """Map phase values into the half-open interval [-pi, pi).

    The input is shifted up by pi, reduced modulo 2*pi (the tensor ``%`` operator follows
    Python's floor-modulo convention, so the remainder is non-negative), and shifted back.
    """
    half_turn = torch.pi if hasattr(torch, 'pi') else np.pi
    shifted = phase_tensor + half_turn
    wrapped = shifted % (2 * half_turn)
    return wrapped - half_turn

def _generate_synthetic_3d_phase(shape_spatial, ramps_scale, device):
    """Build a separable linear 3D phase ramp and its wrapped counterpart.

    Along axis k the ramp runs linearly from ``-pi * ramps_scale[k]`` to
    ``+pi * ramps_scale[k]``; the three 1D ramps are summed via broadcasting.

    Returns:
        ``(true_phase, wrapped_phase)``, both of shape ``shape_spatial`` (D, H, W);
        ``wrapped_phase`` lies in [-pi, pi).
    """
    depth, height, width = shape_spatial
    pi = torch.pi if hasattr(torch, 'pi') else np.pi
    scale_z, scale_y, scale_x = ramps_scale[0], ramps_scale[1], ramps_scale[2]

    ramp_z = torch.linspace(-pi * scale_z, pi * scale_z, depth, device=device)[:, None, None]
    ramp_y = torch.linspace(-pi * scale_y, pi * scale_y, height, device=device)[None, :, None]
    ramp_x = torch.linspace(-pi * scale_x, pi * scale_x, width, device=device)[None, None, :]

    true_phase = (ramp_z + ramp_y + ramp_x).expand(depth, height, width)
    wrapped_phase = _wrap_phase_torch(true_phase)
    return true_phase, wrapped_phase

def generate_synthetic_multicoil_multiecho_data_notebook(
    spatial_shape=(8, 32, 32),  # D, H, W
    num_echoes=3,
    num_coils=4,
    snr_thresh_for_mag_pattern=0.2,  # Magnitude value that defines the mask pattern
    base_ramp_scales=(1.5, 1.8, 2.0),  # Spatial wraps in echo 1 (combined)
    diff_ramp_scales=(0.3, 0.35, 0.4),  # Evolving pattern in later echoes (combined)
    device='cpu',
    magnitude_decay_per_echo=0.15,
    min_magnitude_fraction=0.05,
):
    """
    Generate synthetic multi-coil, multi-echo complex data for the notebook demo.

    Echo 1 carries a linear phase ramp (``base_ramp_scales``) that wraps spatially, and a
    piecewise-constant magnitude:
    - ``snr_thresh_for_mag_pattern / 2`` everywhere;
    - raised to ``snr_thresh_for_mag_pattern * 2`` in a central block covering roughly the
      middle half of each axis.

    Each later echo ``i`` is built from echo 1:
    - its phase adds a mean-centred ramp whose scales are ``diff_ramp_scales * (1 + 0.25 * i)``;
    - its magnitude is scaled by ``max(1 - magnitude_decay_per_echo * i, min_magnitude_fraction)``.

    The floor keeps magnitudes positive for long echo trains. A negative magnitude would
    silently add pi to the signal phase and invalidate the returned ground-truth phase.

    Coil images are ``combined_signal * csm_c``. The mock CSMs are normalized to (approximately)
    unit root-sum-of-squares over coils.

    Args:
        spatial_shape: (D, H, W) volume size.
        num_echoes: number of echoes; must be >= 1.
        num_coils: number of mock receive coils.
        snr_thresh_for_mag_pattern: reference magnitude for the background/block pattern.
        base_ramp_scales: per-axis half-range of the echo-1 phase ramp, in units of pi.
        diff_ramp_scales: per-axis half-range (units of pi) of the phase evolution of later echoes.
        device: torch device for all generated tensors.
        magnitude_decay_per_echo: linear magnitude decay applied per echo index.
        min_magnitude_fraction: lower bound on the per-echo magnitude scale factor.

    Returns:
        Tuple of
        - multi_coil_complex_images: complex64, (num_echoes, num_coils, D, H, W)
        - true_unwrapped_coil_combined_phases: (num_echoes, D, H, W). This is the phase of the
          signal before multiplication by the CSMs; the coil images additionally carry each
          CSM's phase.
        - true_coil_combined_magnitudes: (num_echoes, D, H, W)
        - mock_csms: RSS-normalized complex64 CSMs, (num_coils, D, H, W)

    Raises:
        ValueError: if ``num_echoes`` < 1.
    """
    if num_echoes < 1:
        raise ValueError(f"num_echoes must be >= 1, got {num_echoes}")

    spatial_shape = tuple(spatial_shape)
    d, h, w = spatial_shape
    volume_shape = (num_echoes,) + spatial_shape

    # --- Echo 1: wrapped phase ramp and block-structured magnitude ---------------------
    base_phase, _ = _generate_synthetic_3d_phase(
        spatial_shape, ramps_scale=base_ramp_scales, device=device
    )
    base_magnitude = torch.full(spatial_shape, snr_thresh_for_mag_pattern / 2, device=device)
    qd, qh, qw = d // 4, h // 4, w // 4
    # Use explicit stop indices. ``q:-q`` degenerates to the empty slice ``0:0`` when
    # q == 0 (any axis shorter than 4), which would drop the high-magnitude block entirely.
    base_magnitude[qd:d - qd, qh:h - qh, qw:w - qw] = snr_thresh_for_mag_pattern * 2

    true_unwrapped_coil_combined_phases = torch.zeros(volume_shape, device=device)
    true_coil_combined_magnitudes = torch.zeros(volume_shape, device=device)
    true_unwrapped_coil_combined_phases[0] = base_phase
    true_coil_combined_magnitudes[0] = base_magnitude

    # --- Later echoes: evolving phase and decaying magnitude ---------------------------
    for i in range(1, num_echoes):
        echo_ramp_scales = tuple(s * (1 + 0.25 * i) for s in diff_ramp_scales)
        evolving_pattern, _ = _generate_synthetic_3d_phase(
            spatial_shape, ramps_scale=echo_ramp_scales, device=device
        )
        evolving_pattern = evolving_pattern - torch.mean(evolving_pattern)  # zero-mean evolution
        true_unwrapped_coil_combined_phases[i] = (
            true_unwrapped_coil_combined_phases[0] + evolving_pattern
        )
        decay = max(1 - magnitude_decay_per_echo * i, min_magnitude_fraction)
        true_coil_combined_magnitudes[i] = true_coil_combined_magnitudes[0] * decay

    # --- Mock coil sensitivity maps ------------------------------------------------------
    # The magnitude profile does not depend on the coil index, so build it once.
    profile_phase, _ = _generate_synthetic_3d_phase(
        spatial_shape, ramps_scale=(0.25, 0.25, 0.25), device=device
    )
    csm_mag_profile = (torch.cos(profile_phase) + 1.5) / 2.5  # in [0.2, 1]: positive, varying
    mock_csms = torch.zeros((num_coils,) + spatial_shape, dtype=torch.complex64, device=device)
    for c in range(num_coils):
        scale = 0.1 * (c + 1)
        csm_phase_offset, _ = _generate_synthetic_3d_phase(
            spatial_shape, ramps_scale=(scale, scale, scale), device=device
        )
        mock_csms[c] = csm_mag_profile * torch.exp(1j * csm_phase_offset)
    rss_csms = torch.sqrt(torch.sum(torch.abs(mock_csms) ** 2, dim=0, keepdim=True)) + 1e-9
    mock_csms_normalized = mock_csms / rss_csms

    # --- Coil images: broadcast (E, 1, D, H, W) * (1, C, D, H, W) ----------------------
    combined_signal = true_coil_combined_magnitudes * torch.exp(
        1j * true_unwrapped_coil_combined_phases
    )
    multi_coil_complex_images = (
        combined_signal.unsqueeze(1) * mock_csms_normalized.unsqueeze(0)
    ).to(torch.complex64)

    return (
        multi_coil_complex_images,
        true_unwrapped_coil_combined_phases,
        true_coil_combined_magnitudes,
        mock_csms_normalized,
    )

## Pipeline Execution

In [None]:
# Parameters for synthetic data generation and for the pipeline's mask threshold
num_echoes_g = 3
num_coils_g = 4
spatial_shape_g = (8, 48, 48)  # D, H, W - kept small so the demo runs quickly
snr_threshold_for_mask_g = 0.1  # Magnitude threshold handed to the pipeline for masking

# The generator sets its ground-truth combined magnitude to half the pattern value in the
# background and twice the pattern value in a central block. A pattern value of 1.5x the
# threshold therefore puts the block above the threshold and the background below it.
(
    multi_coil_images_torch,
    true_unwrapped_phases_torch,
    true_magnitudes_torch,
    mock_csms_viz,
) = generate_synthetic_multicoil_multiecho_data_notebook(
    spatial_shape=spatial_shape_g,
    num_echoes=num_echoes_g,
    num_coils=num_coils_g,
    snr_thresh_for_mag_pattern=snr_threshold_for_mask_g * 1.5,
    device=device,
)

for _label, _tensor in (
    ("multi-coil complex images", multi_coil_images_torch),
    ("true unwrapped combined phases", true_unwrapped_phases_torch),
    ("true combined magnitudes", true_magnitudes_torch),
):
    print(f"Generated {_label} shape: {_tensor.shape}")

# Spatial unwrapper handed to the pipeline.
# NOTE(review): unwrap_multi_echo_masked_reference is described as calling its
# spatial_unwrap_fn with (phase, mask). unwrap_phase_3d_quality_guided is described as
# taking (phase, quality_metric_str, sigma_blur). This named adapter bridges the two,
# keeping the default quality metric and fixing sigma_blur=1.0 — confirm both
# signatures against reconlib.
def quality_guided_with_mask_adapter(phase_tensor, mask_tensor):
    """Quality-guided unwrap of ``phase_tensor``, zeroed outside ``mask_tensor``.

    The mask is not passed to the unwrapper itself. It is applied only to the result,
    after conversion with ``.float()``.
    """
    return unwrap_phase_3d_quality_guided(phase_tensor, sigma_blur=1.0) * mask_tensor.float()


spatial_unwrapper_to_use = quality_guided_with_mask_adapter

print("\nRunning preprocess_multi_coil_multi_echo_data...")
# Positional arguments: multi-coil complex data (echoes, coils, D, H, W),
# mask threshold, and the spatial unwrapper defined above.
pipeline_outputs = preprocess_multi_coil_multi_echo_data(
    multi_coil_images_torch,
    snr_threshold_for_mask_g,
    spatial_unwrapper_to_use,
)
final_unwrapped_phases, final_mask, stacked_combined_magnitudes = pipeline_outputs

# One print call with a leading blank line; stdout matches three separate prints.
print(
    "\n"
    + "\n".join(
        f"Output {label} shape: {tensor.shape}"
        for label, tensor in (
            ("unwrapped phases", final_unwrapped_phases),
            ("mask", final_mask),
            ("combined magnitudes", stacked_combined_magnitudes),
        )
    )
)

## Visualization

In [None]:
slice_d_idx = spatial_shape_g[0] // 2
# Visualize the second echo (index 1); fall back to the last echo when fewer than two exist.
echo_idx_viz = min(1, final_unwrapped_phases.shape[0] - 1)
coil_idx_viz = 0

# Move each shared slice to NumPy once. The mask is cast to bool so it works both as a
# multiplicative mask and as a boolean index. Indexing with a float mask raises IndexError,
# and an integer 0/1 mask would silently perform fancy (positional) indexing instead.
mask_slice_raw = final_mask[slice_d_idx, ...].cpu().numpy()
mask_slice_np = mask_slice_raw.astype(bool)
true_masked = true_unwrapped_phases_torch[echo_idx_viz, slice_d_idx, ...].cpu().numpy() * mask_slice_np
calc_masked = final_unwrapped_phases[echo_idx_viz, slice_d_idx, ...].cpu().numpy() * mask_slice_np

# Difference map with the global offset (mean difference inside the mask) removed.
# NOTE(review): if the coil combination sums coil images directly, the combined phase also
# contains the phase of the summed coil sensitivities. That term is spatially varying and is
# not removed by this global offset — confirm before reading residual structure as an
# unwrapping error.
diff_map = calc_masked - true_masked
if mask_slice_np.any():
    diff_map_offset_corrected = diff_map - np.mean(diff_map[mask_slice_np])
else:
    diff_map_offset_corrected = diff_map


def _show_panel(ax, image, title, cmap, colorbar_label=None, show_colorbar=True, **imshow_kwargs):
    """Draw ``image`` on ``ax`` with ``title``, hide the axes, and optionally add a colorbar."""
    im = ax.imshow(image, cmap=cmap, **imshow_kwargs)
    ax.set_title(title)
    ax.axis('off')
    if show_colorbar:
        colorbar_kwargs = {'shrink': 0.8}
        if colorbar_label is not None:
            colorbar_kwargs['label'] = colorbar_label
        ax.figure.colorbar(im, ax=ax, **colorbar_kwargs)
    return im


fig, axes = plt.subplots(2, 3, figsize=(18, 10))
fig.suptitle(
    f"Multi-Coil Multi-Echo Preprocessing Pipeline Results (Echo: {echo_idx_viz}, Slice: {slice_d_idx})",
    fontsize=16,
)

# 1. Magnitude of a single coil for the chosen echo
_show_panel(
    axes[0, 0],
    torch.abs(multi_coil_images_torch[echo_idx_viz, coil_idx_viz, slice_d_idx, ...]).cpu().numpy(),
    f"Mag: Echo {echo_idx_viz}, Coil {coil_idx_viz}, Slice {slice_d_idx}",
    cmap='gray',
)

# 2. Coil-combined magnitude returned by the pipeline
_show_panel(
    axes[0, 1],
    stacked_combined_magnitudes[echo_idx_viz, slice_d_idx, ...].cpu().numpy(),
    f"Combined Mag (Echo {echo_idx_viz}, Slice {slice_d_idx})",
    cmap='gray',
)

# 3. Mask returned by the pipeline (shown as-is, no colorbar)
_show_panel(
    axes[0, 2],
    mask_slice_raw,
    f"Generated SNR Mask (Slice {slice_d_idx})",
    cmap='gray',
    show_colorbar=False,
)

# 4. Ground-truth unwrapped phase inside the mask
_show_panel(
    axes[1, 0],
    true_masked,
    f"True Unwrapped Phase (Echo {echo_idx_viz}, Masked)",
    cmap='viridis',
    colorbar_label='rad',
)

# 5. Pipeline unwrapped phase inside the mask
_show_panel(
    axes[1, 1],
    calc_masked,
    f"Pipeline Unwrapped Phase (Echo {echo_idx_viz}, Masked)",
    cmap='viridis',
    colorbar_label='rad',
)

# 6. Offset-corrected difference, clipped to +/- pi/4 for contrast
_show_panel(
    axes[1, 2],
    diff_map_offset_corrected * mask_slice_np,
    f"Difference Map (Echo {echo_idx_viz}, Masked, Offset-Corr.)",
    cmap='coolwarm',
    colorbar_label='rad',
    vmin=-pi_val / 4,
    vmax=pi_val / 4,
)

plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.show()

## Conclusion

This notebook demonstrated the use of `preprocess_multi_coil_multi_echo_data` to process synthetic multi-coil, multi-echo complex image data. The pipeline performs coil combination and multi-echo phase unwrapping using a reference-echo strategy. It produces unwrapped phase images for each echo and a single magnitude-threshold (SNR) mask shared across echoes. The visualizations show a single-coil magnitude, the combined magnitude, the generated mask, and the final unwrapped phase alongside the ground truth. An offset-corrected difference map highlights the accuracy within the masked region.