### 3D natural cubic spline interpolation on equidistant grids (one grid spacing per dimension), following Habermann and Kindermann (2007)
### Evaluation is jittable and auto-differentiable using JAX (2022/10/4)

In [15]:
import numpy as np
from scipy import linalg

In [16]:
N = 3  # number of independent variables, i.e. the dimension of the problem

In [17]:
# lower bounds a and upper bounds b of the x-domain [a, b] in each dimension,
# ordered [1st dim, 2nd dim, 3rd dim]
a = np.array([0, 0, 0], dtype=float)
b = np.array([1, 2, 3], dtype=float)

In [18]:
# number of grid intervals in each dimension, ordered [1st dim, 2nd dim, 3rd dim]
n = np.array([10, 20, 30], dtype=int)

In [19]:
# grid spacing (width of one interval) in each dimension: h_j = (b_j - a_j) / n_j
h = np.divide(b - a, n)

In [20]:
# x_grid[j] is a numpy array holding the n[j]+1 equidistant grid points of
# dimension j on [a[j], b[j]]
x_grid = tuple(np.linspace(a[j], b[j], n[j] + 1) for j in range(N))

In [21]:
# one y value per grid point: shape (n[0]+1, n[1]+1, ..., n[N-1]+1)
grid_shape = tuple(n[j] + 1 for j in range(N))
y_data = np.zeros(grid_shape)

In [22]:
# test y_data: y(x1, x2, x3) = sin(x1) * sin(x2) * sin(x3) sampled on the grid,
# built by broadcasting the three 1D factors against each other
y_data[...] = (np.sin(x_grid[0])[:, None, None]
               * np.sin(x_grid[1])[None, :, None]
               * np.sin(x_grid[2])[None, None, :])

In [23]:
def _natural_spline_coefs_1d(values, axis):
    '''
    Apply the 1D natural cubic spline coefficient solve along one axis.

    `values` has n+1 samples along `axis`; the result has n+3 coefficients
    c_1..c_{n+3} (stored at indices 0..n+2) along that axis, other axes unchanged.
    Every 1D fiber along `axis` is solved in a single banded solve.
    '''
    v = np.moveaxis(values, axis, 0)
    n = v.shape[0] - 1
    flat = v.reshape(n + 1, -1)  # one column per 1D fiber
    c = np.zeros((n + 3, flat.shape[1]))
    c[1] = flat[0] / 6       # c_{2}
    c[n + 1] = flat[n] / 6   # c_{n+2}
    if n >= 2:
        # right-hand side y_1 - c_2, y_2, ..., y_{n-2}, y_{n-1} - c_{n+2};
        # for n == 2 both boundary corrections apply to the single row
        rhs = flat[1:n].copy()
        rhs[0] -= c[1]
        rhs[-1] -= c[n + 1]
        # tridiagonal system with 4 on the diagonal and 1 on both off-diagonals,
        # stored in LAPACK banded form: rows = (upper, main, lower) diagonals
        m = n - 1
        banded = np.zeros((3, m))
        banded[0, 1:] = 1.0
        banded[1, :] = 4.0
        banded[2, :-1] = 1.0
        c[2:n + 1] = linalg.solve_banded((1, 1), banded, rhs)  # c_{3}..c_{n+1}
    # natural boundary (zero second derivative) determines the outermost coefficients
    c[0] = 2 * c[1] - c[2]              # c_{1}
    c[n + 2] = 2 * c[n + 1] - c[n]      # c_{n+3}
    return np.moveaxis(c.reshape((n + 3,) + v.shape[1:]), 0, axis)


def compute_3D_spline_coefs(y_data):
    '''
    Compute 3D spline coefficient matrix c_i1i2i3

    The 1D natural cubic spline coefficient solve is applied successively along
    axis 0, axis 1 and axis 2 of y_data (tensor-product construction), so that
    the B-spline sum with weights (1, 4, 1) per dimension reproduces y_data at
    every grid point. The grid sizes are taken from y_data itself; no global
    state is read or modified.

    INPUTs:
        y_data: 3D numpy array of y data (real scalar) at x_grid points,
                with at least 2 grid points along every axis
                (arrays of other dimensionality are processed the same way)
    OUTPUT:
        float numpy array of shape (n1+3, n2+3, n3+3), where
        y_data.shape == (n1+1, n2+1, n3+1)
    RAISES:
        ValueError: if y_data is 0-dimensional or some axis has fewer than 2 points
    '''
    coefs = np.asarray(y_data, dtype=float)
    if coefs.ndim == 0 or min(coefs.shape) < 2:
        raise ValueError("y_data needs at least 2 grid points along every axis")
    for axis in range(coefs.ndim):
        coefs = _natural_spline_coefs_1d(coefs, axis)
    return np.ascontiguousarray(coefs)

In [24]:
c_i1i2i3 = compute_3D_spline_coefs(y_data=y_data)  # spline coefficients for the test data

In [25]:
# spot-check one entry of the computed coefficient array
c_i1i2i3[6,9,10]

0.0012534830772130796

### Compute the spline interpolant and its gradient using JAX

