### 4D natural cubic spline interpolation on equidistant grids in each dimension, following Habermann and Kindermann (2007);
### the spline evaluation is jittable and auto-differentiable using JAX, 2022/10/4

In [1]:
import numpy as np
from scipy import linalg

In [2]:
N = 4  # number of dimensions (independent variables) of the interpolation problem

In [3]:
# lower and upper bounds [a, b] of the x-coordinate in each dimension,
# stored as float arrays ordered [1st dim, 2nd dim, 3rd dim, 4th dim]
a = np.array([0, 0, 0, 0], dtype=float)
b = np.array([1, 2, 3, 4], dtype=float)

In [4]:
# number of grid intervals n in each dimension,
# ordered [1st dim, 2nd dim, 3rd dim, 4th dim]
n = np.array([10, 20, 30, 40], dtype=int)

In [5]:
# grid spacing h in each dimension: domain length (b - a) divided by the number of intervals n
h = (b - a) / n

In [6]:
# x_grid is a tuple with one entry per dimension; entry j is the numpy array of
# the n[j] + 1 equidistant grid points from a[j] to b[j] (both ends included),
# e.g. x_grid[0] holds the 1st-dim grid points and x_grid[1] the 2nd-dim ones
x_grid = tuple(np.linspace(a[j], b[j], n[j] + 1) for j in range(N))
# x_grid

In [7]:
# one sample per grid point: n[j] + 1 points along dimension j
grid_shape = tuple(n[j] + 1 for j in range(N))
y_data = np.zeros(grid_shape)  # filled with the sampled function values later

In [8]:
# test y_data: y = sin(x1) * sin(x2) * sin(x3) * sin(x4) sampled on the full grid.
# np.ix_ reshapes the four 1D sine vectors so that their product broadcasts over
# all grid points at once, replacing four nested Python loops over every point.
sin1, sin2, sin3, sin4 = np.ix_(*(np.sin(x_grid[d]) for d in range(4)))
y_data[...] = sin1 * sin2 * sin3 * sin4  # in-place fill keeps the same y_data array
# y_data

In [9]:
def _natural_spline_coefs_along_axis(values, axis):
    """
    Solve the 1D natural cubic spline coefficient problem along one axis

    Every 1D fiber of `values` along `axis` is treated as n + 1 samples on an
    equidistant grid and replaced by its n + 3 spline coefficients.

    INPUTs:
        values: numpy array with at least 2 entries (n >= 1 intervals) along `axis`
        axis: index of the axis to transform

    OUTPUT:
        numpy array shaped like `values`, except that `axis` has length n + 3
    """
    vals = np.moveaxis(values, axis, 0)  # bring the working axis to the front
    n_int = vals.shape[0] - 1  # number of grid intervals along this axis
    rest = vals.shape[1:]

    coefs = np.zeros((n_int + 3,) + rest)
    # The two coefficients next to the ends follow directly from the end samples.
    coefs[1] = vals[0] / 6  # c_{2}
    coefs[n_int + 1] = vals[n_int] / 6  # c_{n+2}

    n_unknown = n_int - 1  # interior coefficients c_{3} .. c_{n+1}
    if n_unknown > 0:
        # Tridiagonal (1, 4, 1) system, built once and shared by all fibers.
        A = (
            4.0 * np.eye(n_unknown)
            + np.eye(n_unknown, k=1)
            + np.eye(n_unknown, k=-1)
        )
        # One right-hand-side column per fiber; np.array copies, so the
        # caller's data is never modified.
        B = np.array(vals[1:n_int], dtype=float).reshape(n_unknown, -1)
        # Move the known boundary coefficients to the right-hand side. Using
        # "-=" keeps both corrections when there is a single unknown (n == 2).
        B[0] -= coefs[1].reshape(-1)
        B[-1] -= coefs[n_int + 1].reshape(-1)
        coefs[2 : n_int + 1] = linalg.solve(A, B).reshape((n_unknown,) + rest)

    # Outermost coefficients: zero second difference at both ends.
    coefs[0] = 2 * coefs[1] - coefs[2]  # c_{1}
    coefs[n_int + 2] = 2 * coefs[n_int + 1] - coefs[n_int]  # c_{n+3}
    return np.moveaxis(coefs, 0, axis)


