In [2]:
%load_ext autoreload
%autoreload 2
import site
import sys
import time
site.addsitedir('..')
# Import the configuration object from the top-level package: this works on
# both old and new JAX releases, whereas the `jax.config` submodule import
# path (`from jax.config import config`) was removed from recent releases.
from jax import config

# Enable 64-bit floats/ints. JAX expects this flag to be set at startup,
# before any arrays are created.
config.update("jax_enable_x64", True)

In [3]:
import numpy as np
import jax.numpy as jnp
from  matplotlib import pyplot as plt
from src.interpolate import *
from src.algorithm import conjugate_gradient
import jax


Create the grids and the volume

In [46]:
# Seed NumPy's global RNG so that `vol` (and every later np.random draw in a
# top-to-bottom run) is reproducible across kernel restarts.
np.random.seed(0)

# Number of grid points per dimension.
nx = 128

# Integer-valued Fourier frequencies:
# fftfreq(nx, 1/nx) = [0, 1, ..., nx/2 - 1, -nx/2, ..., -1].
x_freq = jnp.fft.fftfreq(nx, 1/nx)
y_freq = x_freq
z_freq = x_freq

# Coordinate arrays of the full 3D grid (default meshgrid indexing, 'xy').
X, Y, Z = jnp.meshgrid(x_freq, y_freq, z_freq)

# Per-axis grid descriptors handed to the interpolation routines:
# [first non-zero frequency (i.e. the grid spacing), number of grid points].
# NOTE(review): how src.interpolate interprets these two entries is not
# visible here -- confirm against that module.
x_grid = np.array([x_freq[1], len(x_freq)])
y_grid = np.array([y_freq[1], len(y_freq)])
z_grid = np.array([z_freq[1], len(z_freq)])

# Ground-truth volume: standard-normal noise shifted by +100.
vol = jnp.array(np.random.randn(nx,nx,nx)) + 100

# Every grid coordinate, one column per grid node: shape (3, nx**3).
all_coords = jnp.array([X.ravel(), Y.ravel(), Z.ravel()])

Generate points on the grid

In [47]:
# Number of random test points.
N = 10

# Sampling range for the integer point coordinates. For now it is restricted
# to the extent of the grid; an extended range that would also exercise the
# wrap-around (going outside by half on each side) is kept for reference:
#   low, high = -nx/2 - nx/4, nx/2 + nx/4
low, high = -nx/2, nx/2

# One random integer triple per column, stored as float64.
pts = jnp.array(
    np.random.randint(low, high, size=(3, N))
).astype(jnp.float64)


In [48]:
# Sanity check: one column per point, so the expected shape is (3, N).
pts.shape

(3, 10)

In [49]:
# Sanity check: one column per grid node, so the expected shape is (3, nx**3).
all_coords.shape

(3, 2097152)

In [50]:
@jax.jit
def interpolate_fun(vol):
    """Interpolate the volume ``vol`` at every coordinate in ``all_coords``.

    Thin jitted wrapper around ``interpolate`` that fixes the sampling
    locations (``all_coords``), the grid descriptors (``x_grid``, ``y_grid``,
    ``z_grid``) and the ``"tri"`` method, leaving the volume as the only
    argument.

    Returns whatever ``interpolate`` returns for those inputs.
    """
    # Presumably selects trilinear interpolation -- confirm in src.interpolate.
    method = "tri"
    sampled = interpolate(all_coords, x_grid, y_grid, z_grid, vol, method)
    return sampled

The interpolated values, i.e. the data.

In [51]:
# Apply the forward model to the ground-truth volume to get the synthetic data.
data = interpolate_fun(vol)

In [52]:
@jax.jit
def loss_fun(v):
    """Scaled least-squares misfit between the interpolated ``v`` and ``data``.

    Evaluates ``1/(2*nx**3) * sum((interpolate_fun(v) - data)**2)``.
    """
    residual = interpolate_fun(v) - data
    scale = 1/(2*nx*nx*nx)
    return scale*jnp.sum(residual**2)

@jax.jit
def loss_fun_grad(v):
    """Gradient of ``loss_fun`` with respect to its argument, evaluated at ``v``."""
    grad_of_loss = jax.grad(loss_fun)
    return grad_of_loss(v)

#@jax.jit
#def loss_fun_grad_array(coords, ivals):
#    return jax.vmap(loss_fun_grad, in_axes = (1, 0))(coords, ivals).T

In [55]:
# Sanity check: data was generated from vol, so the loss at vol should be 0.
loss_fun(vol)

