In [3]:
import jax
import jax.numpy as jnp
import jax_md
import numpy as np

from ase.io import read
from jax_md import units
from typing import Dict

from sup_gems import to_jax_md
from sup_gems import SupGemsPotential

import time



In [4]:
# Some helper functions used throughout the example notebook.

# Default nose hoover chain parameters.
def default_nhc_kwargs(
    tau: jnp.float32, 
    overrides: Dict
) -> Dict:
    
    default_kwargs = {
        'chain_length': 3, 
        'chain_steps': 2, 
        'sy_steps': 3,
        'tau': tau
    }
    
    if overrides is None:
        return default_kwargs
  
    return {
      k: overrides.get(k, default_kwargs[k]) for k in default_kwargs
    }

# Read the Molecule

Start by reading the molecular structure. Here we assume some ASE digestable file, e.g. and `XYZ` file.

In [5]:
# Read some molecular structure. Here we read the DHA structure from the test directory.
path_to_xyz = '../tests/test_data/dha.xyz'
atoms = read(path_to_xyz, index=-1)

positions_init = jnp.array(atoms.get_positions())
species = jnp.array(atoms.get_atomic_numbers())
masses = jnp.array(atoms.get_masses())

# Prepare JAX-MD Energy Function

JAX-MD is build around the energy function and the neighborlist. Since we have short- and long-range neighbors, we have two neighborlists compared to the standard case of only a single one. We use the convenience interface `to_jax_md` provided within this repository, which returns the SO3LR energy function as well as the two neighborlists. It takes as input the `SupGemsPotential` itself, as well as a displacement function which determines how update the atomic positions. The `displacement` function calculates the displacement vectors between atoms and can differ, e.g. based on the fact if simulations are performed under perdioic boundary conditions (PBCs) or not. Also positions can be either represented in real or fractional coordinates. For more details on displacement functions as well as on fractional coordinates we refer to the JAX-MD docs. Here we simulate DHA in vacuum (no PBCs) which allows to use `jax_md.space.free` to obtain an displacement function. It also returns a `shift` function which determines how to update the atomic positions during simulation.

In [6]:
displacement, shift = jax_md.space.free()
box = None  # No simulation box. 

neighbor_fn, neighbor_fn_lr, energy_fn = to_jax_md(
    potential=SupGemsPotential(),
    displacement_or_metric=displacement,
    box_size=box,
    species=species,
    capacity_multiplier=1.25,
    buffer_size_multiplier_sr=1.25,
    buffer_size_multiplier_lr=1.25,
    minimum_cell_size_multiplier_sr=1.0,
    disable_cell_list=True  # Cell list partitioning can only be applied if there is a simulation box.
)

# Energy function.
energy_fn = jax.jit(energy_fn)
force_fn = jax.jit(jax_md.quantity.force(energy_fn))

# Initialize the short and long-range neighbor lists.
nbrs = neighbor_fn.allocate(
    positions_init,
    box=box
)
nbrs_lr = neighbor_fn_lr.allocate(
    positions_init,
    box=box
)

  return array(a, dtype=dtype, copy=bool(copy), order=order)


In [7]:
# Check that the energy function is working.
energy_fn(
    positions_init, 
    neighbor=nbrs.idx, 
    neighbor_lr=nbrs_lr.idx, 
    box=box
)

Array(-180.43394, dtype=float32)

# Structure Relaxation

In [8]:
min_cycles = 5
min_steps = 10

fire_init, fire_apply = jax_md.minimize.fire_descent(
    energy_fn, 
    shift, 
    dt_start = 0.05, 
    dt_max = 0.1, 
    n_min = 2
)

fire_apply = jax.jit(fire_apply)
fire_state = fire_init(
    positions_init, 
    box=box, 
    neighbor=nbrs.idx,
    neighbor_lr=nbrs_lr.idx
)

@jax.jit
def step_fire_fn(i, fire_state):
    
    fire_state, nbrs, nbrs_lr = fire_state
    
    fire_state = fire_apply(
        fire_state, 
        neighbor=nbrs.idx, 
        neighbor_lr=nbrs_lr.idx, 
        box=box
    )
    
    nbrs = nbrs.update(
        fire_state.position,
        neighbor=nbrs.idx
    )
    
    nbrs_lr = nbrs_lr.update(
        fire_state.position,
        neighbor_lr=nbrs_lr.idx
    )
    
    return fire_state, nbrs, nbrs_lr

