##### 5D natural cubic spline interpolation on grids that are equidistant in each dimension, following Habermann and Kindermann (2007)
### Jittable and auto-differentiable with JAX (2022/10/4)

In [54]:
import numpy as np
from scipy import linalg

In [55]:
N = 5  # number of independent variables, i.e. the dimensionality of the grid

In [56]:
# lower bound a and upper bound b of the x-coordinate in each dimension (the box [a, b]);
# entries are ordered [1st dim, 2nd dim, 3rd dim, 4th dim, 5th dim]
a = np.zeros(5)
b = np.arange(1, 6, dtype=float)

In [57]:
# number of grid intervals in each dimension, ordered [1st dim, ..., 5th dim]
n = np.full(5, 10, dtype=int)

In [58]:
# grid spacing h in each dimension: interval length (b - a) divided by the number of intervals n
h = np.divide(b - a, n)

In [59]:
# x_grid is a tuple with one numpy array per dimension; x_grid[j] holds the
# n[j]+1 equidistant grid points from a[j] to b[j] (inclusive)
x_grid = tuple(np.linspace(a[j], b[j], n[j] + 1) for j in range(N))

In [60]:
# shape of the data grid: n[j]+1 grid points along dimension j
grid_shape = tuple(n[j] + 1 for j in range(N))
y_data = np.zeros(grid_shape)  # placeholder for the y values at the grid points

In [61]:
# test y_data: y = sin(x1)*sin(x2)*sin(x3)*sin(x4)*sin(x5) evaluated on the full grid.
# np.ix_ turns the per-dimension 1-D sine arrays into mutually broadcastable
# open-mesh arrays, so each axis automatically uses its own grid length and the
# whole product is formed in one vectorized pass (same left-to-right order of
# multiplication as the scalar formula, so the values are identical).
sines = np.ix_(*[np.sin(xg) for xg in x_grid])
y_data = sines[0]
for s in sines[1:]:
    y_data = y_data * s

In [62]:
def _natural_spline_coefs_1d(y, axis):
    '''
    Natural cubic spline coefficients along one axis of y.

    For every 1-D line of y along `axis` holding values y_0..y_n at equidistant
    nodes, compute n+3 coefficients c_1..c_{n+3} (stored at indices 0..n+2) with
        c_{j+1} + 4 c_{j+2} + c_{j+3} = y_j          for j = 0..n
        c_1 - 2 c_2 + c_3 = 0,  c_{n+1} - 2 c_{n+2} + c_{n+3} = 0   (natural BCs)

    INPUTs:
        y: numpy array; its length along `axis` is n+1 with n >= 2
        axis: axis along which the 1-D problems are solved
    RETURNs:
        float array with the shape of y, except n+3 entries along `axis`
    RAISEs:
        ValueError: if fewer than 3 points are given along `axis`
    '''
    y = np.moveaxis(np.asarray(y, dtype=float), axis, 0)
    n = y.shape[0] - 1
    if n < 2:
        raise ValueError(
            'need at least 3 grid points along each axis, got %d' % (n + 1))
    rest_shape = y.shape[1:]
    # every column of `lines` is one independent 1-D spline problem
    lines = y.reshape(n + 1, int(np.prod(rest_shape)))
    c = np.empty((n + 3, lines.shape[1]))
    c[1] = lines[0] / 6        # c_2: natural BC turns y_0 = c_1 + 4c_2 + c_3 into 6c_2
    c[n + 1] = lines[n] / 6    # c_{n+2}: same at the right end
    # tridiagonal (1, 4, 1) system for the interior coefficients c_3..c_{n+1}
    A = 4 * np.eye(n - 1) + np.eye(n - 1, k=1) + np.eye(n - 1, k=-1)
    rhs = lines[1:n].copy()
    rhs[0] -= c[1]
    # subtract (not overwrite): for n == 2 both corrections hit the single row
    rhs[-1] -= c[n + 1]
    c[2:n + 1] = linalg.solve(A, rhs)   # one solve for all right-hand sides
    c[0] = 2 * c[1] - c[2]
    c[n + 2] = 2 * c[n + 1] - c[n]
    return np.moveaxis(c.reshape((n + 3,) + rest_shape), 0, axis)


def compute_5D_spline_coefs(y_data):
    '''
    Compute 5D spline coefficient matrix c_i1i2i3i4i5

    The tensor-product natural cubic spline coefficients are obtained one
    dimension at a time: the 1-D problem is solved along axis 0, then the result
    is solved along axis 1, and so on. Each pass grows that axis from n_j+1 data
    points to n_j+3 coefficients. The grid sizes are taken from y_data.shape;
    no module-level variable is read or modified. Any number of dimensions is
    accepted; the 5D case is the one used in this notebook.

    INPUTs:
        y_data: numpy array of y data (real scalar) at x_grid points,
                shape (n_1+1, ..., n_5+1) with every n_j >= 2
    RETURNs:
        float numpy array of shape (n_1+3, ..., n_5+3); entry [i1-1, ..., i5-1]
        is the coefficient for basis indices i1..i5 (each running 1..n_j+3)
    RAISEs:
        ValueError: if some axis has fewer than 3 grid points
    '''
    c = np.asarray(y_data, dtype=float)
    for axis in range(c.ndim):
        c = _natural_spline_coefs_1d(c, axis)
    return c

In [63]:
# Fit the natural cubic spline coefficients to the gridded test data y_data.
c_i1i2i3i4i5= compute_5D_spline_coefs(y_data)


