# Training a 2-Body UFP in JAX MD

In [1]:
import jax
from jax import vmap, grad, value_and_grad
import jax.numpy as jnp

from jax.config import config
# Switch on the "jax_enable_x64" flag before any arrays are created; the
# result shown at the end of this file is reported with dtype=float64.
config.update("jax_enable_x64", True)

from functools import partial

from uf3.util.random import random_spline
from uf3.jax.jax_splines import ndSpline_unsafe

from uf3.jax.training import loss_uf2
from uf3.regression.regularize import get_regularizer_matrix



In [2]:
def per_atom_f(fn, c):
    """Build the gradient of the summed, vmapped values of ``fn``.

    Args:
        fn: Callable applied to one entry of the input at a time; it must
            accept a ``coefficients`` keyword argument.
        c: Default used for ``coefficients`` when the returned function is
            called without one.

    Returns:
        A function ``(x, coefficients=c)`` that yields the gradient, with
        respect to ``x``, of ``jnp.sum(vmap(fn)(x))``. No sign flip is
        applied here; callers that need the negative gradient negate the
        result themselves.
    """
    def summed_values(x, coefficients=c):
        # Bind the coefficients first so vmap only maps over the entries of x.
        evaluate_one = partial(fn, coefficients=coefficients)
        return jnp.sum(vmap(evaluate_one)(x))

    return grad(summed_values)


In [3]:
# `random_spline` supplies `c` (used as the spline coefficients below), `k`
# and the sample inputs `x` that are later split into training and test sets.
c, k, x = random_spline([20], 0, 10, 300, (3,), 123)
# `s` evaluates the spline at a single input; it also accepts a
# `coefficients` keyword, which the loss computation further down relies on.
s = ndSpline_unsafe(c, k, (3,))

# Reference energies: the spline evaluated independently at every sample.
ref_e = vmap(s)
# Reference forces: the NEGATIVE gradient of the summed energies. The loss
# further down compares `-g` (minus the model gradient) against these values,
# so the reference has to follow the same sign convention; with the positive
# gradient the force residual would be -2 * gradient, not zero, when the
# model reproduces the reference spline exactly.
ref_f = grad(lambda x: -jnp.sum(vmap(s)(x)))
# Plain gradient (no sign flip) that can also be evaluated for other
# coefficients via the `coefficients` keyword; defaults to `c`.
ref_ff = per_atom_f(s, c)



In [4]:
# The first 100 samples form the training set, the remainder the test set.
train_x, test_x = x[:100], x[100:]

# Reference values from `ref_e` and `ref_f` on each split.
train_e, train_f = ref_e(train_x), ref_f(train_x)
test_e, test_f = ref_e(test_x), ref_f(test_x)

In [5]:
# Trial coefficients: all zeros, with the same shape and dtype as `c`.
coeff = jnp.zeros_like(c)


def new(c):
    """Return the gradient function of the summed spline values for ``c``.

    The spline is rebuilt from ``c`` (and the file-level ``k``) each time the
    returned gradient function is evaluated.
    """
    def summed_values(x):
        return jnp.sum(vmap(ndSpline_unsafe(c, k, (3,)))(x))

    return grad(summed_values)

In [7]:
# loss_uf2(coeff, train_x, train_e, train_f, new)

In [14]:
# Inputs of the loss: trial coefficients, training samples (S) and the
# reference data they are compared against (E from ref_e, F from ref_f).
coefficients, S, E, F = coeff, train_x, train_e, train_f
# kappa weights the energy term and (1 - kappa) the force term;
# lam scales the regularization term.
kappa, lam = 0.5, 1.0

# Normalization constants: total squared deviation of each target from its
# own mean.
sigE = jnp.sum((E - E.mean()) ** 2)
sigF = jnp.sum((F - F.mean()) ** 2)

# Model output for the trial coefficients: per-sample values and their
# gradient with respect to the samples.
se = partial(s, coefficients=coefficients)
v = vmap(se)(S)
g = ref_ff(S, coefficients=coefficients)

# Squared residuals; the gradient is negated before it is compared with F.
dE = (v - E) ** 2
dF = (-g - F) ** 2

# Energy contribution to the loss.
E_term = jnp.sum(dE) * (kappa / sigE)

In [20]:

# Force contribution to the loss: summed squared force residuals, scaled by
# the complementary weight (1 - kappa) and normalized by sigF.
F_term = jnp.sum(dF) * ((1 - kappa) / sigF)


In [23]:

# Regularization term lam * ||D c||^2.
# NOTE(review): assumes get_regularizer_matrix returns a square (n, n) matrix
# and `coefficients` is a 1-D vector of length n -- confirm against uf3.
# The matrix must act on the coefficient vector through a matrix-vector
# product. The elementwise `D * coefficients` broadcasts the vector across
# the rows of D and penalizes sum_ij (D_ij * c_j)^2, which drops the coupling
# between different coefficients that the off-diagonal entries of D encode.
D = jnp.asarray(get_regularizer_matrix(len(coefficients)))
L_term = lam * jnp.sum((D @ coefficients) ** 2)

In [29]:
# Total loss: energy term + force term + regularization term.
E_term + F_term + L_term

DeviceArray(1.12840551, dtype=float64)