In [1]:
import numpy as onp

from jax.config import config ; config.update('jax_enable_x64', True)
import jax.numpy as jnp
from jax import random
from jax import jit, grad
from jax import lax

import time

from jax_md import space
from jax_md import smap
from jax_md import energy
from jax_md import quantity
from jax_md import simulate
from jax_md import partition

import matplotlib
import matplotlib.pyplot as plt

import ase


## System

In [2]:
from uf3.representation import bspline
from uf3.data import composition
from uf3.regression import least_squares

element_list = ['W']
degree = 2

chemical_system = composition.ChemicalSystem(element_list=element_list,
                                             degree=degree)

r_min_map = {('W', 'W'): 1.5,
            }
r_max_map = {('W', 'W'): 5.5,
            }
resolution_map = {('W', 'W'): 25,
                 }
trailing_trim = 3

bspline_config = bspline.BSplineBasis(chemical_system,
                                      r_min_map=r_min_map,
                                      r_max_map=r_max_map,
                                      resolution_map=resolution_map,
                                      trailing_trim=trailing_trim)

model_2_body = least_squares.WeightedLinearModel(bspline_config)

model_2_body.load(filename="../../tungsten_extxyz/model_pair.json")

In [3]:
from uf3.forcefield.calculator import *

element_list = ['W']
degree = 3

chemical_system = composition.ChemicalSystem(element_list=element_list,
                                             degree=degree)

r_min_map = {("W", "W"): 1.5,
             ("W", "W", "W"): [1.5, 1.5, 1.5],
            }
r_max_map = {("W", "W"): 5.5,
             ("W", "W", "W"): [3.5, 3.5, 7.0],
            }
resolution_map = {("W", "W"): 25,
                  ("W", "W", "W"): [5, 5, 10],
                 }
trailing_trim = 3

bspline_config = bspline.BSplineBasis(chemical_system,
                                      r_min_map=r_min_map,
                                      r_max_map=r_max_map,
                                      resolution_map=resolution_map,
                                      trailing_trim=trailing_trim)

model_3_body = least_squares.WeightedLinearModel(bspline_config)

model_3_body.load(filename="../../tungsten_extxyz/model_uf23.json")

In [8]:
Nx = particles_per_side = 10
# Nx = particles_per_side = 50

spacing = np.float32(2.0)
side_length = Nx * spacing

R = onp.stack([np.array(r) for r in np.ndindex(Nx, Nx, Nx)]) * spacing
R = np.array(R, np.float64)

In [9]:
print(f"Number of atoms: {R.shape[0]}")
print(f"Box should be larger than 11 for JAX-MD, is: {side_length}")

Number of atoms: 1000
Box should be larger than 11 for JAX-MD, is: 20.0


In [10]:
pos = R
c = onp.identity(3) * side_length
ase.Atoms('W1000', positions=pos, cell=c, pbc=[True,True,True])
# ase.Atoms('W125000', positions=pos, cell=c, pbc=[True,True,True])

Atoms(symbols='W1000', pbc=True, cell=[20.0, 20.0, 20.0])

In [11]:
displacement, shift = space.periodic(side_length)
format = partition.Dense
R = jnp.asarray(R)

In [12]:
calc = UFCalculator(model_3_body)

ndspline2 = calc.pair_potentials[('W','W')]
ndspline3 = calc.trio_potentials[('W','W','W')]

coefficients2 = jnp.asarray(ndspline2.coefficients)
coefficients2 = coefficients2[:,0]

knots2 = ndspline2.knots
knots2 = [jnp.asarray(i) for i in knots2]

coefficients3 = jnp.asarray(ndspline3.coefficients)
coefficients3 = coefficients3[:,:,:,0]

knots3 = ndspline3.knots
knots3 = [jnp.asarray(i) for i in knots3]

## Lennard Jones

In [13]:
from jax_md.energy import lennard_jones_neighbor_list
from jax_md import partition

In [14]:
neighbor_fn, energy_fn = lennard_jones_neighbor_list(displacement, side_length, r_cutoff=5.5, format=partition.Dense)

