In [1]:
import pathlib
import time
from typing import Dict, Optional

import jax
import jax.numpy as jnp
import jax_md
import numpy as np
from ase.io import read
from jax_md import units

from solar import to_jax_md
from solar import SolarPotential



# Nose-Hoover NPT

The code below was already introduced in the `npt_jaxmd.ipynb` example; see that notebook for details.

In [2]:
# Simulation parameters

timestep = 0.0005  # Time step in ps
npt_cycles = 25  # Number of cycles in the NPT run.
npt_steps = 5  # Number of NPT steps per cycle. The total number of MD steps equals npt_cycles * npt_steps

T_init = 300  # Initial temperature (later multiplied by unit['temperature'] before it is used).
pressure = 1.  # Target pressure. 

chain = 3  # Number of chains in the Nose-Hoover chain.
chain_steps = 2  # Number of steps per chain.
sy_steps = 3  # Forwarded as 'sy_steps' to the Nose-Hoover chain settings (presumably Suzuki-Yoshida steps -- confirm against jax_md).

thermo = 100  # Thermostat time constant in multiples of the time step (tau = thermo * timestep).
baro = 1000  # Barostat time constant in multiples of the time step (tau = baro * timestep).

# Default nose hoover chain parameters.
def default_nhc_kwargs(
    tau: jnp.float32,
    overrides: Optional[Dict] = None
) -> Dict:
    """Build the keyword arguments of a Nose-Hoover chain.

    Args:
        tau: Value stored under the 'tau' key.
        overrides: Optional mapping whose entries replace the defaults. Only the keys
            'chain_length', 'chain_steps', 'sy_steps' and 'tau' are considered; every
            other key is ignored. `None` (the default) or an empty mapping yields the
            plain defaults.

    Returns:
        A new dictionary holding exactly the four keys listed above.
    """
    default_kwargs = {
        'chain_length': 3,
        'chain_steps': 2,
        'sy_steps': 3,
        'tau': tau
    }

    if not overrides:
        return default_kwargs

    # Merge key by key so that entries unknown to the chain never end up in the result.
    return {
        key: overrides.get(key, default) for key, default in default_kwargs.items()
    }

In [3]:
# Read some molecular structure. Here we read the last frame of the `water_64.xyz` structure
# from the test directory and wrap the atoms back into the cell.
path_to_xyz = '../tests/test_data/water_64.xyz'
atoms = read(path_to_xyz, index=-1)
atoms.wrap()

# Repeat in each direction.
atoms = atoms * [2, 2, 2]

species = jnp.array(atoms.get_atomic_numbers())
masses = jnp.array(atoms.get_masses())
num_atoms = len(species)

# It is important that the box is an array with ndim = 1 rather than a scalar (ndim = 0). Otherwise
# jax_md.simulate.npt_box(state) returns a 3x3 array, which leads to a TracerConversionError.
# Maybe open an issue on jax_md?
# Only the [0, 0] entry of the cell is used, i.e. a cubic cell is assumed.

box = jnp.array(
    [
        np.array(atoms.get_cell())[0, 0]
    ]
)  

fractional_coordinates = True
displacement, shift = jax_md.space.periodic_general(box=box, fractional_coordinates=fractional_coordinates)

positions_init = jnp.array(atoms.get_positions())
if fractional_coordinates:
    # Dividing by the single box edge length turns the positions into fractional coordinates.
    positions_init = positions_init / box

# Build the short-range and long-range neighbor list functions and the energy function of the potential.
neighbor_fn, neighbor_fn_lr, energy_fn = to_jax_md(
    potential=SolarPotential(),
    displacement_or_metric=displacement,
    box_size=box,
    species=species,
    capacity_multiplier=1.25,
    buffer_size_multiplier_sr=1.25,
    buffer_size_multiplier_lr=1.25,
    minimum_cell_size_multiplier_sr=1.0,
    disable_cell_list=False,
    fractional_coordinates=fractional_coordinates
)

# Energy function and the force function derived from it, both jit-compiled.
energy_fn = jax.jit(energy_fn)
force_fn = jax.jit(jax_md.quantity.force(energy_fn))