def compute_4D_spline_coefs(y_data):
    """
    Compute 4D spline coefficient matrix c_i1i2i3i4

    The 1D coefficient solve is applied along the 1st, 2nd, 3rd and 4th
    dimension in turn; each pass grows the processed dimension from n + 1
    samples to n + 3 coefficients. The number of grid intervals n in each
    dimension is taken from y_data.shape, so no module-level state is read
    or modified.

    INPUTs:
        y_data: 4D numpy array of y data (real scalar) at x_grid points,
                with at least 2 grid points in every dimension

    OUTPUT:
        4D numpy array of shape (n1 + 3, n2 + 3, n3 + 3, n4 + 3)

    RAISEs:
        ValueError: if y_data is not 4-dimensional or has fewer than 2 grid
                    points in some dimension
    """
    N = 4  # dimension of the problem

    y_data = np.asarray(y_data, dtype=float)
    if y_data.ndim != N:
        raise ValueError(f"y_data must be a {N}D array, got {y_data.ndim}D")
    if min(y_data.shape) < 2:
        raise ValueError(
            "y_data needs at least 2 grid points in every dimension, "
            f"got shape {y_data.shape}"
        )

    coefs = y_data
    for k in range(N):
        coefs = _natural_spline_coefs_along_axis(coefs, k)
    # moveaxis returns a view; hand back a plain C-contiguous array
    return np.ascontiguousarray(coefs)

In [39]:
# spline coefficients of the sampled test function
c_i1i2i3i4 = compute_4D_spline_coefs(y_data)

In [40]:
# spot check: print a single entry of the coefficient array
c_i1i2i3i4[6, 9, 10, 6]

0.0001003257034482869

### compute spline interpolation and its gradient using JAX

In [41]:
import jax.numpy as jnp
from jax import lax

In [42]:
# 4D-spline function (jittable and auto-differentiable)
def s4D(x, a, h, c_i1i2i3i4):
    """
    4D-spline interpolation

    The interpolant is the tensor-product sum
        sum_{i1,i2,i3,i4} c_i1i2i3i4 * u_i1(x1) * u_i2(x2) * u_i3(x3) * u_i4(x4)
    evaluated by building one basis vector per dimension and contracting the
    coefficient array with them one axis at a time. Only jax.numpy array
    operations are used, so the function works under jit and grad.

    INPUTs
        x: 4-dim x vector (float) at which the interpolated y-value is evaluated
        a: 4-dim vector (float) of the lower boundary of each x-dimension
        h: 4-dim vector (float) of the grid interval of each x-dimension
        c_i1i2i3i4: spline coefficient (4-dim array)

    OUTPUT
        scalar (0-dim array): interpolated y-value at x
    """

    def basis(size, aa, hh, xx):
        # Values u_i(xx) of all basis functions i = 1 .. size of one dimension.
        ii = jnp.arange(1, size + 1)
        t = jnp.abs((xx - aa) / hh + 2 - ii)
        inner = 4.0 - 6.0 * t**2 + 3.0 * t**3  # branch for t <= 1
        outer = (2.0 - t) ** 3  # branch for t > 1
        # heaviside zeroes the basis function where t > 2
        return jnp.where(t <= 1, inner, outer) * jnp.heaviside(2.0 - t, 1.0)

    result = jnp.asarray(c_i1i2i3i4)
    # Contract the last remaining axis each time: 4th dim first, 1st dim last.
    # Multiply-and-sum is used instead of matmul so the reduction runs at the
    # precision of the arrays.
    for d in (3, 2, 1, 0):
        weights = basis(c_i1i2i3i4.shape[d], a[d], h[d], x[d])
        result = jnp.sum(result * weights, axis=-1)

    return result

