In [None]:
import math
import matplotlib.pyplot as plt
import scipy.io as sio
import torch
import triton
import triton.language as tl
from gapat.algorithms import recon
from gapat.processings import negetive_processing

In [None]:
# Define constants
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
PI = 3.1415926536
# 1/(sqrt(2*pi)*0.9973)/1000000: 0.9973 is the Gaussian mass inside +/-3 sigma,
# matching the 3*res half-width used when building the convolution kernel
GAUSSIAN_PEAK_INTENSITY = 0.4000222589 / 1000000.0
SCALE_FACTOR = 1.0e3  # Scale factor for the raw signal
MIN_HALF_KERNEL_SIZE = 12  # Ensures at least 4 time samples per sigma in the Gaussian kernel
MIN_NUM_BLOCK = 8192  # Minimum number of thread blocks
MAX_SPLIT_K = 1024  # Maximum Split-K partitions

# Input / output files and the variable names stored inside them
DATA_PATH = "data/sensor_Vessel_data_matrix_1024.mat"
LOCATION_PATH = "data/sensor_Vessel_location_1024.mat"
GPAIR_RESULT_PATH = "results/gpair_v7_Vessel_1024.mat"
UBP_RESULT_PATH = "results/result_ubp_Vessel_1024.mat"

DATA_VAR = "simulation_data"
LOCATION_VAR = "detector_locations"
GPAIR_RESULT_VAR = "x_3d"
UBP_RESULT_VAR = "result_ubp"

# Reconstruction grid (coordinates in metres; plotted later in mm) and acquisition parameters
x_range = [-25.60e-3, 25.60e-3]
y_range = [-25.60e-3, 25.60e-3]
z_range = [-12.80e-3, 12.80e-3]
res = 0.10e-3  # voxel edge length
vs = 1500.0  # propagation speed (distance = time * vs)
fs = 40.0e6  # sampling frequency of the measured signals

x_start, x_end = x_range
y_start, y_end = y_range
z_start, z_end = z_range

# Number of voxels along each axis: extent / voxel size, rounded to the nearest integer
num_x, num_y, num_z = (
    int(round((axis_end - axis_start) / res))
    for axis_start, axis_end in ((x_start, x_end), (y_start, y_end), (z_start, z_end))
)
num_voxels = num_x * num_y * num_z

# Load data
# The operator pipeline below computes in float32 (partial_output, conv_kernel, on-the-fly
# voxel coordinates). .mat files frequently store float64 arrays, so the measurements and the
# detector coordinates handed to the Triton kernels are cast to float32 explicitly; otherwise a
# float64 target would be mixed with the float32 prediction in the data-fidelity loss.
detector_locations = (
    torch.from_numpy(sio.loadmat(LOCATION_PATH)[LOCATION_VAR]).to(DEVICE).contiguous()
)
simulation_data = (
    torch.from_numpy(sio.loadmat(DATA_PATH)[DATA_VAR]).to(DEVICE).contiguous()
    * SCALE_FACTOR
).to(torch.float32)

# Invariants relied on below: one signal row per detector location
assert simulation_data.ndim == 2 and detector_locations.ndim == 2, (
    "expected 2-D signal and detector-location arrays"
)
assert simulation_data.shape[0] == detector_locations.shape[0], (
    f"{simulation_data.shape[0]} signal rows but {detector_locations.shape[0]} detector locations"
)

# Sort data by detector z-coordinate (descending), applying the same permutation to both arrays
sorted_indices = torch.argsort(detector_locations[:, 2], descending=True)
detector_locations = detector_locations[sorted_indices].contiguous()
simulation_data = simulation_data[sorted_indices].contiguous()
num_detectors, num_times = simulation_data.shape

# Detector coordinates use SoA layout for improved coalesced memory access (float32 for the kernels)
detector_x = detector_locations[:, 0].to(torch.float32).contiguous()
detector_y = detector_locations[:, 1].to(torch.float32).contiguous()
detector_z = detector_locations[:, 2].to(torch.float32).contiguous()

