In [20]:
%load_ext autoreload
%autoreload 2
import site
import sys
import time
site.addsitedir('..')
from jax.config import config

config.update("jax_enable_x64", True)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [21]:
import numpy as np
import jax.numpy as jnp
from  matplotlib import pyplot as plt
from src.interpolate import *
from src.algorithm import conjugate_gradient
import jax


Create the grids and the volume

In [22]:
nx = 128

x_freq = jnp.fft.fftfreq(nx, 1/nx)
y_freq = x_freq
z_freq = x_freq

X, Y, Z = jnp.meshgrid(x_freq, y_freq, z_freq)


x_grid = np.array([x_freq[1], len(x_freq)])
y_grid = np.array([y_freq[1], len(y_freq)])
z_grid = np.array([z_freq[1], len(z_freq)])

vol = jnp.array(np.random.randn(nx,nx,nx)) + 100

all_coords = jnp.array([X.ravel(), Y.ravel(), Z.ravel()])

Generate points on the grid

In [23]:
N = 10

# The points will be on the grid between low and high. 
# Going outside the range by half, to ensure wrap-around works well.
#low = -nx/2-nx/4
#high = nx/2 + nx/4

# Actually not going outside the range for now
low = -nx/2
high = nx/2

pts = jnp.array(np.random.randint(low, high, size = (3,N))).astype(jnp.float64)


In [24]:
pts.shape

(3, 10)

In [25]:
all_coords.shape

(3, 2097152)

In [26]:
@jax.jit
def interpolate_fun(vol):
    return interpolate(all_coords, x_grid, y_grid, z_grid, vol, "tri")

The interpolated values, i.e. the data.

In [27]:
data = interpolate_fun(vol)

In [28]:
@jax.jit
def loss_fun(v):
    return 1/(2*nx*nx*nx)*jnp.sum((interpolate_fun(v) - data)**2)

@jax.jit
def loss_fun_grad(v):
    return jax.grad(loss_fun)(v)

#@jax.jit
#def loss_fun_grad_array(coords, ivals):
#    return jax.vmap(loss_fun_grad, in_axes = (1, 0))(coords, ivals).T

In [29]:
loss_fun(vol)

DeviceArray(0., dtype=float64)

In [30]:
loss_fun_grad(vol).shape

(128, 128, 128)

And solve the inverse problem

In [37]:
v0 = jnp.array(np.random.randn(nx,nx,nx))
#v0 = vol + 0.1*np.random.randn(nx,nx,nx)


vk = v0
N_iter = 1000
alpha = 100000
for k in range(N_iter):
    vk = vk - alpha * loss_fun_grad(vk)
    
    if jnp.mod(k,100) == 0:
        loss = loss_fun(vk)
        print("iter ", k, ", loss = ", loss)

err = jnp.max(jnp.abs(vk-vol))
print("err =", err)

iter  0 , loss =  4535.322380631188
iter  100 , loss =  0.25873208591614383
iter  200 , loss =  1.4760205926795128e-05
iter  300 , loss =  8.420435302020501e-10
iter  400 , loss =  4.803708771203873e-14
iter  500 , loss =  2.7404306649972042e-18
iter  600 , loss =  1.5633979759919096e-22
iter  700 , loss =  1.0097419586828951e-26
iter  800 , loss =  1.0097419586828951e-26
iter  900 , loss =  1.0097419586828951e-26
err = 1.4210854715202004e-13


In [38]:
v0 = jnp.array(np.random.randn(nx,nx,nx))
zero = jnp.zeros(vol.shape)
AA = lambda a : loss_fun_grad(a) - loss_fun_grad(zero)
b = - loss_fun_grad(zero)


vcg, max_iter = conjugate_gradient(AA, b, v0, 10, verbose = True)
err = jnp.max(jnp.abs(vcg-vol))
print("err =", err)

Iter 0 ||r|| = 5.940862895774757e-12
Iter 1 ||r|| = 8.057426014748657e-10
Iter 2 ||r|| = 8.057207009244758e-10
Iter 3 ||r|| = 8.056988003673399e-10
Iter 4 ||r|| = 8.056769015589757e-10
Iter 5 ||r|| = 8.056550036238406e-10
Iter 6 ||r|| = 8.056331065628839e-10
Iter 7 ||r|| = 8.056112079882385e-10
Iter 8 ||r|| = 8.05589311799068e-10
Iter 9 ||r|| = 8.055674154852273e-10
err = 0.0001371806204559789
