<a href="https://colab.research.google.com/github/google/jax-md/blob/master/notebooks/lanl_summer_school_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title Import & Util

# Install JAX MD from the GitHub master branch, plus dm-haiku.
# NOTE(review): neither install is version-pinned, so a later re-run may pull
# APIs that differ from the ones this notebook was written against.
!pip install -q git+https://www.github.com/google/jax-md
!pip install dm-haiku

# Throughout this notebook `np` is jax.numpy; plain NumPy is imported later as `onp`.
import jax.numpy as np
from IPython.display import set_matplotlib_formats
# Emit inline figures as vector graphics (PDF / SVG).
set_matplotlib_formats('pdf', 'svg')
import matplotlib.pyplot as plt
import seaborn as sns
import pickle

import warnings
# Silences every warning (including deprecation warnings) for the rest of the
# session to keep demo output clean; remove this line when debugging.
warnings.simplefilter("ignore")

sns.set_style(style='white')
# Dark grey RGB colour used as the background of the particle plots below.
background_color = [56 / 256] * 3
def plot(x, y, *args, **kwargs):
  """Draw a line plot with a thick default line on a white axes background.

  Args:
    x, y: Data forwarded to `plt.plot`.
    *args: Extra positional arguments for `plt.plot`, e.g. a format string
      such as 'k'.
    **kwargs: Extra keyword arguments for `plt.plot`, e.g. `label`. The line
      width defaults to 3 unless `linewidth` (or its alias `lw`) is given.
  """
  # Only apply the default width when the caller did not choose one; passing
  # both `linewidth` and `lw` to matplotlib is an error.
  if 'linewidth' not in kwargs and 'lw' not in kwargs:
    kwargs['linewidth'] = 3
  plt.plot(x, y, *args, **kwargs)
  plt.gca().set_facecolor([1, 1, 1])
def draw(R, **kwargs):
  """Scatter-plot particle positions as large discs on a dark background.

  Args:
    R: Particle positions; column 0 is plotted as x and column 1 as y.
    **kwargs: Forwarded to `plt.scatter`. A light default colour is used
      only when neither `c` nor `color` is given.
  """
  # Respect a caller-supplied colour: checking only `c` would silently
  # overwrite an explicit `color=` argument.
  if 'c' not in kwargs and 'color' not in kwargs:
    kwargs['color'] = [1, 1, 0.9]
  # Axis limits run from 0 to the largest coordinate along each axis.
  ax = plt.axes(xlim=(0, float(np.max(R[:, 0]))),
                ylim=(0, float(np.max(R[:, 1]))))
  ax.get_xaxis().set_visible(False)
  ax.get_yaxis().set_visible(False)
  ax.set_facecolor(background_color)
  plt.scatter(R[:, 0], R[:, 1], marker='o', s=1024, **kwargs)
  plt.gcf().patch.set_facecolor(background_color)
  plt.gcf().set_size_inches(6, 6)
  plt.tight_layout()
def draw_big(R, **kwargs):
  """Scatter-plot many particle positions as tiny points on a large figure.

  Args:
    R: Particle positions; column 0 is plotted as x and column 1 as y.
    **kwargs: Forwarded to `plt.scatter`. A light default colour is used
      only when neither `c` nor `color` is given.
  """
  # Respect a caller-supplied colour: checking only `c` would silently
  # overwrite an explicit `color=` argument.
  if 'c' not in kwargs and 'color' not in kwargs:
    kwargs['color'] = [1, 1, 0.9]
  plt.figure(dpi=128)
  # Axis limits run from 0 to the largest coordinate along each axis.
  ax = plt.axes(xlim=(0, float(np.max(R[:, 0]))),
                ylim=(0, float(np.max(R[:, 1]))))
  ax.get_xaxis().set_visible(False)
  ax.get_yaxis().set_visible(False)
  ax.set_facecolor(background_color)
  points = plt.scatter(R[:, 0], R[:, 1], marker='o', s=0.5, **kwargs)
  # Rasterize the points so vector (PDF/SVG) output stays small for large R.
  points.set_rasterized(True)
  plt.gcf().patch.set_facecolor(background_color)
  plt.gcf().set_size_inches(10, 10)
  plt.tight_layout()
