# Molecular Simulation 
Carrying out molecular dynamics (MD) and Monte Carlo (MC) simulations with a potential; this notebook uses an analytic Lennard-Jones potential rather than a trained one.

In [None]:
# Project-local helper called with the path of a `.env` file.
# NOTE(review): presumably loads environment variables from that file — confirm in utils.set_env.
from utils import set_env
set_env('.env')

In [None]:
import os

# JAX/XLA runtime flags, set through the environment in a cell that runs before
# the cell importing jax:
#   JAX_ENABLE_X64                -> "1": enable 64-bit (double precision) types
#   XLA_PYTHON_CLIENT_PREALLOCATE -> "false": do not preallocate accelerator memory
os.environ.update(
    {
        "JAX_ENABLE_X64": "1",
        # "JAX_PLATFORM_NAME": "cpu",
        "XLA_PYTHON_CLIENT_PREALLOCATE": "false",
    }
)

In [None]:
from jaxip.datasets import RunnerDataset
from jaxip.potentials import NeuralNetworkPotential
from jaxip.simulation import (
    MDSimulator, 
    BrendsenThermostat, 
    MCSimulator,
    run_simulation
)
from jaxip.atoms import Structure
from jaxip.units import units as units

import matplotlib.pylab as plt
from pathlib import Path
from ase import Atoms
from ase.visualize import view
import ase.io
import jax.numpy as jnp

In [None]:
# Base directory name for this example; no other cell shown here references it.
base_dir = Path('LJ')

## Data

In [None]:
# Starting configuration: one He atom at the centre of a cubic cell of edge `d`,
# repeated 7 times along each axis and converted to a jaxip Structure.
d = 6  # Angstrom
uc = Atoms('He', positions=[(d / 2,) * 3], cell=(d,) * 3)
s0 = Structure.create_from_ase(uc.repeat((7, 7, 7)))

# Alternative Ar setup (disabled):
# d = 10  # Angstrom
# uc = Atoms('Ar', positions=[(d/2, d/2, d/2)], cell=(d, d, d))
# s0 = Structure.create_from_ase(uc.repeat((7, 7, 7)))

# ASE view of the same structure, e.g. for interactive visualisation.
atoms = s0.to_ase_atoms()
# view(atoms, viewer='x3d') # ase, ngl

## Potential

In [None]:
from jaxip.types import Array
import jax
from functools import partial

@partial(jax.jit, static_argnums=0)
def _compute_pair_energy(obj, r: Array) -> Array:
    """Evaluate the Lennard-Jones pair energy element-wise for distances ``r``.

    The value is ``4 * epsilon * ((sigma / r)**12 - (sigma / r)**6)``, with
    ``sigma`` and ``epsilon`` read from ``obj``.

    ``obj`` is a static jit argument, so it must be hashable and its attributes
    are read while tracing. No cutoff or zero-distance handling happens here;
    callers are expected to mask the result.
    """
    ratio6 = (obj.sigma / r) ** 6
    prefactor = 4.0 * obj.epsilon
    return prefactor * ratio6 * (ratio6 - 1.0)


@partial(jax.jit, static_argnums=0)
def _compute_pair_force(obj, r: Array, R: Array) -> Array:
    """Scale each displacement vector in ``R`` by the Lennard-Jones force factor.

    The per-pair scalar is
    ``-24 * epsilon / r**2 * (2 * (sigma / r)**12 - (sigma / r)**6)``,
    i.e. ``(dU/dr) / r`` for the Lennard-Jones energy ``U``. It is broadcast
    over a new trailing axis and multiplied into ``R``.

    ``obj`` is a static jit argument providing ``sigma`` and ``epsilon``.
    ``R`` must carry one more trailing axis than ``r``. No cutoff or
    zero-distance handling happens here; callers are expected to mask the result.
    """
    ratio6 = (obj.sigma / r) ** 6
    scale = -24.0 * obj.epsilon / (r * r)
    magnitude = scale * ratio6 * (2.0 * ratio6 - 1.0)
    return magnitude[..., jnp.newaxis] * R


class LJPotential:
    """Lennard-Jones pair potential with a hard radial cutoff.

    Instances are callable (returning the total energy of a structure) and
    expose ``compute_force``. Only pairs with ``0 < r < r_cutoff`` contribute.

    The instance is passed as a static argument to the jitted pair kernels, so
    it must stay hashable.

    Args:
        sigma: Lennard-Jones length parameter.
        epsilon: Lennard-Jones energy parameter (well depth).
        r_cutoff: Distance beyond which pairs are ignored.
    """

    def __init__(
        self,
        sigma: float,
        epsilon: float,
        r_cutoff: float,
    ) -> None:
        self.sigma = sigma
        self.epsilon = epsilon
        self.r_cutoff = r_cutoff

    def _masked_pairs(self, structure):
        """Return ``(mask, safe_r, R)`` for all atoms of ``structure``.

        ``mask`` selects pairs with ``0 < r < r_cutoff``. ``safe_r`` equals the
        distance where ``mask`` is true and ``sigma`` elsewhere, so the pair
        kernels never divide by the zero self-distance. Without this the
        kernels produce inf/NaN intermediates that are only masked afterwards.
        """
        r, R = structure.calculate_distance(atom_index=jnp.arange(structure.natoms))
        mask = (0 < r) & (r < self.r_cutoff)
        # Placeholder distance for excluded pairs: sigma / sigma == 1 keeps every
        # kernel term finite; those entries are zeroed by the callers anyway.
        safe_r = jnp.where(mask, r, self.sigma)
        return mask, safe_r, R

    def __call__(self, structure: Structure) -> Array:
        """Return the total Lennard-Jones energy of ``structure``.

        Pair energies inside the cutoff are summed and halved.
        NOTE(review): the 0.5 factor assumes ``calculate_distance`` lists every
        pair twice (once per atom) — confirm against Structure.
        """
        mask, safe_r, _ = self._masked_pairs(structure)
        pair_energies = _compute_pair_energy(self, safe_r)
        return 0.5 * jnp.where(mask, pair_energies, 0.0).sum()

    def compute_force(self, structure: Structure) -> Array:
        """Return the per-atom sum of the masked pair-force terms.

        Pair terms from ``_compute_pair_force`` are zeroed outside the cutoff
        (and for zero distance) and summed over axis 1.
        NOTE(review): assumes axis 1 indexes the neighbours of each atom — confirm.
        """
        mask, safe_r, R = self._masked_pairs(structure)
        pair_forces = jnp.where(
            jnp.expand_dims(mask, axis=-1),
            _compute_pair_force(self, safe_r, R),
            jnp.zeros_like(R),
        )
        return jnp.sum(pair_forces, axis=1)


