In [2]:
import uf3.jax.jax_splines as jsp
from uf3.jax import potentials

import jax.numpy as jnp
import numpy as onp

from jax import jit, vmap, grad
from functools import partial



In [3]:
from uf3.representation import bspline
from uf3.data import composition
from uf3.regression import least_squares

# Single-element tungsten system; degree=2 selects pair (2-body) terms.
element_list = ['W']
degree = 2
chemical_system = composition.ChemicalSystem(
    element_list=element_list,
    degree=degree,
)

# W-W radial basis: range [1.5, 5.5] split into 25 intervals (knot spacing
# 0.16 in the printed knot vector). The last 3 basis functions are trimmed,
# which shows up as three trailing zeros in the loaded coefficients.
r_min_map = {('W', 'W'): 1.5}
r_max_map = {('W', 'W'): 5.5}
resolution_map = {('W', 'W'): 25}
trailing_trim = 3

bspline_config = bspline.BSplineBasis(
    chemical_system,
    r_min_map=r_min_map,
    r_max_map=r_max_map,
    resolution_map=resolution_map,
    trailing_trim=trailing_trim,
)

# Restore the previously fitted model from disk.
model = least_squares.WeightedLinearModel(bspline_config)
model.load(filename="../tungsten_extxyz/model_pair.json")

# NOTE(review): the leading coefficient is dropped -- presumably the one-body
# (per-atom W) term, leaving only the W-W spline coefficients. TODO confirm.
coefficients = jnp.asarray(model.coefficients[1:])
knots = jnp.asarray(model.bspline_config.knots_map[('W', 'W')])

print(coefficients)
print(knots)

[ 6.43001175e+00  5.97354269e+00  5.06058788e+00  3.69113111e+00
  1.65205860e+00  9.23859119e-01  2.81258196e-01 -5.36413584e-03
 -1.65669978e-01 -2.39954889e-01 -2.77954251e-01 -2.80998796e-01
 -2.49653071e-01 -1.77409589e-01 -1.00099854e-01 -6.82599768e-02
 -2.00034436e-02 -3.28071229e-02 -3.43219452e-02 -3.43618244e-02
 -2.88400631e-02 -2.12175436e-02 -6.90983492e-04 -9.25139058e-03
 -8.08694772e-03  0.00000000e+00  0.00000000e+00  0.00000000e+00]
[1.5  1.5  1.5  1.5  1.66 1.82 1.98 2.14 2.3  2.46 2.62 2.78 2.94 3.1
 3.26 3.42 3.58 3.74 3.9  4.06 4.22 4.38 4.54 4.7  4.86 5.02 5.18 5.34
 5.5  5.5  5.5  5.5 ]


In [4]:
def uf2_interaction(
    dr,
    species,
    coefficients = None,
    knots: jnp.ndarray = None,
    cutoff: float = 5.5,
):
    """Pair energies from per-species B-spline coefficient sets.

    Each entry of ``dr`` is evaluated with the coefficient set that matches its
    ``species`` label: label 0 uses ``coefficients['(0,0)']`` and label 1 uses
    ``coefficients['(1,1)']``. An entry contributes 0 when it is outside
    ``(0, cutoff)``, outside the spline support ``[knots[k], knots[-(k + 1)])``,
    or when its species label is neither 0 nor 1.

    Args:
        dr: array of pair distances (1-D in the example cells; other shapes
            are untested -- confirm before relying on them).
        species: integer species label per entry of ``dr`` (same shape).
        coefficients: mapping with keys ``'(0,0)'`` and ``'(1,1)'``. Each value
            holds one coefficient per basis function, i.e.
            ``len(knots) - k - 1`` values.
        knots: B-spline knot vector.
        cutoff: hard upper cutoff applied to ``dr``.

    Returns:
        Array with one energy per entry of ``dr``.

    Raises:
        ValueError: if ``coefficients`` or ``knots`` is not provided.
    """
    if coefficients is None or knots is None:
        raise ValueError("uf2_interaction requires both `coefficients` and `knots`")
    k = 3  # spline degree (cubic; the loaded knots repeat each end knot 4 times)
    mint = knots[k]
    # The basis has len(knots) - k - 1 functions and is supported on
    # [knots[k], knots[-(k + 1)]]. For clamped knot vectors (like the loaded
    # one) knots[-(k + 1)] == knots[-k]; for non-clamped knots the former
    # bound knots[-k] reached past the support.
    maxt = knots[-(k + 1)]
    # TODO the lower cutoff might have to be modified, or knots and
    # coefficients set correspondingly.
    within_cutoff = (dr > 0) & (dr < cutoff) & (dr >= mint) & (dr < maxt)
    # Replace rejected distances by mint (itself accepted by the check above)
    # instead of 0.0. The evaluator presumably does no bounds checking, given
    # its "_unsafe" name, so it then only ever sees points inside the support.
    # Results at those entries are masked out below. Keeping the inputs in
    # range also keeps unselected-branch values out of gradients w.r.t. dr,
    # which jnp.where alone does not do.
    safe_dr = jnp.where(within_cutoff, dr, mint)
    spline = jit(vmap(partial(jsp.deBoor_factor_unsafe, k, knots)))
    basis = spline(safe_dr)  # one row of basis values per entry of dr
    # Mask each species' term by the range check AND species membership, so a
    # pair only receives energy from its own coefficient set.
    is_00 = within_cutoff & (species == 0)
    is_11 = within_cutoff & (species == 1)
    U = jnp.where(
        is_00, jnp.sum(coefficients['(0,0)'] * basis, 1), 0.0
    )  # TODO check performance vs einsum
    U = U + jnp.where(
        is_11, jnp.sum(coefficients['(1,1)'] * basis, 1), 0.0
    )
    return U