# Initialize the short and long-range neighbor lists.
nbrs = neighbor_fn.allocate(
    positions_init,
    box=box
)
nbrs_lr = neighbor_fn_lr.allocate(
    positions_init,
    box=box
)

# For repeated execution of cell, delete the fire_state first, otherwise jit in jupyter environment 
# can meddle with things.
try:
    del fire_state
except NameError:
    # First execution of the cell: there is no fire_state yet.
    pass

# The structure relaxation runs min_cycles * min_steps FIRE steps in total.
min_cycles = 5
min_steps = 10

# FIRE minimizer acting on the energy function of the potential.
fire_init, fire_apply = jax_md.minimize.fire_descent(
    energy_fn, 
    shift, 
    dt_start = 0.05, 
    dt_max = 0.1, 
    n_min = 2
)

fire_apply = jax.jit(fire_apply)
# Initial minimizer state built from the starting positions and the freshly allocated neighbor lists.
fire_state = fire_init(
    positions_init, 
    box=box, 
    neighbor=nbrs.idx,
    neighbor_lr=nbrs_lr.idx
)

@jax.jit
def step_fire_fn(i, fire_state):
    """Advance the FIRE minimization by one step and refresh both neighbor lists.

    Written as a `jax.lax.fori_loop` body: `i` is the (unused) loop index and
    `fire_state` is the carried tuple `(minimizer state, short-range neighbors,
    long-range neighbors)`. A tuple with the same layout is returned. The module-level
    `box` is used for every call.
    """
    minimizer, neighbors_sr, neighbors_lr = fire_state

    # One minimizer step with the current neighbor indices.
    minimizer = fire_apply(
        minimizer,
        neighbor=neighbors_sr.idx,
        neighbor_lr=neighbors_lr.idx,
        box=box
    )

    # Bring both neighbor lists up to date with the moved atoms.
    moved_positions = minimizer.position
    neighbors_sr = neighbors_sr.update(moved_positions, neighbor=neighbors_sr.idx, box=box)
    neighbors_lr = neighbors_lr.update(moved_positions, neighbor_lr=neighbors_lr.idx, box=box)

    return minimizer, neighbors_sr, neighbors_lr

# Structure relaxation: run `min_cycles` blocks of `min_steps` FIRE steps each and report the
# energy and the largest absolute force component after every block.
# NOTE(review): the neighbor lists are not checked for did_buffer_overflow in this loop.
print('Step\tE\tFmax')
print('--------------------------')
for i in range(min_cycles):
    fire_state, nbrs, nbrs_lr = jax.lax.fori_loop(
        0, 
        min_steps, 
        step_fire_fn, 
        (fire_state, nbrs, nbrs_lr)
    )
    
    # Energy of the current structure.
    E = energy_fn(
        fire_state.position, 
        neighbor=nbrs.idx, 
        neighbor_lr=nbrs_lr.idx,
        box=box
    )

    # Forces of the current structure.
    F = force_fn(
        fire_state.position, 
        neighbor=nbrs.idx, 
        neighbor_lr=nbrs_lr.idx,
        box=box
    )
    
    print('{}\t{:.2f}\t{:.2f}'.format(i, E, np.abs(F).max()))

# Dictionary with the NHC settings.
new_nhc_kwargs = {
    'chain_length': chain, 
    'chain_steps': chain_steps, 
    'sy_steps': sy_steps
}

# Convert to metal unit system.
unit = units.metal_unit_system()

# NOTE(review): both names are rescaled in place, so executing this cell a second time without
# re-running the parameter cell applies the conversion twice.
timestep = timestep * unit['time']
T_init = T_init * unit['temperature']

# Fixed seed for the random key passed to init_fn.
rng_key = jax.random.PRNGKey(0)

