In [1]:
import numpy as onp

from jax.config import config ; config.update('jax_enable_x64', True)
# config.update('jax_disable_jit', True)

import jax.numpy as jnp
from jax import random
from jax import jit, grad, xla_computation
from jax import lax

import time

from jax_md import space
from jax_md import smap
from jax_md import energy
from jax_md import quantity
from jax_md import simulate
from jax_md import partition

import matplotlib
import matplotlib.pyplot as plt

import ase

import jax.profiler


## System

In [2]:
from uf3.representation import bspline
from uf3.data import composition
from uf3.regression import least_squares

element_list = ['W']
degree = 2

chemical_system = composition.ChemicalSystem(element_list=element_list,
                                             degree=degree)

r_min_map = {('W', 'W'): 1.5,
            }
r_max_map = {('W', 'W'): 5.5,
            }
resolution_map = {('W', 'W'): 25,
                 }
trailing_trim = 3

bspline_config = bspline.BSplineBasis(chemical_system,
                                      r_min_map=r_min_map,
                                      r_max_map=r_max_map,
                                      resolution_map=resolution_map,
                                      trailing_trim=trailing_trim)

model_2_body = least_squares.WeightedLinearModel(bspline_config)

model_2_body.load(filename="../../tungsten_extxyz/model_pair.json")

In [3]:
from uf3.forcefield.calculator import *

element_list = ['W']
degree = 3

chemical_system = composition.ChemicalSystem(element_list=element_list,
                                             degree=degree)

r_min_map = {("W", "W"): 1.5,
             ("W", "W", "W"): [1.5, 1.5, 1.5],
            }
r_max_map = {("W", "W"): 5.5,
             ("W", "W", "W"): [3.5, 3.5, 7.0],
            }
resolution_map = {("W", "W"): 25,
                  ("W", "W", "W"): [5, 5, 10],
                 }
trailing_trim = 3

bspline_config = bspline.BSplineBasis(chemical_system,
                                      r_min_map=r_min_map,
                                      r_max_map=r_max_map,
                                      resolution_map=resolution_map,
                                      trailing_trim=trailing_trim)

model_3_body = least_squares.WeightedLinearModel(bspline_config)

model_3_body.load(filename="../../tungsten_extxyz/model_uf23.json")

In [4]:
Nx = particles_per_side = 10
# Nx = particles_per_side = 50

spacing = np.float32(2.0)
side_length = Nx * spacing

R = onp.stack([np.array(r) for r in np.ndindex(Nx, Nx, Nx)]) * spacing
R = np.array(R, np.float64)

In [5]:
print(f"Number of atoms: {R.shape[0]}")
print(f"Box should be larger than 11 for JAX-MD, is: {side_length}")

Number of atoms: 1000
Box should be larger than 11 for JAX-MD, is: 20.0


In [6]:
pos = R
c = onp.identity(3) * side_length
ase.Atoms('W13824', positions=pos, cell=c, pbc=[True,True,True])
# ase.Atoms('W125000', positions=pos, cell=c, pbc=[True,True,True])

ValueError: Array "positions" has wrong length: 1000 != 13824.

In [6]:
displacement, shift = space.periodic(side_length)
format = partition.Dense
R = jnp.asarray(R)

In [7]:
calc = UFCalculator(model_3_body)

ndspline2 = calc.pair_potentials[('W','W')]
ndspline3 = calc.trio_potentials[('W','W','W')]

coefficients2 = jnp.asarray(ndspline2.coefficients)
coefficients2 = coefficients2[:,0]

knots2 = ndspline2.knots
knots2 = [jnp.asarray(i) for i in knots2]

coefficients3 = jnp.asarray(ndspline3.coefficients)
coefficients3 = coefficients3[:,:,:,0]

knots3 = ndspline3.knots
knots3 = [jnp.asarray(i) for i in knots3]

## Lennard Jones

In [9]:
from jax_md.energy import lennard_jones_neighbor_list
from jax_md import partition

In [10]:
neighbor_fn, energy_fn = lennard_jones_neighbor_list(displacement, side_length, r_cutoff=5.5, format=partition.Dense)

In [11]:
nbrs = neighbor_fn.allocate(R)
energy_fn = jit(energy_fn, backend='gpu')

In [12]:
nbrs.idx.shape

(125000, 115)

In [20]:
print(energy_fn(R, nbrs))

-239.80404308865104


In [21]:
displacement, shift = space.periodic(side_length)

init_fn, apply_fn = simulate.nve(energy_fn, shift, 1e-3)
state = init_fn(random.PRNGKey(0), R, kT=1e-3, neighbor=nbrs)


