<a href="https://colab.research.google.com/github/google/jax-md/blob/main/notebooks/talk_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title Import & Util

!pip install git+https://github.com/google/jax-md.git
!pip install dm-haiku
!pip install optax

import pickle
import warnings

# TODO: Re-enable warnings once the upstream XLA bug is fixed.
warnings.simplefilter('ignore')

import jax.numpy as np
# `from jax import config` works on current JAX; the `jax.config` submodule
# import path (`from jax.config import config`) has been removed.
from jax import config, device_put
import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import set_matplotlib_formats

# Use 64-bit floats throughout the demo.
config.update('jax_enable_x64', True)

# Render figures as vector graphics.
set_matplotlib_formats('pdf', 'svg')

sns.set_style(style='white')
# Dark-grey RGB background shared by the drawing helpers.
background_color = [56 / 256] * 3
def plot(x, y, *args):
  """Draw a thick (linewidth 3) line plot of `y` against `x` on a white face.

  Extra positional `args` (e.g. a format string such as 'k') are forwarded to
  `plt.plot`.
  """
  plt.plot(x, y, *args, linewidth=3)
  current_axes = plt.gca()
  current_axes.set_facecolor([1, 1, 1])
def draw(R, **kwargs):
  """Scatter 2D positions as large disks on the dark demo background.

  Args:
    R: array whose columns 0 and 1 are used as x and y; the axes span from 0
      to the maximum of each column.
    **kwargs: forwarded to `plt.scatter`. When neither `c` nor `color` is
      given, an off-white `color` is used.
  """
  if 'c' not in kwargs:
    # `setdefault` so that an explicit `color=` from the caller is kept.
    kwargs.setdefault('color', [1, 1, 0.9])
  ax = plt.axes(xlim=(0, float(np.max(R[:, 0]))),
                ylim=(0, float(np.max(R[:, 1]))))
  ax.get_xaxis().set_visible(False)
  ax.get_yaxis().set_visible(False)
  ax.set_facecolor(background_color)
  plt.scatter(R[:, 0], R[:, 1], marker='o', s=1024, **kwargs)
  plt.gcf().patch.set_facecolor(background_color)
  plt.gcf().set_size_inches(6, 6)
  plt.tight_layout()
def draw_big(R, **kwargs):
  """Scatter many 2D positions as tiny rasterized points on a large figure.

  Args:
    R: array whose columns 0 and 1 are used as x and y; the axes span from 0
      to the maximum of each column.
    **kwargs: forwarded to `plt.scatter`. When neither `c` nor `color` is
      given, an off-white `color` is used.
  """
  if 'c' not in kwargs:
    # `setdefault` so that an explicit `color=` from the caller is kept.
    kwargs.setdefault('color', [1, 1, 0.9])
  fig = plt.figure(dpi=128)
  ax = plt.axes(xlim=(0, float(np.max(R[:, 0]))),
                ylim=(0, float(np.max(R[:, 1]))))
  ax.get_xaxis().set_visible(False)
  ax.get_yaxis().set_visible(False)
  ax.set_facecolor(background_color)
  points = plt.scatter(R[:, 0], R[:, 1], marker='o', s=0.5, **kwargs)
  # Rasterize the points so vector output (pdf/svg) stays small for large N.
  points.set_rasterized(True)
  fig.patch.set_facecolor(background_color)
  fig.set_size_inches(10, 10)
  plt.tight_layout()
def draw_displacement(R, dR):
  """Overlay light-red arrows `dR` anchored at positions `R` (columns 0 and 1)."""
  x, y = R[:, 0], R[:, 1]
  u, v = dR[:, 0], dR[:, 1]
  plt.quiver(x, y, u, v, color=[1, 0.5, 0.5])

# Progress Bars

from IPython.display import HTML, display
import time

