<a href="https://colab.research.google.com/github/google/jax-md/blob/master/notebooks/neurips_spotlight_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title Imports & Utils
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
  
sns.set_style(style='white')

import warnings
warnings.filterwarnings("ignore")

!wget -O silica_train.npz https://www.dropbox.com/s/3dojk4u4di774ve/silica_train.npz?dl=0
!wget https://github.com/google/jax-md/blob/master/examples/models/si_gnn.pickle?raw=true

import numpy as onp
from jax.api import device_put

box_size = 10.862

with open('silica_train.npz', 'rb') as f:
  files = onp.load(f)
  qm_positions, qm_energies, qm_forces = [device_put(x) for x in (files['arr_3'], files['arr_4'], files['arr_5'])]
  qm_positions = qm_positions[:300]
  qm_energies = qm_energies[:300]
  qm_forces = qm_forces[:300]

## Demo

www.github.com/google/jax-md -> notebooks -> neurips_spotlight_demo.ipynb

In [None]:
!pip install jax-md

Data from a quantum mechanical simulation of Silicon.

In [None]:
# Summarize the loaded dataset: the box size first, then one shape per line
# for positions, energies and forces.
print(f'Box Size = {box_size}')
print(qm_positions.shape, qm_energies.shape, qm_forces.shape, sep='\n')

Visualize states inside Colab.

In [None]:
from jax_md.colab_tools import renderer

# Draw the first configuration as spheres in a 400x400 viewport.
renderer.render(
    box_size,
    {'atom': renderer.Sphere(qm_positions[0])},
    resolution=[400, 400],
)

### Every simulation starts by defining a space.

In [None]:
from jax_md import space

# Build the two space functions used by every later cell from a periodic
# space sized by `box_size`.
# NOTE(review): `wrapped=False` presumably means shifted positions are not
# wrapped back into the box — confirm against the jax_md.space documentation.
displacement_fn, shift_fn = space.periodic(box_size, wrapped=False)

The `displacement_fn` computes the displacement between two points.

In [None]:
# Displacement between entries 0 and 3 of the first configuration; the value
# is displayed as the cell output.
displacement_fn(qm_positions[0, 0], qm_positions[0, 3])

The `shift_fn` moves a point by a given displacement.

In [None]:
import jax.numpy as np

# Move entry 0 of the first configuration by one unit along the first axis.
shift_fn(qm_positions[0, 0], np.array([1.0, 0.0, 0.0]))

### Load a pretrained Graph Neural Network

In [None]:
from jax_md import energy

# Build the graph-network energy model on top of `displacement_fn`.
# `energy_fn` is called as energy_fn(params, positions) in the cells below.
# The `init_fn` returned here is not used afterwards: the parameters are
# loaded from a pickle file and the name is rebound by the simulation setup.
# NOTE(review): `r_cutoff=3.0` is presumably in the same length units as
# `box_size` — confirm.
init_fn, energy_fn = energy.graph_network(displacement_fn, r_cutoff=3.0) 

In [None]:
import pickle

# The file name keeps the '?raw=true' suffix because the weights were
# downloaded with wget without an -O option.
# NOTE: pickle.load can execute arbitrary code while deserializing; only load
# files that come from a source you trust.
with open('si_gnn.pickle?raw=true', 'rb') as f:
  params = pickle.load(f)

In [None]:
# Compare the network prediction for the first configuration with the
# reference energy. At this point `energy_fn` still takes `params` explicitly;
# a later cell rebinds it with the parameters baked in, so this cell must run
# before that one.
print(f'Predicted E = {energy_fn(params, qm_positions[0])}')
print(f'Actual E = {qm_energies[0]}')

In [None]:
import functools

# Bind the pretrained parameters so that `energy_fn` takes positions only.
# The guard keeps this cell idempotent: without it, re-running the cell would
# bind `params` a second time and break every later call to `energy_fn`.
if not (isinstance(energy_fn, functools.partial)
        and any(bound is params for bound in energy_fn.args)):
  energy_fn = functools.partial(energy_fn, params)

In [None]:
from jax import vmap

# Map the single-configuration energy function over the leading axis so all
# configurations are evaluated in one call.
vectorized_energy_fn = vmap(energy_fn)

# Parity plot: reference energies on x, network predictions on y.
plt.plot(qm_energies, vectorized_energy_fn(qm_positions), 'o')
plt.xlabel('Actual E')
plt.ylabel('Predicted E')
plt.show()

In [None]:
from jax_md import quantity

# Derive a force function from the energy function and evaluate it on the
# second configuration.
force_fn = quantity.force(energy_fn)
predicted_forces = force_fn(qm_positions[1])

# Parity plot over all flattened force components: reference values on x,
# network predictions on y.
plt.plot(qm_forces[1].reshape((-1,)),
         predicted_forces.reshape((-1,)), 'o')
plt.xlabel('Actual force component')
plt.ylabel('Predicted force component')
plt.show()

### Using the network in a simulation

In [None]:
from jax_md.simulate import nvt_nose_hoover

# Simulation constants.
# NOTE(review): units are not stated in this notebook; K_B matches the
# Boltzmann constant in eV/K, so 300 is presumably a temperature in kelvin —
# confirm, along with the units of dt and Si_mass.
dt = 5e-3
Si_mass = 2.91086e-3
K_B = 8.617e-5
kT = 300 * K_B

# Note that this rebinds `init_fn`: from here on it is the simulation
# initializer, not the network initializer.
init_fn, step_fn = nvt_nose_hoover(energy_fn, shift_fn, dt, kT, tau=1.0)

In [None]:
from jax import jit
# Compile the step function once; the simulation loop calls it thousands of times.
step_fn = jit(step_fn)

In [None]:
from jax import random

# Fixed seed, so the initial state built from it uses the same key on every run.
key = random.PRNGKey(0)
# Start the simulation from the first configuration of the dataset.
state = init_fn(key, qm_positions[0], Si_mass)

In [None]:
# Advance the simulation 5000 steps, recording the positions every 25th step.
positions = []

for step in range(5000):
  state = step_fn(state)
  if step % 25 == 0:
    positions.append(state.position)

In [None]:
# Stack the recorded frames into a single array with a leading frame axis.
# `np` is jax.numpy here. This rebinds `positions` from a list to an array.
positions = np.stack(positions)

In [None]:
# Render the recorded frames of the simulation in a 400x400 viewport.
renderer.render(
    box_size,
    {'atom': renderer.Sphere(positions)},
    resolution=[400, 400],
)