In [None]:
# gpu_acceleration.py
import numpy as np

try:
    import cupy as cp
    import cupyx.scipy.sparse as cusparse
    import cupyx.scipy.sparse.linalg as cusparse_linalg
except ImportError:
    raise ImportError("Requires CuPy with cupyx.scipy. "
                      "Install via 'pip install cupy-cuda12x'")


# --------------------- Sparse Laplacians -----------------------------

def _laplacian_2d_aniso(H, W, dx=1.0, dy=1.0, wx=1.0, wy=1.0, dtype=cp.float64):
    """Build the sparse 2D anisotropic, weighted Laplacian for an (H, W) grid.

    The operator is assembled for C-order (row-major) flattening: the pixel
    at ``[i, j]`` lives at flat index ``i * W + j``, which is what
    ``ndarray.ravel()`` produces by default.

    Parameters
    ----------
    H, W : int
        Grid size along axis 0 (rows, y) and axis 1 (columns, x).
    dx, dy : float
        Grid spacing along x (axis 1) and y (axis 0).
    wx, wy : float
        Weights multiplying the x and y second differences.
    dtype : dtype
        Data type used for the diagonals and identity factors.

    Returns
    -------
    sparse matrix (CSR)
        Matrix of shape ``(H*W, H*W)``. The three-point stencil is simply
        truncated at the grid border (out-of-grid neighbours contribute
        nothing).
    """
    e_h = cp.ones(H, dtype=dtype)
    e_w = cp.ones(W, dtype=dtype)

    # 1D second differences scaled by weight / spacing**2
    D_h = cusparse.diags([e_h, -2*e_h, e_h], offsets=[-1,0,1], shape=(H,H), format='csr') * (wy/dy**2)
    D_w = cusparse.diags([e_w, -2*e_w, e_w], offsets=[-1,0,1], shape=(W,W), format='csr') * (wx/dx**2)

    I_h = cusparse.identity(H, dtype=dtype, format='csr')
    I_w = cusparse.identity(W, dtype=dtype, format='csr')

    # kron(A, B) makes A's index the slow one and B's the fast one. In a
    # C-order flattening of an (H, W) array the row index is slow and the
    # column index is fast, so the row operator must be the LEFT factor.
    # (The reverse order would describe a column-major flattening and couples
    # the wrong pixels whenever H != W.)
    L = cusparse.kron(D_h, I_w, format='csr') + cusparse.kron(I_h, D_w, format='csr')
    return L


def _laplacian_3d_aniso(D, H, W, dz=1.0, dy=1.0, dx=1.0, wz=1.0, wy=1.0, wx=1.0, dtype=cp.float64):
    """Build the sparse 3D anisotropic, weighted Laplacian for a (D, H, W) grid.

    The operator is assembled for C-order (row-major) flattening: the voxel
    at ``[d, h, w]`` lives at flat index ``(d * H + h) * W + w``, which is
    what ``ndarray.ravel()`` produces by default.

    Parameters
    ----------
    D, H, W : int
        Grid size along axis 0 (z), axis 1 (y) and axis 2 (x).
    dz, dy, dx : float
        Grid spacing along z, y and x.
    wz, wy, wx : float
        Weights multiplying the z, y and x second differences.
    dtype : dtype
        Data type used for the diagonals and identity factors.

    Returns
    -------
    sparse matrix (CSR)
        Matrix of shape ``(D*H*W, D*H*W)``. The three-point stencil is simply
        truncated at the grid border (out-of-grid neighbours contribute
        nothing).
    """
    e_d = cp.ones(D, dtype=dtype)
    e_h = cp.ones(H, dtype=dtype)
    e_w = cp.ones(W, dtype=dtype)

    # 1D second differences scaled by weight / spacing**2
    D_d = cusparse.diags([e_d, -2*e_d, e_d], offsets=[-1,0,1], shape=(D,D), format='csr') * (wz/dz**2)
    D_h = cusparse.diags([e_h, -2*e_h, e_h], offsets=[-1,0,1], shape=(H,H), format='csr') * (wy/dy**2)
    D_w = cusparse.diags([e_w, -2*e_w, e_w], offsets=[-1,0,1], shape=(W,W), format='csr') * (wx/dx**2)

    I_d = cusparse.identity(D, dtype=dtype, format='csr')
    I_h = cusparse.identity(H, dtype=dtype, format='csr')
    I_w = cusparse.identity(W, dtype=dtype, format='csr')

    # kron(A, B) makes A's index the slow one and B's the fast one. For a
    # C-order flattening of (D, H, W) the factor order must therefore be
    # depth (slowest) -> height -> width (fastest). The reverse order would
    # describe a column-major flattening and couples the wrong voxels.
    L = (
        cusparse.kron(cusparse.kron(D_d, I_h), I_w, format='csr') +
        cusparse.kron(cusparse.kron(I_d, D_h), I_w, format='csr') +
        cusparse.kron(cusparse.kron(I_d, I_h), D_w, format='csr')
    )
    return L