In [None]:
# Generate convolution kernel
def generate_adaptive_conv_kernel(res, vs, fs):
    """Build the 1-D Gaussian-derivative convolution kernel on an adaptively upsampled time grid.

    The kernel half-width spans 3*res of propagation distance. If that is too few samples
    at ``fs``, the time axis is upsampled by an integer factor so the half-width in samples
    exceeds MIN_HALF_KERNEL_SIZE.

    Args:
        res: voxel size, also used as the Gaussian standard deviation (distance units).
        vs: propagation speed (distance per unit time).
        fs: original sampling frequency.

    Returns:
        (conv_kernel, time_interval_length_half, adaptive_ratio):
            conv_kernel: float32 tensor of shape (1, 1, 2 * half + 1) on DEVICE, sampled from
                +half down to -half upsampled steps.
            time_interval_length_half: kernel half-width in upsampled samples.
            adaptive_ratio: integer upsampling factor applied to ``fs``.
    """
    half_samples = int(3.0 * res / vs * fs + 1)
    # floor(MIN / half) + 1 is always >= 1 and makes half * ratio strictly exceed MIN_HALF_KERNEL_SIZE
    ratio = int(MIN_HALF_KERNEL_SIZE / half_samples + 1)
    half_samples *= ratio
    fs_up = fs * ratio

    amplitude = 2.0 * PI * GAUSSIAN_PEAK_INTENSITY * vs / res
    # Sample positions in distance units, descending from +half to -half
    sample_index = torch.arange(
        half_samples,
        -half_samples - 1,
        -1,
        device=DEVICE,
        dtype=torch.float32,
    )
    distance = vs / fs_up * sample_index
    gaussian = torch.exp(-(distance**2) / (2.0 * res**2))
    kernel_1d = amplitude * distance * gaussian

    # Add (out_channels, in_channels) leading dims expected by conv1d / conv_transpose1d
    return kernel_1d[None, None, :], half_samples, ratio


# Compute convolution kernel and upsampling parameters
conv_kernel, time_interval_length_half, adaptive_ratio = generate_adaptive_conv_kernel(res, vs, fs)

# Length of the time axis after upsampling by `adaptive_ratio`
num_times_upsampled = adaptive_ratio * num_times
# Upsampled samples per unit propagation distance (kernels use time index = distance * dd_inv_upsampled)
dd_inv_upsampled = fs * adaptive_ratio / vs
# Row strides of contiguous (num_detectors, num_times_upsampled) buffers
# (stride_out_d appears unused by the operators in this notebook)
stride_out_d = stride_dx_d = num_times_upsampled
print(f"adaptive_ratio={adaptive_ratio}, num_times_upsampled={num_times_upsampled}")

