In [1]:
import pytest

from uf3.jax.potentials import uf3_mapped, uf3_neighbor
from uf3.jax.jax_splines import featurization_with_gradients
import jax.numpy as jnp
from uf3.jax.jax_splines import *
from jax import vmap, grad, jacrev
from uf3.util.random import random_spline
import numpy as onp
import ndsplines
from numpy.testing import assert_allclose
from jax_md import space, energy

from jax.config import config
config.update("jax_enable_x64", True)
# config.update("jax_debug_nans", True)
# config.update('jax_disable_jit', True)

# from utils import make_random_spline

from jax_md import space
from uf3.jax.potentials import *
import ase
from uf3.data import composition
from uf3.representation import bspline
from uf3.regression import least_squares
from uf3.forcefield import calculator



In [2]:
N = 50
dimension = 3
box_size = 12.0
rng = onp.random.default_rng()
seed = rng.integers(0, 999)
print(f"Seed for energy test: {seed}")
rng = onp.random.default_rng(seed)
rng = onp.random.default_rng(848)

R = rng.uniform(0.0, box_size, (N, dimension))
R = jnp.asarray(R, dtype=jnp.float64)

box = jnp.eye(dimension) * box_size
displacement, shift = space.periodic_general(box, fractional_coordinates=False)

species = onp.concatenate([onp.zeros(N // 2), onp.ones(N - (N // 2))])
rng.shuffle(species)
species = jnp.asarray(species, dtype=jnp.int16)

pos = onp.asarray(R)
cell = onp.asarray(box)
pbc = onp.asarray([True, True, True])
atoms = ase.Atoms("W" + str(len(R)), positions=pos, cell=cell, pbc=pbc)
r_min_map = {("W", "W"): 1.5,
            ("W", "W", "W"): [1.5, 1.5, 1.5],
        }
r_max_map = {("W", "W"): 5.5,
            ("W", "W", "W"): [3.5, 3.5, 7.0],
        }
resolution_map = {("W", "W"): 25,
                ("W", "W", "W"): [5, 5, 10],
                }
trailing_trim = 3
chemical_system = composition.ChemicalSystem(element_list=["W"], degree=3)
bspline_config = bspline.BSplineBasis(chemical_system,
                                    r_min_map=r_min_map,
                                    r_max_map=r_max_map,
                                    resolution_map=resolution_map,
                                    trailing_trim=trailing_trim)

model = least_squares.WeightedLinearModel(bspline_config)
# model.coefficients = onp.ones()

model.load(filename="../../tungsten_extxyz/model_uf23.json")

calc = calculator.UFCalculator(model)
ndspline2 = calc.pair_potentials[('W','W')]
ndspline3 = calc.trio_potentials[('W','W','W')]
knots2 = [jnp.asarray(ndspline2.knots[0])]
knots3 = ndspline3.knots
knots3 = [jnp.asarray(i) for i in knots3]

coefficients2 = rng.standard_normal(len(knots2[0]) - 7) * 5
coefficients2 = onp.pad(coefficients2, (0, 3))
coefficients2 = jnp.asarray(coefficients2)


c3_shape = (ndspline3.coefficients.shape[0] -3,ndspline3.coefficients.shape[1] -3,ndspline3.coefficients.shape[2] -3)
coefficients3 = rng.standard_normal(c3_shape) * 5
coefficients3 = onp.pad(coefficients3, ((0, 3),) * 3)
coefficients3 = jnp.asarray(coefficients3)

coefficients2 = jnp.zeros_like(coefficients2)
# coefficients3 = jnp.ones_like(coefficients3)

pair = uf3_pair(
    displacement, [knots2, knots3], coefficients=[coefficients2, coefficients3]
)
nf, ef = uf3_neighbor(
    displacement,
    box_size,
    [knots2, knots3],
    coefficients=[coefficients2, coefficients3],
    cutoff=knots2[0][-1] + 0.5,
)

coeff_dict2 = {}
coeff_dict2[(0, 0)] = coefficients2
coeff_dict2[(0, 1)] = coefficients2
coeff_dict2[(1, 1)] = coefficients2

knot_dict2 = {}
knot_dict2[(0, 0)] = knots2
knot_dict2[(0, 1)] = knots2
knot_dict2[(1, 1)] = knots2




Seed for energy test: 925




In [13]:

coeff_dict3 = {}
coeff_dict3[(0, 0, 0)] = coefficients3
coeff_dict3[(0, 0, 1)] = coefficients3
coeff_dict3[(0, 1, 1)] = coefficients3
coeff_dict3[(1, 0, 0)] = coefficients3
coeff_dict3[(1, 0, 1)] = coefficients3
coeff_dict3[(1, 1, 1)] = coefficients3

knot_dict3 = {}
knot_dict3[(0, 0, 0)] = knots3
knot_dict3[(0, 0, 1)] = knots3
knot_dict3[(0, 1, 1)] = knots3
knot_dict3[(1, 0, 0)] = knots3
knot_dict3[(1, 0, 1)] = knots3
knot_dict3[(1, 1, 1)] = knots3

nfs, efs = uf3_neighbor(
    displacement,
    box_size,
    [knot_dict2, knot_dict3],
    species=jnp.ones_like(species),
    coefficients=[coeff_dict2, coeff_dict3],
    cutoff=knots2[0][-1] + 0.5,
)


In [15]:
len(species)

50

In [14]:
nbrs = nfs.allocate(R)
efs(R, neighbor=nbrs)

DeviceArray(-6.59145981, dtype=float64)

In [None]:

energy_1 = pair(R)

nbrs = nf.allocate(R)
energy_2 = ef(R, neighbor=nbrs)

nbrs = nfs.allocate(R)
energy_3 = efs(R, neighbor=nbrs)

# assert jnp.allclose(energy_1, energy_2)
# assert jnp.allclose(energy_2, energy_3)

ndspline2.coefficients = onp.asarray(coefficients2[:,None])
ndspline3.coefficients = onp.asarray(coefficients3[:,:,:,None])
calc.pair_potentials[('W','W')] = ndspline2
calc.trio_potentials[('W','W','W')] = ndspline3
atoms.set_calculator(calc)

energy_4 = atoms.get_potential_energy(force_consistent=True)

# assert jnp.allclose(energy_2, energy_4)

In [3]:
energy_1

DeviceArray(-6.59145981, dtype=float64)

In [4]:
energy_2

DeviceArray(-6.59145981, dtype=float64)

In [5]:
energy_3

DeviceArray(2.32717351, dtype=float64)

In [6]:
energy_4

-24.96473996921558