# 3D Quality-Guided Phase Unwrapping Example

This notebook demonstrates the `unwrap_phase_3d_quality_guided` function from `reconlib.phase_unwrapping` using synthetic 3D phase data.

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from reconlib.phase_unwrapping import unwrap_phase_3d_quality_guided
from reconlib.plotting import plot_phase_image, plot_unwrapped_phase_map # Using plot_unwrapped_phase_map

%matplotlib inline

## Helper Functions to Generate Synthetic Data

In [None]:
def _wrap_phase_torch(phase_tensor: torch.Tensor) -> torch.Tensor:
    """Wraps phase values to the interval [-pi, pi) using PyTorch operations."""
    pi = getattr(torch, 'pi', np.pi)
    return (phase_tensor + pi) % (2 * pi) - pi

def generate_synthetic_3d_phase_data(shape=(16, 64, 64), ramps_scale=(1.5, 2.0, 2.5), device='cpu'):
    """
    Build a synthetic 3D ground-truth phase volume and its wrapped version.

    The true phase is the sum of three linear ramps, one per axis. Ramp k runs
    from -pi * ramps_scale[k] to +pi * ramps_scale[k] over shape[k] samples.

    Args:
        shape: (depth, height, width) of the volume.
        ramps_scale: Per-axis ramp amplitude, in multiples of pi.
        device: Torch device on which the tensors are created.

    Returns:
        Tuple ``(true_phase, wrapped_phase)``, both of shape ``shape``;
        ``wrapped_phase`` is ``true_phase`` wrapped to [-pi, pi).
    """
    depth, height, width = shape
    pi = getattr(torch, 'pi', np.pi)

    def _ramp(scale, n_samples):
        # Symmetric 1D linear ramp spanning [-pi*scale, pi*scale].
        return torch.linspace(-pi * scale, pi * scale, n_samples, device=device)

    # Reshape each 1D ramp so it broadcasts along its own axis only.
    z_ramp = _ramp(ramps_scale[0], depth)[:, None, None]
    y_ramp = _ramp(ramps_scale[1], height)[None, :, None]
    x_ramp = _ramp(ramps_scale[2], width)[None, None, :]

    # Broadcasting already yields (depth, height, width); expand() pins that shape explicitly.
    true_phase = (z_ramp + y_ramp + x_ramp).expand(depth, height, width)
    return true_phase, _wrap_phase_torch(true_phase)

## Generate Data and Perform Unwrapping

In [None]:
# Prefer the GPU when one is available; the synthetic data is created directly on it.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Volume size as (depth, height, width).
data_shape = (16, 64, 64)
true_phase, wrapped_phase = generate_synthetic_3d_phase_data(
    shape=data_shape,
    device=device,
)

# Standard deviation of the Gaussian blur applied to the quality map
# before it guides the unwrapping order.
sigma_blur_quality_map = 1.0
print(f"Running Quality-Guided Unwrapping with sigma_blur={sigma_blur_quality_map}...")
unwrapped_phase_qg = unwrap_phase_3d_quality_guided(
    wrapped_phase,
    sigma_blur=sigma_blur_quality_map,
)
print("Unwrapping complete.")

## Visualize Results

We visualize the central depth slice (index `data_shape[0] // 2`) of each 3D volume.

In [None]:
# Index of the central slice along the depth (first) axis.
slice_idx = data_shape[0] // 2

true_phase_slice = true_phase[slice_idx, :, :].cpu().numpy()
wrapped_phase_slice = wrapped_phase[slice_idx, :, :].cpu().numpy()
unwrapped_phase_qg_slice = unwrapped_phase_qg[slice_idx, :, :].cpu().numpy()

# The unwrapped result may differ from the ground truth by a constant offset,
# so subtract the mean offset before comparing. Note: localized unwrapping
# errors also shift this mean, so the difference map is relative to it.
offset_qg = np.mean(unwrapped_phase_qg_slice - true_phase_slice)
unwrapped_phase_qg_slice_corrected = unwrapped_phase_qg_slice - offset_qg
diff_map_qg = unwrapped_phase_qg_slice_corrected - true_phase_slice

# One entry per subplot, left to right: (image, title, imshow keyword arguments).
# Panels are drawn with Axes.imshow directly; the reconlib plotting helpers
# open their own figures and would add stray figures outside this grid.
panels = [
    (true_phase_slice,
     f"True Phase (Slice {slice_idx})",
     dict(cmap='viridis')),
    (wrapped_phase_slice,
     f"Wrapped Phase (Slice {slice_idx})",
     dict(cmap='twilight', vmin=-np.pi, vmax=np.pi)),
    (unwrapped_phase_qg_slice_corrected,
     f"QG Unwrapped (Slice {slice_idx}, Corrected)",
     dict(cmap='viridis')),
    (diff_map_qg,
     f"Difference Map (QG - True, Slice {slice_idx})",
     dict(cmap='coolwarm', vmin=-np.pi/4, vmax=np.pi/4)),
]

fig, axes = plt.subplots(1, len(panels), figsize=(20, 5))
for ax, (image, title, imshow_kwargs) in zip(axes, panels):
    # Use the AxesImage returned by imshow as the colorbar mappable.
    im = ax.imshow(image, **imshow_kwargs)
    ax.set_title(title)
    ax.axis('off')
    fig.colorbar(im, ax=ax, orientation='horizontal', fraction=0.046, pad=0.04)

fig.suptitle("Quality-Guided 3D Phase Unwrapping Results", fontsize=16)
# Leave the top 5% of the figure free for the suptitle.
fig.tight_layout(rect=[0, 0, 1, 0.95])
plt.show()

### Notes on Visualization:
The `reconlib.plotting` functions `plot_phase_image` and `plot_unwrapped_phase_map` create their own figures by default, so they cannot be placed into a shared subplot grid. For the multi-panel comparison above, each slice is extracted, converted to NumPy, and drawn directly with `Axes.imshow()` on its subplot axis. The `reconlib.plotting` functions remain useful for quick single-image displays, or for subplot layouts if they are extended to accept an `ax` argument.