# Reconstruction from interpolated values with known locations

In [1]:
%load_ext autoreload
%autoreload 2
import site
import sys
import time
site.addsitedir('..')
from jax import config

# Enable 64-bit floats in JAX (its default is 32-bit). This must run before any
# JAX arrays are created; the ~1e-13 reconstruction errors reported below are
# only attainable in double precision.
config.update("jax_enable_x64", True)

In [2]:
import numpy as np
import jax.numpy as jnp
from  matplotlib import pyplot as plt
from src.interpolate import *
from src.algorithm import conjugate_gradient
import jax
from src.utils import mip_z

Create the (integer, FFT-ordered) frequency grids and a random ground-truth volume.

In [3]:
nx = 5  # number of grid points per dimension

# Seed NumPy's global RNG so the random volume, the sample points and the initial
# guesses drawn in later cells are reproducible on Restart & Run All.
SEED = 0
np.random.seed(SEED)

# fftfreq(nx, 1/nx) gives the integer frequencies in FFT order: [0, 1, ..., -1].
x_freq = jnp.fft.fftfreq(nx, 1/nx)
y_freq = x_freq
z_freq = x_freq

X, Y, Z = jnp.meshgrid(x_freq, y_freq, z_freq)

# Grid descriptors passed to `interpolate`: [spacing, number of points]
# (presumably this is the layout src.interpolate expects — TODO confirm).
x_grid = np.array([x_freq[1], len(x_freq)])
y_grid = np.array([y_freq[1], len(y_freq)])
z_grid = np.array([z_freq[1], len(z_freq)])

# Ground-truth volume: standard-normal values offset by 100.
vol = jnp.array(np.random.randn(nx,nx,nx)) + 100

# Coordinates of every grid node, shape (3, nx**3).
# NOTE(review): jnp.meshgrid defaults to indexing="xy", which swaps the first two
# axes relative to "ij"; confirm this matches the axis convention of src.interpolate.
all_coords = jnp.array([X.ravel(), Y.ravel(), Z.ravel()])

Generate points on or off the grid

In [4]:
N = 2000  # number of sample points

# Choose how the sample locations are generated:
#   POINTS_ON_GRID = True  -> integer (on-grid) points drawn uniformly from [low, high)
#   POINTS_ON_GRID = False -> off-grid points, each coordinate ~ Normal(0, (3*nx)**2)
POINTS_ON_GRID = False

# Range for the on-grid points. It extends well beyond the grid's own nx points,
# so that the wrap-around in the interpolation gets exercised.
low = -3*nx
high = 3*nx

if POINTS_ON_GRID:
    pts = np.random.randint(low, high, size=(3, N)).astype(np.float64)
else:
    # Points between grid nodes. These are NOT confined to [low, high): the
    # coordinates are Gaussian with standard deviation 3*nx.
    pts = 3*nx * np.random.randn(3, N)


In [5]:
@jax.jit
def interpolate_fun(vol):
    """Interpolate `vol` at the fixed sample points `pts` with the "tri" method.

    `pts`, `x_grid`, `y_grid` and `z_grid` are read from the notebook's global
    scope; under `jax.jit` they are captured as constants when the function is
    first traced.
    """
    grids = (x_grid, y_grid, z_grid)
    return interpolate(pts, *grids, vol, "tri")

The interpolated values at `pts`, i.e. the data. When the points lie exactly on grid nodes, these are simply the corresponding entries of the volume.

In [6]:
# Noiseless synthetic measurements: the ground-truth volume sampled at `pts`.
data = interpolate_fun(vol)

In [7]:
@jax.jit
def loss_fun(v):
    """Least-squares misfit between the interpolation of `v` and `data`.

    Returns sum((interpolate_fun(v) - data)**2) / (2 * len(data)). `data` is a
    global captured at trace time.
    """
    residual = interpolate_fun(v) - data
    scale = 1/(2*len(data))
    return scale*jnp.sum(residual**2)

@jax.jit
def loss_fun_grad(v):
    """Gradient of `loss_fun` with respect to the volume `v` (same shape as `v`)."""
    grad_of_loss = jax.grad(loss_fun)
    return grad_of_loss(v)

In [8]:
# Sanity check: `data` was generated from `vol`, so the loss at the ground truth
# should vanish (it is exactly 0 in the recorded run; allow for round-off).
loss_at_truth = loss_fun(vol)
assert loss_at_truth < 1e-20, f"loss at the ground truth is {loss_at_truth}"
loss_at_truth

DeviceArray(0., dtype=float64)

In [9]:
# The gradient must have the same shape as the volume it is taken with respect to.
grad_shape = loss_fun_grad(vol).shape
assert grad_shape == vol.shape, f"gradient shape {grad_shape} != volume shape {vol.shape}"
grad_shape

(5, 5, 5)

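Solve the inverse problem: starting from a random guess, recover the volume from `data` by gradient descent on the least-squares loss.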

In [10]:
# Random initial guess, unrelated to the ground truth.
v0 = jnp.array(np.random.randn(nx,nx,nx))

N_iter = 2000     # number of gradient-descent steps
alpha = 100       # fixed step size
log_every = 100   # print the loss every `log_every` iterations

vk = v0
for k in range(N_iter):
    vk = vk - alpha * loss_fun_grad(vk)

    # Python-int modulo on purpose: `jnp.mod(k, ...)` would launch a device op and
    # force a host sync on every iteration just to decide whether to log.
    if k % log_every == 0:
        loss = loss_fun(vk)  # loss after the k-th update
        print("iter ", k, ", loss = ", loss)

# Mean absolute error of the reconstruction against the ground-truth volume.
err = jnp.mean(jnp.abs(vk-vol))
print("err =", err)

iter  0 , loss =  197.5650115972626
iter  100 , loss =  7.114944364811719e-06
iter  200 , loss =  4.246194136691046e-09
iter  300 , loss =  3.071253985583201e-12
iter  400 , loss =  2.2856403460897854e-15
iter  500 , loss =  1.7127486624861303e-18
iter  600 , loss =  1.2864654141846269e-21
iter  700 , loss =  9.669405116672652e-25
iter  800 , loss =  1.0234744493209826e-27
iter  900 , loss =  4.138427417661846e-28
iter  1000 , loss =  4.138427417661846e-28
iter  1100 , loss =  4.138427417661846e-28
iter  1200 , loss =  4.138427417661846e-28
iter  1300 , loss =  4.138427417661846e-28
iter  1400 , loss =  4.138427417661846e-28
iter  1500 , loss =  4.138427417661846e-28
iter  1600 , loss =  4.138427417661846e-28
iter  1700 , loss =  4.138427417661846e-28
iter  1800 , loss =  4.138427417661846e-28
iter  1900 , loss =  4.138427417661846e-28
err = 1.0470557754160837e-13


In [11]:
# Mean absolute error of the gradient-descent reconstruction.
jnp.abs(vk - vol).mean()

DeviceArray(1.04705578e-13, dtype=float64)

In [12]:
# Fresh random initial guess for conjugate gradient.
v0 = jnp.array(np.random.randn(nx,nx,nx))
zero = jnp.zeros(vol.shape)

# Assuming `interpolate` is linear in the volume (as trilinear interpolation is),
# the loss is quadratic and its gradient is affine: grad(v) = A v - c. Therefore
# A v = grad(v) - grad(0), and the minimiser solves A v = -grad(0) = b.
# grad(0) is evaluated once here rather than on every operator application.
grad_at_zero = loss_fun_grad(zero)

def AA(a):
    """Apply the linear operator A (Hessian of the quadratic loss) to volume `a`."""
    return loss_fun_grad(a) - grad_at_zero

b = -grad_at_zero

cg_max_iter = 100  # iteration cap passed to conjugate_gradient
# NOTE(review): second return value is presumably the number of iterations used — confirm.
vcg, max_iter = conjugate_gradient(AA, b, v0, cg_max_iter, verbose = True)
err = jnp.mean(jnp.abs(vcg-vol))
print("err =", err)

Iter 0 ||r|| = 0.5433335111973929
Iter 1 ||r|| = 0.17033140194651533
Iter 2 ||r|| = 0.06299527291969312
Iter 3 ||r|| = 0.02835262999559435
Iter 4 ||r|| = 0.01654688960755851
Iter 5 ||r|| = 0.010126006269190992
Iter 6 ||r|| = 0.0051273733375634175
Iter 7 ||r|| = 0.003014525261038509
Iter 8 ||r|| = 0.0017783121354024917
Iter 9 ||r|| = 0.0012745583766309836
Iter 10 ||r|| = 0.0008379696159835487
Iter 11 ||r|| = 0.0005625667629377647
Iter 12 ||r|| = 0.00033549928016624416
Iter 13 ||r|| = 0.0002515352752152086
Iter 14 ||r|| = 0.00014084130462132892
Iter 15 ||r|| = 8.015050056803218e-05
Iter 16 ||r|| = 5.154101317950883e-05
Iter 17 ||r|| = 3.2301895216020814e-05
Iter 18 ||r|| = 2.0895866165151074e-05
Iter 19 ||r|| = 1.3639403757749954e-05
Iter 20 ||r|| = 8.899939064961428e-06
Iter 21 ||r|| = 5.041797005316252e-06
Iter 22 ||r|| = 2.7971540031916253e-06
Iter 23 ||r|| = 1.391775348655293e-06
Iter 24 ||r|| = 8.453442106579741e-07
Iter 25 ||r|| = 5.123527600919914e-07
Iter 26 ||r|| = 3.29205664391