In [2]:
%load_ext autoreload
%autoreload 2
import site
import sys
import time
site.addsitedir('..')
from jax.config import config

config.update("jax_enable_x64", True)

In [3]:
import numpy as np
import jax.numpy as jnp
from  matplotlib import pyplot as plt
from src.interpolate import *
from src.algorithm import conjugate_gradient
import jax


Create the frequency grids and a random ground-truth volume.

In [46]:
nx = 128
SEED = 0

# Seed the legacy global NumPy RNG once so that every later np.random draw in
# this notebook (the volume below, the random points and the CG initial guess)
# is reproducible under Restart & Run All.
np.random.seed(SEED)

# Integer frequencies in FFT order: 0, 1, ..., nx/2 - 1, -nx/2, ..., -1.
x_freq = jnp.fft.fftfreq(nx, 1/nx)
y_freq = x_freq
z_freq = x_freq

# NOTE(review): meshgrid defaults to indexing="xy", which swaps the first two
# output axes relative to an (i, j, k) layout; confirm this matches the
# coordinate convention expected by `interpolate`.
X, Y, Z = jnp.meshgrid(x_freq, y_freq, z_freq)

# Grid descriptors passed to `interpolate`, presumably [spacing, number of
# points]; x_freq[1] is the spacing of the frequency grid (1.0 here).
x_grid = np.array([x_freq[1], len(x_freq)])
y_grid = np.array([y_freq[1], len(y_freq)])
z_grid = np.array([z_freq[1], len(z_freq)])

# Ground-truth volume: unit Gaussian noise around a constant offset of 100.
vol = jnp.array(np.random.randn(nx, nx, nx)) + 100

# One column per grid point: shape (3, nx**3).
all_coords = jnp.array([X.ravel(), Y.ravel(), Z.ravel()])

Generate random integer points on the grid.

In [47]:
N = 10  # number of random grid points

# Integer grid coordinates in [-nx/2, nx/2), the same range spanned by
# fftfreq(nx, 1/nx). An earlier experiment extended this range by nx/4 on
# each side to exercise wrap-around; that is disabled for now.
# Integer bounds: np.random.randint documents int arguments.
low = -(nx // 2)
high = nx // 2  # exclusive upper bound

pts = jnp.array(np.random.randint(low, high, size=(3, N))).astype(jnp.float64)
# NOTE(review): `pts` is not used by the cells below, which interpolate at
# `all_coords` instead; confirm whether it is still needed.


In [48]:
# One row per axis, one column per point.
np.shape(pts)

(3, 10)

In [49]:
# One column per grid point.
np.shape(all_coords)

(3, 2097152)

In [50]:
@jax.jit
def interpolate_fun(vol):
    """Interpolate `vol` at every coordinate column of the global `all_coords`.

    Uses the grid descriptors `x_grid`, `y_grid`, `z_grid` and the "tri"
    method (presumably trilinear) of `interpolate`, which comes from the
    star import of `src.interpolate`. Because the function is jitted, the
    globals it reads are baked in as constants when it is first traced.
    """
    grid_specs = (x_grid, y_grid, z_grid)
    return interpolate(all_coords, *grid_specs, vol, "tri")

Interpolate the ground-truth volume at every grid point. These values are the data for the inverse problem.

In [51]:
# Sample the ground-truth volume at every coordinate in `all_coords`; these
# samples are the observations the inverse problem below tries to fit.
data = interpolate_fun(vol)

In [52]:
@jax.jit
def loss_fun(v):
    """Least-squares misfit between the interpolation of `v` and `data`.

    Computes 0.5 * mean((interpolate_fun(v) - data)**2). `jnp.mean` divides by
    the number of data values, so the normalization follows the data instead
    of a hard-coded nx**3. It equals nx**3 when one value is produced per
    column of `all_coords` (shape (3, nx**3)), as set up above.

    Note: `data` is a global captured as a constant when this function is
    first traced; redefining `data` later will not affect a compiled loss.
    """
    residual = interpolate_fun(v) - data
    return 0.5 * jnp.mean(residual ** 2)

@jax.jit
def loss_fun_grad(v):
    """Gradient of `loss_fun` with respect to the volume `v`.

    Returns an array with the same shape as `v`.
    """
    return jax.grad(loss_fun)(v)

In [55]:
# Sanity check: the data were generated from `vol`, so the misfit there
# should vanish (up to floating-point rounding from separate compilations).
loss_at_truth = loss_fun(vol)
assert float(loss_at_truth) < 1e-12, loss_at_truth
loss_at_truth

DeviceArray(0., dtype=float64)

In [56]:
# The gradient has one entry per voxel of the volume.
np.shape(loss_fun_grad(vol))

(128, 128, 128)

Solve the inverse problem: recover the volume from the data with conjugate gradient.

In [59]:
# Assuming `interpolate` is linear in the volume (data = A @ vol), the loss is
# L(v) = c * ||A v - data||^2 for the constant c used in `loss_fun`, so its
# gradient is affine in v: grad L(v) = 2c (A^T A v - A^T data). Therefore
#   grad L(v) - grad L(0) = 2c A^T A v   (symmetric positive semi-definite)
#   -grad L(0)            = 2c A^T data
# and CG below solves the normal equations A^T A v = A^T data.
N_CG_ITER = 10  # number of CG iterations (one "Iter" line is printed per iteration)

v0 = jnp.array(np.random.randn(nx, nx, nx))  # random initial guess
zero = jnp.zeros(vol.shape)

# Loop-invariant: evaluate once instead of on every operator application.
grad_at_zero = loss_fun_grad(zero)


def AA(a):
    """Apply the normal-equations operator 2c * A^T A to `a`."""
    return loss_fun_grad(a) - grad_at_zero


b = -grad_at_zero

vcg, max_iter = conjugate_gradient(AA, b, v0, N_CG_ITER, verbose=True)
err = jnp.max(jnp.abs(vcg - vol))  # max absolute error vs. the ground-truth volume
print("err =", err)

Iter 0 ||r|| = 5.941095886675326e-12
Iter 1 ||r|| = 7.024772364448904e-10
Iter 2 ||r|| = 7.024521148288821e-10
Iter 3 ||r|| = 7.024269933413921e-10
Iter 4 ||r|| = 7.024018742725975e-10
Iter 5 ||r|| = 7.023767564888896e-10
Iter 6 ||r|| = 7.023516399769891e-10
Iter 7 ||r|| = 7.023265211682924e-10
Iter 8 ||r|| = 7.023014072735951e-10
Iter 9 ||r|| = 7.022762939345974e-10
err = 0.00010429372564146888
