In [1]:
# Import required libraries
import time

import matplotlib.pyplot as plt
import numpy as np
import scipy.io as sio
import torch
from scipy.sparse.linalg import LinearOperator
from scipy.sparse.linalg import lsqr

from gapat.algorithms import recon
from gapat.processings import negetive_processing

In [2]:
# Define constants
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Rounded sqrt(3); multiplied by `res` (sqrt(3) * res is the diagonal of a cube
# with side `res`) when sizing the per-voxel time window in the operators.
SQRT_3 = 1.7321
# Rounded pi; used to derive the angular bin widths (PI / 2 / N_SEARCH_SPACE_SIZE).
PI = 3.1415926536
# Number of bins along each index (d, alpha, beta) used to address the
# sphere-integral lookup tables; indices are clamped to [0, N_SEARCH_SPACE_SIZE - 1].
N_SEARCH_SPACE_SIZE = 100
# Integer factor by which the time axis of the data (and `fs`) is upsampled.
UPSAMPLE_FACTOR = 2

# Define paths
# Input .mat files (sensor data, detector positions, lookup tables) and the
# output .mat file that receives the LSQR reconstruction.
DATA_PATH = "data/sensor_Liver_data_matrix.mat"
LOCATION_PATH = "data/sensor_Liver_location.mat"
MBPD_RESULT_PATH = "results/result_mbpd_Liver.mat"
SI_TABLE_PATH = "data/sphere_integral_table.mat"
SI_GRAD_TABLE_PATH = "data/sphere_integral_gradd_table.mat"

# Variable names read from / written to the corresponding .mat files above.
DATA_VAR = "simulation_data"
LOCATION_VAR = "detector_locations"
MBPD_RESULT_VAR = "x_volume"
SI_TABLE_VAR = "sphere_integral_table"
SI_GRAD_TABLE_VAR = "sphere_integral_gradd_table"

# Define variables
# Reconstruction volume bounds per axis; the plots scale these by 1e3 and
# label them in mm, so the values here are in metres.
x_range = [-12.80e-3, 12.80e-3]
y_range = [-12.80e-3, 12.80e-3]
z_range = [-6.40e-3, 6.40e-3]
res = 0.10e-3  # voxel edge length (same unit as the ranges)
vs = 1510.0  # propagation speed (presumably speed of sound in m/s -- confirm)
# Sampling rate after the time axis has been upsampled by UPSAMPLE_FACTOR.
fs = 8.333333e6 * UPSAMPLE_FACTOR
dt = 1 / fs

x_start, x_end = x_range
y_start, y_end = y_range
z_start, z_end = z_range

# Number of voxels along each axis.
num_x, num_y, num_z = (
    int(round((upper - lower) / res))
    for lower, upper in (x_range, y_range, z_range)
)
num_voxels = num_x * num_y * num_z

# Voxel coordinates: each axis runs from its start to (end - res) inclusive.
X, Y, Z = torch.meshgrid(
    *(
        torch.linspace(lower, upper - res, count, device=DEVICE)
        for lower, upper, count in (
            (x_start, x_end, num_x),
            (y_start, y_end, num_y),
            (z_start, z_end, num_z),
        )
    ),
    indexing="ij",
)
# One (x, y, z) row per voxel, in the same order as X.flatten().
voxel_locations = torch.stack((X, Y, Z), dim=-1).reshape(-1, 3)

# Number of consecutive time samples visited per voxel in the operators.
time_interval_length = int(2.0 * SQRT_3 * res / vs * fs) + 1
# Bin widths used to convert angles / distance offsets into lookup-table indices.
alpha_interval = beta_interval = PI / 2.0 / N_SEARCH_SPACE_SIZE
d_interval = res * 4.0 / N_SEARCH_SPACE_SIZE

# Load simulation data and convert to torch.Tensor
def _load_mat_tensor(mat_path, variable_name):
    """Read one variable from a .mat file and wrap it as a (CPU) torch tensor."""
    return torch.from_numpy(sio.loadmat(mat_path)[variable_name])


detector_locations = (
    _load_mat_tensor(LOCATION_PATH, LOCATION_VAR).to(DEVICE).contiguous()
)

# Linear interpolation along the time dimension.
# A leading batch axis is added because 1-D ("linear") interpolation expects a
# 3-D input; it is removed again afterwards.
# NOTE(review): with align_corners=True the first and last samples are pinned,
# so the output spacing is (T - 1) / (UPSAMPLE_FACTOR * T - 1) input samples
# rather than exactly 1 / UPSAMPLE_FACTOR -- confirm this matches the scaled fs.
_sim_np = sio.loadmat(DATA_PATH)[DATA_VAR]
_sim_torch = torch.nn.functional.interpolate(
    torch.from_numpy(_sim_np)[None],
    scale_factor=UPSAMPLE_FACTOR,
    mode="linear",
    align_corners=True,
)[0]
simulation_data = _sim_torch.to(DEVICE).contiguous()

# Sort detectors (and their signals) by z-coordinate, largest z first
sorted_indices = detector_locations[:, 2].argsort(descending=True)
detector_locations = detector_locations[sorted_indices].contiguous()
simulation_data = simulation_data[sorted_indices].contiguous()
num_detectors, num_times = simulation_data.shape

# Load sphere integral lookup tables and convert to torch.Tensor
sphere_integral_table = (
    _load_mat_tensor(SI_TABLE_PATH, SI_TABLE_VAR).to(DEVICE).contiguous()
)
sphere_integral_gradd_table = (
    _load_mat_tensor(SI_GRAD_TABLE_PATH, SI_GRAD_TABLE_VAR).to(DEVICE).contiguous()
)

In [3]:
# Define forward operator
@torch.inference_mode()
def _forward_operator(x: torch.Tensor) -> torch.Tensor:
    """
    Forward operator A * x

    Accumulates, detector by detector, the contribution of every voxel to the
    time samples in that voxel's window, using the sphere-integral lookup
    tables.

    Args:
        x: 1-D tensor with one value per row of ``voxel_locations``.

    Returns:
        1-D tensor of length ``num_detectors * num_times``; sample ``t`` of
        detector ``d`` is stored at index ``d * num_times + t``.
    """
    y = torch.zeros(num_detectors * num_times, device=DEVICE)
    for det_idx in range(num_detectors):
        voxel_to_det_vectors = detector_locations[det_idx, :] - voxel_locations
        voxel_to_det_distances = torch.norm(voxel_to_det_vectors, dim=1)
        # First time sample of each voxel's window (distance minus a
        # SQRT_3 * res margin, converted to samples).
        time_starts = torch.floor(
            (voxel_to_det_distances - SQRT_3 * res) / vs * fs
        ).int()
        # Angular lookup-table indices, folded into [0, N_SEARCH_SPACE_SIZE - 1].
        # NOTE(review): y / x is 0 / 0 = NaN for a voxel that shares the
        # detector's x and y coordinates; torch.atan2 would avoid that.
        alpha_idxs = (
            (
                torch.atan(voxel_to_det_vectors[:, 1] / (voxel_to_det_vectors[:, 0]))
                / alpha_interval
            )
            .abs()
            .int()
            .clamp(0, N_SEARCH_SPACE_SIZE - 1)
        )
        beta_idxs = (
            (
                torch.atan(
                    voxel_to_det_vectors[:, 2]
                    / (torch.norm(voxel_to_det_vectors[:, :2], dim=1))
                )
                / beta_interval
            )
            .abs()
            .int()
            .clamp(0, N_SEARCH_SPACE_SIZE - 1)
        )
        for i in range(time_interval_length):
            time_idxs = time_starts + i
            # Only samples 1 .. num_times - 1 are usable: an index >= num_times
            # would land in the next detector's block (or past the end of y),
            # a negative index is invalid, and index 0 divides by zero below.
            in_window = (time_idxs > 0) & (time_idxs < num_times)
            safe_time_idxs = time_idxs.clamp(0, num_times - 1)
            d_idxs = (
                ((voxel_to_det_distances - vs * time_idxs * dt) / d_interval).int()
                + N_SEARCH_SPACE_SIZE // 2
            ).clamp(0, N_SEARCH_SPACE_SIZE - 1)
            I = sphere_integral_table[d_idxs, alpha_idxs, beta_idxs]
            dI = sphere_integral_gradd_table[d_idxs, alpha_idxs, beta_idxs]
            # Radius associated with each time sample (hoisted; used twice).
            radii = vs * dt * safe_time_idxs
            # Out-of-window entries (including inf/NaN from radius 0) become 0,
            # so scattering them to the clamped index is a no-op.
            contribution = (-vs / radii * (I / radii + dI) * x).masked_fill(
                ~in_window, 0.0
            )
            y.scatter_add_(
                0,
                (det_idx * num_times + safe_time_idxs).long(),
                contribution,
            )
    return y

In [4]:
# Define transpose operator
@torch.inference_mode()
def _transpose_operator(x: torch.Tensor) -> torch.Tensor:
    """
    Transpose operator A^T * x

    Gathers, detector by detector, the time samples in every voxel's window
    and accumulates them per voxel with the same lookup-table weights that the
    forward operator applies.

    Args:
        x: 1-D tensor of length ``num_detectors * num_times``; sample ``t`` of
            detector ``d`` is read from index ``d * num_times + t``.

    Returns:
        1-D tensor of length ``num_voxels``.
    """
    y = torch.zeros(num_voxels, device=DEVICE)
    for det_idx in range(num_detectors):
        voxel_to_det_vectors = detector_locations[det_idx, :] - voxel_locations
        voxel_to_det_distances = torch.norm(voxel_to_det_vectors, dim=1)
        # First time sample of each voxel's window (distance minus a
        # SQRT_3 * res margin, converted to samples).
        time_starts = torch.floor(
            (voxel_to_det_distances - SQRT_3 * res) / vs * fs
        ).int()
        # Angular lookup-table indices, folded into [0, N_SEARCH_SPACE_SIZE - 1].
        # NOTE(review): y / x is 0 / 0 = NaN for a voxel that shares the
        # detector's x and y coordinates; torch.atan2 would avoid that.
        alpha_idxs = (
            (
                torch.atan(voxel_to_det_vectors[:, 1] / (voxel_to_det_vectors[:, 0]))
                / alpha_interval
            )
            .abs()
            .int()
            .clamp(0, N_SEARCH_SPACE_SIZE - 1)
        )
        beta_idxs = (
            (
                torch.atan(
                    voxel_to_det_vectors[:, 2]
                    / (torch.norm(voxel_to_det_vectors[:, :2], dim=1))
                )
                / beta_interval
            )
            .abs()
            .int()
            .clamp(0, N_SEARCH_SPACE_SIZE - 1)
        )
        for i in range(time_interval_length):
            time_idxs = time_starts + i
            # Only samples 1 .. num_times - 1 are usable: an index >= num_times
            # would read the next detector's block (or past the end of x),
            # a negative index is invalid, and index 0 divides by zero below.
            in_window = (time_idxs > 0) & (time_idxs < num_times)
            safe_time_idxs = time_idxs.clamp(0, num_times - 1)
            d_idxs = (
                ((voxel_to_det_distances - vs * time_idxs * dt) / d_interval).int()
                + N_SEARCH_SPACE_SIZE // 2
            ).clamp(0, N_SEARCH_SPACE_SIZE - 1)
            I = sphere_integral_table[d_idxs, alpha_idxs, beta_idxs]
            dI = sphere_integral_gradd_table[d_idxs, alpha_idxs, beta_idxs]
            # Radius associated with each time sample (hoisted; used twice).
            radii = vs * dt * safe_time_idxs
            samples = x[(det_idx * num_times + safe_time_idxs).long()]
            # Out-of-window entries (including inf/NaN from radius 0) are
            # zeroed so they contribute nothing to the voxel sums.
            y += (-vs / radii * (I / radii + dI) * samples).masked_fill(
                ~in_window, 0.0
            )
    return y

In [5]:
# Define the linear operator required by the LSQR algorithm
def Ax_ATx_operator(x, flag):
    """
    GPU-optimized linear operator for computing A and A^T matrix-vector products.

    Args:
        x: numpy.ndarray or torch.Tensor, input vector. Shapes ``(N,)`` and
            ``(N, 1)`` are both accepted; the input is flattened to 1-D.
        flag: str, computation direction, "notransp" for A * x, "transp" for A^T * x

    Returns:
        y: numpy.ndarray, 1-D output vector

    Raises:
        ValueError: if ``flag`` is neither "notransp" nor "transp".
    """
    # Validate the direction first so a bad flag fails before any data is
    # converted or copied to the device.
    if flag == "notransp":
        apply_operator = _forward_operator
    elif flag == "transp":
        apply_operator = _transpose_operator
    else:
        raise ValueError("flag must be 'notransp' or 'transp'")

    if isinstance(x, np.ndarray):
        # from_numpy rejects negatively-strided views; this is a no-op for
        # arrays that are already contiguous.
        x = torch.from_numpy(np.ascontiguousarray(x))
    # The operators multiply x element-wise with per-voxel / per-sample
    # tensors, so a column vector must be flattened to avoid broadcasting.
    x = x.to(dtype=torch.float32, device=DEVICE).reshape(-1)

    return apply_operator(x).cpu().numpy()


# Expose both directions through SciPy's LinearOperator interface:
# rows index (detector, time) samples, columns index voxels.
A_operator = LinearOperator(
    dtype=np.float32,
    shape=(num_detectors * num_times, num_voxels),
    matvec=lambda voxel_vec: Ax_ATx_operator(voxel_vec, "notransp"),
    rmatvec=lambda signal_vec: Ax_ATx_operator(signal_vec, "transp"),
)

In [6]:
# Main program
print("=" * 60)
print("Starting 3D Image Reconstruction Using LSQR Algorithm")
print("=" * 60)

# Data preprocessing: flatten the (detector, time) data into the vector b
print("Data preprocessing...")
b = simulation_data.flatten().cpu().numpy().astype(np.float32)
print(f"Observation data shape: {b.shape}")
print(f"Observation data range: [{b.min():.6f}, {b.max():.6f}]")
print(f"Observation data non-zero elements: {np.count_nonzero(b)}/{len(b)}")

# Set LSQR algorithm parameters
max_iter = 100  # Maximum number of iterations
tolerance = 1.0e-5  # Convergence tolerance (used for both atol and btol)
print("\nLSQR parameter settings:")
print(f"Maximum iterations: {max_iter}")
print(f"Convergence tolerance: {tolerance}")
print(f"Number of reconstruction voxels: {num_voxels}")

# CUDA events and CUDA memory statistics exist only on a CUDA device. DEVICE
# may have fallen back to the CPU, in which case a wall-clock timer is used.
use_cuda = DEVICE.type == "cuda"

# Reset GPU memory statistics
if use_cuda:
    torch.cuda.reset_peak_memory_stats()

# Perform LSQR reconstruction
print("\nStarting LSQR reconstruction...")

if use_cuda:
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
else:
    start_time = time.perf_counter()

x_reconstructed, exit_reason, itn, r1norm, r2norm, anorm, acond, arnorm, xnorm, var = (
    lsqr(
        A_operator,
        b,
        iter_lim=max_iter,
        atol=tolerance,
        btol=tolerance,
        show=False,  # Set to True to print LSQR's iteration log
        damp=3e2,  # Damping coefficient of the regularised least-squares problem
    )
)

if use_cuda:
    end_event.record()
    end_event.synchronize()
    # elapsed_time() reports milliseconds
    elapsed_seconds = start_event.elapsed_time(end_event) / 1000
    # Get GPU memory statistics
    peak_mem = torch.cuda.max_memory_allocated()
else:
    elapsed_seconds = time.perf_counter() - start_time
    peak_mem = None

print("\nLSQR reconstruction completed!")
print(f"Reconstruction time: {elapsed_seconds:.3f} s")
if peak_mem is not None:
    print(f"Peak GPU memory usage: {peak_mem / 1024**2:.2f} MB")
else:
    print("Peak GPU memory usage: n/a (running on CPU)")
print(f"Number of iterations: {itn}")
print(f"Exit reason: {exit_reason}")
print(f"Residual norm: {r1norm:.6e}")
print(f"Solution norm: {xnorm:.6e}")

# Reshape reconstruction result to 3D volume (same axis order as the voxel grid)
x_volume = x_reconstructed.reshape(num_x, num_y, num_z)
x_volume = negetive_processing(x_volume)
sio.savemat(MBPD_RESULT_PATH, {MBPD_RESULT_VAR: x_volume})
# lsqr returns only the final norms, so no per-iteration loss history is
# collected here; the list stays empty and is consumed by the plotting cell.
losses = []
print("\nReconstruction result statistics:")
print(f"Reconstructed volume shape: {x_volume.shape}")
print(f"Reconstructed value range: [{x_volume.min():.6f}, {x_volume.max():.6f}]")
print(f"Reconstructed value mean: {x_volume.mean():.6f}")
print(f"Reconstructed value std: {x_volume.std():.6f}")

In [7]:
# UBP reconstruction
# Each detector normal is the negated detector position, i.e. a (non-unit)
# vector pointing from the detector toward the coordinate origin.
# Alternative kept from the original notebook: all normals along -z
#   detector_normals = torch.zeros(detector_locations.shape[0], 3, device=DEVICE)
#   detector_normals[:, 2] = -1.0
detector_normals = detector_locations.neg()

# recon is called with NumPy arrays, so move the tensors to the host first.
ubp_inputs = [
    tensor.cpu().numpy()
    for tensor in (simulation_data, detector_locations, detector_normals)
]
result_ubp = negetive_processing(
    recon(*ubp_inputs, x_range, y_range, z_range, res, vs, fs, method="ubp")
)

In [8]:
# Visualize results
def _show_xy_panel(position, image, title):
    """Draw `image` (indexed [x, y]) in subplot `position` of the 2x3 grid."""
    plt.subplot(2, 3, position)
    plt.imshow(
        image.T,  # transpose so x runs along the horizontal axis
        cmap=cmap,
        aspect="equal",
        origin="lower",
        extent=[x_start * 1e3, x_end * 1e3, y_start * 1e3, y_end * 1e3],
    )
    plt.colorbar(label="Intensity (a.u.)")
    plt.xlabel("X (mm)")
    plt.ylabel("Y (mm)")
    plt.title(title)


def _show_loss_panel(position, iterations, values, fmt, title, log_scale=False, **plot_kwargs):
    """Plot a loss curve in subplot `position`, or a notice if there is no data."""
    plt.subplot(2, 3, position)
    if len(values) > 0:
        plt.plot(iterations, values, fmt, linewidth=2, **plot_kwargs)
        if log_scale:
            plt.yscale("log")
    else:
        # An empty history would otherwise render as a blank (log-scaled) axis.
        plt.text(
            0.5,
            0.5,
            "No loss history recorded",
            ha="center",
            va="center",
            transform=plt.gca().transAxes,
        )
    plt.xlabel("Iteration")
    plt.ylabel("Loss")
    plt.title(title)
    plt.grid(True, alpha=0.3)


plt.figure(figsize=(20, 10))
cmap = "hot"

# Z-axis MAP (Maximum Amplitude Projection) of both reconstructions
ubp_map = result_ubp.max(axis=2)
mbpd_map = x_volume.max(axis=2)
_show_xy_panel(1, ubp_map, "UBP Reconstruction - Z-axis MAP")
_show_xy_panel(2, mbpd_map, "MB-PD Reconstruction - Z-axis MAP")

# Complete loss curve
_show_loss_panel(
    3, range(len(losses)), losses, "b-", "Complete Loss Curve", log_scale=True
)

# Center slice along z of both reconstructions
center_z_idx = num_z // 2
center_z_mm = z_start * 1e3 + center_z_idx * res * 1e3
ubp_slice = result_ubp[:, :, center_z_idx]
mbpd_slice = x_volume[:, :, center_z_idx]
_show_xy_panel(
    4, ubp_slice, f"UBP Reconstruction - Center Slice (z={center_z_mm:.2f} mm)"
)
_show_xy_panel(
    5, mbpd_slice, f"MB-PD Reconstruction - Center Slice (z={center_z_mm:.2f} mm)"
)

# Loss curve for the last 10 iterations (fewer if the history is shorter)
last_n = min(10, len(losses))
first_shown = len(losses) - last_n
_show_loss_panel(
    6,
    range(first_shown, len(losses)),
    losses[first_shown:],
    "r-",
    f"Last {last_n} Iterations Loss Curve",
    marker="o",
)

# Display results
plt.tight_layout()
plt.show()