<a href="https://colab.research.google.com/github/google/jax-md/blob/main/notebooks/symd.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Sy(JAX)MD

In [166]:
%%capture

!pip install jax-md
!pip install symd

In [None]:
from symd import symd, groups 

import jax.numpy as jnp
from jax import random
from jax import config; config.update('jax_enable_x64', True)

## Set up a symmetric system using SyMD

Set up the simulation parameters, load the symmetry group, and build the constraint function that tests whether a point lies inside the group's asymmetric unit.

In [None]:
# Simulation parameters.
group_number = 11  # Group number passed to `groups.load_group`.
N = 1000           # Number of random trial points; reduced by the filtering below.
dim = 2            # Spatial dimension.

# Load the symmetry group (kept under a separate name from the group number so
# `group` always refers to the loaded group object) and build the predicate
# used below to keep only points inside the group's asymmetric unit.
group = groups.load_group(group_number, dim)
in_unit = symd.asymm_constraints(group.asymm_unit)

Randomly initialize positions and velocities, keeping only the positions that fall inside the asymmetric unit.

In [None]:

# Split the root key exactly once. `key` is kept for later random draws
# (per-image colors and the thermostat initialization); the original code
# re-split `PRNGKey(0)` here, discarding this split and reusing the root key.
key = random.PRNGKey(0)
key, pos_key, vel_key = random.split(key, 3)

# Sample uniform points in fractional coordinates and keep only those that
# lie inside the asymmetric unit.
positions = random.uniform(pos_key, (N, dim))
positions = positions[jnp.array([in_unit(*p) for p in positions])]
N = positions.shape[0]

velocities = random.normal(vel_key, (N, dim))

Transform the positions and velocities using homogeneous coordinates to get all of the images.

In [None]:
# Append a homogeneous coordinate: 1 for positions and 0 for velocities, so
# velocities only see the linear part of each symmetry operation.
homo_positions = jnp.concatenate((positions, jnp.ones((N, 1))), axis=-1)
homo_velocities = jnp.concatenate((velocities, jnp.zeros((N, 1))), axis=-1)
positions = []
velocities = []
colors = []

# Apply every general-position operation of the group, producing one image of
# the asymmetric-unit particles per operation.
for s in group.genpos:
  g = symd.str2mat(s)
  xp = jnp.fmod(homo_positions @ g, 1.0)
  # Drop the homogeneous coordinate. Use `dim` rather than a hard-coded 2 so
  # the same code works for 3D groups.
  positions.append(xp[:, :dim])
  xv = homo_velocities @ g
  velocities.append(xv[:, :dim])
  # Give each image its own random color, repeated for all N particles.
  key, split = random.split(key)
  colors.append(random.uniform(split, (1, 3)) * jnp.ones((N, 1)))

positions = jnp.concatenate(positions, axis=0) + 0.5
velocities = jnp.concatenate(velocities, axis=0)
colors = jnp.concatenate(colors, axis=0)

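Scale the positions from fractional coordinates to real space, using a box sized so that the number density is 0.1. (Working in real space is a convenience rather than a requirement of the method.)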

In [None]:
from jax_md import quantity
# Choose a box that gives a number density of 0.1 for all particle images in
# `dim` dimensions (not a hard-coded 2), then scale the fractional positions.
box = quantity.box_size_at_number_density(len(positions), 0.1, dim)
positions = positions * box

### Visualize the initial system using JAX MD

In [None]:
from jax_md import space
from jax_md.colab_tools import renderer

# Draw every particle as a disk, colored by the symmetry image it belongs to.
initial_disks = renderer.Disk(positions, color=colors)
renderer.render(box, initial_disks, resolution=(512, 512),
                background_color=[1, 1, 1])

## Simulate the system using JAX MD

First, set up a periodic space and a Lennard-Jones potential that uses a neighbor list.

In [None]:
from jax import jit
from jax_md import space
from jax_md import energy
from jax_md import simulate   
from jax_md import minimize   
from jax_md import dataclasses

In [None]:
# Periodic boundary conditions in a box of side `box`.
displacement, shift = space.periodic(box)
# Lennard-Jones potential with a neighbor list: `neighbor_fn` is used below to
# allocate the list, and `energy_fn` is passed to the minimizer and integrator.
neighbor_fn, energy_fn = energy.lennard_jones_neighbor_list(displacement, box)

Run a short energy minimization first so that the randomly placed Lennard-Jones particles do not make the dynamics unstable.


In [None]:
# FIRE minimizer with very small time steps (dt_start=1e-7, dt_max=4e-7) so the
# first relaxation steps from the random configuration stay stable.
init_fn, step_fn = minimize.fire_descent(energy_fn, shift, dt_start=1e-7, dt_max=4e-7)
step_fn = jit(step_fn)

In [None]:
@jit
def sim_fn(state, nbrs):
  """Take one minimization step and refresh the neighbor list.

  Looks up the global `step_fn` when the function is traced; later cells
  rebind `step_fn`, so this version must be run before they are.
  """
  new_state = step_fn(state, neighbor=nbrs)
  return new_state, nbrs.update(new_state.position)

In [None]:
# Allocate the neighbor list with extra capacity so it is less likely to
# overflow as particles move during the simulation.
nbrs = neighbor_fn.allocate(positions, extra_capacity=6)

# Start the FIRE minimizer from the symmetric initial positions.
state = init_fn(positions, neighbor=nbrs)

# Relax the system for 100 minimization steps.
for _ in range(100):
  state, nbrs = sim_fn(state, nbrs)
print(f'Did neighborlist overflow: {nbrs.did_buffer_overflow}')

Now run a simulation at constant temperature using a Nose-Hoover thermostat. First, initialize the simulation functions.

In [None]:
# Nose-Hoover NVT integrator with time step dt=1e-3 at temperature kT=0.8.
# This rebinds `init_fn`/`step_fn`, replacing the minimizer's functions.
init_fn, step_fn = simulate.nvt_nose_hoover(energy_fn, shift, dt=1e-3, kT=0.8)
step_fn = jit(step_fn)

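Define a helper function that re-folds the particles after each step: it regenerates every symmetry image from the first `N` particles, so the system stays exactly symmetric.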

In [None]:
def fold_particles(group, box, n):
  """Return a function that re-imposes the group symmetry on a state.

  Only the first `n` particles (the asymmetric-unit copies) are treated as
  independent. Every operation in `group.genpos` is applied to them, and the
  results overwrite the corresponding blocks of `n` particles in the state,
  so all images stay exactly symmetric.

  Args:
    group: symmetry group from `groups.load_group`; its `genpos` strings are
      converted to homogeneous matrices with `symd.str2mat`.
    box: box size used to map between real-space and fractional coordinates.
    n: number of particles in the asymmetric unit.

  Returns:
    `fold_fn(state) -> state` with position and velocity regenerated.
  """
  # Parse the symmetry operations once, not on every call of `fold_fn`.
  mats = [symd.str2mat(s) for s in group.genpos]

  def fold_fn(state):
    R, V = state.position, state.velocity
    dim = R.shape[-1]  # Spatial dimension; avoids hard-coding 2.
    # Real space -> fractional coordinates in [-0.5, 0.5).
    R = R / box - 0.5
    # Homogeneous coordinates: 1 for positions, 0 for velocities.
    R_homo = jnp.concatenate((R[:n], jnp.ones((n, 1))), axis=-1)
    V_homo = jnp.concatenate((V[:n], jnp.zeros((n, 1))), axis=-1)
    for i, g in enumerate(mats):
      R = R.at[i * n:(i + 1) * n].set(jnp.fmod(R_homo @ g, 1.0)[:, :dim])
      V = V.at[i * n:(i + 1) * n].set((V_homo @ g)[:, :dim])
    # Fractional -> real-space coordinates.
    R = box * (R + 0.5)
    # NOTE(review): depending on the jax-md version, simulation states store
    # either a `velocity` field or a `momentum` field (with `velocity` derived
    # from it); handle both so the replace below does not fail.
    if hasattr(state, 'momentum'):
      return dataclasses.replace(state, position=R, momentum=V * state.mass)
    return dataclasses.replace(state, position=R, velocity=V)
  return fold_fn

Create the folding function and initialize the simulation.

In [None]:
# Build the folding function; the first N particles are the independent ones.
fold_fn = fold_particles(group=group, box=box, n=N)

In [None]:
# Initialize the NVT state from the minimized positions.
state = init_fn(key, state.position, neighbor=nbrs)
# Replace the velocities that JAX MD generates with the symmetric velocities.
# NOTE(review): depending on the jax-md version, the state stores either a
# `velocity` field or a `momentum` field; handle both.
if hasattr(state, 'momentum'):
  state = dataclasses.replace(state, momentum=velocities * state.mass)
else:
  state = dataclasses.replace(state, velocity=velocities)

Run the simulation for 20,000 steps, recording the positions every 100 steps.

In [None]:
from jax import lax

def sim_fn(i, state_nbrs):
  """Body of `lax.fori_loop`: integrate one step, re-fold, refresh neighbors.

  `i` is the loop counter supplied by `fori_loop` and is not used.
  """
  state, nbrs = state_nbrs
  folded = fold_fn(step_fn(state, neighbor=nbrs))
  return folded, nbrs.update(folded.position)

# Record a snapshot, then advance 100 steps inside `lax.fori_loop`;
# 200 snapshots x 100 steps = 20000 steps in total.
frames = []
for _ in range(200):
  frames.append(state.position)
  state, nbrs = lax.fori_loop(0, 100, sim_fn, (state, nbrs))
trajectory = jnp.stack(frames)
print(f'Did neighborlist overflow: {nbrs.did_buffer_overflow}')

In [None]:
# Render the recorded trajectory; colors still mark each symmetry image.
trajectory_disks = renderer.Disk(trajectory, color=colors)
renderer.render(box, trajectory_disks, resolution=(512, 512),
                background_color=[1, 1, 1])