def draw_displacement(R, dR):
  """Overlay arrows of vectors dR anchored at positions R (x = col 0, y = col 1)."""
  arrow_color = [1, 0.5, 0.5]
  x, y = R[:, 0], R[:, 1]
  u, v = dR[:, 0], dR[:, 1]
  plt.quiver(x, y, u, v, color=arrow_color)

# Progress Bars

from IPython.display import HTML, display
import time

def ProgressIter(iter_fun, iter_len=0):
  """Yield the items of `iter_fun` while updating an HTML progress bar.

  Args:
    iter_fun: Iterable to wrap.
    iter_len: Total item count shown by the bar. When falsy (default 0),
      `len(iter_fun)` is used instead, so `iter_fun` must then support `len`.

  Yields:
    The items of `iter_fun`, in order; the bar advances after each one.
  """
  total = iter_len if iter_len else len(iter_fun)
  handle = display(progress(0, total), display_id=True)
  completed = 0
  for item in iter_fun:
    yield item
    completed += 1
    handle.update(progress(completed, total))

def progress(value, max):
    """Return an HTML <progress> bar showing `value` out of `max`.

    Args:
      value: Current progress count.
      max: Total count. (This shadows the builtin `max` inside the function;
        the name is kept so existing keyword callers keep working.)

    Returns:
      An `IPython.display.HTML` object.
    """
    # Attributes are separated by whitespace only: a trailing comma after
    # `max='...'` is an HTML parse error and yields a bogus "," attribute.
    return HTML("""
        <progress
            value='{value}'
            max='{max}'
            style='width: 45%'
        >
            {value}
        </progress>
    """.format(value=value, max=max))

# Data Loading

!wget -O silica_train.npz https://www.dropbox.com/s/3dojk4u4di774ve/silica_train.npz?dl=0
!wget https://github.com/google/jax-md/blob/master/examples/models/si_gnn.pickle?raw=true

import numpy as onp

with open('silica_train.npz', 'rb') as f:
  files = onp.load(f)
  Rs, Es, Fs = [np.array(x) for x in (files['arr_0'], files['arr_1'], files['arr_2'])]
  Rs = Rs[:200]
  Es = Es[:200]
  Fs = Fs[:200]
  test_Rs, test_Es, test_Fs = [np.array(x) for x in (files['arr_3'], files['arr_4'], files['arr_5'])]

## LANL Summer School Demo