In [64]:
# Inspect one coefficient; the coefficient array has n[j]+3 entries along dimension j.
c_i1i2i3i4i5[5,5,5,5,5]

3.335532324807976e-05

### Compute the spline interpolation and its gradient using JAX

In [65]:
from jax import lax
import jax.numpy as jnp

In [66]:
#5D-spline function (jittable and auto-differentiable)
def s5D(x,a,h,c_i1i2i3i4i5):
    '''
    5D-spline interpolation (jittable and auto-differentiable)

    Evaluates
        s(x) = sum_{i1..i5} c[i1-1, ..., i5-1] * prod_d u((x_d - a_d)/h_d + 2 - i_d)
    with the cubic kernel
        u(t) = 4 - 6 t^2 + 3 |t|^3   for |t| <= 1
             = (2 - |t|)^3           for 1 < |t| < 2
             = 0                     otherwise.
    The summand factorizes over dimensions, so the sum is computed as a sequence
    of 1-D contractions of the coefficient array with one weight vector per
    dimension. This avoids a sequential loop over every coefficient and any
    large intermediate arrays. Works for any number of dimensions
    (len(x) == c_i1i2i3i4i5.ndim).

    INPUTs
        x: vector (float) at which the interpolated y-value is evaluated
        a: vector (float) of the lower boundary of each x-dimension
        h: vector (float) of the grid interval of each x-dimension
        c_i1i2i3i4i5: spline coefficients (array with one axis per dimension,
                      e.g. as returned by compute_5D_spline_coefs)
    RETURNs
        0-d JAX array: the interpolated value
    '''

    def u(t):
        t = jnp.abs(t)
        # both branches are polynomials, so jnp.where stays NaN-free under grad
        return jnp.where(t <= 1., 4. - 6. * t**2 + 3. * t**3,
                         jnp.where(t < 2., (2. - t)**3, 0.))

    s = c_i1i2i3i4i5
    for d in range(c_i1i2i3i4i5.ndim):
        i = jnp.arange(1, c_i1i2i3i4i5.shape[d] + 1)   # 1-based basis indices
        w = u((x[d] - a[d]) / h[d] + 2 - i)
        # contract the leading remaining axis (original axis d) with its weights
        s = jnp.tensordot(w, s, axes=1)
    return s

In [67]:
# Convert the NumPy coefficient/grid arrays to JAX arrays for s5D.
c_i1i2i3i4i5_jnp= jnp.array(c_i1i2i3i4i5)
a_jnp=jnp.array(a)
h_jnp=jnp.array(h)

In [68]:
# Evaluation point x_j = a_j + 5*h_j, i.e. grid index 5 in every dimension.
x=jnp.array([0.5,1.0,1.5,2.0,2.5])

In [69]:
s5D(x,a_jnp,h_jnp,c_i1i2i3i4i5_jnp)

DeviceArray(0.2189883, dtype=float32)

In [70]:
# Data value at the same grid point, for comparison with the spline value above.
y_data[5,5,5,5,5]

0.21898831147309558

In [71]:
from jax import grad, jit

In [72]:
# JIT-compiled spline evaluation.
s5D_jitted= jit(s5D)

In [73]:
s5D_jitted(x,a_jnp,h_jnp,c_i1i2i3i4i5_jnp)

DeviceArray(0.2189883, dtype=float32)

In [74]:
# Gradient of s5D with respect to its first argument x (jax.grad default).
ds5D= grad(s5D)

In [75]:
ds5D(x,a_jnp,h_jnp,c_i1i2i3i4i5_jnp)

DeviceArray([ 0.40083998,  0.14059053,  0.01552507, -0.10017766,
             -0.29297203], dtype=float32)

In [76]:
# JIT-compiled gradient.
ds5D_jitted= jit(grad(s5D))

In [77]:
ds5D_jitted(x,a_jnp,h_jnp,c_i1i2i3i4i5_jnp)

DeviceArray([ 0.40083998,  0.14059053,  0.01552507, -0.10017766,
             -0.29297203], dtype=float32)

In [85]:
# Another grid point: index 7 in the 1st dimension, index 5 in the others.
x=jnp.array([0.7,1.0,1.5,2.0,2.5])

In [86]:
ds5D_jitted(x,a_jnp,h_jnp,c_i1i2i3i4i5_jnp)

DeviceArray([ 0.3491447 ,  0.1889156 ,  0.02086146, -0.13461155,
             -0.39367482], dtype=float32)

In [87]:
from jax import value_and_grad

In [88]:
# JIT-compiled function returning the tuple (value, gradient w.r.t. x).
s5D_fun= jit(value_and_grad(s5D))

In [89]:
s5D_fun(x,a_jnp,h_jnp,c_i1i2i3i4i5_jnp)

(DeviceArray(0.29426068, dtype=float32),
 DeviceArray([ 0.3491447 ,  0.1889156 ,  0.02086146, -0.13461155,
              -0.39367482], dtype=float32))

In [90]:
# Same point as In [85]; re-running the compiled function gives the same result.
x=jnp.array([0.7,1.0,1.5,2.0,2.5])

In [91]:
s5D_fun(x,a_jnp,h_jnp,c_i1i2i3i4i5_jnp)

(DeviceArray(0.29426068, dtype=float32),
 DeviceArray([ 0.3491447 ,  0.1889156 ,  0.02086146, -0.13461155,
              -0.39367482], dtype=float32))