# Molecular Dynamics 
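How to run a molecular dynamics (MD) simulation with a trained potential. In the cells below, an analytic Lennard-Jones potential stands in for the trained neural-network potential, whose loading cells are commented out.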

In [None]:
from utils import set_env
set_env(".env")  # project helper from utils; presumably loads variables from the .env file

In [None]:
import os
# JAX/XLA are configured through environment variables; presumably these are
# set before jax is first imported so that they take effect.
os.environ.update(
    {
        "JAX_ENABLE_X64": "1",  # enable 64-bit (double precision) arrays
        "XLA_PYTHON_CLIENT_PREALLOCATE": "false",  # no up-front device memory grab
    }
)
# os.environ["JAX_PLATFORM_NAME"] = "cpu"  # uncomment to force CPU execution

In [None]:
from jaxip.datasets import RunnerDataset
from jaxip.potentials import NeuralNetworkPotential
from jaxip.simulation.molecular_dynamics import MDSimulator
from jaxip.simulation.thermostat import BrendsenThermostat
from jaxip.atoms import Structure
from jaxip.units import units as units

import matplotlib.pylab as plt
from pathlib import Path
from ase import Atoms
from ase.visualize import view
import ase.io
import jax.numpy as jnp

In [None]:
base_dir = Path("LJ")  # directory holding the dataset/potential files (input.data, input.nn)

## Data

In [None]:
# structures = RunnerDataset(Path(base_dir, "input.data"), persist=True) 
# print("Total number of structures:", len(structures))
# structures

In [None]:
# s = structures[0]
# s

In [None]:
# s.position

In [None]:
# Simple-cubic helium crystal: one He atom at the centre of a cubic cell of
# edge `d`, replicated 10 x 10 x 10 times (1000 atoms in total).
d = 4  # Angstrom
uc = Atoms('He', positions=[(d / 2,) * 3], cell=(d,) * 3)
# NOTE(review): the cell is built without `pbc=True` (ASE's default is
# non-periodic); confirm whether jaxip's Structure treats the lattice as
# periodic regardless.
s0 = Structure.create_from_ase(uc.repeat((10,) * 3))

atoms = s0.to_ase_atoms()
# view(atoms, viewer='x3d') # ase, ngl

## Potential

In [None]:
# nnp = NeuralNetworkPotential.create_from_file(Path(base_dir, "input.nn"))
# nnp.load()

In [None]:
# nnp(s)

In [None]:
# nnp.compute_force(s)

In [None]:
from jaxip.types import Array
import jax
from functools import partial

@partial(jax.jit, static_argnums=0)
def _compute_pair_energy(obj, r: Array) -> Array:
    """Return the Lennard-Jones pair energy 4*eps*((sigma/r)**12 - (sigma/r)**6).

    ``obj`` supplies ``sigma`` and ``epsilon``. It is a static jit argument, so
    it must be hashable, and compiled code is reused for objects that hash and
    compare equal. The formula is applied elementwise over ``r``; no cutoff or
    masking is done here.
    """
    sr6 = (obj.sigma / r) ** 6
    return 4.0 * obj.epsilon * (sr6 * sr6 - sr6)

@partial(jax.jit, static_argnums=0)
def _compute_pair_force(obj, r: Array, R: Array) -> Array:
    """Return Lennard-Jones pair force vectors.

    The scalar factor 24*eps/r**2 * (2*(sigma/r)**12 - (sigma/r)**6) is computed
    elementwise over ``r``. It is then given a trailing axis and multiplied by
    the displacement vectors ``R``. ``obj`` (providing ``sigma`` and
    ``epsilon``) is a static jit argument and must be hashable.
    """
    s6 = (obj.sigma / r) ** 6
    s12 = s6 * s6
    prefactor = 24.0 * obj.epsilon / (r * r) * (2 * s12 - s6)
    return prefactor[..., None] * R


class LJPotential:
    """Truncated (unshifted) Lennard-Jones pair potential.

    E = 1/2 * sum over pairs with 0 < r_ij < r_cutoff of
        4 * epsilon * ((sigma / r_ij)**12 - (sigma / r_ij)**6)

    ``sigma``, ``epsilon`` and ``r_cutoff`` must already be expressed in the
    unit system used by the structures passed in. The caller converts them
    with ``units.FROM_*``.

    NOTE: instances are passed as *static* arguments to the jitted pair
    helpers. With the default identity-based hash, changing ``sigma`` or
    ``epsilon`` after the first call will not be seen by the cached compiled
    code. Create a new instance instead.
    """

    def __init__(
        self,
        sigma: float,
        epsilon: float,
        r_cutoff: float,
    ) -> None:
        self.sigma = sigma
        self.epsilon = epsilon
        self.r_cutoff = r_cutoff

    def _masked_distances(self, structure: Structure):
        """Return ``(r, R, mask)`` for all atoms of ``structure``.

        ``r`` and ``R`` come from ``structure.calculate_distance`` over every
        atom index. They are presumably the pairwise distance matrix and the
        matching displacement vectors; confirm against jaxip. ``mask`` keeps
        only pairs with ``0 < r < r_cutoff``, dropping zero distances (e.g. an
        atom with itself) and pairs beyond the cutoff.
        """
        r, R = structure.calculate_distance(atom_index=jnp.arange(structure.natoms))
        mask = (0 < r) & (r < self.r_cutoff)
        return r, R, mask

    def _safe_r(self, r: Array, mask: Array) -> Array:
        """Replace masked-out distances with a finite, non-zero placeholder.

        Evaluating ``sigma / r`` at r == 0 yields inf/nan. ``jnp.where``
        discards that in the forward pass, but gradients through ``where``
        still pick up nan (nan * 0 == nan). Feeding a safe value keeps both
        values and derivatives finite. The placeholder is masked out again
        afterwards, so the result is unchanged.
        """
        return jnp.where(mask, r, self.r_cutoff)

    def __call__(self, structure: Structure) -> Array:
        """Return the total Lennard-Jones energy of ``structure``."""
        r, _, mask = self._masked_distances(structure)
        pair_energies = _compute_pair_energy(self, self._safe_r(r, mask))
        # The factor 1/2 corrects for double counting. This assumes
        # calculate_distance yields both (i, j) and (j, i); confirm.
        return 0.5 * jnp.where(mask, pair_energies, 0.0).sum()

    def compute_force(self, structure: Structure) -> Array:
        """Return per-atom forces, summing pair contributions over axis 1.

        The force on atom i is sum_j f(r_ij) * R_ij. Whether this points along
        the correct direction depends on R_ij being r_i - r_j; confirm the
        sign convention of ``calculate_distance``.
        """
        r, R, mask = self._masked_distances(structure)
        pair_forces = jnp.where(
            jnp.expand_dims(mask, axis=-1),
            _compute_pair_force(self, self._safe_r(r, mask), R),
            jnp.zeros_like(R),
        )
        return jnp.sum(pair_forces, axis=1)
    
# Lennard-Jones parameters given in Angstrom / eV and converted to the
# package's internal units (Bohr / Hartree according to the original notes).
ljpot = LJPotential(
    r_cutoff=6.3095 * units.FROM_ANGSTROM,            # 2.5 * sigma
    epsilon=8.507457e-04 * units.FROM_ELECTRON_VOLT,  # Hartree
    sigma=2.5238 * units.FROM_ANGSTROM,               # Bohr
)

## MD simulator

In [None]:
# Optional: explicit initial velocities and a Berendsen thermostat. To use
# them, uncomment these lines and the matching keyword arguments below.
# v0 = MDSimulator.generate_random_velocity(temperature=300.0, mass=s0.mass, seed=2023)
# brendsen = BrendsenThermostat(target_temperature=300.0, time_constant=50.0 * units.FROM_FEMTO_SECOND)

md = MDSimulator(
    potential=ljpot,
    initial_structure=s0,
    time_step=0.1 * units.FROM_FEMTO_SECOND,  # 0.1 fs
    temperature=300,  # K
    # initial_velocity=v0,
    # thermostat=brendsen,
)

In [None]:
# Warm up (e.g. trigger JIT compilation before timing)
# md.run_simulation()

# %timeit md.run_simulation(num_steps=1, output_freq=-1)

In [None]:
def run_sumulation(
    md: MDSimulator,
    num_steps: int = 1,
    freq: int = 100,
    filename=Path('dump.xyz'),
) -> None:
    """Advance ``md`` by ``num_steps`` MD steps, with periodic logging and dumping.

    Output happens on every step index divisible by ``freq``, starting at
    step 0 and before that step is taken. It prints the simulator's physical
    parameters and pressure (in kbar) and appends the current structure to
    ``filename`` via ``ase.io.write(..., append=True)``.

    Args:
        md: Simulator to advance; it is stepped in place.
        num_steps: Number of calls to ``md.molecular_dynamics_step()``.
        freq: Output interval in steps. A non-positive value disables
            printing and dumping.
        filename: Trajectory file. Frames are appended, so remove any stale
            file before starting a fresh run.
    """
    # The output path is loop-invariant; convert it once.
    dump_path = str(Path(filename))
    for step in range(num_steps):
        # `freq > 0` guards the modulo: freq == 0 would raise ZeroDivisionError,
        # and a negative freq would satisfy `step % freq == 0` on every step.
        if freq > 0 and step % freq == 0:
            print(md.repr_physical_params(), f" Pres[kb]:{md.get_pressure() * units.TO_KILO_BAR}")
            ase.io.write(dump_path, md.structure.to_ase_atoms(), append=True)
        md.molecular_dynamics_step()
    
# run_sumulation appends frames, so delete any previous trajectory first.
# A single shared path keeps the cleanup and the dump target in sync; the old
# shell line removed /home/hossin/... while frames went to /home/hossein/...
dump_file = Path("/home/hossein/dump.xyz")
dump_file.unlink(missing_ok=True)
run_sumulation(md, num_steps=10000, filename=dump_file)

In [None]:
# %time md.run_simulation(num_steps=10000, output_freq=100)

In [None]:
# Sanity check: the centre-of-mass velocity is expected to be (numerically) zero.
assert jnp.isclose(md.get_com_velocity(), jnp.zeros(3)).all()

In [None]:
# Current pressure of the simulated system, converted to kbar (displayed as cell output).
md.get_pressure() * units.TO_KILO_BAR

In [None]:
# Current atomic positions, converted to Angstrom (displayed as cell output).
md.position * units.TO_ANGSTROM

In [None]:
# Convert the simulator's current structure to ASE Atoms and show it with ASE's viewer.
atoms = md.structure.to_ase_atoms()
view(atoms, viewer='x3d') # other viewers: 'ase', 'ngl'