# --------------------- Main function -----------------------------

def biharmonic_inpaint_gpu_aniso(image, mask,
                                 spacing=None,
                                 weights=None,
                                 tol=1e-6, maxiter=None,
                                 solver='cg', verbose=False):
    """
    GPU biharmonic inpainting (2D or 3D) with anisotropic spacing and axis weights.

    The array is flattened in C order, the squared Laplacian ``B = L @ L`` is
    restricted to the masked entries, and the resulting sparse linear system
    is solved iteratively for the missing values while the known values are
    kept fixed.

    Parameters
    ----------
    image : np.ndarray
        2D or 3D array. Converted to float64 on the GPU.
    mask : np.ndarray
        Boolean mask with the same shape as ``image``. True = missing voxel/pixel.
    spacing : tuple of floats, optional
        Voxel spacing (dx, dy) for 2D, (dx, dy, dz) for 3D. x refers to the
        last array axis, y to the one before it, and z to the first axis of
        a 3D array. Defaults to 1.0 along every axis.
    weights : tuple of floats, optional
        Axis weights (wx, wy) for 2D, (wx, wy, wz) for 3D, using the same
        axis convention as ``spacing``. Defaults to 1.0 along every axis.
    tol : float
        Tolerance forwarded to the iterative solver.
    maxiter : int, optional
        Iteration cap forwarded to the solver. Defaults to
        ``max(2000, 2 * number_of_missing_entries)``.
    solver : str
        'cg' selects conjugate gradient; any other value selects GMRES.
    verbose : bool
        If True, print progress messages.

    Returns
    -------
    np.ndarray
        Inpainted array with the shape of ``image`` (float64). If the mask
        marks nothing as missing, ``image.copy()`` is returned unchanged.

    Raises
    ------
    ValueError
        If ``image`` is not 2D/3D or ``mask`` has a different shape.

    Warns
    -----
    RuntimeWarning
        If the iterative solver reports a non-zero ``info`` (no convergence).
    """
    import warnings

    ndim = image.ndim
    if ndim not in (2,3):
        raise ValueError("Only 2D or 3D arrays supported")

    img_gpu = cp.asarray(image, dtype=cp.float64)
    mask_gpu = cp.asarray(mask, dtype=bool)

    shape = img_gpu.shape
    if mask_gpu.shape != shape:
        raise ValueError(
            f"mask shape {mask_gpu.shape} does not match image shape {shape}")

    # C-order flattening; the flat indices below refer to this ordering.
    mask_flat = mask_gpu.ravel()
    unknown_idx = cp.nonzero(mask_flat)[0]
    known_idx = cp.nonzero(~mask_flat)[0]

    # Nothing to fill in.
    if unknown_idx.size == 0:
        return image.copy()

    # Default spacing and weights
    if spacing is None:
        spacing = (1.0,) * ndim
    if weights is None:
        weights = (1.0,) * ndim

    if verbose:
        print(f"Building {ndim}D anisotropic, weighted Laplacian ...")

    if ndim==2:
        dx, dy = spacing
        wx, wy = weights
        L = _laplacian_2d_aniso(shape[0], shape[1], dx, dy, wx, wy)
    else:
        # User-facing tuples are x, y, z; the helper takes z, y, x to follow
        # the (D, H, W) axis order of the array.
        dx, dy, dz = spacing
        wx, wy, wz = weights
        L = _laplacian_3d_aniso(shape[0], shape[1], shape[2], dz, dy, dx, wz, wy, wx)

    if verbose:
        print("Computing biharmonic operator ...")
    B = L.dot(L)

    # Slice the rows of the unknowns once, then split the columns into the
    # unknown/unknown system matrix and the unknown/known coupling block.
    B_rows = B[unknown_idx]
    A = B_rows[:, unknown_idx]
    B_uk = B_rows[:, known_idx]

    # Move the contribution of the known values to the right-hand side.
    known_vals = img_gpu.ravel()[known_idx]
    rhs = -B_uk.dot(known_vals)

    if maxiter is None:
        maxiter = max(2000, 2*unknown_idx.size)

    solve_fn = cusparse_linalg.cg if solver=='cg' else cusparse_linalg.gmres
    if verbose:
        print(f"Solving system of size {unknown_idx.size} ...")

    x, info = solve_fn(A, rhs, tol=tol, maxiter=maxiter)
    if info != 0:
        if verbose:
            print(f"Solver returned info={info}")
        # Surface non-convergence even when verbose is off, so callers do not
        # silently receive a partially converged fill.
        warnings.warn(
            f"{solver!r} solver did not converge (info={info}); "
            "the inpainted values may be inaccurate",
            RuntimeWarning,
            stacklevel=2,
        )

    # Work on a copy so the GPU image is not modified in place.
    result = img_gpu.copy().ravel()
    result[unknown_idx] = x
    return cp.asnumpy(result.reshape(shape))