In [5]:
# Two pairs at the same distance (2.0), one labelled species 0, one species 1.
dr = jnp.asarray([2.0, 2.0])
species = jnp.asarray([0, 1])

In [6]:
# Per-species coefficient sets: '(0,0)' is the fitted W spline, '(1,1)' is
# all ones (which evaluates to ~1 inside the support, see the output below).
test = {
    '(0,0)': coefficients,
    '(1,1)': jnp.ones_like(coefficients),
}

In [7]:
# Reference: the packaged implementation, which takes no species labels.
potentials.uf2_interaction(dr, coefficients, knots)

DeviceArray([1.7074257, 1.7074257], dtype=float32)

In [8]:
# Local species-aware version: pair 0 uses '(0,0)', pair 1 uses '(1,1)'.
uf2_interaction(dr, species, knots=knots, coefficients=test)

DeviceArray([1.7074257 , 0.99999994], dtype=float32)

In [9]:
def f(c):
    """Squared-error loss of ``uf2_interaction`` against a target of 1 per pair.

    Args:
        c: coefficient mapping in the same format as ``test``
            (keys ``'(0,0)'`` and ``'(1,1)'``).

    Returns:
        Scalar sum of squared residuals over all entries of the module-level
        ``dr`` (evaluated with the module-level ``species`` and ``knots``).
    """
    # Derive the target from dr rather than hard-coding its length (was
    # jnp.ones(2)), so the loss stays valid if more pairs are added.
    target = jnp.ones_like(dr)
    residual = target - uf2_interaction(dr, species, coefficients=c, knots=knots)
    return jnp.sum(residual ** 2)

In [10]:
# Loss at the starting coefficient sets.
f(c=test)

DeviceArray(0.50045115, dtype=float32)

In [11]:
# Gradient of the loss w.r.t. every entry of both coefficient sets; the result
# is a dict with the same keys as `test`.
df = grad(f, argnums=0)
df(test)

{'(0,0)': DeviceArray([0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 1.5797353e-01,
              9.2250890e-01, 3.3390841e-01, 4.6056215e-04, 0.0000000e+00,
              0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
              0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
              0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
              0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
              0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],            dtype=float32),
 '(1,1)': DeviceArray([ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
              -1.3310170e-08, -7.7726625e-08, -2.8133686e-08,
              -3.8804984e-11,  0.0000000e+00,  0.0000000e+00,
               0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
               0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
               0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
               0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
  

Gradients with respect to one species' coefficients only (`'(1,1)'`), holding the other coefficient set fixed

In [12]:
def f2(c, cs):
    """Loss as a function of the ``'(1,1)'`` coefficients alone.

    Args:
        c: replacement coefficients for species ``'(1,1)'``; this is the
            argument ``grad(f2)`` differentiates with respect to.
        cs: base coefficient mapping (same format as ``test``). It is NOT
            modified.

    Returns:
        Scalar sum of squared residuals against a target of 1 per pair.
    """
    # Build a new mapping instead of assigning into `cs`. Under grad, `c` is a
    # tracer; writing it into the caller's dict (the global `test`) leaks the
    # tracer and silently changes `test` for every later cell.
    coeffs = {**cs, '(1,1)': c}
    residual = jnp.ones_like(dr) - uf2_interaction(
        dr, species, coefficients=coeffs, knots=knots
    )
    return jnp.sum(residual ** 2)

In [14]:
# With '(1,1)' set to all ones this reproduces f(test) above.
f2(jnp.ones_like(coefficients), cs=test)

DeviceArray(0.50045115, dtype=float32)

In [16]:
# Differentiate only w.r.t. the first argument (the '(1,1)' coefficients);
# `test` is passed through as a constant.
df2 = grad(f2, argnums=0)
df2(coefficients, test)

DeviceArray([0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 1.5797353e-01,
             9.2250890e-01, 3.3390841e-01, 4.6056215e-04, 0.0000000e+00,
             0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
             0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
             0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
             0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
             0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],            dtype=float32)