<a href="https://colab.research.google.com/github/google/jax-md/blob/main/notebooks/sand_castle.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title Imports & Utils


!pip install jax-md
!wget https://raw.githubusercontent.com/google/jax-md/main/examples/models/sand_castle.png

import imageio
import jax.numpy as jnp

from IPython.display import set_matplotlib_formats
# Render inline matplotlib figures as vector graphics (PDF/SVG).
# NOTE(review): newer IPython releases deprecate this helper in favour of
# `matplotlib_inline.backend_inline.set_matplotlib_formats` — confirm against
# the installed IPython version.
set_matplotlib_formats('pdf', 'svg')
import matplotlib.pyplot as plt
import seaborn as sns

# Plain white seaborn style for the energy plots below.
sns.set_style(style='white')
# RGB triple (each channel 56/256). Not referenced by any visible cell;
# presumably intended as a renderer background colour.
background_color = [56 / 256] * 3

def make_from_image(filename, size_in_pixels):
  """Build a packing of sand grains from the opaque pixels of an image.

  The image is sampled every `size_in_pixels` pixels along both axes. Each
  sampled pixel that is fully opaque (alpha == 255, or any pixel of an image
  without an alpha channel) becomes one grain coloured like that pixel. Odd
  sample rows are shifted by half a sample and rows are spaced by a factor
  sqrt(3)/2, giving a triangular (hexagonal) lattice.

  Args:
    filename: Path to an RGB or RGBA image readable by `imageio.imread`.
    size_in_pixels: Sampling stride in pixels (one grain per stride).

  Returns:
    A tuple `(box_size, position, color)`: the side length of the square
    simulation box (1.5x the larger image extent), an `(N, 2)` array of grain
    positions `(x, y)` offset so the image extent sits in the middle of the
    box, and an `(N, 3)` array of RGB colours in [0, 1].
  """
  img = imageio.imread(filename)
  height, width = img.shape[:2]

  # Nearest-neighbour spacing; 2**(1/6) is the minimum of the Lennard-Jones
  # pair potential with sigma = 1 (presumably chosen so the lattice starts
  # near mechanical equilibrium).
  scale = 2**(1/6)
  # Row spacing of a triangular lattice relative to the in-row spacing.
  ratio = jnp.sqrt(1 - 0.25)

  position = []
  color = []
  for row, y in enumerate(range(0, height, size_in_pixels)):
    for x in range(0, width, size_in_pixels):
      pixel = img[y, x]
      # Skip pixels that are not fully opaque (images without an alpha
      # channel are treated as fully opaque).
      if len(pixel) > 3 and pixel[3] != 255:
        continue
      r, g, b = pixel[:3]
      hshift = size_in_pixels * (row % 2) / 2.0
      position.append([scale * (x + hshift) / size_in_pixels,
                       scale * (height - y) / size_in_pixels * ratio])
      color.append([r / 255, g / 255, b / 255])

  # Positions are stored as (x, y), so the image extent must be listed as
  # (width, height). `img.shape[:2]` is (height, width), and `.T` on a 1-D
  # array is a no-op, so it cannot be used to swap the two.
  img_size = jnp.array([width, height]) / size_in_pixels * scale
  box_size = jnp.max(img_size) * 1.5
  # reshape keeps the (N, 2) / (N, 3) shapes even when no pixel is opaque.
  position = (jnp.array(position, jnp.float64).reshape(-1, 2)
              + box_size / 2.0 - img_size / 2)
  color = jnp.array(color, jnp.float64).reshape(-1, 3)

  return box_size, position, color

# Sand Castle

In this demo we simulate a sand castle and then demolish it using a projectile.

## Load the sand castle

In [None]:
# Sample the image every 24 pixels; each opaque sample becomes one grain.
box, positions, colors = make_from_image('sand_castle.png', 24)

In [None]:
from jax_md.colab_tools import renderer

# Draw every grain as a disk coloured like its source pixel.
renderer.render(box,
                renderer.Disk(positions, color=colors))

In [None]:
print(f'There are {len(positions)} grains.')

## Spaces

In [None]:
from jax_md import space

# Periodic boundary conditions in a square box of side `box`: returns a
# function giving displacements between points and one for moving points.
displacement_fn, shift_fn = space.periodic(box)

In [None]:
positions[0]

In [None]:
# Displacement between the first and the last grain.
displacement_fn(positions[0], positions[-1])

In [None]:
# Move the first grain by 10 along x.
shift_fn(positions[0], jnp.array([10.0, 0.0]))

## Energy

"Energy" in physics plays a role similar to that of "loss" in machine learning.