DeviceArray(0., dtype=float64)

In [56]:
# Sanity check: the gradient should have the same shape as the volume.
loss_fun_grad(vol).shape

(128, 128, 128)

And solve the inverse problem

In [24]:
# Plain gradient descent on loss_fun, starting from a random volume.
v0 = jnp.array(np.random.randn(nx,nx,nx))
# Alternative start close to the solution:
#v0 = vol + 0.1*np.random.randn(nx,nx,nx)

vk = v0        # current iterate
N_iter = 20000
alpha = 1      # fixed step size
# NOTE(review): loss_fun is scaled by 1/(2*nx**3), so a fixed alpha = 1 gives
# smaller effective steps as nx grows -- confirm that N_iter is still enough
# for the current nx.
for k in range(N_iter):
    # The gradient must be evaluated at the current iterate `vk`
    # (the previous code referenced an undefined name here).
    vk = vk - alpha * loss_fun_grad(vk)

    # k is a plain Python int, so a plain modulo is enough for the
    # progress report every 1000 iterations.
    if k % 1000 == 0:
        loss = loss_fun(vk)
        print("iter ", k, ", loss = ", loss)

# Largest absolute deviation from the ground-truth volume.
err = jnp.max(jnp.abs(vk-vol))
print("err =", err)

iter  0 , loss =  4920.983093914427
iter  1000 , loss =  0.0005192735256950996
iter  2000 , loss =  5.479494432746607e-11
iter  3000 , loss =  5.782096105923676e-18
iter  4000 , loss =  5.737725394277006e-25
iter  5000 , loss =  3.881448089177049e-25
iter  6000 , loss =  3.881448089177049e-25
iter  7000 , loss =  3.881448089177049e-25
iter  8000 , loss =  3.881448089177049e-25
iter  9000 , loss =  3.881448089177049e-25
iter  10000 , loss =  3.881448089177049e-25
iter  11000 , loss =  3.881448089177049e-25
iter  12000 , loss =  3.881448089177049e-25
iter  13000 , loss =  3.881448089177049e-25
iter  14000 , loss =  3.881448089177049e-25
iter  15000 , loss =  3.881448089177049e-25
iter  16000 , loss =  3.881448089177049e-25
iter  17000 , loss =  3.881448089177049e-25
iter  18000 , loss =  3.881448089177049e-25
iter  19000 , loss =  3.881448089177049e-25
err = 8.810729923425242e-13


In [59]:
# Solve the same problem with conjugate gradient.
v0 = jnp.array(np.random.randn(nx,nx,nx))
zero = jnp.zeros(vol.shape)

# Assuming interpolate is linear in the volume, loss_fun is quadratic and its
# gradient is affine: grad(v) = AA(v) - b with b = -grad(0). The gradient at
# zero is evaluated once and reused by both the operator and the right-hand
# side, instead of being recomputed on every operator application.
grad_at_zero = loss_fun_grad(zero)


def AA(a):
    """Linear part of the gradient of loss_fun: grad(a) - grad(0)."""
    return loss_fun_grad(a) - grad_at_zero


b = - grad_at_zero

# The fourth positional argument (10) is presumably the iteration count --
# confirm against src.algorithm.conjugate_gradient.
vcg, max_iter = conjugate_gradient(AA, b, v0, 10, verbose = True)

# Largest absolute deviation from the ground-truth volume.
err = jnp.max(jnp.abs(vcg-vol))
print("err =", err)

Iter 0 ||r|| = 5.941095886675326e-12
Iter 1 ||r|| = 7.024772364448904e-10
Iter 2 ||r|| = 7.024521148288821e-10
Iter 3 ||r|| = 7.024269933413921e-10
Iter 4 ||r|| = 7.024018742725975e-10
Iter 5 ||r|| = 7.023767564888896e-10
Iter 6 ||r|| = 7.023516399769891e-10
Iter 7 ||r|| = 7.023265211682924e-10
Iter 8 ||r|| = 7.023014072735951e-10
Iter 9 ||r|| = 7.022762939345974e-10
err = 0.00010429372564146888


In [39]:
# NOTE(review): stale cell -- `xk` is not defined in any visible cell of this
# notebook and the recorded output is a NameError; re-run or remove.
jnp.max(jnp.abs(pts-xk))

NameError: name 'pts' is not defined

In [16]:
# First random point (column 0 of pts).
pts[:,0] 

DeviceArray([ 1., -2.,  1.], dtype=float64)