In [43]:
# convert the NumPy coefficient array, lower bounds and grid spacings to JAX
# arrays so they can be passed to s4D
c_i1i2i3i4_jnp = jnp.array(c_i1i2i3i4)
a_jnp = jnp.array(a)
h_jnp = jnp.array(h)

In [44]:
# evaluation point: midpoint of the domain [a, b] in every dimension
x = jnp.array([0.5, 1.0, 1.5, 2.0])

In [45]:
# evaluate the spline interpolant at x (not jitted)
s4D(x, a_jnp, h_jnp, c_i1i2i3i4_jnp)

DeviceArray(0.3659122, dtype=float32)

In [16]:
# reference value: with h = 0.1 in every dimension, x = [0.5, 1.0, 1.5, 2.0]
# is the grid point with indices (5, 10, 15, 20)
y_data[5, 10, 15, 20]

0.36591228786591046

In [17]:
from jax import grad, jit

In [18]:
# jit-compiled version of the spline evaluation
s4D_jitted = jit(s4D)

In [19]:
# evaluate the jitted interpolant at x
s4D_jitted(x, a_jnp, h_jnp, c_i1i2i3i4_jnp)

DeviceArray(0.3659122, dtype=float32)

In [20]:
# gradient of s4D with respect to its first argument x (grad's default argnums=0)
ds4D = grad(s4D)

In [21]:
# evaluate the gradient at x (not jitted)
ds4D(x, a_jnp, h_jnp, c_i1i2i3i4_jnp)

DeviceArray([ 0.669772  ,  0.23494934,  0.02594866, -0.16746247], dtype=float32)

In [22]:
# jit-compiled gradient of s4D with respect to x
ds4D_jitted = jit(grad(s4D))

In [23]:
# evaluate the jitted gradient at x
ds4D_jitted(x, a_jnp, h_jnp, c_i1i2i3i4_jnp)

DeviceArray([ 0.669772  ,  0.23494934,  0.02594866, -0.16746247], dtype=float32)

In [24]:
# evaluation point at the upper corner of the domain (x equals b in every dimension)
x = jnp.array([1.0, 2.0, 3.0, 4.0])

In [25]:
# evaluate the jitted gradient at the current x
ds4D_jitted(x, a_jnp, h_jnp, c_i1i2i3i4_jnp)

DeviceArray([-0.05483124,  0.03503776,  0.57090926, -0.07293978], dtype=float32)

In [26]:
# set x to the domain midpoint and evaluate the jitted gradient there
x = jnp.array([0.5, 1.0, 1.5, 2.0])
ds4D_jitted(x, a_jnp, h_jnp, c_i1i2i3i4_jnp)

DeviceArray([ 0.669772  ,  0.23494934,  0.02594866, -0.16746247], dtype=float32)

In [27]:
from jax import value_and_grad

In [28]:
# jitted function returning the pair (value, gradient with respect to x) in one call
s4D_fun = jit(value_and_grad(s4D))

In [29]:
# evaluate value and gradient together at the current x
s4D_fun(x, a_jnp, h_jnp, c_i1i2i3i4_jnp)

(DeviceArray(0.3659122, dtype=float32),
 DeviceArray([ 0.669772  ,  0.23494934,  0.02594866, -0.16746247], dtype=float32))

In [30]:
# evaluation point: midpoint of the domain [a, b] in every dimension
x = jnp.array([0.5, 1.0, 1.5, 2.0])

In [31]:
# evaluate value and gradient together at the current x
s4D_fun(x, a_jnp, h_jnp, c_i1i2i3i4_jnp)

(DeviceArray(0.3659122, dtype=float32),
 DeviceArray([ 0.669772  ,  0.23494934,  0.02594866, -0.16746247], dtype=float32))