Source: [github.com/google/jax-md](https://github.com/google/jax-md) → `notebooks/lanl_summer_school_demo.ipynb`

### Energy and Automatic Differentiation

$u(r) = \begin{cases}\frac13(1 - r)^3 & \text{if $r < 1$} \\ 0 & \text{otherwise} \end{cases}$

In [None]:
import jax.numpy as np

def soft_sphere(r):
  """Soft-sphere pair energy, applied elementwise.

  u(r) = (1 - r)^3 / 3 when r < 1, and 0 otherwise.
  """
  overlap = 1 - r
  inside = r < 1
  return np.where(inside, 1/3 * overlap ** 3, 0.)

print(soft_sphere(0.5))

In [None]:
# Evaluate u(r) on 200 evenly spaced radii in [0, 2] and plot it.
r = np.linspace(0, 2., 200)
plot(r, soft_sphere(r))

We can compute its derivative automatically with `grad`.

In [None]:
# Import from the public `jax` namespace; the internal `jax.api` module is not
# importable in current JAX releases.
from jax import grad

# du/dr as a function of a scalar radius r.
du_dr = grad(soft_sphere)

print(du_dr(0.5))

We can vectorize the derivative computation over many radii with `vmap`.

In [None]:
# Import from the public `jax` namespace; the internal `jax.api` module is not
# importable in current JAX releases.
from jax import vmap

# Map the scalar derivative over an array of radii.
du_dr_v = vmap(du_dr)

# Plot the energy and the pair force -du/dr on the same axes.
plot(r, soft_sphere(r))
plot(r, -du_dr_v(r))

### Randomly Initialize a System

In [None]:
from jax import random

# Explicit PRNG key: JAX random functions take a key instead of global state.
key = random.PRNGKey(0)

particle_count = 128
# Spatial dimension of the system.
dim = 2

In [None]:
from jax_md.quantity import box_size_at_number_density

# number_density = N / V: choose the box side so `particle_count` particles
# sit at density 1.0 in `dim` dimensions.
side_length = box_size_at_number_density(particle_count = particle_count, 
                                         number_density = 1.0, 
                                         spatial_dimension = dim)

# Uniform random positions in [0, side_length) along each axis.
R = random.uniform(key, (particle_count, dim), maxval=side_length)
draw(R)

### Displacements and Distances


In [None]:
from jax_md import space

# `displacement` and `shift` functions for a periodic box of side `side_length`.
displacement, shift = space.periodic(side_length)

# Displacement between the first two particles.
print(displacement(R[0], R[1]))

In [None]:
# Scalar distance between two points, built from `displacement`.
metric = space.metric(displacement)

print(metric(R[0], R[1]))

Compute distances between all pairs of points.

In [None]:
# Vectorize over all pairs: `metric(Ra, Rb)` now returns the distance between
# every point of Ra and every point of Rb.
# NOTE(review): this rebinds `displacement` and `metric`, so re-running the
# cell wraps them in `map_product` a second time. Later cells (`energy`,
# `strain_energy`) rely on these pairwise versions.
displacement = space.map_product(displacement)
metric = space.map_product(metric)

print(metric(R[:3], R[:3]))

### Total energy of a system

In [None]:
def energy(R):
  """Total soft-sphere energy of configuration R, summed over distinct pairs.

  Uses the pairwise `metric` (all i, j). The factor 0.5 corrects for each
  pair appearing twice; the diagonal (i == j) is masked out.
  """
  dr = metric(R, R)
  # The diagonal has dr = 0, where soft_sphere(0) = 1/3. Without this mask
  # every particle would add a spurious self-energy (N / 6 in total).
  not_self = ~np.eye(R.shape[0], dtype=bool)
  return 0.5 * np.sum(np.where(not_self, soft_sphere(dr), 0.))

In [None]:
# Total energy of the random initial configuration.
print(energy(R))

In [None]:
# The gradient of the scalar energy w.r.t. R has the same shape as R.
print(grad(energy)(R).shape)

### Minimization

In [None]:
from jax_md.minimize import fire_descent

# FIRE descent: `init_fn` builds an optimizer state from positions and
# `apply_fn` advances that state by one step.
init_fn, apply_fn = fire_descent(energy, shift)

state = init_fn(R)

# Step until the largest force component is at most 1e-3. This is a plain
# Python loop over un-jitted steps (slow) with no iteration cap.
while np.max(np.abs(state.force)) > 1e-3:
  state = apply_fn(state)

draw(state.position)

In [None]:
def cond_fn(state):
  """Keep minimizing while any force component exceeds 1e-3 in magnitude."""
  largest_force = np.max(np.abs(state.force))
  return largest_force > 1e-3

### Making it Fast

In [None]:
def minimize(R):
  """Take 20 FIRE steps from positions R and return the resulting energy.

  Under `jit`, this fixed-length Python loop is unrolled into 20 copies of
  the step; a later cell replaces it with `while_loop`.
  """
  fire_init, fire_step = fire_descent(energy, shift)

  state = fire_init(R)
  for _step in range(20):
    state = fire_step(state)

  return energy(state.position)

In [None]:
%%timeit
# Baseline timing of the un-jitted `minimize`. block_until_ready() waits for
# the asynchronously dispatched result so the work is actually timed.
minimize(R).block_until_ready()

In [None]:
# Import from the public `jax` namespace; the internal `jax.api` module is not
# importable in current JAX releases.
from jax import jit

# Just-in-time compile with XLA for the default backend (GPU/TPU when
# available, otherwise CPU).
minimize = jit(minimize)

In [None]:
# The first call incurs a compilation cost
# The first call traces and compiles `minimize`; later calls with inputs of
# the same shape and dtype reuse the compiled version.
minimize(R)

In [None]:
%%timeit
# Timing after compilation; the compile cost was paid in the previous cell.
minimize(R).block_until_ready()

In [None]:
from jax.lax import while_loop

def minimize(R):
  """Run FIRE descent from R while `cond_fn` holds; return the final positions."""
  fire_init, fire_step = fire_descent(energy, shift)
  # A JAX `while_loop` is traced once instead of being unrolled, which keeps
  # compilation cheap compared with a Python loop.
  final_state = while_loop(cond_fun=cond_fn,
                           body_fun=fire_step,
                           init_val=fire_init(R))
  return final_state.position

In [None]:
# Import from the public `jax` namespace; the internal `jax.api` module is not
# importable in current JAX releases.
from jax import jit

# Compile the while_loop version of `minimize`.
minimize = jit(minimize)

In [None]:
# First call compiles. The minimized positions are reused in the elastic-moduli section.
R_is = minimize(R)

In [None]:
%%timeit
# Timing of the compiled while_loop minimizer.
minimize(R).block_until_ready()

### Elastic Moduli

In [None]:
def strain_energy(strain, R):
  """Soft-sphere energy of configuration R after deforming it by `strain`.

  Every pairwise displacement is multiplied on the right by the matrix
  `strain`, so the identity matrix gives the undeformed energy. Self-pairs
  (i == j, at distance 0) are excluded.
  """
  dR = np.dot(displacement(R, R), strain)
  dr = space.distance(dR)
  # The diagonal has dr = 0, where soft_sphere(0) = 1/3; drop it so only
  # distinct pairs contribute. This is a constant, so derivatives are unaffected.
  not_self = ~np.eye(R.shape[0], dtype=bool)
  return 0.5 * np.sum(np.where(not_self, soft_sphere(dr), 0.))

In [None]:
# With the identity strain the displacements are unchanged: undeformed energy.
strain_energy(np.eye(2),  R_is)

In [None]:
# Import from the public `jax` namespace; the internal `jax.api` module is not
# importable in current JAX releases.
from jax import hessian

# Second derivatives of the energy w.r.t. the strain matrix, at no strain.
K = hessian(strain_energy)(np.eye(2),  R_is)
print(K.shape)

In [None]:
# Full strain Hessian.
print(K)

In [None]:
from jax_md.quantity import bulk_modulus

# Bulk modulus computed from the strain Hessian K by jax_md's `bulk_modulus`.
B = bulk_modulus(K)
print(B)

In [None]:
# Second derivative of the energy w.r.t. the off-diagonal strain entry (0, 1);
# used here as the shear modulus G.
G = K[0, 1, 0, 1]
print(G)

In [None]:
from functools import partial

@jit
def elastic_moduli(number_density, key):
  """Bulk and shear moduli of a minimized random soft-sphere packing.

  Args:
    number_density: Particles per unit volume; sets the periodic box size.
    key: JAX PRNG key used to draw the initial positions.

  Returns:
    A pair `(B, G)`: `bulk_modulus` of the strain Hessian and its
    `[0, 1, 0, 1]` component, used here as the shear modulus.
  """
  side_length = box_size_at_number_density(particle_count    = particle_count,
                                           number_density    = number_density,
                                           spatial_dimension = dim)
  # Randomly initialize particles uniformly over the whole box. Sampling in
  # [0, 1) alone would start every particle crammed into one corner.
  R = random.uniform(key, (particle_count, dim)) * side_length

  displacement, shift = space.periodic(side_length)
  displacement = space.map_product(displacement)

  # Identity strain (no deformation), sized to the spatial dimension.
  no_strain = np.eye(dim)

  # Define an energy function at a specific strain.
  def energy(strain, R):
    dR = displacement(R, R) @ strain
    dr = space.distance(dR)
    return 0.5 * np.sum(soft_sphere(dr))

  # Minimize at no strain, iterating while `cond_fn` holds.
  init_fn, apply_fn = fire_descent(partial(energy, no_strain), shift)

  state = init_fn(R)
  state = while_loop(cond_fn, apply_fn, state)

  # Compute the elastic constants from the strain Hessian at no strain.
  K = hessian(energy)(no_strain, state.position)
  return bulk_modulus(K), K[0, 1, 0, 1]

In [None]:
number_densities = np.linspace(1.1, 1.6, 40)

elastic_moduli = vmap(elastic_moduli, in_axes=(0, None))
B, G = elastic_moduli(number_densities, key)

plot(number_densities, B)
plot(number_densities, G)

In [None]:
plot(number_densities, G / B)
plt.ylim([0, 1])

In [None]:
keys = random.split(key, 5)

elastic_moduli = vmap(elastic_moduli, in_axes=(None, 0))
B, G = elastic_moduli(number_densities, keys)

for b, g in zip(B, G):
  plt.plot(number_densities, g / b)

plot(number_densities, np.mean(G / B, axis=0), 'k')
plt.ylim([0, 1])

### Going Big

In [None]:
# Re-create the PRNG key (same seed as before).
key = random.PRNGKey(0)

# Scale up to 128,000 particles at unit number density.
# NOTE(review): this rebinds globals used by earlier cells (`particle_count`,
# `side_length`, `R`, `displacement`, `shift`); re-running earlier cells
# after this one may pick up the large system.
particle_count = 128000
side_length = box_size_at_number_density(particle_count    = particle_count, 
                                         number_density    = 1.0, 
                                         spatial_dimension = dim)


# Uniform positions in [0, 1) scaled up to fill the box.
R = random.uniform(key, (particle_count, dim)) * side_length

# Plain periodic displacement (not the pairwise `map_product` version used earlier).
displacement, shift = space.periodic(side_length)

draw_big(R)

In [None]:
from jax_md.energy import soft_sphere_neighbor_list

# Soft-sphere energy evaluated with a neighbor list. `neighbor_fn` builds or
# updates the list, and the energy/minimizer take it through a `neighbor=`
# keyword (see the calls below).
neighbor_fn, energy_fn = soft_sphere_neighbor_list(displacement, side_length)

init_fn, apply_fn = fire_descent(energy_fn, shift)

In [None]:
# Build the initial neighbor list. `idx` presumably holds neighbor indices,
# so its shape reflects the allocated buffer size.
nbrs = neighbor_fn(R)
print(nbrs.idx.shape)

In [None]:
from jax.lax import fori_loop

state = init_fn(R, neighbor=nbrs)

# Run 10 FIRE steps in one compiled call, updating the neighbor list from the
# current positions before each step.
@jit
def take_steps(state_and_nbrs):
  def step(i, state_and_nbrs):
    state, nbrs = state_and_nbrs
    # Update (rather than rebuild) the existing list inside the compiled loop.
    nbrs = neighbor_fn(state.position, nbrs)
    state = apply_fn(state, neighbor=nbrs)
    return state, nbrs
  return fori_loop(0, 10, step, state_and_nbrs)

# 80 outer iterations of 10 steps each, with a progress bar.
for i in ProgressIter(range(80)):
  new_state, nbrs = take_steps((state, nbrs))

  # If the neighbor buffer overflowed during these steps, discard `new_state`
  # and rebuild the list from the last accepted `state`; that iteration's
  # steps are lost. The rebuilt list may have a different size, in which case
  # `take_steps` presumably recompiles on the next call.
  if nbrs.did_buffer_overflow:
    nbrs = neighbor_fn(state.position)
  else:
    state = new_state

draw_big(state.position)

## Neural Network Potentials

Here is the data we loaded: configurations of a 64-atom silicon system, with energies and forces computed using DFT.

In [None]:
# Shapes of the training split loaded at the top of the notebook.
print(Rs.shape)  # Positions
print(Es.shape)  # Energies
print(Fs.shape)  # Forces

In [None]:
# Energy statistics; used below to rescale the network output to physical energies.
E_mean = np.mean(Es)
E_std = np.std(Es)

print(f'E_mean = {E_mean}, E_std = {E_std}')

In [None]:
# Distribution of training energies. The trailing ';' suppresses the
# (counts, bins, patches) repr that would otherwise clutter the output.
plt.hist(Es)
plt.xlabel('Energy (eV)')
plt.ylabel('Count');

Set up the system and a graph neural network energy function.

In [None]:
# Periodic cell for the DFT configurations.
# NOTE(review): 10.862 presumably matches the simulation cell used to generate
# the training data (units presumably Å); confirm against the data source.
box_size = 10.862
displacement, shift = space.periodic(box_size)

In [None]:
from jax_md.energy import graph_network

# Graph-network energy with a 3.0 cutoff radius. `init_fn` (unused here)
# presumably initializes fresh parameters; pretrained ones are loaded instead
# and passed as `energy_fn(params, R)`.
init_fn, energy_fn = graph_network(displacement, r_cutoff=3.0)

# The odd file name comes from wget keeping the URL's '?raw=true' query string.
# SECURITY NOTE(review): pickle.load can execute arbitrary code; only load
# this file if you trust where it came from.
with open('si_gnn.pickle?raw=true', 'rb') as f:
  params = pickle.load(f)

In [None]:
# Raw network output for one test configuration, before rescaling by E_std / E_mean.
energy_fn(params, test_Rs[0])

In [None]:
@jit
def scaled_energy_fn(R, **kwargs):
  """GNN energy of configuration R, converted back to physical units.

  The network output is multiplied by `E_std` and shifted by `E_mean`.
  Extra keyword arguments are accepted and ignored, presumably so that
  simulation code can pass options through without breaking the call.
  """
  normalized_energy = energy_fn(params, R)
  return E_mean + E_std * normalized_energy

In [None]:
# Predicted vs. reference energies on the training set; points on the
# diagonal are exact predictions.
predicted_Es = vmap(scaled_energy_fn)(Rs)
plt.plot(Es, predicted_Es, 'o')

In [None]:
from jax_md.quantity import force

# Forces from the scaled GNN energy via jax_md's `force` (-dE/dR).
force_fn = force(scaled_energy_fn)
predicted_Fs = force_fn(Rs[1])

# Compare every force component of configuration 1 against the reference.
plt.plot(Fs[1].reshape((-1,)), predicted_Fs.reshape((-1,)), 'o')

This energy can be used to drive a molecular dynamics simulation.

In [None]:
from jax_md.simulate import nvt_nose_hoover
from jax_md.quantity import temperature

# Boltzmann constant in eV/K; converts between kT (eV) and kelvin.
K_B = 8.617e-5
# Integration time step. NOTE(review): units presumably ps; confirm.
dt = 1e-3
# Thermal energy at 300 K.
kT = K_B * 300 
# Silicon atomic mass. NOTE(review): presumably in units consistent with
# eV, Å and ps; confirm.
Si_mass = 2.91086E-3

# Nose-Hoover (NVT) dynamics driven by the GNN energy; `tau` is presumably
# the thermostat time constant.
init_fn, apply_fn = nvt_nose_hoover(scaled_energy_fn, shift, dt, kT, tau=2.5)

apply_fn = jit(apply_fn)

In [None]:
# Start from the first training configuration, thermalized at the target kT.
state = init_fn(key, Rs[0], Si_mass, T_initial=kT)

# Note: this redefines the `take_steps` from the neighbor-list section.
@jit
def take_steps(state):
  """Advance the NVT simulation by 1000 `apply_fn` steps in one compiled call."""
  return fori_loop(0, 1000, lambda i, state: apply_fn(state), state)

# Fixed-width columns keep every row aligned under the header; mixing tab
# counts between the header and rows makes the columns drift.
print(f'{"Energy (eV)":<16}Temperature (K)')
for i in range(10):
  state = take_steps(state)

  print(f'{scaled_energy_fn(state.position):<16.2f}'
        f'{temperature(state.velocity, Si_mass) / K_B:.2f}')