In [None]:
# Generate dynamic Split-K partition count
def generate_dynamic_split_k(num_detectors, min_num_block=None, max_split_k=None):
    """Choose a power-of-two Split-K partition count for the forward kernel launch.

    The target is ceil(min_num_block / num_detectors), so that num_detectors * split_k
    roughly reaches ``min_num_block``. The target is capped at ``max_split_k`` and then
    rounded to the nearest power of two in log2 space. The result never exceeds
    ``max_split_k``, even when the cap is not itself a power of two.

    Args:
        num_detectors: number of detectors (must be positive).
        min_num_block: desired minimum number of (detector, split) work items;
            defaults to MIN_NUM_BLOCK.
        max_split_k: upper bound on the partition count; defaults to MAX_SPLIT_K.

    Returns:
        A power of two in [1, max_split_k].

    Raises:
        ValueError: if ``num_detectors`` or ``max_split_k`` is not positive.
    """
    if min_num_block is None:
        min_num_block = MIN_NUM_BLOCK
    if max_split_k is None:
        max_split_k = MAX_SPLIT_K
    if num_detectors <= 0:
        raise ValueError(f"num_detectors must be positive, got {num_detectors}")
    if max_split_k < 1:
        raise ValueError(f"max_split_k must be >= 1, got {max_split_k}")

    # ceil(min_num_block / num_detectors), at least one partition
    split_k = max(1, (min_num_block + num_detectors - 1) // num_detectors)
    split_k = min(split_k, max_split_k)
    # Round to the nearest power of two in log2 space (may round up)
    split_k = 2 ** int(math.log2(split_k) + 0.5)
    # Rounding up can overshoot a cap that is not a power of two: clamp to the largest power of two <= cap
    cap_pow2 = 1 << (int(max_split_k).bit_length() - 1)
    return min(split_k, cap_pow2)


# Number of voxel partitions (Split-K) for the forward operator launch, derived from the detector count
split_k = generate_dynamic_split_k(num_detectors=num_detectors)
print(f"[forward] split_k={split_k}")

In [None]:
# ==================== Autotune Configuration ====================
# Forward operator optimization key points:
# 1) Detector blocking (BLOCK_DET): voxel/x data loaded once, reused for multiple detectors
# 2) On-the-fly voxel coordinate computation (division/modulo): avoids global reads of 3 large voxel_x/y/z arrays
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_VOXEL": 256, "BLOCK_DET": 4}, num_stages=3, num_warps=4),
        triton.Config({"BLOCK_VOXEL": 256, "BLOCK_DET": 8}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_VOXEL": 512, "BLOCK_DET": 4}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_VOXEL": 512, "BLOCK_DET": 8}, num_stages=4, num_warps=8),
        triton.Config({"BLOCK_VOXEL": 1024, "BLOCK_DET": 4}, num_stages=4, num_warps=8),
        triton.Config({"BLOCK_VOXEL": 1024, "BLOCK_DET": 8}, num_stages=4, num_warps=8),
        triton.Config({"BLOCK_VOXEL": 2048, "BLOCK_DET": 4}, num_stages=5, num_warps=8),
        triton.Config({"BLOCK_VOXEL": 256, "BLOCK_DET": 16}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_VOXEL": 512, "BLOCK_DET": 16}, num_stages=4, num_warps=8),
        triton.Config({"BLOCK_VOXEL": 256, "BLOCK_DET": 32}, num_stages=3, num_warps=8),
    ],
    key=["num_voxels", "num_detectors", "num_times_upsampled", "split_k"],
    # The kernel accumulates into partial_output with atomic_add. Without resetting it, every
    # benchmark launch the autotuner performs for a new key adds into the caller's buffer, and
    # the first call returns the sum of all benchmark runs instead of a single A*x.
    reset_to_zero=["partial_output_ptr"],
)
@triton.jit
def forward_kernel_splitk(
    x_ptr,  # Input voxel data pointer [num_voxels]
    detector_x_ptr,  # Detector X coordinate pointer [num_detectors] - SoA
    detector_y_ptr,  # Detector Y coordinate pointer [num_detectors] - SoA
    detector_z_ptr,  # Detector Z coordinate pointer [num_detectors] - SoA
    partial_output_ptr,  # [split_k, num_detectors, num_times_upsampled] - independent output per split
    num_voxels: tl.constexpr,  # Total number of voxels
    num_detectors: tl.constexpr,  # Total number of detectors
    num_times_upsampled: tl.constexpr,  # Number of upsampled time points
    dd_inv_upsampled: tl.constexpr,  # Inverse of upsampled single time step distance
    stride_partial_k: tl.constexpr,  # Split-K dimension stride of partial_output
    stride_partial_d: tl.constexpr,  # Detector dimension stride of partial_output
    split_k: tl.constexpr,  # Number of Split-K partitions
    x_start: tl.constexpr,  # Voxel grid start x coordinate
    y_start: tl.constexpr,  # Voxel grid start y coordinate
    z_start: tl.constexpr,  # Voxel grid start z coordinate
    res: tl.constexpr,  # Voxel grid resolution
    num_y: tl.constexpr,  # Voxel grid y dimension size
    num_z: tl.constexpr,  # Voxel grid z dimension size
    BLOCK_VOXEL: tl.constexpr,  # Triton block size for voxels
    BLOCK_DET: tl.constexpr,  # Number of detectors processed per program
):
    """Split-K Triton implementation of forward operator A*x (detector blocking + on-the-fly voxel coordinates).

    Program (i, k) handles detectors [i*BLOCK_DET, (i+1)*BLOCK_DET) and the k-th contiguous
    chunk of voxels. Each voxel value, weighted by 1/distance, is atomically added to the
    detector's nearest upsampled time sample in partial_output[k]. The buffer must therefore
    start at zero: the caller allocates it with zeros, and the autotuner resets it through
    reset_to_zero before benchmarking.
    """
    det_block_idx = tl.program_id(0)
    split_k_idx = tl.program_id(1)

    # Detector block for the current program
    det_start = det_block_idx * BLOCK_DET
    out_start = partial_output_ptr + split_k_idx * stride_partial_k

    # Compute the voxel range for each split
    voxels_per_split = tl.cdiv(num_voxels, split_k)
    v_start_base = split_k_idx * voxels_per_split
    v_end = tl.minimum(v_start_base + voxels_per_split, num_voxels)

    # Precompute num_y * num_z for division
    num_yz = num_y * num_z

    # Software pipeline: num_stages configured by autotune
    for v_start in tl.range(v_start_base, v_end, BLOCK_VOXEL):
        v_offsets = v_start + tl.arange(0, BLOCK_VOXEL)
        v_mask = v_offsets < v_end

        # x loaded once, reused for BLOCK_DET detectors
        x_val = tl.load(x_ptr + v_offsets, mask=v_mask, other=0.0)

        # On-the-fly voxel coordinates (using division/modulo, supports arbitrary dimensions)
        v = v_offsets.to(tl.int32)
        z_idx = v % num_z
        y_idx = (v // num_z) % num_y
        x_idx = v // num_yz

        vox_loc_x = x_start + x_idx * res
        vox_loc_y = y_start + y_idx * res
        vox_loc_z = z_start + z_idx * res

        # Scalar detector loop: avoids register/local memory pressure from 2D broadcasting
        for d in tl.static_range(0, BLOCK_DET):
            det_id = det_start + d
            d_mask = det_id < num_detectors

            det_loc_x = tl.load(detector_x_ptr + det_id, mask=d_mask, other=0.0)
            det_loc_y = tl.load(detector_y_ptr + det_id, mask=d_mask, other=0.0)
            det_loc_z = tl.load(detector_z_ptr + det_id, mask=d_mask, other=0.0)

            dx = det_loc_x - vox_loc_x
            dy = det_loc_y - vox_loc_y
            dz = det_loc_z - vox_loc_z

            dist_sq = dx * dx + dy * dy + dz * dz
            dist_inv = tl.rsqrt(dist_sq)

            # Nearest upsampled time index: distance * (samples per unit distance), rounded
            time_center = (dist_sq * dist_inv * dd_inv_upsampled + 0.5).to(tl.int32)
            t_mask = (time_center >= 0) & (time_center < num_times_upsampled)

            out_ptr = out_start + det_id * stride_partial_d + time_center
            val_to_add = x_val * dist_inv
            mask = t_mask & v_mask & d_mask
            # Different voxels can map to the same (detector, time) sample, hence the atomic
            tl.atomic_add(out_ptr, val_to_add, mask=mask, sem="relaxed")


def forward_operator_triton(x: torch.Tensor) -> torch.Tensor:
    """Optimized Triton forward operator A*x (Split-K + detector blocking + on-the-fly voxel coordinates).

    Pipeline:
        1. Per split, scatter 1/distance-weighted voxel values to the nearest upsampled time
           sample of each detector.
        2. Sum the Split-K partial buffers.
        3. Apply conv_transpose1d with the Gaussian-derivative kernel (output length is preserved).
        4. Decimate back to the original sampling rate.

    Args:
        x: voxel values with exactly ``num_voxels`` elements. The kernel decodes flat index v as
           x = v // (num_y*num_z), y = (v // num_z) % num_y, z = v % num_z.

    Returns:
        Flat float32 tensor of ``num_detectors * num_times`` samples, detector-major.

    Raises:
        ValueError: if ``x`` does not contain ``num_voxels`` elements.
    """
    if x.numel() != num_voxels:
        raise ValueError(f"expected {num_voxels} voxel values, got {x.numel()}")
    # The kernel reads x through a raw pointer as a dense 1-D array, so enforce contiguity
    x = x.contiguous()

    # Allocate partial_output buffer [split_k, num_detectors, num_times_upsampled]
    # (zero-initialised because the kernel accumulates into it)
    partial_output = torch.zeros(
        (split_k, num_detectors, num_times_upsampled),
        device=DEVICE,
        dtype=torch.float32,
    )
    stride_partial_k = num_detectors * num_times_upsampled
    stride_partial_d = num_times_upsampled

    # Grid: (detector_blocks, split_k); BLOCK_DET selected by autotune
    grid = lambda META: (triton.cdiv(num_detectors, META["BLOCK_DET"]), split_k)

    forward_kernel_splitk[grid](
        x,
        detector_x,
        detector_y,
        detector_z,
        partial_output,
        num_voxels,
        num_detectors,
        num_times_upsampled,
        dd_inv_upsampled,
        stride_partial_k,
        stride_partial_d,
        split_k,
        x_start,
        y_start,
        z_start,
        res,
        num_y,
        num_z,
    )

    # Reduction: sum along the Split-K dimension
    dy_upsampling_batch = partial_output.sum(dim=0)

    # Convolution operation (padding = half-width keeps the time length unchanged)
    dy_conv_transpose = torch.nn.functional.conv_transpose1d(
        dy_upsampling_batch.unsqueeze(1),
        conv_kernel,
        padding=time_interval_length_half,
    ).squeeze(1)

    # Downsampling and flattening
    y = dy_conv_transpose[:, ::adaptive_ratio].contiguous().flatten()
    return y


# ==================== Autotune Optimized Transpose Kernel ====================
# Note: The transpose operator does not use the Split-K strategy for the following reasons:
# 1. Each voxel is an independent output target, no write conflicts (no atomic_add needed)
# 2. Accumulation is performed in registers, very efficient
# 3. Sufficient voxel count (e.g., 256*256*128=8M), thread blocks already saturate the GPU
# 4. Removing Split-K saves memory allocation and reduction overhead
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_VOXEL": 128}, num_stages=2, num_warps=4),
        triton.Config({"BLOCK_VOXEL": 256}, num_stages=3, num_warps=4),
        triton.Config({"BLOCK_VOXEL": 512}, num_stages=4, num_warps=8),
        triton.Config({"BLOCK_VOXEL": 1024}, num_stages=4, num_warps=8),
    ],
    # Output is written with tl.store (no accumulation), so repeated launches of the
    # same inputs overwrite the same values; unlike the forward kernel, no reset_to_zero is needed.
    key=["num_voxels", "num_detectors", "num_times_upsampled"],
)
@triton.jit
def transpose_kernel(
    dx_conv_ptr,  # [num_detectors, num_times_upsampled]
    detector_x_ptr,  # Detector X coordinate pointer [num_detectors] - SoA
    detector_y_ptr,  # Detector Y coordinate pointer [num_detectors] - SoA
    detector_z_ptr,  # Detector Z coordinate pointer [num_detectors] - SoA
    output_ptr,  # [num_voxels], direct output
    num_voxels: tl.constexpr,  # Total number of voxels
    num_detectors: tl.constexpr,  # Total number of detectors
    num_times_upsampled: tl.constexpr,  # Number of upsampled time points
    dd_inv_upsampled: tl.constexpr,  # Inverse of upsampled single time step distance
    stride_dx_d: tl.constexpr,  # Stride of dx_conv
    x_start: tl.constexpr,  # Voxel grid start x coordinate
    y_start: tl.constexpr,  # Voxel grid start y coordinate
    z_start: tl.constexpr,  # Voxel grid start z coordinate
    res: tl.constexpr,  # Voxel grid resolution
    num_y: tl.constexpr,  # Voxel grid y dimension size
    num_z: tl.constexpr,  # Voxel grid z dimension size
    BLOCK_VOXEL: tl.constexpr,  # Triton block size for voxels
):
    """Triton implementation of transpose operator A^T*x (gather form).

    Each program owns BLOCK_VOXEL consecutive voxels. For every detector it reads the filtered
    sample at the voxel's nearest upsampled time index, weights it by 1/distance, and accumulates
    in registers. The sum is stored once at the end. This mirrors the scatter done by
    forward_kernel_splitk. Voxel coordinates are computed from the flat index (on the fly),
    avoiding reads of voxel_x/y/z arrays.
    """
    voxel_block_idx = tl.program_id(0)

    # Voxel block for the current program
    v_start = voxel_block_idx * BLOCK_VOXEL
    v_offsets = v_start + tl.arange(0, BLOCK_VOXEL)
    v_mask = v_offsets < num_voxels

    # Precompute num_y * num_z for division
    num_yz = num_y * num_z

    # On-the-fly voxel coordinates (using division/modulo, supports arbitrary dimensions)
    v = v_offsets.to(tl.int32)
    z_idx = v % num_z
    y_idx = (v // num_z) % num_y
    x_idx = v // num_yz

    vox_loc_x = x_start + x_idx * res
    vox_loc_y = y_start + y_idx * res
    vox_loc_z = z_start + z_idx * res

    # Initialize accumulator to 0
    accum = tl.zeros((BLOCK_VOXEL,), dtype=tl.float32)

    # Iterate over all detectors (no Split-K); each voxel's sum stays private to this program
    for det_idx in tl.range(0, num_detectors):
        det_loc_x = tl.load(detector_x_ptr + det_idx)
        det_loc_y = tl.load(detector_y_ptr + det_idx)
        det_loc_z = tl.load(detector_z_ptr + det_idx)

        dx = det_loc_x - vox_loc_x
        dy = det_loc_y - vox_loc_y
        dz = det_loc_z - vox_loc_z

        dist_sq = dx * dx + dy * dy + dz * dz
        dist_inv = tl.rsqrt(dist_sq)

        # Same rounding to the nearest upsampled time index as the forward kernel
        time_center = (dist_sq * dist_inv * dd_inv_upsampled + 0.5).to(tl.int32)
        t_mask = (time_center >= 0) & (time_center < num_times_upsampled)

        dx_row_ptr = dx_conv_ptr + det_idx * stride_dx_d + time_center
        mask = t_mask & v_mask
        dx_val = tl.load(dx_row_ptr, mask=mask, other=0.0).to(tl.float32)

        val_to_add = dx_val * dist_inv
        # Masked-out lanes contribute exactly 0 regardless of dist_inv
        accum += tl.where(mask, val_to_add, 0.0)

    tl.store(output_ptr + v_offsets, accum, mask=v_mask)


def transpose_operator_triton(x: torch.Tensor) -> torch.Tensor:
    """Transpose (adjoint) operator A^T*x, without Split-K; voxel coordinates computed on the fly.

    Steps, each the adjoint of the matching forward step in reverse order:
        1. Zero-insert upsampling by ``adaptive_ratio`` (adjoint of the forward decimation).
        2. conv1d with the same kernel and padding (adjoint of the forward conv_transpose1d).
        3. Per-voxel gather of 1/distance-weighted samples (adjoint of the forward scatter).

    Args:
        x: ``num_detectors * num_times`` samples, detector-major.

    Returns:
        Float32 tensor of ``num_voxels`` elements.
    """
    upsampled = torch.zeros(
        (num_detectors, num_times_upsampled), device=DEVICE, dtype=torch.float32
    )
    upsampled[:, ::adaptive_ratio] = x.reshape(num_detectors, num_times)

    filtered = torch.nn.functional.conv1d(
        upsampled.unsqueeze(1), conv_kernel, padding=time_interval_length_half
    )
    # The kernel addresses rows with stride_dx_d, so the buffer must be dense
    filtered = filtered.squeeze(1).contiguous()

    voxel_values = torch.empty(num_voxels, device=DEVICE, dtype=torch.float32)

    def launch_grid(meta):
        # One program per BLOCK_VOXEL voxels (BLOCK_VOXEL chosen by autotune)
        return (triton.cdiv(num_voxels, meta["BLOCK_VOXEL"]),)

    transpose_kernel[launch_grid](
        filtered,
        detector_x,
        detector_y,
        detector_z,
        voxel_values,
        num_voxels,
        num_detectors,
        num_times_upsampled,
        dd_inv_upsampled,
        stride_dx_d,
        x_start,
        y_start,
        z_start,
        res,
        num_y,
        num_z,
    )
    return voxel_values


# Custom Autograd Function
class A_Operator_Triton_Function(torch.autograd.Function):
    """
    Autograd bridge for the linear forward operator A.

    forward returns A*x via forward_operator_triton; backward returns A^T*grad_output via
    transpose_operator_triton. A is linear, so nothing needs to be saved on ``ctx``.
    """

    @staticmethod
    def forward(ctx, x):
        projected = forward_operator_triton(x)
        return projected

    @staticmethod
    def backward(ctx, grad_output):
        # For y = A x, dL/dx = A^T dL/dy
        back_projected = transpose_operator_triton(grad_output)
        return back_projected


# Wrap Triton operators
def A_operator_triton(x):
    """
    Apply the Triton forward operator A to ``x`` as a differentiable torch op
    (gradients flow through the transpose operator A^T).
    """
    apply_operator = A_Operator_Triton_Function.apply
    return apply_operator(x)

In [None]:
# Vessel continuity regularization
def vessel_continuity_loss(x, num_x, num_y, num_z, beta=0.1, eps=1e-8):
    """
    Vessel continuity regularizer: smoothed Hessian Frobenius norm plus ``beta`` * smoothed TV.

    All derivatives are zero-padded backward differences computed by slicing (no diff+prepend):
    along an axis, out[0] = 0 and out[i] = t[i] - t[i-1]. Second-order and mixed derivatives
    reuse the first-order differences.

    Args:
        x: tensor with num_x * num_y * num_z elements, reshaped to (num_x, num_y, num_z).
        num_x, num_y, num_z: grid dimensions.
        beta: weight of the total-variation term.
        eps: smoothing constant added inside every square root.

    Returns:
        Scalar tensor sum(sqrt(||H||_F^2 + eps)) + beta * sum(sqrt(||grad||^2 + eps)).
    """
    volume = x.reshape(num_x, num_y, num_z)

    def backward_diff(t, axis):
        # Zero-padded backward difference of `t` along `axis`
        upper = [slice(None)] * t.dim()
        lower = [slice(None)] * t.dim()
        upper[axis] = slice(1, None)
        lower[axis] = slice(None, -1)
        out = torch.zeros_like(t)
        out[tuple(upper)] = t[tuple(upper)] - t[tuple(lower)]
        return out

    # First-order differences and the smoothed TV norm
    d_x = backward_diff(volume, 0)
    d_y = backward_diff(volume, 1)
    d_z = backward_diff(volume, 2)
    tv_norm = torch.sqrt(d_x.pow(2) + d_y.pow(2) + d_z.pow(2) + eps).sum()

    # Second-order (diagonal) and mixed (off-diagonal) Hessian entries
    d_xx = backward_diff(d_x, 0)
    d_yy = backward_diff(d_y, 1)
    d_zz = backward_diff(d_z, 2)
    d_xy = backward_diff(d_x, 1)
    d_xz = backward_diff(d_x, 2)
    d_yz = backward_diff(d_y, 2)

    # Frobenius norm of the symmetric Hessian: off-diagonal terms appear twice
    diagonal_sq = d_xx.pow(2) + d_yy.pow(2) + d_zz.pow(2)
    off_diagonal_sq = d_xy.pow(2) + d_xz.pow(2) + d_yz.pow(2)
    hessian_norm = torch.sqrt(diagonal_sq + 2 * off_diagonal_sq + eps).sum()

    return hessian_norm + beta * tv_norm


# Non-negative function
def nonnegative(z, eps=1e-8):
    """Reparameterize ``z`` as ``(z + eps)**2`` (computed as a product), which is never negative."""
    shifted = z + eps
    return shifted * shifted


# Accelerate with torch.compile: rebind both names so later calls use the compiled versions
vessel_continuity_loss, nonnegative = (
    torch.compile(vessel_continuity_loss),
    torch.compile(nonnegative),
)

In [None]:
# Use parameter transformation to ensure non-negativity constraint
def gradient_descent_reconstruction_nonnegative(
    observed_data,
    max_iterations=50,
    learning_rate=1e-6,
    lambda_reg=1e-4,
    T_0=10,  # First restart cycle length
    T_mult=1,  # Restart cycle multiplier
    eta_min=1e-8,  # Minimum learning rate
    is_print=False,  # Whether to print progress
):
    """
    Iterative reconstruction with a non-negativity constraint via reparameterization.

    The optimization variable is ``z``; the image is ``x = nonnegative(z) = (z + eps)**2 >= 0``.
    Each iteration minimizes
        mse(A_operator_triton(x), observed_data) + lambda_reg * vessel_continuity_loss(x)
    with Adam under a CosineAnnealingWarmRestarts learning-rate schedule.

    Args:
        observed_data: flat measurement vector in the same layout as A_operator_triton's output.
        max_iterations: number of optimizer steps.
        learning_rate: initial Adam learning rate.
        lambda_reg: weight of the vessel continuity regularizer.
        T_0, T_mult, eta_min: CosineAnnealingWarmRestarts parameters.
        is_print: print per-iteration loss components when True.

    Returns:
        (x_final, losses): x_final is the detached reconstructed image (``num_voxels`` elements);
        losses holds the total loss of every iteration as Python floats. Losses are recorded
        regardless of ``is_print``.
    """
    # Initialize parameter z; the image is x = nonnegative(z)
    z = torch.zeros(num_voxels, device=DEVICE, dtype=torch.float32, requires_grad=True)

    # Create optimizer
    optimizer = torch.optim.Adam([z], lr=learning_rate)

    # Create cosine annealing warm restarts learning rate scheduler
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=T_0, T_mult=T_mult, eta_min=eta_min
    )

    # Per-iteration losses stay on the device; converting once after the loop avoids
    # a host synchronization every iteration
    loss_history = []

    # Iterative reconstruction
    for iteration in range(max_iterations):
        # Zero gradients
        optimizer.zero_grad()

        # Ensure non-negativity through squaring transformation
        x = nonnegative(z)

        # Forward pass: y = A * x
        predicted_data = A_operator_triton(x)

        # Compute loss function
        data_fidelity = torch.nn.functional.mse_loss(predicted_data, observed_data)
        regularization = lambda_reg * vessel_continuity_loss(x, num_x, num_y, num_z)
        loss = data_fidelity + regularization

        # Backward pass
        loss.backward()

        # Update parameters
        optimizer.step()

        # Update learning rate
        scheduler.step()

        # Record the loss unconditionally so callers always receive the history
        loss_history.append(loss.detach())

        # Print progress
        if is_print:
            print(
                f"  Iter {iteration:3d}: Loss = {loss.item():.6e}, "
                f"Data Fidelity = {data_fidelity.item():.6e}, "
                f"Regularization = {regularization.item():.6e}, "
            )

    print("Non-negative constrained gradient descent reconstruction completed!")
    losses = torch.stack(loss_history).tolist() if loss_history else []
    # The returned image does not need the autograd graph
    x_final = nonnegative(z).detach()
    return x_final, losses