print('Step\tE\tFmax')
print('----------------------------------------')
for i in range(min_cycles):
    fire_state, nbrs, nbrs_lr = jax.lax.fori_loop(
        0, 
        min_steps, 
        step_fire_fn, 
        (fire_state, nbrs, nbrs_lr)
    )
    
    E = energy_fn(
        fire_state.position, 
        neighbor=nbrs.idx, 
        neighbor_lr=nbrs_lr.idx,
        box=box
    )

    F = force_fn(
        fire_state.position, 
        neighbor=nbrs.idx, 
        neighbor_lr=nbrs_lr.idx,
        box=box
    )
    
    print('{}\t{:.2f}\t{:.2f}'.format(i, E, np.abs(F).max()))
    

Step	E	Fmax
----------------------------------------
0	-180.91	0.51
1	-181.09	0.22
2	-181.20	0.16
3	-181.28	0.12
4	-181.32	0.06


In [9]:
# Print the delta between original and optimized positions for the first three atoms.
(positions_init - fire_state.position)[:3]

Array([[-0.06604004,  0.05084229,  0.06851196],
       [-0.00616455, -0.01034546,  0.04959106],
       [ 0.00949097,  0.0005188 , -0.03643799]], dtype=float32)

# Nose-Hoover Chain NVT Simulation

In [10]:
# Simulation parameters

timestep = 0.0005  # Time step in ps
nvt_cycles = 25  # Number of Cycles in the NVT.
nvt_steps = 100  # Number of NVT steps per cylce. The total number of MD steps equals nvt_cylces * nvt_steps

T_init = None  # Initial temperature.
T_nvt = 300  # Target temperature. 

chain = 3  # Number of chains in the Nose-Hoover chain.
chain_steps = 2  # Number of steps per chain.
sy_steps = 3
thermo = 100  # Thermo value in the Nose-Hoover chain. 

# Set the temprature at initialization
if T_init is None:
    T_init = float(T_nvt / nvt_cycles) 
else:
    T_init = T_init

# Dictionary with the NHC settings.
new_nhc_kwargs = {
    'chain_length': chain, 
    'chain_steps': chain_steps, 
    'sy_steps': sy_steps
}

# Convert to metal unit system.
unit = units.metal_unit_system()

timestep = timestep * unit['time']
T_init = T_init * unit['temperature']
T_nvt = T_nvt * unit['temperature']

rng_key = jax.random.PRNGKey(0)

  return asarray(x, dtype=self.dtype)


In [11]:
# Chosse Nose-Hoover thermostat.
init_fn, apply_fn = jax_md.simulate.nvt_nose_hoover(
    energy_fn, 
    shift, 
    dt=timestep, 
    kT=T_init,
    box=box,
    thermostat_kwargs=default_nhc_kwargs(thermo * timestep, new_nhc_kwargs)
)

apply_fn = jax.jit(apply_fn)
init_fn = jax.jit(init_fn)

# Initialize state using position and neigbhors structure relaxation.
state = init_fn(
    rng_key, 
    fire_state.position, 
    box=box, 
    neighbor=nbrs.idx, 
    neighbor_lr=nbrs_lr.idx,
    kT=T_init,
    mass=masses
)

@jax.jit
def step_nvt_fn(i, state):
    state, nbrs, nbrs_lr, box, temp_i = state
    
    state = apply_fn(
        state, 
        neighbor=nbrs.idx, 
        neighbor_lr=nbrs_lr.idx, 
        kT=temp_i,
        box=box
    )
    
    nbrs = nbrs.update(
        state.position, 
        neighbor=nbrs.idx, 
        box = box
    )
    
    nbrs_lr = nbrs_lr.update(
        state.position, 
        neighbor_lr=nbrs_lr.idx, 
        box = box
    )
    
    return state, nbrs, nbrs_lr, box, temp_i

# Track total time and step times averaged over cycle.
total_time = time.time()

positions_md = []

