<a href="https://colab.research.google.com/github/routhleck/jax-md/blob/main/examples-with-unit/2-Faster_Simulation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [8]:
#@title Import & Util

# %%capture

# !pip install git+https://github.com/google/jax-md
!pip install git+https://github.com/routhleck/jax-md.git
!pip install brainunit
!pip install brainstate

import jax.numpy as np
from jax import device_put
from jax import config
# TODO: Uncomment this and enable warnings when XLA bug is fixed.
import warnings; warnings.simplefilter('ignore')
config.update('jax_enable_x64', True)
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('pdf', 'svg')
import matplotlib.pyplot as plt
import seaborn as sns
import pickle

import warnings
warnings.simplefilter("ignore")

sns.set_style(style='white')
background_color = [56 / 256] * 3
def plot(x, y, *args):
  plt.plot(x, y, *args, linewidth=3)
  plt.gca().set_facecolor([1, 1, 1])
def draw(R, **kwargs):
  if 'c' not in kwargs:
    kwargs['color'] = [1, 1, 0.9]
  ax = plt.axes(xlim=(0, float(jnp.max(R[:, 0]))), 
                ylim=(0, float(jnp.max(R[:, 1]))))
  ax.get_xaxis().set_visible(False)
  ax.get_yaxis().set_visible(False)
  ax.set_facecolor(background_color)
  plt.scatter(R[:, 0], R[:, 1],  marker='o', s=1024, **kwargs)
  plt.gcf().patch.set_facecolor(background_color)
  plt.gcf().set_size_inches(6, 6)
  plt.tight_layout()
def draw_big(R, **kwargs):
  if 'c' not in kwargs:
    kwargs['color'] = [1, 1, 0.9]
  fig = plt.figure(dpi=128)
  ax = plt.axes(xlim=(0, float(jnp.max(R[:, 0]))),
                ylim=(0, float(jnp.max(R[:, 1]))))
  ax.get_xaxis().set_visible(False)
  ax.get_yaxis().set_visible(False)
  ax.set_facecolor(background_color)
  s = plt.scatter(R[:, 0], R[:, 1], marker='o', s=0.5, **kwargs)
  s.set_rasterized(True)
  plt.gcf().patch.set_facecolor(background_color)
  plt.gcf().set_size_inches(10, 10)
  plt.tight_layout()
def draw_displacement(R, dR):
  plt.quiver(R[:, 0], R[:, 1], dR[:, 0], dR[:, 1], color=[1, 0.5, 0.5])

# Progress Bars

from IPython.display import HTML, display
import time

def ProgressIter(iter_fun, iter_len=0):
  if not iter_len:
    iter_len = len(iter_fun)
  out = display(progress(0, iter_len), display_id=True)
  for i, it in enumerate(iter_fun):
    yield it
    out.update(progress(i + 1, iter_len))

def progress(value, max):
    return HTML("""
        <progress
            value='{value}'
            max='{max}',
            style='width: 45%'
        >
            {value}
        </progress>
    """.format(value=value, max=max))

# 2-Faster Simulation

### Energy and Automatic Differentiation

$$
u(r) = \begin{cases}\frac13(1 - r)^3 & \text{if $r < 1$} \\ 0 & \text{otherwise} \end{cases}
$$

In [2]:
import jax.numpy as jnp
import brainunit as u
import brainstate as bst

@u.assign_units(r=u.angstrom, result=u.eV)
def soft_sphere(r):
  return jnp.where(r < 1, 
                   1/3 * (1 - r) ** 3,
                   0.)

print(soft_sphere(0.5 * u.angstrom))

0.04166667 * electronvolt


### Randomly Initialize a System

In [3]:
from jax import random

key = random.PRNGKey(1)

particle_count = 128
number_density = 1.2 / u.angstrom ** 2
dim = 2

In [4]:
from jax_md import quantity

# number_density = N / V
box_size = quantity.box_size_at_number_density(particle_count = particle_count, 
                                               number_density = number_density, 
                                               spatial_dimension = dim)

R = bst.random._random_for_unit.uniform_for_unit(key, (particle_count, dim), minval=0*u.angstrom, maxval=box_size)

### Displacements and Distances


In [15]:
from jax_md import space

displacement, shift = space.periodic(box_size)

print(displacement(R[0], R[1]))

ArrayImpl([ 2.6709671, -4.09407854], dtype=float32) * angstrom


In [17]:
metric = space.metric(displacement)

print(metric(R[0], R[1]))

4.888307 * angstrom


Compute distances between pairs of points

In [18]:
v_displacement = space.map_product(displacement)
v_metric = space.map_product(metric)

print(v_metric(R[:3], R[:3]))

ArrayImpl([[0.       , 4.88830709, 4.64909458],
           [4.88830709, 0.       , 4.22363997],
           [4.64909458, 4.22363997, 0.       ]], dtype=float32) * angstrom


### Total Energy of a System

In [13]:
def energy_fn(R):
  dr = v_metric(R, R)
  return 0.5 * u.math.sum(soft_sphere(dr))

### Minimization

### Faster Simulation Through Compilation

In [5]:
import jax

cond_fn = lambda state: u.math.max(u.math.abs(state.force)) > 1e-4 * u.IMF

In [11]:
from jax_md import minimize
def min(R):
  init, apply = minimize.fire_descent(energy_fn, shift)

  state = init(R)

  for _ in range(20):
    state = apply(state)

  return energy_fn(state.position)

In [19]:
%%timeit
jax.block_until_ready(min(R))

96.6 ms ± 7.49 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [20]:
from jax import jit

# Just-In-Time compile to GPU
min = jit(min)

In [21]:
# The first call incurs a compilation cost
min(R)

22.277498 * electronvolt

In [22]:
%%timeit
jax.block_until_ready(min(R))

6.15 ms ± 101 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [23]:
from jax import lax

@jit
def min(R):
  init_fn, apply_fn = minimize.fire_descent(energy_fn, shift)

  state = init_fn(R)
  # Using a JAX loop reduces compilation cost
  state = lax.while_loop(cond_fun=cond_fn,
                         body_fun=apply_fn,
                         init_val=state)

  return state.position

In [24]:
R_is = min(R)

In [25]:
%%timeit
jax.block_until_ready(min(R))

128 ms ± 2.72 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