In [None]:
# Spacing: dx=1, dy=2 means the vertical (row) spacing is twice the horizontal (column) spacing.

# Weights: wx=2.0, wy=1.0 favors stronger continuity along the x-axis.

In [None]:
# 2D run with anisotropic settings; both tuples are given in (x, y) order.
inpainted = biharmonic_inpaint_gpu_aniso(
    image=img,
    mask=mask,
    spacing=(1.0, 2.0),  # (dx, dy)
    weights=(2.0, 1.0),  # (wx, wy)
    verbose=True,
)

In [None]:
# -------------------- 3D demo with anisotropic spacing and axis weights --------------------
# NOTE(review): assumes `plt` (matplotlib.pyplot) was imported in an earlier cell — not visible here.
D, H, W = 20, 40, 40
z, y, x = np.mgrid[:D, :H, :W]
vol = np.sin(x / 10.0) + np.cos(y / 15.0) + np.sin(z / 8.0)

# Mask: spherical hole in the center
mask3d = (x - W//2)**2 + (y - H//2)**2 + (z - D//2)**2 < (min(D,H,W)//3)**2
corrupted3d = vol.copy()
corrupted3d[mask3d] = 0

# biharmonic_inpaint_gpu_aniso unpacks both tuples in x, y, z order
# (dx, dy, dz) / (wx, wy, wz), i.e. the reverse of the (D, H, W) array axes.
spacing = (0.5, 1.0, 1.5)   # (dx, dy, dz): z spaced more coarsely, x finer
weights = (2.0, 1.0, 0.5)   # (wx, wy, wz): stronger continuity along the x-axis

print("Running 3D GPU biharmonic inpainting with anisotropic spacing & axis weights...")
inpaint3d = biharmonic_inpaint_gpu_aniso(corrupted3d, mask3d,
                                         spacing=spacing,
                                         weights=weights,
                                         verbose=True)

# Visualize mid slices along each axis: one row per axis, one column per volume
mid_z, mid_y, mid_x = D//2, H//2, W//2
mid_slices = (
    ('Z', np.s_[mid_z, :, :]),
    ('Y', np.s_[:, mid_y, :]),
    ('X', np.s_[:, :, mid_x]),
)
panels = (
    ('Original {}-slice', vol),
    ('Corrupted {}', corrupted3d),
    ('Inpainted {}', inpaint3d),
)

fig, axes = plt.subplots(3, 3, figsize=(12, 12))
for col, (title_fmt, volume) in enumerate(panels):
    for row, (axis_name, mid) in enumerate(mid_slices):
        ax = axes[row, col]
        ax.imshow(volume[mid], cmap='viridis')
        ax.set_title(title_fmt.format(axis_name))
        ax.axis('off')

plt.tight_layout()
plt.show()