def ProgressIter(iter_fun, iter_len=0):
  """Yield the items of `iter_fun` while updating an HTML progress bar.

  Args:
    iter_fun: iterable to consume.
    iter_len: total number of items. If falsy, `len(iter_fun)` is used, so
      `iter_fun` must then support `len`.

  Yields:
    Each item of `iter_fun`; the bar advances after the consumer resumes.
  """
  total = iter_len if iter_len else len(iter_fun)
  handle = display(progress(0, total), display_id=True)
  for done, item in enumerate(iter_fun, start=1):
    yield item
    handle.update(progress(done, total))

def progress(value, max):
    """Return an HTML <progress> bar showing `value` out of `max`.

    Args:
      value: number of completed items.
      max: total number of items (named after the HTML attribute).

    Returns:
      An `IPython.display.HTML` object wrapping the <progress> element.
    """
    # No comma between attributes: a stray `,` after `max='...'` is parsed by
    # browsers as a bogus attribute.
    return HTML("""
        <progress
            value='{value}'
            max='{max}'
            style='width: 45%'
        >
            {value}
        </progress>
    """.format(value=value, max=max))

# Data Loading

!wget -O silica_train.npz https://www.dropbox.com/s/3dojk4u4di774ve/silica_train.npz?dl=0
!wget https://raw.githubusercontent.com/google/jax-md/main/examples/models/si_gnn.pickle

import numpy as onp

with open('silica_train.npz', 'rb') as f:
  files = onp.load(f)
  Rs, Es, Fs = [device_put(x) for x in (files['arr_0'], files['arr_1'], files['arr_2'])]
  Rs = Rs[:10]
  Es = Es[:10]
  Fs = Fs[:10]
  test_Rs, test_Es, test_Fs = [device_put(x) for x in (files['arr_3'], files['arr_4'], files['arr_5'])]
  test_Rs = test_Rs[:200]
  test_Es = test_Es[:200]
  test_Fs = test_Fs[:200]

def tile(box_size, positions, tiles):
  """Replicate a periodic cubic cell `tiles` times along every spatial axis.

  Args:
    box_size: scalar side length of the input cell.
    positions: array of shape (N, dim) with real-space coordinates in the cell.
    tiles: number of copies per axis; must be >= 1.

  Returns:
    (tiled_box_size, fractional_positions):
      - tiled_box_size: `box_size * tiles`.
      - fractional_positions: array of shape (N * tiles**dim, dim) holding the
        coordinates divided by `tiled_box_size`. The original cell comes first,
        followed by the images in lexicographic order of their integer offsets.

  Raises:
    ValueError: if `tiles < 1`, since the enlarged box would have zero size.
  """
  import itertools

  if tiles < 1:
    raise ValueError(f'tiles must be >= 1, got {tiles}.')
  dim = positions.shape[-1]
  # Build every image and concatenate once. Concatenating inside the loop
  # re-copied the growing array on every iteration (quadratic cost).
  images = [positions + box_size * np.array([offset])
            for offset in itertools.product(range(tiles), repeat=dim)]
  box_size = box_size * tiles
  return box_size, np.concatenate(images) / box_size

## Demo

