In [1]:
from ase.atoms import Atoms
from ase.build import bulk
from ase.md.velocitydistribution import MaxwellBoltzmannDistribution, Stationary
from ase.calculators.lj import LennardJones

In [5]:
def initialize_cubic_argon(
    multiplier: int = 5,
    sigma: float = 2.0,
    epsilon: float = 1.5,
    rc: float = 10.0,
    ro: float = 6.0,
    temperature_K: float = 30,
    rng=None,
) -> Atoms:
    """Build a thermalized cubic argon supercell with a smoothed Lennard-Jones calculator.

    Parameters
    ----------
    multiplier : int
        Number of repetitions of the cubic ``bulk("Ar", cubic=True)`` cell along
        each of the three axes.
    sigma, epsilon : float
        Lennard-Jones length and energy parameters passed to ASE's ``LennardJones``.
        NOTE(review): these defaults differ from literature argon values; presumably
        chosen deliberately for this benchmark — confirm before reusing elsewhere.
    rc, ro : float
        Cutoff radius and onset of the smoothing region for
        ``LennardJones(..., smooth=True)``.
    temperature_K : float
        Temperature (K) used to draw initial velocities from a Maxwell-Boltzmann
        distribution.
    rng : numpy.random.Generator or None, optional
        Random generator for the velocity draw. ``None`` keeps ASE's default
        behaviour; pass e.g. ``np.random.default_rng(42)`` for reproducible
        initial velocities.

    Returns
    -------
    Atoms
        The replicated supercell with momenta initialized (centre-of-mass motion
        removed) and the Lennard-Jones calculator attached as ``atoms.calc``.
    """
    atoms = bulk("Ar", cubic=True) * [multiplier, multiplier, multiplier]
    # Draw random velocities, then remove net centre-of-mass momentum so the
    # whole system does not drift during MD.
    MaxwellBoltzmannDistribution(atoms, temperature_K=temperature_K, rng=rng)
    Stationary(atoms)
    atoms.calc = LennardJones(sigma=sigma, epsilon=epsilon, rc=rc, ro=ro, smooth=True)
    return atoms

In [None]:
atoms = initialize_cubic_argon()
n = len(atoms)  # total number of atoms in the replicated supercell

## Influence of `batch_size`
- In JAX-MD simulations, we call the number of MD steps that are dispatched to XLA in one chunk the `batch_size`.
- For example, with `steps = 1000` and `batch_size = 5`:
  - We perform 1000 MD steps in total.
  - We run them in 200 batches of 5 steps each, i.e. 200 calls to `lax.fori_loop(0, 5, step_fn, (state, neighbors))`.
  - After every batch of 5 steps, control returns to Python, e.g. to check whether the neighbor list has overflowed.
  
- A smaller `batch_size` means more frequent context switches (round trips) between XLA and Python.
- 