### 1D natural cubic spline interpolation assuming an equidistant grid in each dimension, following the method of Habermann and Kindermann (2007)
### Jittable and auto-differentiable with JAX (2022/10/4)

In [82]:
import numpy as np
from scipy import linalg

In [83]:
# Dimensionality of the problem, i.e. the number of independent variables.
N: int = 1

In [84]:
# Domain [a, b] of the x-coordinate, one entry per dimension:
# a holds the lower bounds, b the upper bounds.
a = np.array([0], dtype=float)
b = np.array([1], dtype=float)

In [85]:
# number of grid interval n in each dimension
n = [10]  # [1st dim, 2nd dim]
n = np.array(n, dtype=int)

In [86]:
# Grid spacing (width of one interval) in each dimension; the grid is equidistant.
h = np.true_divide(b - a, n)

In [87]:
# Tuple of 1D arrays of grid points, one array per dimension:
# x_grid[0] holds the n[0] + 1 points of the 1st dimension, x_grid[1] those of the 2nd, ...
x_grid = tuple(np.linspace(a[j], b[j], n[j] + 1) for j in range(N))

In [88]:
# Shape of the data grid: n[j] + 1 grid points along each dimension j.
grid_shape = tuple(n[j] + 1 for j in range(N))
y_data = np.zeros(grid_shape)

In [89]:
# Test data: y = sin(x) sampled at the 1st-dimension grid points.
for q1, x_q in enumerate(x_grid[0][: n[0] + 1]):
    y_data[q1] = np.sin(x_q)

In [90]:
def compute_1D_spline_coefs(y_data):
    """
    Compute the 1D natural cubic spline coefficient vector c_i1
    (method of Habermann and Kindermann 2007).

    The coefficients satisfy, at every grid point x_j (j = 0..n),
        c_i1[j] + 4 * c_i1[j+1] + c_i1[j+2] = y_j,
    together with the natural boundary condition (zero second derivative
    at both ends), c_i1[0] - 2 c_i1[1] + c_i1[2] = 0 and the analogue at x_n.

    INPUTs:
        y_data: 1D array of y data (real scalar) at the n + 1 equidistant
                x_grid points; n >= 1 (at least two points).

    OUTPUT:
        c_i1: 1D numpy float array of length n + 3 with the spline coefficients.

    Raises:
        ValueError: if y_data is not 1D or has fewer than two points.
    """
    y = np.asarray(y_data, dtype=float)
    if y.ndim != 1 or y.shape[0] < 2:
        raise ValueError("y_data must be a 1D array with at least 2 points")
    # Number of grid intervals; kept local so no global state is modified.
    n = y.shape[0] - 1

    c_i1 = np.zeros(n + 3)
    # Natural boundary: c_0 - 2 c_1 + c_2 = 0 combined with the interpolation
    # condition c_0 + 4 c_1 + c_2 = y_0 gives c_1 = y_0 / 6 (same at x_n).
    c_i1[1] = y[0] / 6  # c_{2}
    c_i1[n + 1] = y[n] / 6  # c_{n+2}

    if n >= 2:
        # Interior conditions j = 1..n-1 for the unknowns c_i1[2..n], with the
        # already-known boundary coefficients moved to the right-hand side.
        rhs = y[1:n].copy()
        rhs[0] -= c_i1[1]
        rhs[-1] -= c_i1[n + 1]  # same element as rhs[0] when n == 2
        # Tridiagonal (1, 4, 1) matrix in banded storage -> O(n) solve.
        ab = np.empty((3, n - 1))
        ab[0] = 1.0  # super-diagonal (ab[0, 0] is ignored)
        ab[1] = 4.0  # main diagonal
        ab[2] = 1.0  # sub-diagonal (ab[2, -1] is ignored)
        c_i1[2 : n + 1] = linalg.solve_banded((1, 1), ab, rhs)

    # Outermost coefficients from the natural boundary condition.
    c_i1[0] = 2 * c_i1[1] - c_i1[2]
    c_i1[n + 2] = 2 * c_i1[n + 1] - c_i1[n]

    return c_i1

In [91]:
# Spline coefficients for the test data (length n + 3).
c_i1 = compute_1D_spline_coefs(y_data=y_data)

In [92]:
# Inspect a single spline coefficient as a sanity check.
c_i1[6]

0.08003786432695967

### Compute the spline interpolation and its gradient using JAX

In [93]:
import jax.numpy as jnp
from jax import lax