# He
# Lennard-Jones parameters for helium. Values are given in Angstrom / eV and
# multiplied by the jaxip unit constants; the trailing comments name the
# resulting units. The cutoff literal equals 2.5 * sigma.
ljpot = LJPotential(
    sigma=2.5238 * units.FROM_ANGSTROM,  # Bohr
    epsilon=4.7093e-04 * units.FROM_ELECTRON_VOLT,  # Hartree
    r_cutoff=6.3095 * units.FROM_ANGSTROM,  # 2.5 * sigma
)

# Ar
# Alternative argon parameters (disabled); they pair with the commented-out Ar
# structure in the data cell.
# ljpot = LJPotential(
#     sigma=3.405 * units.FROM_ANGSTROM,                       # Bohr
#     epsilon=0.01032439284 * units.FROM_ELECTRON_VOLT,        # Hartree
#     r_cutoff=8.5125 * units.FROM_ANGSTROM,                   # 2.5 * sigma
# )

## Molecular Dynamics (MD)

In [None]:
# Optional explicit initial velocities and thermostat (disabled); enable them
# together with the matching keyword arguments below.
# v0 = MDSimulator.generate_random_velocity(temperature=300.0, mass=s0.mass, seed=2023)
# brendsen = BrendsenThermostat(target_temperature=300.0, time_constant=50.0 * units.FROM_FEMTO_SECOND)

# MD simulator for the He lattice `s0` driven by the Lennard-Jones potential,
# using a 0.5 fs time step (converted with units.FROM_FEMTO_SECOND) at 300 K.
# NOTE(review): with no initial_velocity given, velocities are presumably
# generated from `temperature` — confirm in MDSimulator.
md = MDSimulator(
    potential=ljpot,
    initial_structure=s0,
    time_step=0.5 * units.FROM_FEMTO_SECOND,
    temperature=300, # K
    # initial_velocity=v0,
    # thermostat=brendsen
)

In [None]:
# Warm up
# run_simulation(md)

# %timeit run_simulation(md, num_steps=1, output_freq=-1)

In [None]:
# Run 10000 MD steps, writing output to md.xyz.
# NOTE(review): output_freq=100 presumably means one output every 100 steps — confirm.
# The filename is passed as a Path, matching the MC run cell.
run_simulation(md, num_steps=10000, output_freq=100, filename=Path("md.xyz"))

In [None]:
# Pressure reported by the MD simulator, converted with units.TO_KILO_BAR.
md.get_pressure() * units.TO_KILO_BAR

In [None]:
# Current MD atom positions, converted with units.TO_ANGSTROM for display.
md.position * units.TO_ANGSTROM

In [None]:
# Convert the simulator's current structure to an ASE Atoms object
# (rebinds `atoms`); the viewer call is left disabled.
atoms = md.get_structure().to_ase_atoms()
# view(atoms, viewer='x3d') # ase, ngl

## Monte Carlo (MC)

In [None]:
# MC simulator starting from the same initial structure `s0` and potential as
# the MD run, at 300 K, with a translation step of 0.3 Angstrom (converted with
# units.FROM_ANGSTROM).
# NOTE(review): movements_per_step=10 presumably sets the number of trial moves
# per MC step — confirm in MCSimulator.
mc = MCSimulator(
    potential=ljpot,
    initial_structure=s0,
    temperature=300, # K
    translate_step=0.3 * units.FROM_ANGSTROM,
    movements_per_step=10,
)

In [None]:
# Warm up
# run_simulation(mc)

# %timeit run_simulation(mc, num_steps=1, output_freq=-1)

In [None]:
# Run 10000 MC steps, writing output to mc.xyz.
# NOTE(review): output_freq=100 presumably means one output every 100 steps — confirm.
run_simulation(mc, num_steps=10000, output_freq=100, filename=Path("mc.xyz"))

In [None]:
# Current MC atom positions, converted with units.TO_ANGSTROM for display.
mc.position * units.TO_ANGSTROM

In [None]:
# Convert the MC simulator's current structure to an ASE Atoms object
# (rebinds `atoms`); the viewer call is left disabled.
atoms = mc.get_structure().to_ase_atoms()
# view(atoms, viewer='x3d') # ase, ngl