In [17]:
# NOTE(review): `xk` is not defined in the visible notebook; presumably an
# estimate of `pts` left over from an earlier experiment -- confirm.
xk[:,0]

DeviceArray([ 0.94920501, -1.96687241,  1.00613827], dtype=float64)

In [19]:
# NOTE(review): stale cell -- this passes two arguments, but loss_fun as
# defined in this notebook takes a single volume argument; `xk` is also not
# defined in the visible cells.
loss_fun(xk[:,0], data[0])

DeviceArray(0., dtype=float64)

In [20]:
# Loss evaluated separately for each of the first 10 points.
# NOTE(review): stale cell -- uses the two-argument form of loss_fun and the
# undefined `xk`; the literal 10 presumably stands for N -- confirm.
lf = jnp.array([loss_fun(xk[:,i], data[i]) for i in range(10)])

In [21]:
# Total of the per-point losses stored in lf.
jnp.sum(lf)

DeviceArray(1.00974196e-29, dtype=float64)

In [22]:
# Show the individual per-point losses.
lf

DeviceArray([0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
             0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
             0.00000000e+00, 1.00974196e-29, 0.00000000e+00,
             0.00000000e+00], dtype=float64)

In [23]:
# Pick a single point to inspect in detail: the estimate x, the true point p
# and the matching data value d, all selected by the same index.
idx = 0
x = xk[:,idx]
p = pts[:,idx]
d = data[idx]

# NOTE(review): stale call -- two arguments are passed, but loss_fun as defined
# in this notebook takes a single volume argument; `xk` is not defined in the
# visible cells either.
# Kept as the last expression so the cell displays the loss value (a stray `Z`
# expression previously followed it and would have displayed the full meshgrid
# array instead).
loss_fun(p,d)

DeviceArray(0., dtype=float64)

In [24]:
# NOTE(review): stale cell -- interpolate_fun as defined in this notebook takes
# a volume, not a point; this call matches an earlier version of the function.
interpolate_fun(x)

DeviceArray(99.0487436, dtype=float64)

In [25]:
# NOTE(review): stale cell -- interpolate_fun as defined in this notebook takes
# a volume, not a point; this call matches an earlier version of the function.
interpolate_fun(p)

DeviceArray(99.0487436, dtype=float64)

In [26]:
# Interpolate vol at the single point x with the "nn" method
# (presumably nearest-neighbour -- confirm in src.interpolate).
interpolate(x, x_grid, y_grid, z_grid, vol, "nn")

DeviceArray(99.0487436, dtype=float64)

In [30]:
# Step through the interpolation by hand for the point x: first locate the
# surrounding grid points, then interpolate from them.
coords_x, nearest_pts_x = find_nearest_eight_grid_points_idx(x, x_grid, y_grid, z_grid)

# Dump the inputs and intermediates, one per line. nearest_pts_x is a pair of
# arrays (see its recorded display further down in the notebook).
print(x, coords_x, nearest_pts_x[0], nearest_pts_x[1], x_grid, sep="\n")
#print(x-coords_x)

interp_pts_x = tri_interp_point(coords_x, vol, nearest_pts_x)
print(interp_pts_x)

[ 0.94920501 -1.96687241  1.00613827]
[0.94920501 3.03312759 1.00613827]
[[0. 1.]
 [3. 4.]
 [1. 2.]]
[[0 1]
 [3 4]
 [1 2]]
[1. 5.]
99.04874360395219


In [31]:
# Same manual interpolation steps as for x, now for the true point p.
coords_p, nearest_pts_p = find_nearest_eight_grid_points_idx(p, x_grid, y_grid, z_grid)

# Dump the inputs and intermediates, one per line.
print(p, coords_p, nearest_pts_p[0], nearest_pts_p[1], sep="\n")
#print(p-coords_p)

interp_pts_p = tri_interp_point(coords_p, vol, nearest_pts_p)
print(interp_pts_p)

[ 1. -2.  1.]
[1. 3. 1.]
[[1. 2.]
 [3. 4.]
 [1. 2.]]
[[1 2]
 [3 4]
 [1 2]]
99.0487436039522


In [33]:
# Difference between the grid value vol[3,1,1] and the value interpolated at x.
vol[3,1,1] - interp_pts_x

DeviceArray(1.42108547e-14, dtype=float64)

#### Q: Why is `interp_pts_x` equal (to within ~1e-14) to `interp_pts_p` and to `vol[3,1,1]`, even though `x` and `p` are different points?