print('Step\tKE\tPE\tTotal Energy\tTemperature\tH\ttime/steps')
print('-----------------------------------------------------------------------------------')
for i in range(nvt_cycles):
    
    temp_i = T_nvt

    old_time = time.time()
    
    # Do `nvt_steps` NVT steps.
    new_state, nbrs, nbrs_lr, new_box, temp_i = jax.block_until_ready(
        jax.lax.fori_loop(
            0,
            nvt_steps,
            step_nvt_fn,
            (state, nbrs, nbrs_lr, box, temp_i)  # carry state is tuple
        )
    )
    
    new_time = time.time()
    
    # Check for overflor of both sr and lr neighbors.
    if nbrs.did_buffer_overflow:
        print('Neighbor list overflowed, reallocating.')
        nbrs = neighbor_fn.allocate(state.position, box = box)
        if nbrs_lr.did_buffer_overflow:
            print('Long-range neighbor list also overflowed, reallocating.')
            nbrs_lr = neighbor_fn_lr.allocate(state.position, box = box)
    elif nbrs_lr.did_buffer_overflow:
        print('Long-range neighbor list overflowed, reallocating.')
        nbrs_lr = neighbor_fn_lr.allocate(state.position, box = box)
    else:
        state = new_state
        box = new_box

    # Calculate some quantities for printing
    KE = jax_md.quantity.kinetic_energy(
        momentum=state.momentum,
        mass=state.mass
    )
    
    PE = energy_fn(
        state.position,
        neighbor=nbrs.idx,
        neighbor_lr=nbrs_lr.idx, 
        box=box
    )
    
    T = jax_md.quantity.temperature(
        momentum=state.momentum,
        mass = state.mass
    ) / unit['temperature']
    
    H = jax_md.simulate.nvt_nose_hoover_invariant(
        energy_fn, 
        state, 
        kT=temp_i,
        neighbor=nbrs.idx,
        neighbor_lr=nbrs_lr.idx, 
        box=box
    )

    positions_md.append(np.array(state.position))
    
    print(f'{i*nvt_steps}\t{KE:.2f}\t{PE:.2f}\t{KE+PE:.3f}\t{T:.1f}\t{H:.3f}\t{(new_time - old_time) / nvt_steps:.4f}')

print('Total_time: ', time.time()-total_time)

# Clear all caches
jax.clear_caches()

  return _reduction(a, "sum", np.sum, lax.add, 0, preproc=_cast_to_numeric,


Step	KE	PE	Total Energy	Temperature	H	time/steps
-----------------------------------------------------------------------------------
0	0.12	-181.31	-181.190	16.4	-181.229	0.2125
100	0.22	-181.32	-181.094	30.7	-181.229	0.0158
200	0.34	-181.27	-180.926	47.6	-181.228	0.0160
300	0.49	-181.18	-180.686	67.6	-181.228	0.0160
400	0.86	-181.18	-180.321	118.4	-181.227	0.0159
500	1.19	-181.01	-179.820	164.7	-181.227	0.0157
600	1.68	-180.91	-179.230	231.8	-181.227	0.0159
700	2.04	-180.71	-178.667	282.2	-181.226	0.0160
800	2.22	-180.50	-178.279	306.4	-181.226	0.0160
900	2.24	-180.35	-178.105	309.6	-181.226	0.0157
1000	1.98	-179.90	-177.922	273.7	-181.225	0.0161
1100	2.77	-180.18	-177.404	382.9	-181.225	0.0168
1200	2.51	-180.34	-177.831	346.7	-181.225	0.0158
1300	2.17	-180.42	-178.249	300.0	-181.226	0.0160
1400	1.97	-180.29	-178.321	272.6	-181.226	0.0160
1500	2.26	-180.29	-178.021	312.9	-181.226	0.0162
1600	2.46	-180.16	-177.699	339.5	-181.226	0.0158
1700	2.58	-180.57	-177.983	357.1	-181.227	0.0157
1

In [15]:
# If you want to do visualization in the notebook, install nglview by doing `pip install nglview` in your virtualenv
import nglview as nv
from ase import Atoms

atoms_traj = []
for positions in positions_md:
    atoms_traj.append(
        Atoms(numbers=np.array(species), positions=positions),
    )

nv.show_asetraj(atoms_traj)

NGLWidget(max_frame=24)

In [44]:
# Save the frames to xyz.
from ase.io import write

for frame in atoms_traj:
    write( 
        'nvt_md_trajectory.xyz',
        frame,
        append=True
    )