In [94]:
# 1D-spline function (jittable and auto-differentiable)
def s1D(x, a, h, c_i1):
    """
    1D-spline interpolation

    Evaluates
        s(x) = sum_{i=1}^{m} c_i1[i-1] * phi(|(x - a) / h + 2 - i|),   m = len(c_i1),
    with the cubic kernel
        phi(t) = 4 - 6 t^2 + 3 t^3   for 0 <= t <= 1
               = (2 - t)^3           for 1 < t < 2
               = 0                   for t >= 2.

    All m kernel values are computed in one vectorized step and combined by an
    elementwise product and sum. There is no sequential lax.scan and no
    per-element lax.cond. The function stays jittable and differentiable with
    respect to x and c_i1.

    INPUTs
        x: 1-dim x vector (float) at which the interpolated y-value is
           evaluated; only x[0] is used
        a: 1-dim vector (float) of the lower boundary of each x-dimension
        h: 1-dim vector (float) of the grid interval of each x-dimension
        c_i1: spline coefficients (1-dim array), e.g. the n + 3 values from
              compute_1D_spline_coefs
    OUTPUT
        scalar interpolated value s(x[0])
    """
    i1arr = jnp.arange(1, c_i1.shape[0] + 1)
    t = jnp.abs((x[0] - a[0]) / h[0] + 2 - i1arr)
    # Both polynomial branches are finite for every t, so the branch that
    # jnp.where does not select contributes a zero (not NaN) gradient.
    u = jnp.where(
        t <= 1.0,
        4.0 - 6.0 * t**2 + 3.0 * t**3,
        jnp.where(t < 2.0, (2.0 - t) ** 3, 0.0),
    )
    # Elementwise product + sum avoids reduced-precision matmul paths on some backends.
    return jnp.sum(c_i1 * u)

In [95]:
# JAX copies of the spline data for s1D; with JAX's default settings these
# become float32 arrays (as the outputs below show).
c_i1_jnp, a_jnp, h_jnp = (jnp.array(arr) for arr in (c_i1, a, h))

In [96]:
# Evaluation point as a length-1 vector (s1D reads x[0]).
x = jnp.array((0.5,))

In [97]:
# Evaluate the spline at x = 0.5, which is grid point 5 (h = 0.1), so the
# result should reproduce the data value y_data[5] = sin(0.5).
s1D(x, a_jnp, h_jnp, c_i1_jnp)

DeviceArray(0.47942552, dtype=float32)

In [98]:
# Exact data value at grid point 5, for comparison with the spline value.
y_data[5]

0.479425538604203

In [99]:
from jax import grad, jit

In [100]:
# JIT-compiled version of s1D.
s1D_jitted = jit(s1D)

In [101]:
# Should give the same value as the non-jitted call.
s1D_jitted(x, a_jnp, h_jnp, c_i1_jnp)

DeviceArray(0.47942552, dtype=float32)

In [102]:
# Gradient of s1D with respect to its first argument x.
ds1D = grad(s1D)

In [103]:
# ds/dx at x = 0.5; for y = sin(x) this should be close to cos(0.5) ≈ 0.8776.
ds1D(x, a_jnp, h_jnp, c_i1_jnp)

DeviceArray([0.87754846], dtype=float32)

In [104]:
# JIT-compiled gradient of s1D with respect to x.
ds1D_jitted = jit(grad(s1D))

In [105]:
# Should match the non-jitted gradient.
ds1D_jitted(x, a_jnp, h_jnp, c_i1_jnp)

DeviceArray([0.87754846], dtype=float32)

In [106]:
# Evaluate the jitted gradient again with a newly created x of the same shape and dtype.
x = jnp.array([0.5])
ds1D_jitted(x, a_jnp, h_jnp, c_i1_jnp)

DeviceArray([0.87754846], dtype=float32)

In [107]:
from jax import value_and_grad

In [108]:
# JIT-compiled function returning the pair (s(x), ds/dx) in a single call.
s1D_fun = jit(value_and_grad(s1D))

In [109]:
# Value and gradient at x = 0.5.
s1D_fun(x, a_jnp, h_jnp, c_i1_jnp)

(DeviceArray(0.47942552, dtype=float32),
 DeviceArray([0.87754846], dtype=float32))

In [110]:
# Evaluate value and gradient again with a newly created x of the same shape and dtype.
x = jnp.array([0.5])
s1D_fun(x, a_jnp, h_jnp, c_i1_jnp)

(DeviceArray(0.47942552, dtype=float32),
 DeviceArray([0.87754846], dtype=float32))