# 4D Radial Basis Function interpolation on time axis using scipy RBF (CPU) and pykeops (GPU)
Demo using multiquadric RBF

In [1]:
import numpy as np
import nibabel as nib

In [2]:
imgs = np.array([None]*2)

In [3]:
#for i, _ in enumerate(imgs):
for i in range(2):
    folder = f"{i+1:02d}"
    imgs[i] = nib.load("../../Elies-longitudinal-data-test/"+folder+"/Flair.nii.gz")

In [4]:
zrdim, yrdim, xrdim = imgs[0].shape

In [5]:
imgs_data = np.array([img.get_fdata(dtype=np.float32) for img in imgs])

In [6]:
zdim, ydim, xdim = 10,10,10
intervals = (10,)

In [7]:
#imgs_data = np.random.randn(2,zdim,ydim,xdim)
imgs_data = imgs_data[:, zrdim//2:zrdim//2+zdim, yrdim//2:yrdim//2+ydim, xrdim//2:xrdim//2+xdim]
imgs_data.shape

(2, 10, 10, 10)

In [8]:
def gaussian_kernel(x, y, sigma=.1):
    x_i = LazyTensor(x[:, None, :])  # (M, 1, :)
    y_j = LazyTensor(y[None, :, :])  # (1, N, :)
    D_ij = ((x_i - y_j) ** 2).sum(-1)  # (M, N) symbolic matrix of squared distances
    return (- D_ij / (2 * sigma ** 2)).exp()  # (M, N) symbolic Gaussian kernel matrix
def laplacian_kernel(x, y, sigma=.1):
    x_i = LazyTensor(x[:, None, :])  # (M, 1, :)
    y_j = LazyTensor(y[None, :, :])  # (1, N, :)
    D_ij = ((x_i - y_j) ** 2).sum(-1)  # (M, N) symbolic matrix of squared distances
    return (- D_ij.sqrt() / sigma).exp()  # (M, N) symbolic Laplacian kernel matrix
def multiquadric_kernel(x, y, epsilon=1):
    x_i = LazyTensor(x[:, None, :])  # (M, 1, :)
    y_j = LazyTensor(y[None, :, :])  # (1, N, :)
    D_ij = ((x_i - y_j) ** 2).sum(-1)  # (M, N) symbolic matrix of squared distances
    return ((1/epsilon * D_ij) ** 2 + 1).sqrt()

In [9]:
def flatten_tensor(T):
    # Flattens the M x N x ..., x D tensor T while preserving original indices
    # to the output shape (MxNx...xD, F, M, N, ..., D)
    # where F is a vector of the values in T, and M, N, ..., D a vector of 
    # original indices in each dimension in T (for the values in F).
    # https://stackoverflow.com/questions/46135070/generalise-slicing-operation-in-a-numpy-array/46135084#46135084
    n = T.ndim
    grid = np.ogrid[tuple(map(slice, T.shape))]
    out = np.empty(T.shape + (n+1,), dtype=T.dtype)
    for i in range(n):
        out[...,i+1] = grid[i]
    out[...,0] = T
    out.shape = (-1,n+1)
    # Return everything
    return out
    # Only return voxels that are not np.nan
    #return out[~np.isnan(out[:,0])]
    # Only return voxels that are not zero
    #return out[out[:,0] != 0]

In [10]:
# Flatten the stacked data, for use in Rbf
imgs_data_flattened = flatten_tensor(imgs_data)

In [11]:
# Get the colums in the flattened data
# The voxel values
b = imgs_data_flattened[:,0]
#"""
# Time coordinates of the voxel values
t = imgs_data_flattened[:,1]
# Z coordinates of the voxel values
z = imgs_data_flattened[:,2]
# Y coordinates of the voxel values
y = imgs_data_flattened[:,3]
# X coordinates of the voxel values
x = imgs_data_flattened[:,4]
#"""

x_all = imgs_data_flattened[:,1:]

## Scipy Rbf (CPU)

In [12]:
from scipy.interpolate import Rbf

In [13]:
"""
# Make grids of indices with resolutions we want after the interpolation
grids = [np.mgrid[time_idx:time_idx+1:1/interval_duration, 0:zdim, 0:ydim, 0:xdim] \
for time_idx, interval_duration in enumerate(intervals)]
# Stack all grids
TI, ZI, YI, XI = np.hstack(tuple(grids))

# Create radial basis functions
#rbf_clinst = Rbf(t, z, y, x, f, function="multiquadric", norm='euclidean')
rbf = Rbf(t, z, y, x, b, function='multiquadric') # If scipy 1.1.0 , only euclidean, default

# Interpolate the voxel values f to have values for the indices in the grids,
# resulting in interpolated voxel values FI
# This uses the Rbfs
FI = rbf(TI, ZI, YI, XI)

data_interpolated = FI

volfig()
volshow(data_interpolated)
"""

'\n# Make grids of indices with resolutions we want after the interpolation\ngrids = [np.mgrid[time_idx:time_idx+1:1/interval_duration, 0:zdim, 0:ydim, 0:xdim] for time_idx, interval_duration in enumerate(intervals)]\n# Stack all grids\nTI, ZI, YI, XI = np.hstack(tuple(grids))\n\n# Create radial basis functions\n#rbf_clinst = Rbf(t, z, y, x, f, function="multiquadric", norm=\'euclidean\')\nrbf = Rbf(t, z, y, x, b, function=\'multiquadric\') # If scipy 1.1.0 , only euclidean, default\n\n# Interpolate the voxel values f to have values for the indices in the grids,\n# resulting in interpolated voxel values FI\n# This uses the Rbfs\nFI = rbf(TI, ZI, YI, XI)\n\ndata_interpolated = FI\n\nvolfig()\nvolshow(data_interpolated)\n'

In [15]:
import os.path
import sys
sys.path.append('/home/ivar/Downloads/keops')

In [16]:
import torch
from pykeops.torch import LazyTensor
use_cuda = torch.cuda.is_available()
dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor

In [17]:
xi = np.asarray([np.asarray(a, dtype=np.float_).flatten()
                              for a in (t, z, y, x)])
N = xi.shape[-1]
ximax = np.amax(xi, axis=1)
ximin = np.amin(xi, axis=1)
edges = ximax - ximin
edges = edges[np.nonzero(edges)]
newepsilon = np.power(np.prod(edges)/N, 1.0/edges.size)
newepsilon

0.7770060192334054

In [18]:
epsilon = torch.from_numpy(np.array(newepsilon)).type(dtype)
epsilon

tensor(0.7770, device='cuda:0')

In [19]:
b = torch.from_numpy(b).type(dtype).view(-1,1)
x_all = torch.from_numpy(x_all).type(dtype)

In [20]:
#K_xx = gaussian_kernel(x, x, sigma=100)
#K_xx = laplacian_kernel(x_all, x_all)
K_xx = multiquadric_kernel(x_all, x_all, epsilon=epsilon)

In [21]:
alpha = 10  # Ridge regularization
a = K_xx.solve(b, alpha=alpha)
#a = K_xx.solve(b)

Compiling libKeOpstorch892fdfdd61 in /home/ivar/Downloads/keops/pykeops/common/../build/build-libKeOpstorch892fdfdd61:
       formula: Sum_Reduction((Sqrt((Square((Var(1,1,2) * Sum(Square((Var(2,4,0) - Var(3,4,1)))))) + IntCst(1))) * Var(0,1,1)),0)
       aliases: Var(0,1,1); Var(1,1,2); Var(2,4,0); Var(3,4,1); 
       dtype  : float32
... Done.


In [22]:
#Z = torch.linspace(0, zdim-1, zdim).type(dtype)
Z = torch.from_numpy(np.mgrid[0:zdim]).type(dtype)
Y = torch.from_numpy(np.mgrid[0:ydim]).type(dtype)
X = torch.from_numpy(np.mgrid[0:xdim]).type(dtype)
T = torch.stack(tuple(torch.from_numpy(np.mgrid[time_idx:time_idx+1:1/interval_duration]).type(dtype) for time_idx, interval_duration in enumerate(intervals)), dim=0).view(-1)
T, Z, Y, X = torch.meshgrid(T, Z, Y, X)
t = torch.stack((T.contiguous().view(-1), \
                 Z.contiguous().view(-1), \
                 Y.contiguous().view(-1), \
                 X.contiguous().view(-1)), dim=1)

K_tx = multiquadric_kernel(t, x_all, epsilon=epsilon)
mean_t = K_tx @ a
mean_t = mean_t.view(np.sum(intervals), zdim, ydim, xdim)

Compiling libKeOpstorch0b06cb4531 in /home/ivar/Downloads/keops/pykeops/common/../build/build-libKeOpstorch0b06cb4531:
       formula: Sum_Reduction((Sqrt((Square((Var(0,1,2) * Sum(Square((Var(1,4,0) - Var(2,4,1)))))) + IntCst(1))) * Var(3,1,1)),0)
       aliases: Var(0,1,2); Var(1,4,0); Var(2,4,1); Var(3,1,1); 
       dtype  : float32
... Done.