This notebook is at [github.com/google/jax-md](https://github.com/google/jax-md) → `notebooks` → `talk_demo.ipynb`.

### Energy and Automatic Differentiation

$u(r) = \begin{cases}\frac13(1 - r)^3 & \text{if $r < 1$} \\ 0 & \text{otherwise} \end{cases}$

In [None]:
import jax.numpy as np

def soft_sphere(r):
  """Soft-sphere pair energy u(r) = (1 - r)^3 / 3 for r < 1, and 0 otherwise."""
  overlapping = r < 1
  overlap_energy = 1/3 * (1 - r) ** 3
  return np.where(overlapping, overlap_energy, 0.)

print(soft_sphere(0.5))

In [None]:
# 200 evenly spaced radii on [0, 2] and the pair energy at each of them.
r = np.linspace(0, 2., num=200)
plot(r, soft_sphere(r))

We can compute its derivative automatically with `grad`.

In [None]:
from jax import grad

# Differentiate with respect to the first (and only) argument, r.
du_dr = grad(soft_sphere, argnums=0)

# Analytically, du/dr = -(1 - r)^2 for r < 1, so this prints -0.25.
print(du_dr(0.5))

We can vectorize the derivative computation over many radii with `vmap`.

In [None]:
from jax import vmap

# Map the scalar derivative over the leading axis of its input.
du_dr_v = vmap(du_dr, in_axes=0)

plot(r, soft_sphere(r))
# The pair force is -du/dr.
plot(r, -du_dr_v(r))

### Randomly Initialize a System

In [None]:
from jax import random

# Fixed seed so the random configuration is reproducible.
key = random.PRNGKey(0)

# 128 particles in two dimensions.
particle_count, dim = 128, 2

In [None]:
from jax_md.quantity import box_size_at_number_density

# Pick the box side so that number_density = N / V equals 1.2.
box_size = box_size_at_number_density(particle_count=particle_count,
                                      number_density=1.2,
                                      spatial_dimension=dim)

# Uniform random positions inside [0, box_size) along each axis.
R = random.uniform(key, (particle_count, dim), maxval=box_size)

In [None]:
from jax_md.colab_tools import renderer
# Render the random initial configuration as disks in the periodic box.
renderer.render(box_size, renderer.Disk(R), resolution=(512, 512))

### Displacements and Distances


In [None]:
from jax_md import space

# Periodic space: `displacement(Ra, Rb)` gives the displacement between two
# points, and `shift(R, dR)` moves a point (both respect the periodic box).
displacement, shift = space.periodic(box_size)

# Displacement between particles 0 and 1.
print(displacement(R[0], R[1]))

In [None]:
# `space.metric` turns the displacement function into a distance function.
metric = space.metric(displacement)

# Distance between particles 0 and 1.
print(metric(R[0], R[1]))

Compute distances between all pairs of points with `space.map_product`.

In [None]:
# Map the pairwise functions over all pairs of particles. They are rebuilt from
# `space.periodic` instead of wrapping the current `displacement`/`metric`, so
# re-running this cell does not map an already-mapped function a second time.
pair_displacement, _ = space.periodic(box_size)
displacement = space.map_product(pair_displacement)
metric = space.map_product(space.metric(pair_displacement))

# 3 x 3 matrix of distances among the first three particles.
print(metric(R[:3], R[:3]))

### Total energy of a system

In [None]:
def energy(R):
  """Total soft-sphere energy of the configuration `R`.

  `metric(R, R)` yields the full N x N distance matrix, so every pair appears
  twice; the factor 0.5 compensates. The diagonal (each particle with itself,
  distance 0) adds soft_sphere(0) = 1/3 per particle. This constant offset does
  not change the forces.
  """
  pair_distances = metric(R, R)
  pair_energies = soft_sphere(pair_distances)
  return 0.5 * np.sum(pair_energies)

In [None]:
# Total energy of the random initial configuration.
print(energy(R))

In [None]:
# The gradient has the same shape as R: one derivative vector per particle.
print(grad(energy)(R).shape)

### Minimization

In [None]:
from jax_md.minimize import fire_descent

# FIRE minimizer: `init_fn(R)` builds the optimizer state and `apply_fn(state)`
# takes one step.
init_fn, apply_fn = fire_descent(energy, shift)

In [None]:
state = init_fn(R)

trajectory = []

# Step FIRE from Python, recording every frame, until the largest force
# component is no larger than 1e-4.
while np.max(np.abs(state.force)) > 1e-4:
  state = apply_fn(state)
  trajectory.append(state.position)

In [None]:
# Stack the recorded frames into a (frames, particles, dim) array for playback.
trajectory = np.stack(trajectory)

renderer.render(box_size, renderer.Disk(trajectory), resolution=(512, 512))

In [None]:
def cond_fn(state):
  """Keep minimizing while the largest force component exceeds 1e-4."""
  return np.max(np.abs(state.force)) > 1e-4

### Making it Fast

In [None]:
def minimize(R):
  """Take 20 FIRE steps from `R` and return the energy of the final positions.

  The loop is a plain Python loop, so under `jit` it is unrolled into 20
  copies of the step.
  """
  fire_init, fire_step = fire_descent(energy, shift)

  state = fire_init(R)
  for _step in range(20):
    state = fire_step(state)

  final_positions = state.position
  return energy(final_positions)

In [None]:
%%timeit
# Time the un-jitted version; block_until_ready waits for JAX's async dispatch.
minimize(R).block_until_ready()

In [None]:
from jax import jit

# Just-in-time compile with XLA (on GPU when one is available). Later calls
# with same-shaped inputs reuse the compiled code.
minimize = jit(minimize)

In [None]:
# The first call incurs a compilation cost (tracing plus XLA compilation).
minimize(R)

In [None]:
%%timeit
# Time the jitted version; compilation already happened in the previous cell.
minimize(R).block_until_ready()

In [None]:
from jax.lax import while_loop

def minimize(R):
  """Relax `R` with FIRE until `cond_fn` is false; return the final positions.

  `lax.while_loop` traces its body once instead of unrolling it like a Python
  loop, which keeps the compilation cost low.
  """
  fire_init, fire_step = fire_descent(energy, shift)
  relaxed = while_loop(cond_fn, fire_step, fire_init(R))
  return relaxed.position

In [None]:
from jax import jit

# Compile the while_loop-based `minimize`.
minimize = jit(minimize)

In [None]:
# Relaxed positions, reused below for pressure, stress and elastic moduli.
R_is = minimize(R)

In [None]:
%%timeit
# Time the full while_loop minimization to convergence.
minimize(R).block_until_ready()

### Physical Properties

In [None]:
# Recreate the single-pair `displacement`; the name was rebound to its
# `map_product` (all-pairs) version above.
displacement, shift = space.periodic(box_size)

In [None]:
# NOTE(review): this import rebinds `energy` (until now the hand-written energy
# function) to the jax_md module, and the next line rebinds `soft_sphere`.
# Cells above that use these names behave differently if re-run after this one.
from jax_md import energy

# Library soft-sphere pair energy; alpha=3 presumably matches the cubic
# hand-written potential above — confirm against the jax_md docs.
soft_sphere = energy.soft_sphere_pair(displacement, alpha=3)

In [None]:
from jax_md import quantity

# Pressure of the relaxed configuration.
quantity.pressure(soft_sphere, R_is, box_size)

In [None]:
# Stress tensor of the relaxed configuration.
quantity.stress(soft_sphere, R_is, box_size)

In [None]:
from jax_md import elasticity
# `moduli_fn(R, box)` returns athermal elastic constants. It closes over this
# `soft_sphere`, whose periodic box is the current `box_size`.
moduli_fn = elasticity.athermal_moduli(soft_sphere, tether_strength=1e-4)

In [None]:
# Elastic constants of the relaxed configuration.
elastic_constants = moduli_fn(R_is, box_size)

In [None]:
# Bulk modulus derived from the elastic constants.
quantity.bulk_modulus(elastic_constants)

In [None]:
from functools import partial

@jit
def elastic_moduli(number_density, key):
  """Relax a random packing at `number_density` and return its bulk modulus.

  Args:
    number_density: particles per unit volume (per unit area in 2D).
    key: PRNG key used to place the particles.

  Returns:
    The bulk modulus from the athermal elastic constants of the FIRE-relaxed
    configuration.
  """
  # Randomly initialize particles.
  box_size = box_size_at_number_density(particle_count    = particle_count,
                                        number_density    = number_density,
                                        spatial_dimension = dim)
  R = random.uniform(key, (particle_count, dim), maxval=box_size)

  # Create the space and energy function for *this* box.
  displacement, shift = space.periodic_general(box_size,
                                               fractional_coordinates=False)
  soft_sphere = energy.soft_sphere_pair(displacement, alpha=3)

  # Minimize at no strain.
  init_fn, apply_fn = fire_descent(soft_sphere, shift)

  state = init_fn(R)
  state = while_loop(cond_fn, apply_fn, state)

  # Compute the bulk modulus. Build the moduli function from the local energy.
  # The global `moduli_fn` wraps an energy whose periodic box is the one chosen
  # earlier at number density 1.2, not this `box_size`, so it would wrap
  # displacements with the wrong boundary.
  local_moduli_fn = elasticity.athermal_moduli(soft_sphere,
                                               tether_strength=1e-4)
  elastic_constants = local_moduli_fn(state.position, box_size)
  return quantity.bulk_modulus(elastic_constants)

In [None]:
number_densities = np.linspace(1.25, 1.6, 20)

# Map over the densities (axis 0 of the first argument) with one shared key.
# NOTE: this rebinds `elastic_moduli`; re-running the cell maps it again.
elastic_moduli = vmap(elastic_moduli, (0, None))
B = elastic_moduli(number_densities, key)

plot(number_densities, B)

In [None]:
keys = random.split(key, 10)

# Add a second map over 10 keys on top of the density map from the previous
# cell. The result holds one curve of 20 densities per key.
elastic_moduli = vmap(elastic_moduli, (None, 0))
B_ensemble = elastic_moduli(number_densities, keys)

for B in B_ensemble:
  plt.plot(number_densities, B)

# Ensemble average in black.
ensemble_mean = np.mean(B_ensemble, axis=0)
plot(number_densities, ensemble_mean, 'k')

### Going Big

In [None]:
key = random.PRNGKey(0)

# 128k particles at unit number density.
particle_count = 128000
box_size = box_size_at_number_density(particle_count=particle_count,
                                      number_density=1.0,
                                      spatial_dimension=dim)

# Uniform positions in [0, box_size) along each axis.
R = random.uniform(key, (particle_count, dim)) * box_size

displacement, shift = space.periodic(box_size)

renderer.render(box_size, renderer.Disk(R), resolution=(512, 512))

In [None]:
from jax_md.energy import soft_sphere_neighbor_list

# Soft-sphere energy evaluated over a neighbor list, so the cost scales with
# the number of nearby pairs rather than all N^2 pairs.
neighbor_fn, energy_fn = soft_sphere_neighbor_list(displacement, box_size)

init_fn, apply_fn = fire_descent(energy_fn, shift)

In [None]:
# Allocate a neighbor list for the initial positions and show the size of its
# index buffer.
nbrs = neighbor_fn.allocate(R)
print(nbrs.idx.shape)

In [None]:
state = init_fn(R, neighbor=nbrs)

# Distinct names, so the global `cond_fn` used by the earlier minimization
# cells (with a different signature) is not overwritten.
def nbr_cond_fn(state_and_nbrs):
  """Continue while any force component exceeds 1e-4 and the neighbor list has
  not overflowed. After an overflow the list no longer holds every neighbor,
  so further steps would use incomplete forces."""
  state, nbrs = state_and_nbrs
  return np.logical_and(np.any(np.abs(state.force) > 1e-4),
                        np.logical_not(nbrs.did_buffer_overflow))

def nbr_step_fn(state_and_nbrs):
  """Refresh the neighbor list at the current positions, then take a FIRE step."""
  state, nbrs = state_and_nbrs
  nbrs = nbrs.update(state.position)
  state = apply_fn(state, neighbor=nbrs)
  return state, nbrs

state, nbrs = while_loop(nbr_cond_fn,
                         nbr_step_fn,
                         (state, nbrs))

renderer.render(box_size,
                renderer.Disk(state.position),
                resolution=(700, 700))

In [None]:
# True if the neighbor buffer overflowed during the loop; if so, reallocate.
nbrs.did_buffer_overflow

In [None]:
# Reallocate the neighbor list for the relaxed positions.
nbrs = neighbor_fn.allocate(state.position)

In [None]:
# Size of the reallocated index buffer.
nbrs.idx.shape

## Neural Network Potentials

Here is the data we loaded: configurations of a 64-atom silicon system, with energies and forces computed using DFT.

In [None]:
# Shapes of the training arrays (the first 10 configurations).
print(Rs.shape)  # Positions
print(Es.shape)  # Energies
print(Fs.shape)  # Forces

In [None]:
# Summary statistics of the training energies.
# NOTE(review): these are not used by the loss functions below.
E_mean = np.mean(Es)
E_std = np.std(Es)

print(f'E_mean = {E_mean}, E_std = {E_std}')

In [None]:
# Histogram of the training energies.
plt.hist(Es)

Set up the system and a graph neural network energy function.

In [None]:
# Side length of the periodic cell holding the 64-atom configurations.
# NOTE(review): 10.862 = 2 x 5.431, presumably two silicon lattice constants in
# Angstrom — confirm.
box_size = 10.862
displacement, shift = space.periodic(box_size)

In [None]:
# GNN energy: `init_fn(key, R)` creates parameters and `energy_fn(params, R)`
# evaluates the energy. `r_cutoff` is presumably the graph's edge cutoff radius.
init_fn, energy_fn = energy.graph_network(displacement, r_cutoff=3.0)

In [None]:
# Initialize untrained parameters from one configuration and evaluate its energy.
params = init_fn(key, test_Rs[0])
energy_fn(params, test_Rs[0])

In [None]:
# Batch the energy over configurations (axis 0 of the positions); the params
# are shared by every configuration.
vectorized_energy_fn = vmap(energy_fn, in_axes=(None, 0))

# Parity plot of the untrained model on the test set.
predicted_Es = vectorized_energy_fn(params, test_Rs)
plt.plot(test_Es, predicted_Es, 'o')

Define a loss function.

In [None]:
def energy_loss_fn(params):
  """Mean squared error between predicted and reference training energies."""
  residual = vectorized_energy_fn(params, Rs) - Es
  return np.mean(residual ** 2)

def force_loss_fn(params):
  """Mean squared error between predicted forces (-dE/dR) and reference forces.

  Differentiates with respect to the positions (argument 1), not the
  parameters, and maps over the batch of configurations.
  """
  dE_dR = vmap(grad(energy_fn, argnums=1), in_axes=(None, 0))
  # F = -dE/dR, so the force residual is -(dE/dR + F_ref).
  return np.mean((dE_dR(params, Rs) + Fs) ** 2)

@jit
def loss_fn(params):
  """Unweighted sum of the energy and force losses."""
  return energy_loss_fn(params) + force_loss_fn(params)

Take a few optimization steps (Adam with gradient clipping).

In [None]:
import optax

# Clip the global gradient norm to 0.01, then apply Adam with learning rate 1e-4.
opt = optax.chain(
    optax.clip_by_global_norm(0.01),
    optax.adam(1e-4),
)

opt_state = opt.init(params)

@jit
def update(params, opt_state):
  """Apply one optimizer step to `params` using the gradient of `loss_fn`."""
  grads = grad(loss_fn)(params)
  updates, new_opt_state = opt.update(grads, opt_state)
  return optax.apply_updates(params, updates), new_opt_state

for i in ProgressIter(range(100)):
  params, opt_state = update(params, opt_state)
  # Report the loss of the freshly updated params every 10 steps.
  if i % 10 == 0:
    print(f'Loss at step {i} is {loss_fn(params)}')

In [None]:
# Parity plot after the short training run.
predicted_Es = vectorized_energy_fn(params, test_Rs)
plt.plot(test_Es, predicted_Es, 'o')

Now load a pretrained model.

In [None]:
# NOTE(review): `pickle.load` can execute arbitrary code. This file is
# downloaded from the network, so only run this against a trusted source.
with open('si_gnn.pickle', 'rb') as f:
  params = pickle.load(f)

In [None]:
from functools import partial

# Rebuild the unbound GNN energy (same arguments as when it was first created)
# and bind the pretrained `params` to it. Binding a fresh function, rather than
# `partial(energy_fn, params)`, keeps the cell idempotent: re-running it does
# not bind `params` a second time.
_, gnn_energy_fn = energy.graph_network(displacement, r_cutoff=3.0)
energy_fn = partial(gnn_energy_fn, params)

In [None]:
# Parity plot of the pretrained model on the test set.
predicted_Es = vmap(energy_fn)(test_Rs)
plt.plot(test_Es, predicted_Es, 'o')

In [None]:
from jax_md.quantity import force

# Force function (-dE/dR) of the parameter-bound energy.
force_fn = force(energy_fn)
predicted_Fs = force_fn(test_Rs[1])

# Parity plot over every force component of test configuration 1.
plt.plot(test_Fs[1].reshape((-1,)), predicted_Fs.reshape((-1,)), 'o')

This learned energy can be used to run a molecular dynamics simulation.

In [None]:
from jax_md.simulate import nvt_nose_hoover
from jax_md.quantity import temperature

K_B = 8.617e-5         # Boltzmann constant in eV / K
dt = 1e-3
kT = K_B * 300         # thermal energy at 300 K
# NOTE(review): presumably 28.0855 amu expressed in eV ps^2 / A^2 — confirm.
Si_mass = 2.91086E-3

# Nose-Hoover thermostatted (NVT) integrator at temperature kT.
init_fn, apply_fn = nvt_nose_hoover(energy_fn, shift, dt, kT)

apply_fn = jit(apply_fn)

In [None]:
from jax.lax import fori_loop

state = init_fn(key, Rs[0], Si_mass, T_initial=300 * K_B)

nvt_steps_per_record = 100   # integrator steps between recorded frames
nvt_record_count = 100       # number of recorded frames

@jit
def take_steps(state):
  """Advance the NVT simulation by `nvt_steps_per_record` integrator steps."""
  return fori_loop(0, nvt_steps_per_record,
                   lambda i, state: apply_fn(state), state)

# Each loop iteration advances `nvt_steps_per_record * dt` of simulated time,
# so the time stamps are spaced by that amount, not by `dt`. Same convention
# as the NPT run below.
times = np.arange(0, nvt_record_count * nvt_steps_per_record,
                  nvt_steps_per_record) * dt
temperatures = []
trajectory = []

for _ in ProgressIter(times):
  state = take_steps(state)

  # Keyword/momentum form, matching the `quantity.temperature` call used for
  # the NPT diagnostics below.
  temperatures += [temperature(momentum=state.momentum, mass=Si_mass) / K_B]
  trajectory += [state.position]

In [None]:
# Instantaneous temperature (K) versus simulation time.
plot(times, temperatures)

In [None]:
# Stack the recorded frames and animate the NVT trajectory as spheres.
trajectory = np.stack(trajectory)

renderer.render(box_size,
                renderer.Sphere(trajectory),
                resolution=(512,512))

We can also use neighbor lists to simulate a larger system.

In [None]:
# Tile the cell 4x per axis; R holds fractional coordinates in the new box.
# NOTE(review): this rebinds `box_size`, so re-running the cell enlarges the
# box by another factor of 4.
box_size, R = tile(box_size, Rs[0], 4)

In [None]:
# Number of atoms after tiling (64 * 4**3 = 4096 for a 64-atom cell).
len(R)

In [None]:
from jax_md.space import periodic_general
from jax_md.energy import graph_network_neighbor_list

# General periodic space, so the box can change during the NPT run below.
displacement, shift = space.periodic_general(box_size)

# Neighbor-list GNN energy over fractional coordinates (as returned by `tile`).
neighbor_fn, _, energy_fn = graph_network_neighbor_list(displacement, 
                                                        box_size,
                                                        r_cutoff=3.0,
                                                        dr_threshold=0.5,
                                                        fractional_coordinates=True)
# Bind the pretrained parameters.
energy_fn = partial(energy_fn, params)

In [None]:
# Allocate the neighbor list for the tiled system and show its buffer size.
nbrs = neighbor_fn.allocate(R)
nbrs.idx.shape

This time we'll run an NPT simulation.

In [None]:
from jax_md.simulate import npt_nose_hoover

K_B = 8.617e-5         # Boltzmann constant in eV / K
dt = 1e-3
P_start = 0.0          # target pressure for the first half of the run
P_end = 0.05           # target pressure for the second half
kT = K_B * 300         # thermal energy at 300 K
# Same silicon mass as in the NVT cell (28.0855 amu ~= 2.91086e-3 in these
# units); the earlier 2.81086e-3 here was a typo.
Si_mass = 2.91086E-3

# Nose-Hoover NPT integrator starting at pressure P_start and temperature kT.
init_fn, step_fn = npt_nose_hoover(energy_fn, shift, dt, P_start, kT)

In [None]:
inner_steps = 20

@jit
def take_steps(state, nbrs, pressure):
  """Run `inner_steps` NPT steps at target `pressure`.

  After every step the neighbor list is refreshed for the current box.

  Returns:
    The updated `(state, nbrs)` pair.
  """
  def one_step(i, carry):
    sim_state, neighbors = carry
    sim_state = step_fn(sim_state, pressure=pressure, neighbor=neighbors)
    neighbors = neighbors.update(sim_state.position, box=sim_state.box)
    return sim_state, neighbors

  return fori_loop(0, inner_steps, one_step, (state, nbrs))

In [None]:
@jit
def compute_diagnostics(state, nbrs):
  """Return (temperature in K, pressure, real-space positions) for an NPT state.

  Positions are mapped from fractional coordinates with the current box.
  """
  momentum = state.momentum
  T_kelvin = quantity.temperature(momentum=momentum, mass=Si_mass) / K_B
  kinetic = quantity.kinetic_energy(momentum=momentum, mass=Si_mass)
  P_now = quantity.pressure(energy_fn, state.position, state.box, kinetic,
                            neighbor=nbrs)
  real_positions = space.transform(state.box, state.position)
  return T_kelvin, P_now, real_positions

In [None]:
total_steps = 2000
times = np.arange(0, total_steps, inner_steps) * dt
temperatures = []
pressures = []
trajectory = []

state = init_fn(key, R, box_size, Si_mass, neighbor=nbrs)

for t in ProgressIter(times):
  # Hold P_start for the first half of the run, then switch to P_end.
  P_target = P_start if t < times[-1] / 2 else P_end
  state, nbrs = take_steps(state, nbrs, P_target)
  # Unpack into names other than `temperature`: that global is the function
  # imported for the NVT cells, and overwriting it breaks them on re-run.
  T_now, P_now, position_now = compute_diagnostics(state, nbrs)

  temperatures += [T_now]
  pressures += [P_now]
  trajectory += [position_now]

In [None]:
# Check whether the neighbor buffer overflowed during the NPT run.
nbrs.did_buffer_overflow

In [None]:
# Pressure versus time; the target switches from P_start to P_end halfway.
plot(times, pressures)

In [None]:
# Temperature (K) versus time during the NPT run.
plot(times, temperatures)

In [None]:
from jax_md import partition  # NOTE(review): not used in this cell
trajectory = np.stack(trajectory)

blue = [0.2, 0.2, 1.0]
red = [1.0, 0.2, 0.2]

# Render atoms as blue spheres and the final neighbor list `nbrs` as red bonds.
# NOTE(review): `box_size` is the initial tiled box. The NPT box (`state.box`)
# can change over the run, so later frames may not match the drawn box.
renderer.render(box_size,
                {
                    'atoms': renderer.Sphere(trajectory, color=blue),
                    'bonds': renderer.Bond('atoms', nbrs, color=red)
                },
                resolution=(512,512))