In [None]:
# Main program
print("Starting gradient descent reconstruction...")

# Observed data as one flat vector (detector-major, the same layout the forward operator returns)
b = simulation_data.flatten().contiguous()

# Reset GPU memory statistics so the reported peak covers only the reconstruction
torch.cuda.reset_peak_memory_stats()

# Reconstruction hyperparameters (Adam with cosine annealing warm restarts)
reconstruction_options = {
    "max_iterations": 60,
    "learning_rate": 8e-2,
    "lambda_reg": 1e-5,
    "T_0": 30,  # First restart cycle length
    "T_mult": 1,  # Restart cycle multiplier
    "eta_min": 5e-4,  # Minimum learning rate
    "is_print": False,  # Whether to print progress
}

# Time the reconstruction with CUDA events
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
x_reconstructed, losses = gradient_descent_reconstruction_nonnegative(b, **reconstruction_options)
end_event.record()
end_event.synchronize()

# Get GPU memory statistics
peak_mem = torch.cuda.max_memory_allocated()

print(f"Gradient descent reconstruction completed! Total time: {start_event.elapsed_time(end_event) / 1000:.3f} s")
print(f"Peak GPU memory usage: {peak_mem / 1024**2:.2f} MB")

# Reshape to a (num_x, num_y, num_z) volume on the host and save
x_3d = x_reconstructed.reshape(num_x, num_y, num_z).detach().cpu().numpy()
sio.savemat(GPAIR_RESULT_PATH, {GPAIR_RESULT_VAR: x_3d})