We write down an energy function, $\epsilon(r)$, for a pair of sand grains separated by a distance $r$.

The total energy is the sum of the pair energies over all distinct pairs of grains,

$$E = \sum_{i<j} \epsilon(r_{ij})$$

where $r_{ij}$ is the distance between grain $i$ and grain $j$.


We want to model wet sand:

*   Grains are hard (no interpenetration).
*   Grains stick together a little bit.
*   Grains far away from one another don't notice each other.

In [None]:
from jax_md import energy

# Plot the Lennard-Jones pair energy as a function of separation.
rs = jnp.linspace(0.5, 2.5)
plt.plot(rs, energy.lennard_jones(rs))

plt.ylim([-1, 1])
plt.xlim([0, 2.5])
plt.xlabel('$r_{ij}$')
plt.ylabel('$\\epsilon$')

In [None]:
# Total Lennard-Jones energy of a configuration, with pair distances measured
# by the periodic `displacement_fn`.
sand_energy = energy.lennard_jones_pair(displacement_fn)

sand_energy(positions)

## Simulate

In [None]:
from jax import random

# Total number of integration steps, and how often (in steps) to record a frame.
simulation_steps = 10000
write_every = 50
key = random.PRNGKey(1)

In [None]:
from jax_md import simulate
from jax import jit

# Langevin integrator driven by the sand pair energy, at zero temperature
# (kT=0.0) with friction coefficient gamma.
init_fn, step_fn = simulate.nvt_langevin(sand_energy, shift_fn, dt=5e-3, kT=0.0, gamma=1e-2)

sand = init_fn(key, positions)
# JIT-compile the single-step update; it is called from a Python loop below.
step_fn = jit(step_fn)

In [None]:
# Plain Python loop: record positions every `write_every` steps.
trajectory = []

for i in range(simulation_steps):
  if i % write_every == 0:
    trajectory += [sand.position]
    
  sand = step_fn(sand)

# Stack the recorded frames along a new leading axis.
trajectory = jnp.stack(trajectory)

In [None]:
# Animate the recorded frames.
renderer.render(box, renderer.Disk(trajectory, color=colors))

## Simulate slightly faster...

In [None]:
from jax import lax

def simulation_fn(i, sand_trajectory):
  """Store frame `i` of the trajectory, then take `write_every` steps.

  Args:
    i: Frame index to write into the trajectory buffer.
    sand_trajectory: Pair `(sand, trajectory)` holding the integrator state
      and the preallocated position buffer.

  Returns:
    The pair `(sand, trajectory)` after recording and advancing.
  """
  state, frames = sand_trajectory

  def advance(_, current):
    return step_fn(current)

  # Record the positions *before* stepping, so frame 0 is the initial state.
  frames = frames.at[i].set(state.position)
  return lax.fori_loop(0, write_every, advance, state), frames

In [None]:
# Number of recorded frames and number of grains.
write_steps = simulation_steps // write_every
n = positions.shape[0]

# Fresh initial state; note this uses PRNGKey(0) rather than `key`.
sand = init_fn(random.PRNGKey(0), positions)
# Preallocated trajectory buffer, filled one frame per outer iteration.
trajectory = jnp.zeros((write_steps, n, 2))
# The whole run is expressed as `lax.fori_loop` calls, so JAX traces the loop
# body once instead of dispatching every step from Python.
sand, trajectory = lax.fori_loop(0, write_steps, simulation_fn, (sand, trajectory))

In [None]:
renderer.render(box, renderer.Disk(trajectory, color=colors))

## Let's blow it up!

### The projectile

In [None]:
# Projectile starts near the left edge (x = 1.0), a third of the way up the box.
projectile = jnp.array([1.0, box / 3.0])

# Projectile radius (used by `projectile_energy` and when rendering),
# soft-sphere stiffness, and the displacement applied to the projectile on
# every integration step.
radius = jnp.array(2.0)
strength = 1000.0
velocity = jnp.array([3e-2, 0.0])

Model the projectile by adding a term to the energy,

$$E = \sum_{i<j}\epsilon(r_{ij}) + \sum_i \epsilon_p(r_{ip})$$

where $r_{ip}$ is the distance between grain $i$ and the projectile.

We want the projectile to only repel the sand (no attraction), so for $\epsilon_p$ we use a soft-sphere potential instead of Lennard-Jones.

In [None]:
from jax_md import energy

# Compare the Lennard-Jones sand-sand energy with the soft-sphere energy
# (stiffness `strength`) used for the projectile.
rs = jnp.linspace(0.5, 2.5)
plt.plot(rs, energy.lennard_jones(rs))
plt.plot(rs, energy.soft_sphere(rs, epsilon=strength))

