<a href="https://colab.research.google.com/github/routhleck/jax-md/blob/main/examples-with-unit/3-Physical_Quantities.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Computing physical quantities

In [1]:
#@title Imports and Definitions

#!pip install jax-md
# !pip install -q git+https://www.github.com/google/jax-md

import numpy as onp

import jax.numpy as jnp
from jax import config
config.update('jax_enable_x64', True)

from jax import random
from jax import jit, lax, grad, vmap, hessian
import jax.scipy as jsp
import brainstate as bst
import brainunit as u

from jax_md import space, energy, smap, simulate, minimize, util, elasticity, quantity, partition
from jax_md.colab_tools import renderer

f32 = jnp.float32
f64 = jnp.float64

from functools import partial

import matplotlib
import matplotlib.pyplot as plt
plt.rcParams.update({'font.size': 16})

def format_plot(x, y):
  """Enable the grid on the current axes and label both axes.

  Args:
    x: text for the horizontal-axis label.
    y: text for the vertical-axis label.
  """
  label_fontsize = 20
  plt.grid(True)
  plt.xlabel(x, fontsize=label_fontsize)
  plt.ylabel(y, fontsize=label_fontsize)
  
def finalize_plot(shape=(1, 0.7)):
  """Resize the current figure relative to its present height.

  Both the new width and the new height are scaled from the figure's
  current height, so ``shape`` behaves like a (width, height) aspect ratio.

  Args:
    shape: (width_factor, height_factor); each factor is multiplied by
      1.5 times the current figure height in inches.
  """
  fig = plt.gcf()
  current_height = fig.get_size_inches()[1]
  new_width = shape[0] * 1.5 * current_height
  new_height = shape[1] * 1.5 * current_height
  fig.set_size_inches(new_width, new_height)
  
def run_minimization_while(energy_fn, R_init, shift, max_grad_thresh = 1e-12, max_num_steps=1000000, **kwargs):
  """Minimize ``energy_fn`` with FIRE descent inside a ``lax.while_loop``.

  The loop stops as soon as the largest absolute force component is no
  longer above ``max_grad_thresh``, or once ``max_num_steps`` steps have run.

  Args:
    energy_fn: energy function handed (jitted) to ``minimize.fire_descent``.
    R_init: initial positions.
    shift: shift function of the simulation space.
    max_grad_thresh: convergence threshold on max |force|.
    max_num_steps: upper bound on the number of FIRE steps.
    **kwargs: forwarded to ``minimize.fire_descent``.

  Returns:
    Tuple ``(final_positions, max_abs_force, num_iterations)``.
  """
  fire_init, fire_apply = minimize.fire_descent(jit(energy_fn), shift, **kwargs)
  fire_apply = jit(fire_apply)

  @jit
  def max_abs_force(fire_state):
    return jnp.max(jnp.abs(fire_state.force))

  @jit
  def should_continue(carry):
    fire_state, iteration = carry
    not_converged = max_abs_force(fire_state) > max_grad_thresh
    within_budget = iteration < max_num_steps
    return jnp.logical_and(not_converged, within_budget)

  @jit
  def advance(carry):
    fire_state, iteration = carry
    return fire_apply(fire_state), iteration + 1

  start = (fire_init(R_init), 0)
  final_state, num_iterations = lax.while_loop(should_continue, advance, start)

  return final_state.position, max_abs_force(final_state), num_iterations