# Choose the Nose-Hoover NPT integrator.
# NOTE(review): `pressure` is passed as is, without a unit conversion like the one applied to
# `timestep` and `T_init` -- confirm that this is intended.
init_fn, apply_fn = jax_md.simulate.npt_nose_hoover(
    energy_fn, 
    shift, 
    dt=timestep,
    pressure=pressure, 
    kT=T_init,
    barostat_kwargs=default_nhc_kwargs(baro * timestep, new_nhc_kwargs),
    thermostat_kwargs=default_nhc_kwargs(thermo * timestep, new_nhc_kwargs)
)

apply_fn = jax.jit(apply_fn)
init_fn = jax.jit(init_fn)

# Initialize the state using the positions and neighbor lists from the structure relaxation.
state = init_fn(
    rng_key, 
    fire_state.position, 
    box=box, 
    neighbor=nbrs.idx, 
    neighbor_lr=nbrs_lr.idx,
    kT=T_init,
    mass=masses
)

@jax.jit
def step_npt_fn(i, state):
    """Advance the NPT simulation by one step and refresh both neighbor lists.

    Written as a `jax.lax.fori_loop` body: `i` is the (unused) loop index and `state`
    is the carried tuple `(NPT state, short-range neighbors, long-range neighbors, box)`.
    A tuple with the same layout is returned, holding the box of the updated NPT state.
    """
    # The carried box is not read: the box is taken from the updated NPT state below.
    npt_state, neighbors_sr, neighbors_lr, _ = state

    npt_state = apply_fn(
        npt_state,
        neighbor=neighbors_sr.idx,
        neighbor_lr=neighbors_lr.idx,
        kT=T_init,
        pressure=pressure
    )

    current_box = jax_md.simulate.npt_box(npt_state)

    # Bring both neighbor lists up to date with the new positions and the new box.
    new_positions = npt_state.position
    neighbors_sr = neighbors_sr.update(new_positions, neighbor=neighbors_sr.idx, box=current_box)
    neighbors_lr = neighbors_lr.update(new_positions, neighbor_lr=neighbors_lr.idx, box=current_box)

    return npt_state, neighbors_sr, neighbors_lr, current_box