plt.ylim([-1, 10])
plt.xlim([0, 2.5])
plt.xlabel('$r_{ij}$')
plt.ylabel('$\\epsilon$')

In [None]:
def projectile_energy(sand, projectile):
  """Soft-sphere energy between every grain and the projectile.

  Args:
    sand: Grain positions, one row per grain.
    projectile: Position of the projectile, broadcast against `sand`.

  Returns:
    Scalar sum of the per-grain soft-sphere energies. The interaction
    diameter is `radius + 1.0` and the stiffness is `strength` (both read
    from module globals).
  """
  separation = jnp.linalg.norm(sand - projectile, axis=-1)
  per_grain = energy.soft_sphere(separation,
                                 sigma=radius + 1.0,
                                 epsilon=strength)
  return per_grain.sum()

def total_energy(sand, projectile, **kwargs):
  """Sand-sand pair energy plus the sand-projectile repulsion.

  Any extra keyword arguments are accepted and ignored.
  """
  pair_term = sand_energy(sand)
  projectile_term = projectile_energy(sand, projectile)
  return pair_term + projectile_term

### Run the simulation

In [None]:
from jax_md import dataclasses

# Container for everything that changes during the run; it is used as the
# `lax.fori_loop` carry below (and also reused to hold trajectory buffers).
@dataclasses.dataclass
class SandCastle:
  """Simulation state: Langevin state of the sand and projectile position.

  Field order matters: instances are constructed positionally.
  """
  sand: simulate.NVTLangevinState
  projectile: jnp.ndarray

In [None]:
# Integration steps, frame interval, and resulting number of frames.
simulation_steps = 10000
write_every = 50
write_steps = simulation_steps // write_every

In [None]:
from jax_md import simulate

# Integrator on the combined sand + projectile energy.
# NOTE(review): unlike the other `nvt_langevin` calls in this notebook,
# `gamma` is not passed here, so the library default friction applies —
# confirm this is intended.
init_fn, step_fn = simulate.nvt_langevin(total_energy, shift_fn, dt=5e-3, kT=0.0)

In [None]:
from jax import lax

def simulation_fn(i, state_trajectory):
  """Record frame `i`, then advance sand and projectile by `write_every` steps.

  Args:
    i: Frame index to write.
    state_trajectory: Pair `(state, traj)`: the current `SandCastle` and a
      `SandCastle` whose fields are the preallocated trajectory buffers.

  Returns:
    The advanced state and the updated trajectory buffers.
  """
  current, frames = state_trajectory

  frames = SandCastle(
      frames.sand.at[i].set(current.sand.position),
      frames.projectile.at[i].set(current.projectile)
  )

  def advance(_, castle):
    # Sand forces use the projectile's current position; the projectile then
    # moves by a fixed `velocity` per step.
    next_sand = step_fn(castle.sand, projectile=castle.projectile)
    next_projectile = castle.projectile + velocity
    return SandCastle(next_sand, next_projectile)

  return lax.fori_loop(0, write_every, advance, current), frames

In [None]:
n = positions.shape[0]

# `total_energy` requires `projectile`, so it is supplied to `init_fn` as a
# keyword argument.
state = SandCastle(
    init_fn(key, positions, projectile=projectile),
    projectile
)
# Trajectory buffers reuse the SandCastle container: per-frame sand
# positions and per-frame projectile positions.
trajectory = SandCastle(
    jnp.zeros((write_steps, n, 2)),
    jnp.zeros((write_steps, 2))
)

state, trajectory = lax.fori_loop(0, write_steps, simulation_fn, (state, trajectory))

In [None]:
# Draw the sand and the projectile; the projectile is a single disk per
# frame, hence the inserted axis.
renderer.render(
    box,
    {
        'sand': renderer.Disk(trajectory.sand, color=colors),
        'projectile': renderer.Disk(trajectory.projectile[:, None, :], 
                                    diameter=radius * 2)
    }
)

## Scaling Up

So far at each step we have been computing the interaction between every pair of grains.

But grains that are far apart don't affect each other, so we can save work by only computing interactions between nearby grains, using a neighbor list.

In [None]:
# Re-sample the image every 6 pixels instead of 24, producing many more grains.
box, positions, colors = make_from_image('sand_castle.png', 6)

In [None]:
len(positions)

In [None]:
from jax_md.colab_tools import renderer

renderer.render(box, renderer.Disk(positions, color=colors))

In [None]:
# The box size changed, so the periodic space is rebuilt.
displacement_fn, shift_fn = space.periodic(box)

### Neighbor lists

