In [1]:
import numpy as np
import matplotlib.pyplot as plt
import ase

import timeit

f64 = np.float64
f32 = np.float32

from uf3.data.composition import ChemicalSystem
from uf3.representation.bspline import BSplineBasis
from uf3.regression.least_squares import WeightedLinearModel

from uf3.forcefield import calculator


from jax.config import config
config.update("jax_enable_x64", True)

from uf3.jax.potentials import uf2_pair, uf2_neighbor, get_stress_fn
from uf3.jax import potentialsOld
from jax import grad, jit, vmap, lax, hessian
import jax.numpy as jnp

from jax_md import space, smap, energy, minimize, quantity, simulate, partition, interpolate




In [2]:
element_list = ["W"]
degree = 2 # include two-body interactions (pair potential)

chemical_system = ChemicalSystem(element_list=element_list,
                                 degree=degree)

r_min_map = {('W', 'W'): 1.5,}  # minimum distance cutoff (angstroms)
r_max_map = {('W', 'W'): 5.5,}  # maximum distance cutoff (angstroms)
resolution_map = {('W', 'W'): 12}  # number of knot intervals

bspline_config = BSplineBasis(chemical_system,
                               r_min_map=r_min_map,
                               r_max_map=r_max_map,
                               resolution_map=resolution_map)

model_coefficients = [-7.66691574e+00,  1.25129255e+01,  5.93366978e+00,  2.64404190e+00,\
        3.89362649e-01, -1.80794686e-01, -2.84676830e-01, -2.67508086e-01,\
       -9.45645302e-02,  1.47269014e-02, -6.01878769e-03, -8.12415903e-03,\
        8.67558130e-03, -4.74148052e-03, -7.13338724e-04,  0.00000000e+00]

model = WeightedLinearModel(bspline_config)

model.coefficients = np.asarray(model_coefficients)

In [3]:
calc = calculator.UFCalculator(model)

pairs = calc.bspline_config.chemical_system.interactions_map[2]

knots = calc.bspline_config.knots_map[pairs[0]]
coefficients = calc.solutions[pairs[0]]

print(coefficients)

[ 1.25129255e+01  5.93366978e+00  2.64404190e+00  3.89362649e-01
 -1.80794686e-01 -2.84676830e-01 -2.67508086e-01 -9.45645302e-02
  1.47269014e-02 -6.01878769e-03 -8.12415903e-03  8.67558130e-03
 -4.74148052e-03 -7.13338724e-04  0.00000000e+00]


# Code Demo

## Setting up Parameters

### Coefficients and Knots for Tungsten

In [4]:
knots = [1.5, 1.5, 1.5, 1.5, 1.83333333, 2.16666667, 2.5, 2.83333333, 3.16666667, 3.5, 3.83333333, 4.16666667, 4.5, 4.83333333, 5.16666667, 5.5, 5.5, 5.5, 5.5]

coefficients = [ 1.25129255e+01,  5.93366978e+00,  2.64404190e+00,  3.89362649e-01, -1.80794686e-01, -2.84676830e-01, -2.67508086e-01, \
                -9.45645302e-02, 1.47269014e-02, -6.01878769e-03, -8.12415903e-03, 8.67558130e-03, -4.74148052e-03, -7.13338724e-04,  0.00000000e+00]

## Create a large Custom Cell

In [5]:
from jax_md.quantity import box_size_at_number_density

def creat_random_cell(particle_count, density):

    dim = 3
    box_size = box_size_at_number_density(particle_count    = particle_count, 
                                          number_density    = density, 
                                          spatial_dimension = dim)

    rng = np.random.default_rng(3)
    pos = rng.uniform(0.0, box_size, (particle_count,3))

    atoms = ase.Atoms('W'+str(particle_count),
                      positions=pos,
                      cell=[box_size, box_size, box_size],
                      pbc=[1,1,1])
    return atoms

In [6]:
size = 100

atoms = creat_random_cell(size, 1/15.0)

print(atoms.cell.cellpar()[0]) # Edges have to be at least 11 long for JAX MD to work right
   
R = atoms.get_positions(wrap=True)



11.447142425533317


## Using the ASE-based UFP Implementation

In [8]:
calc = calculator.UFCalculator(model)

geom = atoms.copy()
geom.set_calculator(calc)


In [9]:
#%%timeit -n1 -r5

energy = geom.get_potential_energy(force_consistent=True)
forces = geom.get_forces()
stress = geom.get_stress()


In [10]:
print("Energy:", energy)
print("Forces:\n", forces[0:5])
print("Max force:", np.max(np.abs(forces)))
print("Stresses (numerical):", stress)

Energy: 571.9160833398482
Forces:
 [[ 52.14344602 -63.39688627  21.10488864]
 [ 34.9254233    9.70466697 -37.2808141 ]
 [ 36.39109213  22.41294595  30.4890162 ]
 [ -8.00302773 -11.41969786   9.6287169 ]
 [  0.82743839  -7.72977273 -10.00567689]]
Max force: 229.68753385959016
Stresses (numerical): [-1.11388509 -1.16451295 -1.35145374  0.1701487   0.06297906  0.04704345]


## Using the JAX M.D. UFP Implementation

In [13]:
knots = jnp.asarray(knots)
coefficients = jnp.asarray(coefficients)

In [14]:
# setup displacement
box = atoms.get_cell().array

# ---Experimenting with different JAX-MD spaces --------
displacement, shift = space.periodic_general(box, fractional_coordinates=False)
    
ufp = uf2_pair(displacement, knots=knots, coefficients=coefficients)

In [21]:
#%%timeit -n1 -r5

# Use this to calculate the total energy
energy = ufp(R)

# Use grad to calculate the net force
force_fn = grad(ufp)
force = -force_fn(R)

# Stress
stress_fn = get_stress_fn(ufp, box)
stress = stress_fn(R)

In [23]:
print(energy * 2)
print(force[0:5] * 2)
print(np.max(np.abs(force * 2)))
print(stress * 2)

571.9160818159853
[[ 52.14344624 -63.39688655  21.10488873]
 [ 34.92542298   9.70466692 -37.28081374]
 [ 36.39109195  22.41294583  30.48901608]
 [ -8.00302773 -11.41969785   9.6287169 ]
 [  0.82743841  -7.72977272 -10.00567684]]
229.68753488472413
[[-1.02629656  0.02274775  0.04006818]
 [ 0.02274775 -1.08964242  0.14115386]
 [ 0.04006818  0.14115386 -1.28837716]]


In [None]:
%%timeit -n1 -r5

ufpj = jit(ufp)

# Use this to calculate the total energy
energy = ufpj(R)

# Use grad to calculate the net force
force_fn = grad(ufpj)
force = force_fn(R)

# Stress
stress_fn = get_stress_fn(ufpj, box)
stress = stress_fn(R)

165 ms ± 3.62 ms per loop (mean ± std. dev. of 5 runs, 1 loop each)


In [None]:
# Hessian of Forces
he_fn = hessian(ufp)
he = he_fn(R)
print(he[0][0][0:5])

[[ 1.31975776e+02 -2.26713103e+02  7.83711339e+01]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00]
 [ 5.81072503e-02 -1.41280586e-02 -1.22263409e-02]
 [-1.21300490e-01  8.48324108e-02 -1.56314539e-01]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00]]
