In [None]:
import numpy as np
from numba import njit, prange
import matplotlib.pyplot as plt
from skimage.data import coffee

In [None]:
# 2-D grayscale test image: average skimage's RGB "coffee" sample over its
# channel axis (axis 2), since remap() below indexes img as img[row, col].
img = np.mean(coffee(), axis=2)

# @njit(inline="always")


# TODO: support complex and multi-channel images?
@njit(parallel=True, nogil=True, fastmath=False)
def remap(img, rr, cc):
    """Sample a 2-D image at fractional (row, col) positions with cubic convolution.

    Each output value is a separable weighted sum over the integer taps
    floor(x)-1 .. ceil(x)+1 on each axis, using the cubic convolution kernel
    with a = -0.5 (Keys). Taps that fall outside the image are replaced by the
    nearest border pixel (edge replication).

    Parameters
    ----------
    img : 2-D array
        Source image, indexed as ``img[row, col]``.
    rr, cc : arrays of identical shape
        Fractional row / column coordinates at which to sample ``img``.

    Returns
    -------
    ndarray
        Same shape as ``rr``, dtype ``img.dtype``. Entries whose coordinate is
        NaN, or lies outside ``0 <= r < img.shape[0]`` and
        ``0 <= c < img.shape[1]``, are left at the NaN fill value.

    Raises
    ------
    ValueError
        If ``rr`` and ``cc`` do not have the same shape.

    Notes
    -----
    The output takes ``img.dtype``. For an integer image, neither the NaN fill
    nor the fractional interpolated values can be represented, so pass a
    floating-point image. Complex and multi-channel images are not supported.
    """

    def ker(x):
        # Cubic convolution kernel, a = -0.5; it is zero for |x| >= 2.
        ax = abs(x)
        if ax < 1:
            return 1.5 * ax**3 - 2.5 * ax**2 + 1
        elif ax < 2:
            return -0.5 * ax**3 + 2.5 * ax**2 - 4 * ax + 2
        else:
            return 0.0

    if rr.shape != cc.shape:
        raise ValueError("Coordinate arrays must have the same shape.")

    naz, nrg = img.shape

    # Work on 1-D arrays: plain integer indexing is fast in numba, whereas
    # ``.flat[i]`` indexing is slow for non-contiguous arrays. ravel() copies
    # only when the input is not already C-contiguous.
    r_flat = rr.ravel()
    c_flat = cc.ravel()
    out_flat = np.full(r_flat.size, np.nan, dtype=img.dtype)

    for idx in prange(r_flat.size):
        r = r_flat[idx]
        c = c_flat[idx]

        # Unmapped (NaN) and out-of-image positions keep the NaN fill value.
        if np.isnan(r) or np.isnan(c):
            continue
        if not ((r >= 0) and (r < naz) and (c >= 0) and (c < nrg)):
            continue

        acc = 0.0
        for i in range(int(np.floor(r)) - 1, int(np.ceil(r)) + 2):
            # Edge replication: clamp the row tap into the image.
            ii = min(max(0, i), naz - 1)
            # The row weight does not depend on j, so compute it once per row.
            wr = ker(r - i)
            for j in range(int(np.floor(c)) - 1, int(np.ceil(c)) + 2):
                jj = min(max(0, j), nrg - 1)
                acc += wr * ker(c - j) * img[ii, jj]

        out_flat[idx] = acc

    # out_flat is freshly allocated and contiguous, so reshape returns a view.
    return out_flat.reshape(rr.shape)

naz, nrg = img.shape

# Upsampling factor. The sampling grid steps 1/scale source pixels along each
# axis, so rr and cc each hold (scale*naz) x (scale*nrg) float32 values.
scale = 20.0

# Build the 1-D row/column axes and broadcast them to 2-D. The previous
# np.mgrid[...].astype("float32") first materialised a float64 array of shape
# (2, scale*naz, scale*nrg), twice the size of the final float32 pair.
# The values are bit-identical: exact integers cast to float32, then divided
# by scale in float32.
rows = np.arange(scale * naz).astype("float32")
cols = np.arange(scale * nrg).astype("float32")
rows /= scale
cols /= scale
rr, cc = np.meshgrid(rows, cols, indexing="ij")


In [None]:
# Resample img on the (rr, cc) grid. The first call also includes numba's JIT
# compilation of remap.
# NOTE(review): `re` shadows the stdlib `re` module name. It is kept because
# the plotting cell reads `re`.
re = remap(img, rr, cc)

# Sanity checks: the output follows the grid's shape, and the grid built
# above stays inside the image, so no sample should be left at the NaN fill.
assert re.shape == rr.shape
assert not np.isnan(re).any(), "remap left NaN samples inside the image"


In [None]:
# Show the resampled image. interpolation="none" keeps matplotlib from
# smoothing it again on screen, so the remap result is what is displayed.
fig, ax = plt.subplots(figsize=(10, 10))
im = ax.imshow(re, interpolation="none")
ax.set(title="Cubic-convolution remap of img", xlabel="output column", ylabel="output row")
fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
plt.show()