In [5]:
%load_ext autoreload
%autoreload 2
import site
import sys
import time
site.addsitedir('..')
from jax.config import config

config.update("jax_enable_x64", True)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [6]:
import numpy as np
import jax.numpy as jnp
from  matplotlib import pyplot as plt
from src.interpolate import interpolate
import jax

Create the grids and the volume

In [449]:
nx = 16

x_freq = jnp.fft.fftfreq(nx, 1/nx)
y_freq = x_freq
z_freq = x_freq

x_grid = np.array([x_freq[1], len(x_freq)])
y_grid = np.array([y_freq[1], len(y_freq)])
z_grid = np.array([z_freq[1], len(z_freq)])

vol = jnp.array(np.random.randn(nx,nx,nx))

Generate points on the grid

In [450]:
N = 10

# The points will be on the grid between low and high. 
# Going outside the range by half, to ensure wrap-around works well.
low = -nx/2-nx/4
high = nx/2 + nx/4

# Actually not going outside the range for now
low = -nx/2
high = nx/2

pts = jnp.array(np.random.randint(low, high, size = (3,N))).astype(jnp.float64)


In [451]:
pts.shape

(3, 10)

In [452]:
@jax.jit
def interpolate_fun(coords):
    return interpolate(coords, x_grid, y_grid, z_grid, vol, "tri")

The interpolated values, i.e. the data.

In [453]:
data = interpolate_fun(pts)

In [454]:
@jax.jit
def loss_fun(coords):
    return 1/(2*N)*jnp.sum((interpolate_fun(coords) - data)**2)

@jax.jit
def loss_fun_grad(coords):
    return jax.grad(loss_fun)(coords)

#@jax.jit
#def loss_fun_grad_array(coords, ivals):
#    return jax.vmap(loss_fun_grad, in_axes = (1, 0))(coords, ivals).T

In [455]:
loss_fun(pts)

DeviceArray(0., dtype=float64)

In [456]:
loss_fun_grad(pts)

DeviceArray([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
             [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
             [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float64)

And solve the inverse problem - for some reason some points in pts don't converge with GD. CG doesn't work at all so far.

In [472]:
#TODO: move CG implemented here to a separate file and make it work.

x0 = jnp.array(np.random.randn(3, N) * nx/2)
x0 = pts + 0.1*np.random.randn(3,N)
zero = jnp.zeros(pts.shape)

AA = lambda c : loss_fun_grad(c) - loss_fun_grad(zero)
b = - loss_fun_grad(zero)
x = x0

N_iter = 100

r = b - AA(x)
p = r
for i in range(N_iter):
    r0 = r
    
    AAp = AA(p)
    alpha = jnp.dot(r.ravel(),r.ravel())/jnp.dot(p.ravel(), AAp.ravel())
    x = x + alpha * p
    r = r - alpha * AAp
    beta = jnp.dot(r.ravel(),r.ravel())/jnp.dot(r0.ravel(),r0.ravel())
    p = r + beta*p
     
    if jnp.mod(i, 10) == 0:
        loss = loss_fun(x)
        print("iter = ", i, ", max loss = ", loss)



iter =  0 , max loss =  0.328256820925883
iter =  10 , max loss =  2355520.3406002787
iter =  20 , max loss =  2355510.950100164
iter =  30 , max loss =  2355510.6675147978
iter =  40 , max loss =  2355511.138549726
iter =  50 , max loss =  2355511.546842881
iter =  60 , max loss =  2355518.6896798084
iter =  70 , max loss =  2355530.3551045116
iter =  80 , max loss =  2355531.109605692
iter =  90 , max loss =  2355531.10977351


In [430]:
loss_fun(pts, data)

DeviceArray(0., dtype=float64)

In [463]:
jnp.dot(r.ravel(), r.ravel())

DeviceArray(2309.31678337, dtype=float64)

In [469]:
alpha

DeviceArray(1.06842555e-17, dtype=float64)

In [470]:
beta

DeviceArray(1.09502754, dtype=float64)