In [18]:
nbrs = neighbor_fn.allocate(R)
energy_fn = jit(energy_fn, backend='cpu')

In [19]:
nbrs.idx.shape

(1000, 115)

In [20]:
print(energy_fn(R, nbrs))

-239.80404308865104


In [21]:
displacement, shift = space.periodic(side_length)

init_fn, apply_fn = simulate.nve(energy_fn, shift, 1e-3)
state = init_fn(random.PRNGKey(0), R, kT=1e-3, neighbor=nbrs)


def body_fn(i, state):
    state, nbrs = state
    nbrs = nbrs.update(state.position)
    state = apply_fn(state, neighbor=nbrs)
    return state, nbrs


print('Step\tKE\tPE\tTotal Energy\ttime/step')
print('----------------------------------------')
PE = []
KE = []
print_every = 4
step = 0
old_time = time.time_ns()
while step < 40:
    new_state, nbrs = lax.fori_loop(0, 100, body_fn, (state, nbrs))
    if nbrs.did_buffer_overflow:
        print("Neighbor list overflowed, reallocating.")
        nbrs = neighbor_fn.allocate(state.position)
    else:
        state = new_state
        step += 1

    if step % print_every == 0 and step > 0:
        new_time = time.time_ns()
        PE += [energy_fn(state.position, nbrs)]
        KE += [quantity.kinetic_energy(state.velocity)]
        print(
            "{}\t{:.2f}\t{:.2f}\t{:.2f}\t\t{:.2f}ms".format(
                step * print_every,
                KE[-1],
                PE[-1],
                KE[-1] + PE[-1],
                (new_time - old_time) / print_every / 100.0 / 1000000.0,
            )
        )
        old_time = new_time

Step	KE	PE	Total Energy	time/step
----------------------------------------
16	220.31	-30008.80	-29788.48		25.29ms
32	364.11	-30152.59	-29788.48		22.64ms
Neighbor list overflowed, reallocating.
48	823.34	-30611.82	-29788.48		34.18ms
64	2417.24	-32205.72	-29788.48		27.95ms
80	9413.04	-39201.53	-29788.49		26.43ms
96	28060.07	-57848.68	-29788.61		26.63ms
112	47036.03	-76824.78	-29788.75		26.72ms
128	59273.28	-89062.02	-29788.74		26.93ms
144	69008.78	-98797.53	-29788.75		26.91ms
160	78446.54	-108235.31	-29788.76		27.09ms


## Stillinger-Weber

In [9]:
from jax_md.energy import stillinger_weber_neighbor_list
from jax_md import partition

In [32]:
neighbor_fn, energy_fn =stillinger_weber_neighbor_list(displacement, side_length, cutoff=5.5)

In [33]:
nbrs = neighbor_fn.allocate(R)
energy_fn = jit(energy_fn)

In [34]:
nbrs.idx.shape

(125000, 115)

In [35]:
print(energy_fn(R, nbrs))

196125389.47292036


In [21]:
displacement, shift = space.periodic(side_length)

init_fn, apply_fn = simulate.nve(energy_fn, shift, 1e-3)
state = init_fn(random.PRNGKey(0), R, kT=1e-3, neighbor=nbrs)


def body_fn(i, state):
    state, nbrs = state
    nbrs = nbrs.update(state.position)
    state = apply_fn(state, neighbor=nbrs)
    return state, nbrs


print('Step\tKE\tPE\tTotal Energy\ttime/step')
print('----------------------------------------')
PE = []
KE = []
print_every = 4
step = 0
old_time = time.time_ns()
while step < 40:
    new_state, nbrs = lax.fori_loop(0, 100, body_fn, (state, nbrs))
    if nbrs.did_buffer_overflow:
        print("Neighbor list overflowed, reallocating.")
        nbrs = neighbor_fn.allocate(state.position)
    else:
        state = new_state
        step += 1

    if step % print_every == 0 and step > 0:
        new_time = time.time_ns()
        PE += [energy_fn(state.position, nbrs)]
        KE += [quantity.kinetic_energy(state.velocity)]
        print(
            "{}\t{:.2f}\t{:.2f}\t{:.2f}\t\t{:.2f}ms".format(
                step * print_every,
                KE[-1],
                PE[-1],
                KE[-1] + PE[-1],
                (new_time - old_time) / print_every / 100.0 / 1000000.0,
            )
        )
        old_time = new_time

