In [1]:
import jax.numpy as np
import jax.ops
from jax import grad, jit
from functools import partial

In [11]:
# Load the first N_SAMPLES frames of the HOOH MD dataset.
# NOTE(review): presumably 'R' holds atomic positions (inputs X) and 'D' the
# dipole moments (targets y) — confirm against the dataset's documentation.
DATA_PATH = 'data/HOOH.DFT.PBE-TS.light.MD.500K.50k.R_E_F_D_Q.npz'
N_SAMPLES = 10

# An .npz archive is backed by an open file; `with` closes it once the arrays
# we need have been copied out into device arrays.
with np.load(DATA_PATH) as data:
    X = np.array(data['R'][:N_SAMPLES])
    y = np.array(data['D'][:N_SAMPLES])

In [3]:
import warnings
warnings.filterwarnings('ignore')

In [12]:
from IPython import embed

In [54]:
def fill_diagonal(a, value):
    """Return a copy of the 2-D array `a` with its main diagonal set to `value`.

    JAX arrays are immutable, so this is a functional update: `a` is left
    unchanged and a new array is returned. The diagonal length is taken from
    `a.shape[0]` (square matrices in this notebook).

    Uses the `.at[...].set(...)` indexed-update syntax; the older
    `jax.ops.index_update` API was deprecated and later removed from JAX.
    The update is differentiable: gradients flowing to the overwritten
    diagonal entries are zero.
    """
    diag = np.diag_indices(a.shape[0])
    return a.at[diag].set(value)

def descriptor(x):
    """Flattened inverse-distance matrix of a set of points.

    Assumes `x` is (n_points, n_dims), with coordinates along the last axis
    (e.g. (4, 3) frames in this notebook). Entry (i, j) of the (n, n)
    result with i > j holds 1 / ||x[i] - x[j]||; the diagonal and the upper
    triangle are zero. The matrix is returned flattened to length n * n.
    """
    diff = x[:, None] - x[None, :]
    sq_dist = np.sum(diff ** 2, axis=-1)
    # sqrt has an infinite derivative at 0, which yields NaN gradients on the
    # zero self-distances; put 1 on the diagonal first. Those entries are
    # discarded by the strictly-lower-triangle mask below anyway.
    dist = np.sqrt(fill_diagonal(sq_dist, 1))
    # k=-1 keeps only the strictly-lower triangle, which also zeroes the diagonal.
    return np.tril(1 / dist, k=-1).flatten()

In [55]:
def gaussian(x, x_, sigma=1):
    """Gaussian kernel between two point sets, compared in descriptor space.

    Both inputs are mapped through `descriptor`, and the result is
    exp(-||descriptor(x) - descriptor(x_)||^2 / sigma).

    Note: `sigma` divides the squared distance directly; it is not the
    standard deviation of the exp(-r^2 / (2 * sigma^2)) parameterisation.
    """
    delta = descriptor(x) - descriptor(x_)
    return np.exp(-np.sum(delta ** 2) / sigma)
    

In [56]:
from jax import jacfwd, jacrev
def hessian(f, argnums=0):
    """Return a function that computes the Hessian of `f`.

    Uses the forward-over-reverse composition `jacfwd(jacrev(f))`. For a
    scalar-valued `f` whose differentiated argument has shape S, the
    result has shape S + S (e.g. (4, 3, 4, 3) for a (4, 3) input).

    Args:
        f: function to differentiate twice.
        argnums: positional argument index (or tuple of indices) to
            differentiate with respect to. It is forwarded to both `jacrev`
            and `jacfwd`. The default of 0 reproduces the original
            single-argument behaviour.
    """
    return jacfwd(jacrev(f, argnums=argnums), argnums=argnums)

In [57]:
# `from jax import config` replaces the `jax.config` submodule import path,
# which newer JAX releases deprecated and then removed; the `config` name
# bound here is the same configuration object in both cases.
from jax import config

# Ask JAX to raise as soon as an operation produces a NaN (debugging aid;
# it slows execution, so disable it for production runs).
config.update("jax_debug_nans", True)

In [59]:
# Fix the kernel's first argument to the reference frame X[0]; `hessian`
# then differentiates twice with respect to the remaining argument.
_gaussian = partial(gaussian, X[0])
hess = hessian(_gaussian)
# Second derivatives w.r.t. every pair of coordinates of X[1], so the shape
# is X[1].shape + X[1].shape.
hess(X[1]).shape

(4, 3, 4, 3)

In [9]:
# Recompute, outside the kernel, the squared descriptor-space distance that
# `gaussian` exponentiates, here for the first two frames.
x = X[0]
x_ = X[1]
d = descriptor(x)
d_ = descriptor(x_)
sq_distance = ((d - d_) ** 2).sum()

In [28]:
# Display the squared descriptor distance between frames X[0] and X[1]
# (the quantity `gaussian` divides by sigma inside its exponent).
sq_distance

DeviceArray(0.00263891, dtype=float32)