In [None]:
# Neighbor-list variant of the Lennard-Jones energy: `neighbor_fn` builds the
# neighbor lists and `sand_energy` now takes the list as a second argument.
neighbor_fn, sand_energy = energy.lennard_jones_neighbor_list(displacement_fn, box)

In [None]:
# Build the initial neighbor list for the current positions.
nbrs = neighbor_fn.allocate(positions)

In [None]:
# Shape of the neighbor index array.
nbrs.idx.shape

In [None]:
def total_energy(sand, projectile, neighbor, **kwargs):
  """Neighbor-list sand energy plus the sand-projectile repulsion.

  Any extra keyword arguments are accepted and ignored.
  """
  pair_term = sand_energy(sand, neighbor)
  projectile_term = projectile_energy(sand, projectile)
  return pair_term + projectile_term

### Simulation

In [None]:
# Longer run with sparser recording for the larger system.
simulation_steps = 30000
write_every = 400
write_steps = simulation_steps // write_every

# Reset the projectile for the new box and enlarge it. `projectile_energy`
# looks up the global `radius` each time it runs, so this also changes the
# interaction size in the new simulation.
projectile = jnp.array([1.0, box / 3.0])
radius = jnp.array(8.0)

In [None]:
from jax_md import partition

# Redefinition of `SandCastle` that also carries the neighbor list, so the list
# can be updated inside `lax.fori_loop` together with the rest of the state.
@dataclasses.dataclass
class SandCastle:
  """Simulation state: sand integrator state, projectile, neighbor list.

  Field order matters: instances are constructed positionally.
  """
  sand: simulate.NVTLangevinState
  projectile: jnp.ndarray
  # Neighbor list for the sand; set to None when the container holds
  # trajectory buffers.
  neighbor: partition.NeighborList

In [None]:
from jax_md import simulate

# Integrator on the neighbor-list energy; `projectile` and `neighbor` must be
# supplied as keyword arguments when calling `init_fn`/`step_fn`.
init_fn, step_fn = simulate.nvt_langevin(total_energy, shift_fn, dt=5e-3, kT=0.0, gamma=1e-2)

In [None]:
from jax import lax

def simulation_fn(i, state_trajectory):
  """Record frame `i`, then advance the sand castle by `write_every` steps.

  Intended as the body of `lax.fori_loop(0, write_steps, simulation_fn, ...)`.

  Args:
    i: Index of the trajectory frame to write.
    state_trajectory: Pair `(state, traj)`. `state` is the current
      `SandCastle` (sand integrator state, projectile position, neighbor
      list). `traj` is a `SandCastle` whose `sand`/`projectile` fields are
      preallocated trajectory buffers; its `neighbor` field is unused (None).

  Returns:
    The advanced `state` and the updated trajectory buffers.
  """
  state, traj = state_trajectory

  traj = SandCastle(
      traj.sand.at[i].set(state.sand.position),
      traj.projectile.at[i].set(state.projectile),
      None
  )

  def total_step_fn(_, state):
    sand = step_fn(state.sand,
                   projectile=state.projectile,
                   neighbor=state.neighbor)
    projectile = state.projectile + velocity
    # Update the neighbor list from the *post-step* positions. Using
    # `state.sand.position` here would rebuild it from where the grains were
    # before this step, so the list would always lag one step behind.
    neighbor = state.neighbor.update(sand.position)
    return SandCastle(sand, projectile, neighbor)

  state = lax.fori_loop(0, write_every, total_step_fn, state)

  return state, traj

In [None]:
n = positions.shape[0]

# The energy needs both `projectile` and `neighbor`, so both are passed to
# `init_fn` as keyword arguments.
state = SandCastle(
    init_fn(random.PRNGKey(0), positions, projectile=projectile, neighbor=nbrs),
    projectile,
    nbrs
)
# Trajectory buffers; the neighbor slot is not recorded.
trajectory = SandCastle(
    jnp.zeros((write_steps, n, 2)),
    jnp.zeros((write_steps, 2)),
    None
)

state, trajectory = lax.fori_loop(0, write_steps, simulation_fn, (state, trajectory))

In [None]:
# Check whether the neighbor list ran out of capacity during the run.
# NOTE(review): if this is True the trajectory is presumably unreliable; the
# usual remedy is to reallocate the list and rerun — confirm in the jax_md docs.
state.neighbor.did_buffer_overflow

In [None]:
renderer.render(
    box,
    {
        'sand': renderer.Disk(trajectory.sand, color=colors),
        'projectile': renderer.Disk(trajectory.projectile[:, None, :], 
                                    diameter=radius * 2)
    },
    buffer_size=10
)