# Run a simulation

This notebook runs the entire simulation. It serves as:

1. A demo for anyone trying to replicate a simple run of the simulation
2. A guide showing which steps are needed and why
3. A way to generate simulation data, if needed

This is only the simulation run, without any diffraction pattern calculation.

**Note**: This simulation provides both wrapped and unwrapped trajectories.
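To illustrate how the two relate, here is a minimal sketch that rebuilds unwrapped coordinates from wrapped ones by adding up minimum-image displacements. It assumes a cubic box spanning `[0, BOX_SIZE)`, float coordinates of shape `(n_frames, n_atoms, 3)`, and that no atom moves more than half a box length between consecutive snapshots. `unwrap_trajectory` is a hypothetical helper, not part of this repository.

```python
import numpy as np


def unwrap_trajectory(wrapped: np.ndarray, box_size: float) -> np.ndarray:
    """Rebuild unwrapped coordinates from wrapped ones (hypothetical helper).

    Assumes `wrapped` is a float array of shape (n_frames, n_atoms, 3), a cubic
    box of edge `box_size`, and that no atom moves more than box_size / 2
    between consecutive frames.
    """
    # Raw frame-to-frame displacements; these jump by ~box_size at a wrap.
    deltas = np.diff(wrapped, axis=0)
    # Minimum-image convention: shift each displacement into [-L/2, L/2].
    deltas -= box_size * np.round(deltas / box_size)
    # Start from the first wrapped frame and accumulate the corrected steps.
    unwrapped = np.empty_like(wrapped)
    unwrapped[0] = wrapped[0]
    unwrapped[1:] = wrapped[0] + np.cumsum(deltas, axis=0)
    return unwrapped
```

Going the other way is just `np.mod(unwrapped, box_size)`. Storing the unwrapped trajectory directly, as this simulation does, avoids relying on the small-step assumption above.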

### Set up the environment

1. Make sure you select the correct kernel, most likely the one based on the environment created in the [README](../README.md).
2. Execute the cell below. It lets us import Python modules from the parent directory, which the `simulation` and `helpers` imports below depend on.

```python
import sys
import os

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))
```

This cell is optional. I recommend executing JAX-MD simulations on the GPU for performance reasons. If, for whatever reason, CPU execution is preferred, un-comment the following two lines:

```python
# os.environ['JAX_PLATFORM_NAME'] = "cpu"
# os.environ['JAX_PLATFORMS'] = "cpu"
```
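JAX reads these environment variables when it is first imported, so they only take effect if set before the import cell below runs. A small guard can make that ordering explicit by refusing to set them once `jax` is already loaded. `force_cpu` is a hypothetical helper, not part of this repository.

```python
import os
import sys


def force_cpu() -> None:
    """Ask JAX to run on the CPU (hypothetical helper).

    Must be called before `jax` is imported for the first time, because JAX
    reads these environment variables at import time.
    """
    if "jax" in sys.modules:
        raise RuntimeError("jax is already imported; restart the kernel first")
    os.environ["JAX_PLATFORM_NAME"] = "cpu"
    os.environ["JAX_PLATFORMS"] = "cpu"
```

After importing JAX, `jax.default_backend()` can be used to confirm which backend is actually in use.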

Lastly, import the required modules:

```python
import jax.numpy as jnp
import jax

import numpy as onp

import matplotlib.pyplot as plt
import random

from simulation.simulate_unwrapped import run_simulation

import helpers.bridge as bridge
import helpers.converters
```

### Define simulation parameters

Note that 343 molecules (7 per axis) over 75,000 steps (150 ps at 2 fs per step) is quite large for JAX-MD and requires a dedicated GPU with at least 8 GB of graphics memory.

```python
LJ_SIGMA_OO = 3.188
N_RUNTIME = 150  # picoseconds
N_MOLECULES_PER_AXIS = 7
N_SNAPSHOTS = 500

N_STEPS = N_RUNTIME * 1000 // 2  # 2 fs per step
N_MOLECULES = N_MOLECULES_PER_AXIS ** 3
# 3 atoms per (water) molecule at a number density of 0.1
BOX_SIZE = helpers.converters.get_box_length_from_density(N_MOLECULES * 3, 0.1)
# Random seed for the simulation; recorded in info.txt for reproducibility
initial_key = random.randint(0, 20000)
```
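`get_box_length_from_density` is called with the number of atoms (three per water molecule) and a number density of 0.1, which matches liquid water if the unit is atoms per Å³. Assuming the helper simply solves ρ = N / L³ for the edge L of a cubic box, it is equivalent to the sketch below. `box_length_from_density` is a hypothetical stand-in, not the repository's implementation.

```python
def box_length_from_density(n_atoms: int, number_density: float) -> float:
    """Edge length of a cubic box holding `n_atoms` at `number_density`.

    Hypothetical stand-in for helpers.converters.get_box_length_from_density,
    assuming density is given in atoms per cubic length unit (e.g. atoms / Å^3).
    """
    volume = n_atoms / number_density  # V = N / rho
    return volume ** (1.0 / 3.0)  # cubic box: L = V^(1/3)


# 7**3 = 343 water molecules -> 1029 atoms at 0.1 atoms / Å^3
print(box_length_from_density(343 * 3, 0.1))  # roughly 21.75 (Å)
```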

```python
print(f"Running simulation with {N_MOLECULES} molecules for {N_RUNTIME} ps ({N_STEPS} steps) while taking {N_SNAPSHOTS} snapshots.")
```
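The same parameters also fix the time resolution of the output. Assuming a constant 2 fs time step and snapshots taken at evenly spaced steps, the spacing works out as below. `snapshot_spacing` is a hypothetical helper for this arithmetic only; the `dt_per_snapshot` value returned by `run_simulation` is what the later steps actually use.

```python
def snapshot_spacing(runtime_ps: int, n_snapshots: int, dt_fs: float = 2.0):
    """Return (n_steps, steps_per_snapshot, ps_per_snapshot).

    Hypothetical helper; assumes a fixed time step of `dt_fs` femtoseconds
    and snapshots taken at evenly spaced steps.
    """
    n_steps = int(runtime_ps * 1000 // dt_fs)  # 1 ps = 1000 fs
    steps_per_snapshot = n_steps // n_snapshots
    ps_per_snapshot = steps_per_snapshot * dt_fs / 1000
    return n_steps, steps_per_snapshot, ps_per_snapshot


print(snapshot_spacing(150, 500))  # (75000, 150, 0.3)
```

With the default parameters, one snapshot is taken every 150 steps, i.e. every 0.3 ps, which is the shortest time lag available to any analysis of the stored trajectory.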

### Run the simulation

Depending on the hardware, this might take a while. Because the simulation runs as compiled JAX code, no progress callback is available.

```python
snapshots, snapshots_unwrapped, dt_per_snapshot = run_simulation(
    LJ_SIGMA_OO=LJ_SIGMA_OO,
    N_STEPS=N_STEPS,
    N_MOLECULES_PER_AXIS=N_MOLECULES_PER_AXIS,
    N_SLICES=N_SNAPSHOTS,
    init_key=initial_key
)
```
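Without a progress callback, it can still help to record how long a run took, for example to extrapolate from a short test run to a longer one. The sketch below uses only the standard library; `run_timed` is a hypothetical helper. Since JAX dispatches work asynchronously, passing `sync=jax.block_until_ready` makes sure the clock stops only once the returned arrays are actually computed.

```python
import time
from typing import Any, Callable, Optional, Tuple


def run_timed(fn: Callable[..., Any], *args: Any,
              sync: Optional[Callable[[Any], Any]] = None,
              **kwargs: Any) -> Tuple[Any, float]:
    """Call fn(*args, **kwargs) and return (result, wall-clock seconds).

    Hypothetical helper. `sync` may be e.g. jax.block_until_ready, so that
    asynchronously dispatched JAX work is included in the measurement.
    """
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    if sync is not None:
        result = sync(result)
    elapsed = time.perf_counter() - start
    return result, elapsed
```

For example, `run_timed(run_simulation, LJ_SIGMA_OO=LJ_SIGMA_OO, ..., sync=jax.block_until_ready)` would return the usual three outputs together with the elapsed seconds. The first call likely also includes JIT compilation time, so short test runs overestimate the per-step cost.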

### Store the simulation results

`MDABridge` builds a bridge between this simulation and the `MDAnalysis` Python package. `MDAnalysis` can be used for various tasks, such as:

- Computing RDFs, MSDs, ...
- Computing diffraction patterns
- Storing data, e.g. as `.dcd` or `.pdb` files

Additionally, the bridge can store itself to and restore itself from a single file (`MDABridge.dump` and `MDABridge.from_file`).

```python
# Output directory for all files written below
OUTPUT_DIR = "../data/demo_simulation"
os.makedirs(OUTPUT_DIR, exist_ok=True)

mdabridge = bridge.MDABridge(
    trajectory=snapshots,  # wrapped coordinates
    dt_per_frame=dt_per_snapshot,
    box_size=BOX_SIZE,
    masses=None,
)

mdabridge.dump(os.path.join(OUTPUT_DIR, "dump.npz"))
mdabridge.write_lammps_dcd(os.path.join(OUTPUT_DIR, "trajectory.dcd"))
# All frames go to a separate file so the single-frame PDB below does not overwrite them
mdabridge.write_lammps_pdb(os.path.join(OUTPUT_DIR, "trajectory_all_frames.pdb"), all_frames=True)
# Only the first frame (required for Sassena)
mdabridge.write_lammps_pdb(os.path.join(OUTPUT_DIR, "trajectory.pdb"), all_frames=False)

# The unwrapped trajectory is stored separately as a plain NumPy archive
onp.savez(os.path.join(OUTPUT_DIR, "unwrapped_trajectory2.npz"), trajectory=snapshots_unwrapped, dt_per_frame=dt_per_snapshot)

parameter_string = f"""=== Simulation information ===
LJ_SIGMA_OO: {LJ_SIGMA_OO}
N_RUNTIME: {N_RUNTIME}
N_MOLECULES_PER_AXIS: {N_MOLECULES_PER_AXIS}
N_SNAPSHOTS: {N_SNAPSHOTS}
BOX_SIZE: {BOX_SIZE}
initial_key: {initial_key}
"""

with open(os.path.join(OUTPUT_DIR, "info.txt"), "w") as f:
    f.write(parameter_string)
```
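The unwrapped trajectory is what a mean squared displacement (MSD) needs, since wrapped coordinates jump by a box length at every boundary crossing. Assuming the stored `trajectory` array has shape `(n_frames, n_atoms, 3)`, a simple NumPy estimate averaged over all atoms and time origins could look like this. `mean_squared_displacement` is a hypothetical helper; `MDAnalysis` also ships its own MSD analysis (`MDAnalysis.analysis.msd.EinsteinMSD`) as an alternative.

```python
import numpy as np


def mean_squared_displacement(traj: np.ndarray, max_lag: int) -> np.ndarray:
    """MSD for lags 0..max_lag (in frames), averaged over atoms and origins.

    Hypothetical helper; `traj` must hold unwrapped coordinates with shape
    (n_frames, n_atoms, 3).
    """
    n_frames = traj.shape[0]
    if not 0 <= max_lag < n_frames:
        raise ValueError("max_lag must be in [0, n_frames)")
    msd = np.zeros(max_lag + 1)
    for lag in range(1, max_lag + 1):
        # Displacements between every pair of frames `lag` apart.
        disp = traj[lag:] - traj[:-lag]
        # Squared length per atom and origin, then averaged over both.
        msd[lag] = np.mean(np.sum(disp ** 2, axis=-1))
    return msd
```

With the files written above, one could load the data via `data = onp.load("../data/demo_simulation/unwrapped_trajectory2.npz")`, pass `data["trajectory"]`, and multiply the lag index by `data["dt_per_frame"]` to convert lags to time, in whatever units `run_simulation` reports.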