In [None]:
# Use UBP reconstruction for comparison and save
# Every detector normal points along -z. To make detectors face the origin instead,
# the negated detector locations could be used as normals.
detector_normals = torch.tensor([0.0, 0.0, -1.0], device=DEVICE).repeat(
    detector_locations.shape[0], 1
)
result_ubp = recon(
    simulation_data.cpu().numpy(),
    detector_locations.cpu().numpy(),
    detector_normals.cpu().numpy(),
    x_range,
    y_range,
    z_range,
    res,
    vs,
    fs,
    method="ubp",
)
# Post-process with gapat's negetive_processing before saving
result_ubp = negetive_processing(result_ubp)
sio.savemat(UBP_RESULT_PATH, {UBP_RESULT_VAR: result_ubp})

In [None]:
# Visualize results
cmap = "hot"
# Image extent in millimetres (grid coordinates are in metres)
xy_extent = [x_start * 1e3, x_end * 1e3, y_start * 1e3, y_end * 1e3]
center_z_idx = num_z // 2
center_z_mm = z_start * 1e3 + center_z_idx * res * 1e3


def _show_xy_image(ax, image, title):
    """Draw a 2-D array indexed [x, y] on ``ax`` in millimetre coordinates, with a colorbar.

    The array is transposed so x runs horizontally and y vertically (origin at the bottom).
    """
    mappable = ax.imshow(
        image.T,
        cmap=cmap,
        aspect="equal",
        origin="lower",
        extent=xy_extent,
    )
    ax.figure.colorbar(mappable, ax=ax, label="Intensity (a.u.)")
    ax.set_xlabel("X (mm)")
    ax.set_ylabel("Y (mm)")
    ax.set_title(title)