def body_fn(i, state):
    state, nbrs = state
    nbrs = nbrs.update(state.position)
    state = apply_fn(state, neighbor=nbrs)
    return state, nbrs


print('Step\tKE\tPE\tTotal Energy\ttime/step')
print('----------------------------------------')
PE = []
KE = []
print_every = 4
step = 0
old_time = time.time_ns()
while step < 40:
    new_state, nbrs = lax.fori_loop(0, 100, body_fn, (state, nbrs))
    if nbrs.did_buffer_overflow:
        print("Neighbor list overflowed, reallocating.")
        nbrs = neighbor_fn.allocate(state.position)
    else:
        state = new_state
        step += 1

    if step % print_every == 0 and step > 0:
        new_time = time.time_ns()
        PE += [energy_fn(state.position, nbrs)]
        KE += [quantity.kinetic_energy(state.velocity)]
        print(
            "{}\t{:.2f}\t{:.2f}\t{:.2f}\t\t{:.2f}ms".format(
                step * print_every,
                KE[-1],
                PE[-1],
                KE[-1] + PE[-1],
                (new_time - old_time) / print_every / 100.0 / 1000000.0,
            )
        )
        old_time = new_time

Step	KE	PE	Total Energy	time/step
----------------------------------------
16	220.31	-30008.80	-29788.48		25.29ms
32	364.11	-30152.59	-29788.48		22.64ms
Neighbor list overflowed, reallocating.
48	823.34	-30611.82	-29788.48		34.18ms
64	2417.24	-32205.72	-29788.48		27.95ms
80	9413.04	-39201.53	-29788.49		26.43ms
96	28060.07	-57848.68	-29788.61		26.63ms
112	47036.03	-76824.78	-29788.75		26.72ms
128	59273.28	-89062.02	-29788.74		26.93ms
144	69008.78	-98797.53	-29788.75		26.91ms
160	78446.54	-108235.31	-29788.76		27.09ms


## Stillinger-Weber

In [14]:
from jax_md.energy import stillinger_weber_neighbor_list, _sw_angle_interaction
from jax_md import partition

In [10]:
neighbor_fn, energy_fn =stillinger_weber_neighbor_list(displacement, side_length, cutoff=5.5)

In [11]:
nbrs = neighbor_fn.allocate(R)
energy_fn = jit(energy_fn)

In [12]:
nbrs.idx.shape

(1000, 115)

In [13]:
print(energy_fn(R, nbrs))

1569003.1157833629


In [18]:
from jax import vmap
from functools import partial
sigma: float = 2.0951
gamma: float = 1.2
cutoff: float = 3.77118
three_body_fn = partial(_sw_angle_interaction, gamma, sigma, cutoff)
three_body_fn = vmap(vmap(vmap(three_body_fn, (0, None)), (None, 0)))

In [19]:
dR = space.map_product(displacement)(R, R)

In [20]:
c = xla_computation(three_body_fn)(dR,dR)

In [None]:
print(c.as_hlo_text())

In [21]:
displacement, shift = space.periodic(side_length)

init_fn, apply_fn = simulate.nve(energy_fn, shift, 1e-3)
state = init_fn(random.PRNGKey(0), R, kT=1e-3, neighbor=nbrs)


def body_fn(i, state):
    state, nbrs = state
    nbrs = nbrs.update(state.position)
    state = apply_fn(state, neighbor=nbrs)
    return state, nbrs


print('Step\tKE\tPE\tTotal Energy\ttime/step')
print('----------------------------------------')
PE = []
KE = []
print_every = 4
step = 0
old_time = time.time_ns()
while step < 40:
    new_state, nbrs = lax.fori_loop(0, 100, body_fn, (state, nbrs))
    if nbrs.did_buffer_overflow:
        print("Neighbor list overflowed, reallocating.")
        nbrs = neighbor_fn.allocate(state.position)
    else:
        state = new_state
        step += 1

    if step % print_every == 0 and step > 0:
        new_time = time.time_ns()
        PE += [energy_fn(state.position, nbrs)]
        KE += [quantity.kinetic_energy(state.velocity)]
        print(
            "{}\t{:.2f}\t{:.2f}\t{:.2f}\t\t{:.2f}ms".format(
                step * print_every,
                KE[-1],
                PE[-1],
                KE[-1] + PE[-1],
                (new_time - old_time) / print_every / 100.0 / 1000000.0,
            )
        )
        old_time = new_time

