# 3D Goldstein-Style Phase Unwrapping Example

This notebook demonstrates the `unwrap_phase_3d_goldstein` function from `reconlib.phase_unwrapping` on synthetic 3D phase data. The method is Goldstein-style and uses FFT-based k-space filtering. We compare its output for two values of the `k_filter_strength` parameter.

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from reconlib.phase_unwrapping import unwrap_phase_3d_goldstein
from reconlib.plotting import plot_phase_image, plot_unwrapped_phase_map # Using plot_unwrapped_phase_map

%matplotlib inline

## Helper Functions to Generate Synthetic Data

In [None]:
def _wrap_phase_torch(phase_tensor: torch.Tensor) -> torch.Tensor:
    """Wraps phase values to the interval [-pi, pi) using PyTorch operations."""
    pi = getattr(torch, 'pi', np.pi)
    return (phase_tensor + pi) % (2 * pi) - pi

def generate_synthetic_3d_phase_data(shape=(16, 64, 64), ramps_scale=(1.5, 2.0, 2.5), device='cpu'):
    """
    Generate a synthetic 3D true phase volume and its wrapped counterpart.

    The true phase is the sum of three separable linear ramps, one per axis.
    Along axis ``i`` the ramp runs from ``-pi * ramps_scale[i]`` to
    ``pi * ramps_scale[i]`` in ``shape[i]`` evenly spaced samples.

    Args:
        shape: ``(depth, height, width)`` of the volume.
        ramps_scale: per-axis ramp half-extent in units of pi, ordered like ``shape``.
        device: torch device on which the tensors are created.

    Returns:
        ``(true_phase, wrapped_phase)``, both tensors of shape ``shape``.
        ``wrapped_phase`` is ``true_phase`` wrapped into [-pi, pi).

    Raises:
        ValueError: if ``shape`` or ``ramps_scale`` does not have exactly 3 entries.
    """
    if len(shape) != 3 or len(ramps_scale) != 3:
        raise ValueError(
            "shape and ramps_scale must each have 3 entries (depth, height, width); "
            f"got shape={shape!r}, ramps_scale={ramps_scale!r}"
        )

    d, h, w = shape
    pi = getattr(torch, 'pi', np.pi)
    z_coords = torch.linspace(-pi * ramps_scale[0], pi * ramps_scale[0], d, device=device)
    y_coords = torch.linspace(-pi * ramps_scale[1], pi * ramps_scale[1], h, device=device)
    x_coords = torch.linspace(-pi * ramps_scale[2], pi * ramps_scale[2], w, device=device)

    # Broadcasting (d,1,1) + (1,h,1) + (1,1,w) already yields a full (d,h,w) tensor,
    # so no expand()/reshape is needed afterwards.
    true_phase = z_coords.view(-1, 1, 1) + y_coords.view(1, -1, 1) + x_coords.view(1, 1, -1)

    wrapped_phase = _wrap_phase_torch(true_phase)
    return true_phase, wrapped_phase

## Generate Data and Perform Unwrapping

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Volume dimensions, ordered (depth, height, width).
data_shape = (16, 64, 64)
true_phase, wrapped_phase = generate_synthetic_3d_phase_data(shape=data_shape, device=device)

# Two k-space filter strengths to compare: a baseline setting and a stronger one.
k_filter_strength = 1.0
k_filter_strength_strong = 2.0

# Run the Goldstein-style unwrapper once per strength, in the order listed above.
unwrapped_results = []
for strength in (k_filter_strength, k_filter_strength_strong):
    print(f"Running Goldstein-Style Unwrapping with k_filter_strength={strength}...")
    unwrapped_results.append(unwrap_phase_3d_goldstein(wrapped_phase, k_filter_strength=strength))
    print("Unwrapping complete.")

unwrapped_phase_gs, unwrapped_phase_gs_strong = unwrapped_results

## Visualize Results

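We visualize the central depth slice of each 3D volume for both filter strengths. Unwrapping only recovers the phase up to a global constant, so the mean offset from the true phase is subtracted before the difference maps are computed.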

In [None]:
slice_idx = data_shape[0] // 2  # central slice along the depth axis


def _volume_slice(volume: torch.Tensor, index: int) -> np.ndarray:
    """Return slice ``index`` along the first (depth) axis of ``volume`` as a NumPy array."""
    return volume[index, :, :].detach().cpu().numpy()


def _remove_mean_offset(estimate: np.ndarray, reference: np.ndarray):
    """Align ``estimate`` to ``reference`` by subtracting their mean difference.

    Unwrapping only determines the phase up to a global constant, so the estimate
    is shifted by its mean deviation from the reference before comparison.

    Returns:
        ``(offset, corrected, difference)``. Here ``corrected = estimate - offset``
        and ``difference = corrected - reference``.
    """
    offset = np.mean(estimate - reference)
    corrected = estimate - offset
    return offset, corrected, corrected - reference


def _plot_result_row(fig, row_axes, true_slice, wrapped_slice, corrected_slice, diff_map,
                     strength, index):
    """Draw the true / wrapped / unwrapped / difference panels for one filter strength.

    Args:
        fig: figure that owns ``row_axes`` (used for colorbars).
        row_axes: sequence of 4 axes for this row.
        true_slice, wrapped_slice: reference images shared by every row.
        corrected_slice: offset-corrected unwrapped image for this strength.
        diff_map: ``corrected_slice - true_slice``.
        strength: k_filter_strength used, shown in the titles.
        index: slice index, shown in the titles.

    Returns:
        The AxesImage of the difference panel.
    """
    # Show the unwrapped slice on the true phase's colour range so the two panels are
    # directly comparable; independent autoscaling can make different results look alike.
    true_range = (float(true_slice.min()), float(true_slice.max()))
    # The difference colour map saturates at +/- pi/4, so also report the RMSE numerically.
    rmse = float(np.sqrt(np.mean(diff_map ** 2)))
    panels = [
        (true_slice, 'viridis', true_range, f"True Phase (Slice {index})"),
        (wrapped_slice, 'twilight', (-np.pi, np.pi), f"Wrapped Phase (Slice {index})"),
        (corrected_slice, 'viridis', true_range, f"GS Unwrapped (k={strength}, Corrected)"),
        (diff_map, 'coolwarm', (-np.pi / 4, np.pi / 4),
         f"Difference (GS k={strength} - True)\nRMSE = {rmse:.3g} rad"),
    ]
    image = None
    for ax, (data, cmap, (vmin, vmax), title) in zip(row_axes, panels):
        image = ax.imshow(data, cmap=cmap, vmin=vmin, vmax=vmax)
        ax.set_title(title)
        ax.axis('off')
        fig.colorbar(image, ax=ax, orientation='horizontal', fraction=0.046, pad=0.08)
    return image  # the last panel drawn is the difference map


true_phase_slice = _volume_slice(true_phase, slice_idx)
wrapped_phase_slice = _volume_slice(wrapped_phase, slice_idx)
unwrapped_phase_gs_slice = _volume_slice(unwrapped_phase_gs, slice_idx)
unwrapped_phase_gs_strong_slice = _volume_slice(unwrapped_phase_gs_strong, slice_idx)

# Correct each result for its constant offset relative to the true phase.
offset_gs, unwrapped_phase_gs_slice_corrected, diff_map_gs = _remove_mean_offset(
    unwrapped_phase_gs_slice, true_phase_slice)
offset_gs_strong, unwrapped_phase_gs_strong_slice_corrected, diff_map_gs_strong = _remove_mean_offset(
    unwrapped_phase_gs_strong_slice, true_phase_slice)

fig, axes = plt.subplots(2, 4, figsize=(20, 10))

# Row 1: baseline k_filter_strength; Row 2: stronger filter.
im_diff_gs = _plot_result_row(
    fig, axes[0], true_phase_slice, wrapped_phase_slice,
    unwrapped_phase_gs_slice_corrected, diff_map_gs, k_filter_strength, slice_idx)
im_diff_gs_strong = _plot_result_row(
    fig, axes[1], true_phase_slice, wrapped_phase_slice,
    unwrapped_phase_gs_strong_slice_corrected, diff_map_gs_strong, k_filter_strength_strong, slice_idx)

fig.suptitle("Goldstein-Style 3D Phase Unwrapping Results", fontsize=16)
fig.tight_layout(rect=[0, 0, 1, 0.95])
plt.show()