In [26]:
from jax import lax
import jax.numpy as jnp

In [27]:
#3D-spline function (jittable and auto-differentiable)
def s3D(x,a,h,c_i1i2i3):
    '''
    3D-spline interpolation

    Evaluates the tensor-product spline
        s(x) = sum_{i1,i2,i3} c_i1i2i3[i1-1,i2-1,i3-1] * u(i1,x[0]) * u(i2,x[1]) * u(i3,x[2])
    (i_j running from 1 to c_i1i2i3.shape[j]) with the cubic kernel
        u(i, x) = g(|(x - a)/h + 2 - i|),
        g(t) = 4 - 6 t^2 + 3 t^3 for t <= 1,  (2 - t)^3 for 1 < t <= 2,  0 for t > 2.
    Fully vectorized, so it can be wrapped in jit/grad/value_and_grad.

    INPUTs
        x: 3-element x vector (float) at which the interpolated y-value is evaluated
        a: 3-element vector (float) of the lower boundary of each x-dimension
        h: 3-element vector (float) of the grid interval of each x-dimension
        c_i1i2i3: spline coefficients (3-dim array)
    OUTPUT
        scalar interpolated value s(x)
    '''

    def u(ii, aa, hh, xx):
        # kernel values for all coefficient indices ii of one dimension at once;
        # both polynomial branches are finite everywhere, so jnp.where gives clean gradients
        t = jnp.abs((xx - aa) / hh + 2 - ii)
        return jnp.where(t <= 1, 4. - 6. * t**2 + 3. * t**3,
                         jnp.where(t <= 2, (2. - t)**3, 0.))

    u1 = u(jnp.arange(1, c_i1i2i3.shape[0] + 1), a[0], h[0], x[0])
    u2 = u(jnp.arange(1, c_i1i2i3.shape[1] + 1), a[1], h[1], x[1])
    u3 = u(jnp.arange(1, c_i1i2i3.shape[2] + 1), a[2], h[2], x[2])

    # contract the coefficient array against the three 1D kernel vectors
    return jnp.einsum('ijk,i,j,k->', c_i1i2i3, u1, u2, u3)

In [28]:
# JAX array copies of the spline coefficients, lower bounds and grid spacings for s3D
c_i1i2i3_jnp, a_jnp, h_jnp = map(jnp.array, (c_i1i2i3, a, h))



In [29]:
x = jnp.array((0.5, 1.0, 1.6))  # evaluation point (x1, x2, x3)

In [30]:
# interpolated value at x
s3D(x,a_jnp,h_jnp,c_i1i2i3_jnp)

DeviceArray(0.4032507, dtype=float32)

In [31]:
# data value at grid point (x_grid[0][5], x_grid[1][10], x_grid[2][15]) = (0.5, 1.0, 1.5);
# NOTE(review): x[2] = 1.6 was used for the interpolated value above, so the points differ
y_data[5,10,15]

0.40241210089342777

In [32]:
from jax import grad, jit

In [33]:
# jit-compiled version of s3D
s3D_jitted= jit(s3D)

In [34]:
# compiled evaluation at x
s3D_jitted(x,a_jnp,h_jnp,c_i1i2i3_jnp)

DeviceArray(0.4032507, dtype=float32)

In [35]:
ds3D = grad(s3D, argnums=0)  # gradient of s3D with respect to its first argument x

In [36]:
# gradient of the interpolant at x
ds3D(x,a_jnp,h_jnp,c_i1i2i3_jnp)

DeviceArray([ 0.7381165 ,  0.25892413, -0.01177985], dtype=float32)

In [37]:
ds3D_jitted = jit(grad(s3D, argnums=0))  # compiled gradient of s3D with respect to x

In [38]:
# compiled gradient at x
ds3D_jitted(x,a_jnp,h_jnp,c_i1i2i3_jnp)

DeviceArray([ 0.7381165 ,  0.25892413, -0.01177985], dtype=float32)

In [39]:
# re-evaluate the compiled gradient after rebinding x to the same point
x=jnp.array([0.5,1.0,1.6])
ds3D_jitted(x,a_jnp,h_jnp,c_i1i2i3_jnp)

DeviceArray([ 0.7381165 ,  0.25892413, -0.01177985], dtype=float32)

In [40]:
from jax import value_and_grad

In [41]:
s3D_fun = jit(value_and_grad(s3D, argnums=0))  # compiled (value, gradient w.r.t. x) of s3D

In [42]:
# value and gradient at x in one compiled call
s3D_fun(x,a_jnp,h_jnp,c_i1i2i3_jnp)

(DeviceArray(0.4032507, dtype=float32),
 DeviceArray([ 0.7381165 ,  0.25892413, -0.01177985], dtype=float32))

In [43]:
# evaluate at (0.5, 1.0, 1.5), the grid point whose data value is y_data[5,10,15]
x=jnp.array([0.5,1.0,1.5])
s3D_fun(x,a_jnp,h_jnp,c_i1i2i3_jnp)

(DeviceArray(0.4024121, dtype=float32),
 DeviceArray([0.7365817 , 0.25838578, 0.02853701], dtype=float32))