Step	KE	PE	Total Energy	time/step
----------------------------------------
16	220.31	-30008.80	-29788.48		25.29ms
32	364.11	-30152.59	-29788.48		22.64ms
Neighbor list overflowed, reallocating.
48	823.34	-30611.82	-29788.48		34.18ms
64	2417.24	-32205.72	-29788.48		27.95ms
80	9413.04	-39201.53	-29788.49		26.43ms
96	28060.07	-57848.68	-29788.61		26.63ms
112	47036.03	-76824.78	-29788.75		26.72ms
128	59273.28	-89062.02	-29788.74		26.93ms
144	69008.78	-98797.53	-29788.75		26.91ms
160	78446.54	-108235.31	-29788.76		27.09ms


## UF potentials

In [9]:
from uf3.jax.potentials import uf3_neighbor

In [19]:
c = [coefficients2, coefficients3]
k = [knots2, knots3]

In [20]:
neighbor_fn, energy_fn = uf3_neighbor(displacement, side_length, coefficients=c, knots=k, cutoff=5.5)

In [21]:
nbrs = neighbor_fn.allocate(R)
energy_fn = jit(energy_fn)

In [22]:
# Run once for jit
print(energy_fn(R, nbrs))
# print(grad(energy_fn)(R, nbrs))

26811.792926319988


In [None]:
nbrs.idx.shape

In [15]:
displacement, shift = space.periodic(side_length)

init_fn, apply_fn = simulate.nve(energy_fn, shift, 1e-3)
state = init_fn(random.PRNGKey(0), R, kT=1e-3, neighbor=nbrs)


def body_fn(i, state):
    state, nbrs = state
    nbrs = nbrs.update(state.position)
    state = apply_fn(state, neighbor=nbrs)
    return state, nbrs


print('Step\tKE\tPE\tTotal Energy\ttime/step')
print('----------------------------------------')
PE = []
KE = []
print_every = 4
step = 0
old_time = time.time_ns()
while step < 40:
    new_state, nbrs = lax.fori_loop(0, 100, body_fn, (state, nbrs))
    if nbrs.did_buffer_overflow:
        print("Neighbor list overflowed, reallocating.")
        nbrs = neighbor_fn.allocate(state.position)
    else:
        state = new_state
        step += 1

    if step % print_every == 0 and step > 0:
        new_time = time.time_ns()
        PE += [energy_fn(state.position, nbrs)]
        KE += [quantity.kinetic_energy(state.velocity)]
        print(
            "{}\t{:.2f}\t{:.2f}\t{:.2f}\t\t{:.2f}ms".format(
                step * print_every,
                KE[-1],
                PE[-1],
                KE[-1] + PE[-1],
                (new_time - old_time) / print_every / 100.0 / 1000000.0,
            )
        )
        old_time = new_time

Step	KE	PE	Total Energy	time/step
----------------------------------------
16	597.64	334525.19	335122.84		60.27ms
32	63447.30	271675.94	335123.23		56.77ms
Neighbor list overflowed, reallocating.
32	63447.30	271675.94	335123.23		16.34ms
48	184899.66	150221.47	335121.13		77.15ms
64	201742.86	133378.02	335120.87		70.36ms
80	207305.33	127815.40	335120.73		70.40ms
96	211815.78	123304.81	335120.59		70.49ms
112	213208.65	121911.91	335120.57		70.45ms
128	215093.71	120026.74	335120.45		70.50ms
144	215924.58	119195.88	335120.46		70.19ms
160	216459.48	118660.95	335120.43		70.44ms