jax-md partition: cell_size=0.18145161867141724, cl.id_buffer.size=2875, N=1536, cl.did_buffer_overflow=False
Step	E	Fmax
--------------------------
jax-md partition: cell_size=0.18145161867141724, cl.id_buffer.size=2875, N=1536, cl.did_buffer_overflow=Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=3/0)>
0	-5042.64	1.23
1	-5081.91	1.88
2	-5110.96	0.55
3	-5123.87	0.57
4	-5132.59	0.53


  return asarray(x, dtype=self.dtype)
  return _reduction(a, "sum", np.sum, lax.add, 0, preproc=_cast_to_numeric,
  return _reduction(a, "sum", np.sum, lax.add, 0, preproc=_cast_to_numeric,


# Production run
Production runs require efficient saving of the frames and additional properties during the simulation. Operations like `ase.io.write` can become very slow for larger structures with thousands of atoms and turn into a bottleneck, requiring orders of magnitude more time than the MD steps themselves. This is why production runs use `.hdf5`, which allows for efficient writing to disk. 

It is built around `DataSetEntry` and `HDF5Store`, provided by the `mlff` package. Below you find an example function called `init_hdf5_store`, which initializes an `.hdf5` database under `save_to`. Instead of saving frames one at a time, the intended usage is to first collect multiple values of the property of interest (e.g. multiple frames) and flush them to the database in a single call. The number of values that are flushed at once is the `batch_size`. As the database needs to know the shapes a priori, each property that is to be saved has to be initialized with the correct shape, e.g. `(batch_size, num_atoms, 3)` for the positions. During the simulation, arrays of the specified shape are stacked along the 0-th axis, such that after e.g. 10 flush operations `positions` has shape `(10, batch_size, num_atoms, 3)`. In addition to the positions we save the `velocities` and the `box`, which can be represented by a fixed number of entries. For details on the box representation in `jax_md` see the `npt_jaxmd.ipynb` example.

In [4]:
from mlff.mdx.hdfdict import DataSetEntry, HDF5Store


def init_hdf5_store(
    save_to,
    batch_size,
    num_atoms,
    num_box_entries,
    exist_ok=False
):
    """Create the HDF5 store that collects the trajectory.

    Args:
        save_to: Path of the `.hdf5` file. A leading `~` is expanded and missing parent
            directories are created.
        batch_size: Number of frames contained in one appended entry.
        num_atoms: Number of atoms per frame.
        num_box_entries: Number of values used to store the box of one frame.
        exist_ok: If False (default), an already existing file at `save_to` raises an
            error. If True, the store is opened with mode 'w' on the existing path.

    Returns:
        An `HDF5Store` opened with mode 'w' holding the float32 datasets 'positions' and
        'velocities' of shape `(batch_size, num_atoms, 3)` and 'box' of shape
        `(batch_size, num_box_entries)`.

    Raises:
        RuntimeError: If the file already exists and `exist_ok` is False.
    """
    # Normalize the path once, so that the directory creation, the existence check and the
    # store itself all refer to the same location (also for paths starting with '~').
    _save_to = pathlib.Path(save_to).expanduser().resolve()

    # parents=True also creates nested output directories such as 'runs/a/b/'.
    _save_to.parent.mkdir(parents=True, exist_ok=True)

    if _save_to.exists() and not exist_ok:
        raise RuntimeError(
            f'File exists save_to={_save_to}. '
            f'Set exist_ok=True to override file.'
        )

    per_atom_shape = (batch_size, num_atoms, 3)
    dataset = {
        'positions': DataSetEntry(
            chunk_length=1,
            shape=per_atom_shape,
            dtype=np.float32
        ),
        'velocities': DataSetEntry(
            chunk_length=1,
            shape=per_atom_shape,
            dtype=np.float32
        ),
        'box': DataSetEntry(
            chunk_length=1,
            shape=(batch_size, num_box_entries),
            dtype=np.float32
        )
    }

    return HDF5Store(_save_to, datasets=dataset, mode='w')

In [5]:
# Initialize the hdf5_store object.

# How many NPT cycle outputs are collected before flush is called. 
# I think a good default might be 100 but one prob. has to benchmark
# different choices of buffer size.
hdf5_buffer_size = 3

hdf5_store = init_hdf5_store(
    save_to='hdf5/trajectory.hdf5',
    batch_size=hdf5_buffer_size,  # number of frames that is flushed in one call
    num_atoms=num_atoms,  # number of atoms
    num_box_entries=1,  # the cubic box is stored as a single entry (`box` is built as a length-1 array in this notebook)
    exist_ok=False  # if save_to is allowed to exist. If true, the .hdf5 file will be overwritten
)

In [None]:
# Track total time and step times averaged over cycle.
# perf_counter is a monotonic clock, so the measured intervals are not affected by
# adjustments of the system time.
total_time = time.perf_counter()

velocities, positions, boxes = [], [], []
print('Step\tKE\tPE\tTot. Energy\tTemp.\ttime/steps')
print('------------------------------------------------------------------------------------------------------------')
for i in range(npt_cycles):

    old_time = time.perf_counter()
    
    # Do `npt_steps` NPT steps.
    new_state, nbrs, nbrs_lr, new_box = jax.block_until_ready(
        jax.lax.fori_loop(
            0, 
            npt_steps, 
            step_npt_fn, 
            (state, nbrs, nbrs_lr, box)
        )
    )
    
    new_time = time.perf_counter()
    
    # Check for overflow of both sr and lr neighbors. On overflow the result of this cycle is
    # discarded: the lists are reallocated from the last accepted state and `state` / `box`
    # stay unchanged. The cycle is not repeated, so the printed step count still advances.
    if nbrs.did_buffer_overflow:
        print('Neighbor list overflowed, reallocating.')
        nbrs = neighbor_fn.allocate(state.position, box = box)
        if nbrs_lr.did_buffer_overflow:
            print('Long-range neighbor list also overflowed, reallocating.')
            nbrs_lr = neighbor_fn_lr.allocate(state.position, box = box)
    elif nbrs_lr.did_buffer_overflow:
        print('Long-range neighbor list overflowed, reallocating.')
        nbrs_lr = neighbor_fn_lr.allocate(state.position, box = box)
    else:
        state = new_state
        box = new_box

        # Only save if there are no overflows in the neighbor lists.
        
        ######## HDF5 saving part ########
        
        # Append velocities, calculated from the momenta.
        velocities.append(
            jnp.divide(
                state.momentum,
                state.mass
            )
        )

        # Append positions of the last frame of `npt_steps` MD steps.
        positions.append(
            jax_md.space.transform(box=box, R=state.position)  # this is only correct for cubic cells with equally long lattice vectors
        )
        
        # Append current box.
        boxes.append(
            box
        )
    
    # If the buffer is full and positions is not empty, append the velocities, positions and boxes
    # of `hdf5_buffer_size` frames to the trajectory.hdf5
    if (len(positions) % hdf5_buffer_size == 0) and (len(positions) > 0):   
        
        # Create dictionary from stacked velocities and positions.
        step_data = dict(
            velocities=jnp.stack(velocities, axis=0),  # (hdf5_buffer_size, num_atoms, 3)
            positions=jnp.stack(positions, axis=0),  # (hdf5_buffer_size, num_atoms, 3)
            box=jnp.stack(boxes, axis=0)            # (hdf5_buffer_size, num_box_entries)
        )
        
        # Make the jax.numpy arrays to numpy arrays.
        step_data = jax.tree.map(np.asarray, step_data)

        # Append the step data to the hdf5 store object.
        hdf5_store.append(
            step_data
        )
        
        # Create fresh lists.
        velocities, positions, boxes = [], [], []
    
        ######## End of HDF5 saving part ########

    # Calculate some quantities for printing
    KE = jax_md.quantity.kinetic_energy(
        momentum=state.momentum,
        mass=state.mass
    )
    
    PE = energy_fn(
        state.position,
        neighbor=nbrs.idx,
        neighbor_lr=nbrs_lr.idx, 
        box=box
    )
    
    T = jax_md.quantity.temperature(
        momentum=state.momentum,
        mass = state.mass
    ) / unit['temperature']

    print(
        f'{(i+1)*npt_steps}\t{KE:.2f}\t{PE:.2f}\t{KE+PE:.3f}\t{T:.1f}\t{(new_time - old_time) / npt_steps:.2f}'
    )

# The datasets are created with a fixed leading size of `hdf5_buffer_size` frames per entry, so a
# partially filled buffer is never appended. Report the leftover frames instead of dropping them
# silently (this happens whenever the number of saved cycles is not a multiple of the buffer size).
if len(positions) > 0:
    print(
        f'Warning: {len(positions)} buffered frame(s) were not written to the hdf5 file, '
        f'because the buffer holds fewer than hdf5_buffer_size={hdf5_buffer_size} frames.'
    )

print('Total_time: ', time.perf_counter() - total_time)

# Clear all caches
jax.clear_caches()

Step	KE	PE	Tot. Energy	Temp.	time/steps
------------------------------------------------------------------------------------------------------------
jax-md partition: cell_size=0.18145161867141724, cl.id_buffer.size=2875, N=1536, cl.did_buffer_overflow=Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=3/0)>
5	42.14	-5114.97	-5072.826	212.2	10.79


# Postprocess the Trajectory

In [39]:
import h5py

# Open the trajectory written during the production run in read-only mode.
# NOTE(review): the file handle is not closed in this cell; close it (or use a `with` block)
# once no later cell needs `file` any more.
file = h5py.File('hdf5/trajectory.hdf5', mode='r')

# Load velocities and reshape.
velocities = np.array(file['velocities'])
# velocities have shape (num_flushes, hdf5_buffer_size, num_atoms, 3)
print(velocities.shape)

# Reshape velocities such that there is only one leading axis of consecutive velocities
# separated by delta_t = npt_steps * time_step [ps].
velocities = velocities.reshape(-1, *file['velocities'].shape[-2:])

print(velocities.shape)  # (num_flushes*hdf5_buffer_size, num_atoms, 3)

(9, 3, 1536, 3)
(27, 1536, 3)