Step	KE	PE	Total Energy	time/step
----------------------------------------
16	220.31	-30008.80	-29788.48		25.29ms
32	364.11	-30152.59	-29788.48		22.64ms
Neighbor list overflowed, reallocating.
48	823.34	-30611.82	-29788.48		34.18ms
64	2417.24	-32205.72	-29788.48		27.95ms
80	9413.04	-39201.53	-29788.49		26.43ms
96	28060.07	-57848.68	-29788.61		26.63ms
112	47036.03	-76824.78	-29788.75		26.72ms
128	59273.28	-89062.02	-29788.74		26.93ms
144	69008.78	-98797.53	-29788.75		26.91ms
160	78446.54	-108235.31	-29788.76		27.09ms


## UF potentials

In [8]:
from uf3.jax.potentials import uf3_neighbor

In [9]:
c = [coefficients2, coefficients3]
k = [knots2, knots3]

In [10]:
neighbor_fn, energy_fn = uf3_neighbor(displacement, side_length, coefficients=c, knots=k, cutoff=5.5)

In [11]:
nbrs = neighbor_fn.allocate(R)
energy_fn = jit(energy_fn)

In [12]:
# jax.profiler.save_device_memory_profile("potential1.prof")

In [13]:
# Run once for jit
print(energy_fn(R, nbrs))
# print(grad(energy_fn)(R, nbrs))

26811.792926319944


In [14]:
nbrs.idx.shape

(1000, 115)

In [15]:
print(grad(energy_fn)(R, nbrs))
# For #R 13824 with 115 nbrs 
# 54 GB with checkpoint
# 329 GB without checkpoint

# without checkpoint max 1000 R with 115 nbrs at 268ms
# with checkpoint max 3375 R with 115 nbrs at 1.58s


[[ 2.01390987e-13  9.00043928e-14 -2.52263488e-13]
 [ 2.54130050e-13  1.89709359e-13  2.07750483e-13]
 [ 2.13162821e-14  1.35003120e-13 -6.77027878e-14]
 ...
 [-2.20268248e-13 -1.40925466e-13  6.86881108e-14]
 [ 5.56603375e-14 -9.92608773e-14  3.31609740e-14]
 [-1.15983747e-13  1.35003120e-13 -1.13901943e-13]]


In [16]:
%%timeit
grad(energy_fn)(R, nbrs).block_until_ready()
# for #R 1000 and 115 nbrs
# without checkpoint 268ms
# with checkpoint 397ms 7 runs 1 loop

1.58 s ± 366 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [22]:
from uf3.jax.potentials import uf3_mapped

In [23]:
f = uf3_mapped(knots3)
c = xla_computation(f)(dR,dR,coefficients=coefficients3)

In [None]:
print(c.as_hlo_text())

In [16]:
displacement, shift = space.periodic(side_length)

init_fn, apply_fn = simulate.nve(energy_fn, shift, 1e-3)
state = init_fn(random.PRNGKey(0), R, kT=1e-3, neighbor=nbrs)


def body_fn(i, state):
    state, nbrs = state
    nbrs = nbrs.update(state.position)
    state = apply_fn(state, neighbor=nbrs)
    return state, nbrs


print('Step\tKE\tPE\tTotal Energy\ttime/step')
print('----------------------------------------')
PE = []
KE = []
print_every = 4
step = 0
old_time = time.time_ns()
while step < 40:
    new_state, nbrs = lax.fori_loop(0, 100, body_fn, (state, nbrs))
    if nbrs.did_buffer_overflow:
        print("Neighbor list overflowed, reallocating.")
        nbrs = neighbor_fn.allocate(state.position)
    else:
        state = new_state
        step += 1

    if step % print_every == 0 and step > 0:
        new_time = time.time_ns()
        PE += [energy_fn(state.position, nbrs)]
        KE += [quantity.kinetic_energy(state.velocity)]
        print(
            "{}\t{:.2f}\t{:.2f}\t{:.2f}\t\t{:.2f}ms".format(
                step * print_every,
                KE[-1],
                PE[-1],
                KE[-1] + PE[-1],
                (new_time - old_time) / print_every / 100.0 / 1000000.0,
            )
        )
        old_time = new_time

Step	KE	PE	Total Energy	time/step
----------------------------------------
Neighbor list overflowed, reallocating.
16	97844.10	-32208.65	65635.45		447.78ms
Neighbor list overflowed, reallocating.
32	350249.14	14721.93	364971.07		561.76ms
48	615169.38	30518.33	645687.71		589.56ms
64	892525.49	33819.02	926344.51		585.80ms
80	1165043.09	39274.53	1204317.63		585.81ms
96	1419552.27	38744.08	1458296.35		585.85ms
112	1657270.45	45162.44	1702432.90		585.87ms
128	1874076.56	35335.55	1909412.11		585.92ms
Neighbor list overflowed, reallocating.
144	2127293.70	43320.50	2170614.20		747.10ms
160	2332224.48	46032.59	2378257.07		578.13ms