def _show_losses(ax, first_iteration, values, title, style, **plot_kwargs):
    """Plot ``values`` against iteration numbers starting at ``first_iteration``.

    Shows a placeholder note instead of an empty plot when no losses were recorded.
    """
    if values:
        iterations = range(first_iteration, first_iteration + len(values))
        ax.plot(iterations, values, style, **plot_kwargs)
    else:
        ax.text(
            0.5,
            0.5,
            "No loss history recorded",
            ha="center",
            va="center",
            transform=ax.transAxes,
        )
    ax.set_xlabel("Iteration")
    ax.set_ylabel("Loss")
    ax.set_title(title)
    ax.grid(True, alpha=0.3)


fig, axes = plt.subplots(2, 3, figsize=(20, 10))

# Top row: z-axis MAP (Maximum Amplitude Projection) of both reconstructions, complete loss curve
ubp_map = result_ubp.max(axis=2)
gpair_map = x_3d.max(axis=2)
_show_xy_image(axes[0, 0], ubp_map, "UBP Reconstruction - Z-axis MAP")
_show_xy_image(axes[0, 1], gpair_map, "GPAIR Reconstruction - Z-axis MAP")
_show_losses(axes[0, 2], 0, losses, "Complete Loss Curve", "b-", linewidth=2)
if losses:
    # Log scale only makes sense once there is positive data to show
    axes[0, 2].set_yscale("log")

# Bottom row: center z slices of both reconstructions, last (up to) 10 iterations of the loss
ubp_slice = result_ubp[:, :, center_z_idx]
gpair_slice = x_3d[:, :, center_z_idx]
_show_xy_image(
    axes[1, 0], ubp_slice, f"UBP Reconstruction - Center Slice (z={center_z_mm:.2f} mm)"
)
_show_xy_image(
    axes[1, 1], gpair_slice, f"GPAIR Reconstruction - Center Slice (z={center_z_mm:.2f} mm)"
)
last_n = min(10, len(losses))
_show_losses(
    axes[1, 2],
    len(losses) - last_n,
    losses[len(losses) - last_n:],
    f"Last {last_n} Iterations Loss Curve",
    "r-",
    linewidth=2,
    marker="o",
)

# Display results
fig.tight_layout()
plt.show()