In [234]:
# Build a test coordinate cc by overwriting the first two components of
# coords_p with values close to those of coords_x (the third component stays
# at its coords_p value), then compare the interpolated values at cc and at
# coords_x against the grid value vol[3,1,1].
# (A dead initial assignment of cc, immediately overwritten, was dropped.)
cc = coords_p.at[0].set(0.9492051)
cc = cc.at[1].set(3.0354838478076)
print("cc               =", cc)
print("coords_x         =", coords_x)
print("max(c-coords_x)  = ", jnp.max(jnp.abs(cc-coords_x)))
#print("i(cc)-i(coords_x)= ", tri_interp_point(cc, vol, nearest_pts_x) -tri_interp_point(coords_x, vol, nearest_pts_x))

# Both lines print the interpolated value minus vol[3,1,1], not the raw value.
print("interp(cc)       =", tri_interp_point(cc, vol, nearest_pts_x) - vol[3,1,1])
print("interp(coords_x) =", tri_interp_point(coords_x, vol, nearest_pts_x) - vol[3,1,1])



cc               = [0.9492051  3.03548385 1.        ]
coords_x         = [0.94920501 3.03312759 1.00613827]
max(c-coords_x)  =  0.006138267515643836
(3,)
(3, 2)
[ True  True  True]
[ True  True  True]
interp(cc)       = 0.0
(3,)
(3, 2)
[ True  True  True]
[ True  True  True]
interp(coords_x) = -1.4210854715202004e-14


In [235]:
# Show the neighbour description returned for x by
# find_nearest_eight_grid_points_idx.
nearest_pts_x

(DeviceArray([[0., 1.],
              [3., 4.],
              [1., 2.]], dtype=float64),
 DeviceArray([[0, 1],
              [3, 4],
              [1, 2]], dtype=int64))

In [240]:
# Coordinates returned for the true point p by
# find_nearest_eight_grid_points_idx.
coords_p

DeviceArray([1., 3., 1.], dtype=float64)

In [241]:
# The perturbed test coordinate built from coords_p.
cc

DeviceArray([0.9492051 , 3.03548385, 1.        ], dtype=float64)

In [140]:
# Coordinates for the true point p (same display as in an earlier cell).
coords_p

DeviceArray([1., 3., 1.], dtype=float64)

In [88]:
# Largest component-wise distance between the coordinates of x and of p.
jnp.max(jnp.abs(coords_x - coords_p))

DeviceArray(0.05079499, dtype=float64)

In [133]:
# Sensitivity check: shift all three components of coords_p by +incr and by
# -incr and print the change in the interpolated value relative to the
# unperturbed one, for increments spanning 1e-17 .. 1.
increments = [1e-17, 1e-16, 1e-15, 1e-14, 1e-13, 1e-12, 1e-11, 1e-10, 1e-9, 1e-8, 1e-7, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1]

# The unperturbed value does not depend on incr, so compute it once.
baseline = tri_interp_point(coords_p, vol, nearest_pts_x)

for incr in increments:
    print(incr)
    # Both perturbations use the neighbours found for x (nearest_pts_x); the
    # previous code referenced an undefined name in the first call.
    print(tri_interp_point(coords_p + incr, vol, nearest_pts_x) - baseline)
    print(tri_interp_point(coords_p - incr, vol, nearest_pts_x) - baseline)
    print()

1e-17
0.0
0.0

1e-16
0.0
-1.4210854715202004e-14

1e-15
0.0
-1.4210854715202004e-14

1e-14
1.4210854715202004e-14
-1.4210854715202004e-14

1e-13
9.947598300641403e-14
-9.947598300641403e-14

1e-12
7.958078640513122e-13
-7.815970093361102e-13

1e-11
7.958078640513122e-12
-7.958078640513122e-12

1e-10
7.963762982399203e-11
-7.963762982399203e-11

1e-09
7.964189308040659e-10
-7.964331416587811e-10

1e-08
7.964189308040659e-09
-7.964175097185944e-09

1e-07
7.96419925563896e-08
-7.964202097809903e-08

1e-06
7.964196271359469e-07
-7.964202382026997e-07

1e-05
7.964167679119782e-06
-7.964231329538052e-06

0.0001
7.963881169814613e-05
-7.96451775784135e-05

0.001
0.0007961011147017416
-0.000796737701875827

0.01
0.007931826586727198
-0.007995485302714656

0.1
0.0759154855338835
-0.08228135713223139

1
-0.0654474806920291
-0.5711396791456878

