In [None]:
import math
import os

import matplotlib.pyplot as plt
import scipy.io as sio
import torch

In [None]:
# Compute device: CUDA when available, otherwise CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Quadrature points per edge of the square integration patch (N x N points in total).
N_QUAD_POINT_PER_EDGE = 1000
# Samples per search-space axis (alpha, beta and the plane offset d).
N_SEARCH_SPACE_SIZE = 100
# Full-precision constant instead of a truncated literal.
PI = math.pi
# Tent-kernel half-width / grid resolution. Presumably metres (0.1 mm) -- TODO confirm units.
RES = 0.10e-3
# Half-width of the quadrature patch and of the d search range: twice the kernel half-width.
BOUNDARY = 2.0 * RES

In [None]:
def trilinear_interp(coords: torch.Tensor, res: "float | None" = None) -> torch.Tensor:
    """
    Separable tent (trilinear) kernel weight of each point relative to the origin.

    Each row c gets prod_k max(0, 1 - |c_k| / res): the weight that trilinear
    interpolation assigns to a grid node at the origin for a sample located at
    ``c``. Any point with |c_k| >= res on some axis gets weight 0.

    Args:
        coords: coordinates, shape (N, 3), in the same units as ``res``.
        res: kernel half-width (grid spacing). Defaults to the global ``RES``
            when omitted, which keeps existing callers unchanged.

    Returns:
        weights: interpolation weights, shape (N,)
    """
    if res is None:
        res = RES
    coords_abs = coords.abs() / res
    # Clamping every 1-D factor at zero yields exactly the compact support
    # (|c_k| < res on all axes) without a separate mask, and avoids the
    # 0 * inf = NaN that mask-and-multiply produces for non-finite coordinates.
    return (1.0 - coords_abs).clamp(min=0.0).prod(dim=1)

In [None]:
def interp_plane_rotate(
    coords: torch.Tensor, alpha: torch.Tensor, beta: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Rotate points by the orientation (alpha, beta) and return the plane normal.

    The rotation matrix R has columns (u, v, n) with
        u = (cos a sin b, sin a sin b, -cos b)
        v = (-sin a,      cos a,       0    )
        n = (cos a cos b, sin a cos b,  sin b)
    which are orthonormal with det(R) = +1. Points lying in the z = 0 plane are
    therefore mapped onto the plane through the origin with normal ``n``.

    Args:
        coords: points, shape (N, 3).
        alpha: first orientation angle in radians (0-d tensor or Python float).
        beta: second orientation angle in radians (0-d tensor or Python float).

    Returns:
        (coords @ R.T, n): the rotated points, shape (N, 3), and the unit normal,
        shape (3,). Both have the dtype and device of ``coords``.
    """
    # Match coords' dtype/device so the matmul never hits a dtype/device mismatch.
    alpha = torch.as_tensor(alpha, dtype=coords.dtype, device=coords.device)
    beta = torch.as_tensor(beta, dtype=coords.dtype, device=coords.device)
    sa, ca = torch.sin(alpha), torch.cos(alpha)
    sb, cb = torch.sin(beta), torch.cos(beta)
    zero = torch.zeros_like(sa)
    # torch.stack keeps everything on-device; torch.tensor([...]) of 0-d tensors
    # would pull each entry to the host via .item() and drop autograd history.
    R = torch.stack(
        [
            torch.stack([ca * sb, -sa, ca * cb]),
            torch.stack([sa * sb, ca, sa * cb]),
            torch.stack([-cb, zero, sb]),
        ]
    )
    norm_vector = torch.stack([ca * cb, sa * cb, sb])
    return coords @ R.T, norm_vector

In [None]:
# Quadrature grid on the z = 0 plane: N x N points (N = N_QUAD_POINT_PER_EDGE)
# starting at -BOUNDARY with spacing 2 * BOUNDARY / N along each axis; the upper
# end point BOUNDARY itself is excluded (left-endpoint rule).
x = torch.linspace(
    -BOUNDARY,
    BOUNDARY * (1 - 2 / N_QUAD_POINT_PER_EDGE),
    N_QUAD_POINT_PER_EDGE,
    device=DEVICE,
)
y = x.clone()  # same sampling along y

X, Y = torch.meshgrid(x, y, indexing="ij")
Z = torch.zeros_like(X)
# One row per quadrature point: (x, y, 0).
coords = torch.stack((X.reshape(-1), Y.reshape(-1), Z.reshape(-1)), dim=1)

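**Computing the sphere integral lookup table takes approximately 70 minutes on a GPU:** the table has `N_SEARCH_SPACE_SIZE`³ entries (one per $(d, \alpha, \beta)$ combination), and each entry is a quadrature over `N_QUAD_POINT_PER_EDGE`² points.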

In [None]:
# Search space: two plane-orientation angles on [0, PI/2) and the signed plane
# offset d on [-BOUNDARY, BOUNDARY), each sampled at N_SEARCH_SPACE_SIZE points
# with the upper end point excluded.
alpha = torch.linspace(
    0, PI / 2.0 * (1 - 1 / N_SEARCH_SPACE_SIZE), N_SEARCH_SPACE_SIZE, device=DEVICE
)
beta = alpha.clone()  # identical sampling to alpha
d = torch.linspace(
    -BOUNDARY,
    BOUNDARY * (1 - 2 / N_SEARCH_SPACE_SIZE),
    N_SEARCH_SPACE_SIZE,
    device=DEVICE,
)
# Table layout: [d index, alpha index, beta index].
sphere_integral_table = torch.zeros((N_SEARCH_SPACE_SIZE,) * 3, device=DEVICE)

for i, alpha_i in enumerate(alpha):
    print(f"Processing {i}")
    for j, beta_j in enumerate(beta):
        # Rotate the base z = 0 patch once per orientation.
        coords_rotated, norm_vector = interp_plane_rotate(coords, alpha_i, beta_j)
        # Each offset shifts the un-shifted rotated patch, so displacements never
        # accumulate across offsets. Entries are raw point sums (scaled below).
        sphere_integral_table[:, i, j] = torch.stack(
            [
                trilinear_interp(coords_rotated + offset * norm_vector).sum()
                for offset in d
            ]
        )

# Multiply the point sums by the area of one quadrature cell,
# (2 * BOUNDARY / N_QUAD_POINT_PER_EDGE) ** 2.
sphere_integral_table.mul_((BOUNDARY * 2.0 / N_QUAD_POINT_PER_EDGE) ** 2)

In [None]:
# Persist the lookup table. Create the output directory first: otherwise a fresh
# checkout without data/ makes savemat raise FileNotFoundError after the long compute.
os.makedirs("data", exist_ok=True)
sio.savemat(
    "data/sphere_integral_table.mat",
    {"sphere_integral_table": sphere_integral_table.cpu().numpy()},
)

In [None]:
# Host (NumPy) copy of the lookup table for plotting.
sphere_integral_table_cpu = sphere_integral_table.cpu().numpy()

# Quick sanity plot. The table is indexed [d, alpha, beta], so this slice fixes
# d index 30 and alpha index 10 and sweeps over beta.
fig, ax = plt.subplots()
ax.plot(sphere_integral_table_cpu[30, 10, :])
ax.set(
    title="Sphere integral table (d index 30, alpha index 10)",
    xlabel="beta Index",
    ylabel="Integral Value",
)
plt.show()

In [None]:
# Derivative of the lookup table with respect to the plane offset d, i.e. along
# axis 0 of the [d, alpha, beta] table: central differences on interior samples,
# forward / backward differences on the first / last sample.
dd = d[1] - d[0]  # d is uniformly spaced

sphere_integral_gradd_table = torch.zeros_like(sphere_integral_table)

# Interior samples, all at once via shifted slices: (f[k+1] - f[k-1]) / (2 * dd).
sphere_integral_gradd_table[1:-1] = (
    sphere_integral_table[2:] - sphere_integral_table[:-2]
) / (2 * dd)
# First sample: forward difference.
sphere_integral_gradd_table[0] = (
    sphere_integral_table[1] - sphere_integral_table[0]
) / dd
# Last sample: backward difference.
sphere_integral_gradd_table[-1] = (
    sphere_integral_table[-1] - sphere_integral_table[-2]
) / dd

print("Gradient computation completed")

In [None]:
# Save the gradient table to a .mat file, creating data/ first so the write
# does not fail with FileNotFoundError on a fresh checkout.
os.makedirs("data", exist_ok=True)
sio.savemat(
    "data/sphere_integral_gradd_table.mat",
    {"sphere_integral_gradd_table": sphere_integral_gradd_table.cpu().numpy()},
)

print("Gradient table saved to data/sphere_integral_gradd_table.mat")

In [None]:
# Visual check of one table column and its d-derivative at a fixed orientation.
# The table axes are [d, alpha, beta]; the two values below are *indices* into
# alpha / beta, not angles (the titles show the corresponding angles in degrees).
plot_alpha_idx, plot_beta_idx = 30, 10

sphere_integral_gradd_table_cpu = sphere_integral_gradd_table.cpu().numpy()
# Read the tensor directly so this cell does not depend on the earlier plot cell.
integral_slice = sphere_integral_table.cpu().numpy()[:, plot_alpha_idx, plot_beta_idx]
gradient_slice = sphere_integral_gradd_table_cpu[:, plot_alpha_idx, plot_beta_idx]
d_values = d.cpu().numpy()
orientation = (
    f"alpha[{plot_alpha_idx}]={alpha[plot_alpha_idx].item() * 180.0 / PI:.1f} deg, "
    f"beta[{plot_beta_idx}]={beta[plot_beta_idx].item() * 180.0 / PI:.1f} deg"
)

fig, (ax_table, ax_grad, ax_both) = plt.subplots(1, 3, figsize=(12, 4))

# Original data
ax_table.plot(integral_slice)
ax_table.set(
    title=f"Original Integral Table\n({orientation})",
    xlabel="d Index",
    ylabel="Integral Value",
)

# Gradient data
ax_grad.plot(gradient_slice)
ax_grad.set(
    title=f"Gradient along d direction\n({orientation})",
    xlabel="d Index",
    ylabel="Gradient Value",
)

# The integral and its d-derivative differ greatly in magnitude (see the scaling
# note at the end of the notebook). On a shared y-axis the integral curve would
# look flat, so the gradient gets its own y-axis on the right.
ax_both.plot(d_values, integral_slice, "b-", label="Original Data")
ax_both.set(
    title=f"Original Data vs Gradient\n({orientation})",
    xlabel="d Value",
    ylabel="Integral Value",
)
ax_both_grad = ax_both.twinx()
ax_both_grad.plot(d_values, gradient_slice, "r-", label="Gradient Data")
ax_both_grad.set_ylabel("Gradient Value")
# Merge the legend entries of both y-axes into a single legend.
lines_left, labels_left = ax_both.get_legend_handles_labels()
lines_right, labels_right = ax_both_grad.get_legend_handles_labels()
ax_both.legend(lines_left + lines_right, labels_left + labels_right)

fig.tight_layout()
plt.show()

- **Note:** The integral values are areas and therefore scale with $\mathrm{RES}^2$. Their gradients along $d$ are derivatives with respect to a length, so they scale only with $\mathrm{RES}$.