def run_minimization_while_neighbor_list(energy_fn, neighbor_fn, R_init, shift,  
                                         max_grad_thresh = 1e-12, max_num_steps = 1000000, 
                                         step_inc = 1000, verbose = False, nbrs = None, **kwargs):
  """FIRE-minimize ``energy_fn`` using a neighbor list, in chunks of ``step_inc``.

  Each chunk runs ``step_inc`` FIRE steps under ``lax.scan``, updating the
  neighbor list before every step. If the neighbor list overflowed during a
  chunk, the chunk is discarded, the list is reallocated from the last
  accepted state, and the chunk is retried. Otherwise the chunk is accepted
  and convergence (max |force| <= ``max_grad_thresh``) is checked.

  Args:
    energy_fn: energy function handed (jitted) to ``minimize.fire_descent``.
    neighbor_fn: neighbor-list function providing ``allocate`` and ``update``.
    R_init: initial positions.
    shift: shift function of the simulation space.
    max_grad_thresh: convergence threshold on max |force|.
    max_num_steps: upper bound on accepted FIRE steps.
    step_inc: number of FIRE steps per scanned chunk.
    verbose: if True, print a summary when finished.
    nbrs: optional pre-allocated neighbor list; allocated from ``R_init``
      when None.
    **kwargs: forwarded to ``minimize.fire_descent``.

  Returns:
    Tuple ``(final_positions, max_abs_force, neighbor_list, accepted_steps)``.
  """
  if nbrs is None:
    nbrs = neighbor_fn.allocate(R_init)

  fire_init, fire_apply = minimize.fire_descent(jit(energy_fn), shift, **kwargs)
  fire_apply = jit(fire_apply)

  @jit
  def max_abs_force(fire_state):
    return jnp.max(jnp.abs(fire_state.force))

  @jit
  def scan_step(carry, _):
    fire_state, neighbors = carry
    neighbors = neighbor_fn.update(fire_state.position, neighbors)
    fire_state = fire_apply(fire_state, neighbor=neighbors)
    return (fire_state, neighbors), 0

  state = fire_init(R_init, neighbor=nbrs)

  step = 0
  while step < max_num_steps:
    chunk_indices = step + jnp.arange(step_inc)
    (candidate_state, nbrs), _ = lax.scan(scan_step, (state, nbrs), chunk_indices)
    if nbrs.did_buffer_overflow:
      # Reject the chunk: rebuild the neighbor list from the last accepted
      # state and run the same chunk again.
      print('Buffer overflow. Reallocating...')
      nbrs = neighbor_fn.allocate(state.position)
      continue
    state = candidate_state
    step += step_inc
    if max_abs_force(state) <= max_grad_thresh:
      break

  if verbose:
    summary = ('Minimized the energy in {} minimization loops ({} steps each) and reached a final '
               'maximum gradient of {}')
    print(summary.format(step // step_inc, step_inc, max_abs_force(state)))

  return state.position, max_abs_force(state), nbrs, step

def hessian2dynamicalmatrix(H, masses, species=None, reshape=False):
  r"""Convert a Hessian matrix into a dynamical matrix.

  The dynamical matrix :math:`D_{ij}^{\alpha \beta}` is defined by

  .. math::

      D_{ij}^{\alpha \beta} = H_{ij}^{\alpha \beta} / \sqrt{m_i m_j}

  where :math:`i` and :math:`j` run over particles, :math:`\alpha` and
  :math:`\beta` run over the spatial dimensions, :math:`H_{ij}^{\alpha \beta}`
  is the Hessian matrix, and :math:`m_i` is the mass of particle :math:`i`.

  Args:
    H: array of shape (N, dimension, N, dimension) representing the Hessian
      of an energy function.
    masses: array of shape (N,) or (N_species,) giving the mass of each
      particle or of each species type.
    species: optional integer array of shape (N,) giving the species of each
      particle. When provided, ``masses`` is indexed by it to obtain
      per-particle masses.
    reshape: if True, the output is flattened to shape
      (N * dimension, N * dimension).

  Returns:
    The dynamical matrix as an array of shape (N, dimension, N, dimension),
    or (N * dimension, N * dimension) if ``reshape=True``.
  """
  if species is not None:
    # Expand per-species masses into per-particle masses.
    masses = masses[species]

  m_rescale = 1 / jnp.sqrt(masses)
  # D[i, a, j, b] = H[i, a, j, b] / sqrt(m_i) / sqrt(m_j)
  D = jnp.einsum('iajb,i,j->iajb', H, m_rescale, m_rescale)

  if reshape:
    n_flat = D.shape[0] * D.shape[1]
    D = D.reshape(n_flat, n_flat)
  return D

  from .autonotebook import tqdm as notebook_tqdm


## Define a system

In [2]:
N = 1000
dimension = 2

# Number density of the system. The bare number is kept in one place so the
# unit-carrying `density` and the unitless value passed to
# `box_size_at_number_density` cannot drift apart.
density_value = 0.8
density = density_value / u.angstrom**dimension

# Define boundary conditions: a periodic box sized so that N particles sit at
# the requested number density.
box_size = quantity.box_size_at_number_density(N, density_value, dimension)
displacement, shift = space.periodic(box_size)

# Define initial positions: uniformly random inside the box.
key = random.PRNGKey(0)
R_init = box_size * random.uniform(key, (N, dimension), dtype=jnp.float64)

# The system is a 50:50 mixture of two types of particles, one small
# (species 0, sigma 1.0) and one large (species 1, sigma 1.4).
sigma = jnp.array([[1.0, 1.2], [1.2, 1.4]])
N_2 = N // 2
species = jnp.where(jnp.arange(N) < N_2, 0, 1)
diameters = sigma.diagonal()[species]
# Per-species masses scale as diameter**dimension.
masses = sigma.diagonal() ** dimension

# Define energy and neighbor functions
neighbor_fn, energy_fn = energy.soft_sphere_neighbor_list(
    displacement,
    box_size,
    species,
    sigma,
    dr_threshold=0.2*u.angstrom,
    format=partition.Sparse)

# Allocate the neighbor list
nbrs_init = neighbor_fn.allocate(R_init)

  R_init = box_size * random.uniform(key, (N, dimension), dtype=jnp.float64)


## Run an NVT simulation

In [3]:
dt = 5e-3 * u.fsecond
# NOTE(review): kT is expressed in femtoseconds, which is neither a temperature
# nor an energy unit, while the log below stores kT in kelvin and H in eV.
# Confirm which unit `simulate.nvt_nose_hoover` expects for kT.
kT = 0.01 * u.fsecond

init, apply = simulate.nvt_nose_hoover(energy_fn, shift, dt, kT)
state = init(key, R_init, neighbor=nbrs_init)

steps = 10000
write_every = 100

def step_fn(i, state_nbrs_log):
  """Run one NVT step, logging kT, the Nose-Hoover invariant and positions."""
  state, nbrs, log = state_nbrs_log
  # Refresh the neighbor list for the current positions before stepping.
  nbrs = nbrs.update(state.position)

  # Log information about the simulation.
  T = quantity.temperature(momentum=state.momentum)
  log['kT'] = log['kT'].at[i].set(T)
  H = simulate.nvt_nose_hoover_invariant(energy_fn, state, kT, neighbor=nbrs)
  log['H'] = log['H'].at[i].set(H)
  # Record positions every `write_every` steps.
  log['position'] = lax.cond(i % write_every == 0,
                             lambda p: p.at[i // write_every].set(state.position),
                             lambda p: p,
                             log['position'])

  # Take a simulation step.
  state = apply(state, kT=kT, neighbor=nbrs)

  return state, nbrs, log

log = {
    'kT': jnp.zeros((steps,)) * u.kelvin,
    'H': jnp.zeros((steps,)) * u.eV,
    'position': jnp.zeros((steps // write_every,) + R_init.shape) * u.angstrom
}

state, nbrs, log = lax.fori_loop(0, steps, step_fn, (state, nbrs_init, log))

# Inside the compiled loop the neighbor list is only updated, never
# reallocated, so an overflow would go unnoticed and the trajectory would have
# been computed with an incomplete neighbor list. Fail loudly instead.
assert not nbrs.did_buffer_overflow, (
    'Neighbor list overflowed during the NVT run; reallocate it and rerun.')

  return _reduction(a, "sum", np.sum, lax.add, 0, preproc=_cast_to_numeric,
  return _reduction(a, "sum", np.sum, lax.add, 0, preproc=_cast_to_numeric,


In [None]:
# `renderer` is already imported in the imports cell.
# Per-species RGB palette: species 0 uses clrs[0], species 1 uses clrs[1].
clrs = [[0.8, 0.8, 1.0], [0.4, 0.2, 1.0]]
# Index the palette with `species` so colors always line up with the species
# assignment (also correct when N is odd).
colors = jnp.array(clrs)[species]
renderer.render(
    box_size,
    {'particles': renderer.Disk(log['position'], diameter=diameters, color=colors)},
    resolution=(512, 512)
)

### Stress and pressure

Calculate the stress tensor

In [4]:
# Stress tensor of the final configuration; the current velocities and
# neighbor list are passed through to `quantity.stress`.
quantity.stress(
    energy_fn, state.position, box_size,
    neighbor=nbrs, velocity=state.velocity)

  return _reduction(a, "sum", np.sum, lax.add, 0, preproc=_cast_to_numeric,


ArrayImpl([[ 0.06686106, -0.00097949],
           [-0.00097949,  0.06656311]], dtype=float32) * 1.602176565 * 10.0^11 * pascal

Calculate the pressure

In [5]:
# Pressure of the final configuration; the current velocities and neighbor
# list are passed through to `quantity.pressure`.
quantity.pressure(
    energy_fn, state.position, box_size,
    neighbor=nbrs, velocity=state.velocity)

0.05854513 * 1.602176565 * 10.0^11 * pascal