In [1]:
import time
from typing import Dict, Optional

import jax
import jax.numpy as jnp
import jax_md
import numpy as np
from ase.io import read
from jax_md import units

from solar import SolarPotential
from solar import to_jax_md



In [2]:
# Some helper functions used throughout the example notebook.

# Default nose hoover chain parameters.
def default_nhc_kwargs(
    tau: jnp.float32,
    overrides: Optional[Dict] = None
) -> Dict:
    """Build the keyword arguments for a Nose-Hoover chain.

    Args:
        tau: Value stored under the 'tau' key of the defaults.
        overrides: Optional mapping whose values replace the defaults. Only
            keys that exist in the defaults ('chain_length', 'chain_steps',
            'sy_steps', 'tau') are looked up; any other key is ignored.
            `None` (the default) means "use the defaults unchanged".

    Returns:
        A new dict with exactly the keys 'chain_length', 'chain_steps',
        'sy_steps' and 'tau'.
    """
    default_kwargs = {
        'chain_length': 3,
        'chain_steps': 2,
        'sy_steps': 3,
        'tau': tau
    }

    if overrides is None:
        return default_kwargs

    # Iterate over the default keys so that unknown override keys never leak
    # into the result.
    return {
        k: overrides.get(k, default_kwargs[k]) for k in default_kwargs
    }

In [3]:
# Read a molecular structure. Here we read the water box (`water_64.xyz`) from the
# test data directory, taking the frame at index -1, and wrap the atoms.
path_to_xyz = '../tests/test_data/water_64.xyz'
atoms = read(path_to_xyz, index=-1)
atoms.wrap()

# Repeat the cell twice along each of the three directions.
atoms = atoms.repeat((2, 2, 2))

# Per-atom quantities as JAX arrays.
species = jnp.array(atoms.get_atomic_numbers())
masses = jnp.array(atoms.get_masses())
num_atoms = len(atoms)

# Periodic Boundary Conditions

In contrast to the `nvt_jaxmd.ipynb` example, here the simulations are performed within a box such that periodic boundary conditions (PBCs) need to be applied. The functions calculating the `displacement` vectors between atoms and the `shift` function for updating atomic positions during simulation can be created using `jax_md.space.periodic_general`. It takes the simulation `box` as input, which can be represented by a single scalar for a box with equal side lengths `L`, by a vector `[Lx, Ly, Lz]` for an orthorhombic cell, or by an upper triangular matrix for a general triclinic cell. See also the `jax_md` docs [here](https://jax-md.readthedocs.io/en/main/jax_md.space.html#jax_md.space.periodic_general). The water box loaded above is cubic, i.e. it has equal side lengths, so we can represent it in `jax_md` by a single scalar.

Since we are running simulations in fractional coordinates, we need to project the positions on the hypercube with side length 1. For the cubic case this is easy and corresponds to just dividing by the length of the lattice vectors. This is also described [here](https://jax-md.readthedocs.io/en/main/jax_md.space.html#jax_md.space.periodic_general).

In [4]:
# The box must be an array with one dimension, not a 0-d scalar. As reported by the
# original author, jax_md.simulate.npt_box(state) otherwise returns a 3x3 array when it is
# given a jnp.array with ndim = 0, which leads to a TracerConversionError.
# Maybe open an issue on jax_md?
cell_matrix = np.array(atoms.get_cell())
box = jnp.array([cell_matrix[0, 0]])

print('box =', box)

fractional_coordinates = True
displacement, shift = jax_md.space.periodic_general(
    box=box,
    fractional_coordinates=fractional_coordinates
)

# Positions as read from the structure; divided by the box when working in
# fractional coordinates.
positions_init = jnp.array(atoms.get_positions())
if fractional_coordinates:
    positions_init = positions_init / box

box = [24.8]


# Prepare JAX-MD Energy Function

JAX-MD is built around the energy function and the neighbor list. Since we have short- and long-range neighbors, we have two neighbor lists instead of the single one used in the standard case. We use the convenience interface `to_jax_md` provided within this repository, which returns the SO3LR energy function as well as the two neighbor-list functions. It takes as input the `SolarPotential` itself, as well as a displacement function, which determines how displacement vectors between atoms are computed (see above).

In [5]:
# `to_jax_md` returns the short-range neighbor-list function, the long-range
# neighbor-list function and the energy function, in that order.
neighbor_fn, neighbor_fn_lr, energy_fn = to_jax_md(
    potential=SolarPotential(),
    displacement_or_metric=displacement,
    box_size=box,
    species=species,
    fractional_coordinates=fractional_coordinates,
    disable_cell_list=False,
    minimum_cell_size_multiplier_sr=1.0,
    capacity_multiplier=1.25,
    buffer_size_multiplier_sr=1.25,
    buffer_size_multiplier_lr=1.25
)

# Allocate the short- and long-range neighbor lists for the initial positions.
nbrs = neighbor_fn.allocate(positions_init, box=box)
nbrs_lr = neighbor_fn_lr.allocate(positions_init, box=box)

# Jit the energy function and derive a jitted force function from it.
energy_fn = jax.jit(energy_fn)
force_fn = jax.jit(jax_md.quantity.force(energy_fn))



jax-md partition: cell_size=0.18145161867141724, cl.id_buffer.size=2875, N=1536, cl.did_buffer_overflow=False


In [6]:
# Sanity check: evaluate the energy function once on the initial structure.
# The value of this expression is displayed as the cell output.
energy_fn(positions_init, neighbor=nbrs.idx, neighbor_lr=nbrs_lr.idx, box=box)

Array(-4941.2207, dtype=float32)

# Structure Relaxation

In [27]:
# For repeated execution of this cell, delete the fire_state first, otherwise jit in a
# Jupyter environment can meddle with things.
# On a fresh kernel `fire_state` does not exist yet and `del` raises NameError, which is
# the only exception that should be ignored here.
try:
    del fire_state
except NameError:
    pass

# The minimization runs `min_cycles` cycles of `min_steps` FIRE steps each.
min_cycles = 5
min_steps = 10

# Build the FIRE minimizer from the energy function and the shift function.
fire_init, fire_apply = jax_md.minimize.fire_descent(
    energy_fn,
    shift,
    dt_start=0.05,
    dt_max=0.1,
    n_min=2
)

fire_apply = jax.jit(fire_apply)

# Initial minimizer state for the initial positions and neighbor lists.
fire_state = fire_init(
    positions_init,
    box=box,
    neighbor=nbrs.idx,
    neighbor_lr=nbrs_lr.idx
)

@jax.jit
def step_fire_fn(i, fire_state):
    """Perform one FIRE step and refresh both neighbor lists.

    Written as a loop body for `jax.lax.fori_loop`.

    Args:
        i: Loop index (not used in the body).
        fire_state: Carry tuple `(fire_state, nbrs, nbrs_lr)`.

    Returns:
        The updated carry tuple `(fire_state, nbrs, nbrs_lr)`.
    """
    opt_state, sr_nbrs, lr_nbrs = fire_state

    # Advance the minimizer using the current neighbor indices.
    opt_state = fire_apply(
        opt_state,
        neighbor=sr_nbrs.idx,
        neighbor_lr=lr_nbrs.idx,
        box=box
    )

    # Update both neighbor lists with the new positions.
    new_position = opt_state.position
    sr_nbrs = sr_nbrs.update(new_position, neighbor=sr_nbrs.idx, box=box)
    lr_nbrs = lr_nbrs.update(new_position, neighbor_lr=lr_nbrs.idx, box=box)

    return opt_state, sr_nbrs, lr_nbrs

print('Step\tE\tFmax')
print('--------------------------')
for i in range(min_cycles):
    # Run `min_steps` FIRE steps in one fori_loop call.
    carry = (fire_state, nbrs, nbrs_lr)
    fire_state, nbrs, nbrs_lr = jax.lax.fori_loop(0, min_steps, step_fire_fn, carry)

    # Energy and forces at the current positions, evaluated with the same keyword arguments.
    eval_kwargs = dict(neighbor=nbrs.idx, neighbor_lr=nbrs_lr.idx, box=box)
    E = energy_fn(fire_state.position, **eval_kwargs)
    F = force_fn(fire_state.position, **eval_kwargs)

    # Report the cycle index, the energy and the largest absolute force component.
    print('{}\t{:.2f}\t{:.2f}'.format(i, E, np.abs(F).max()))
    

Step	E	Fmax
--------------------------
jax-md partition: cell_size=0.18145161867141724, cl.id_buffer.size=2875, N=1536, cl.did_buffer_overflow=Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=3/0)>
0	-5042.64	1.23
1	-5081.91	1.88
2	-5110.96	0.55
3	-5123.87	0.57
4	-5132.59	0.53


Print the delta between the original and the optimized positions for the first three atoms. The difference is in fractional coordinates. To go back to real space we can use `jax_md.space.transform`. Alternatively, one can use the `displacement` function returned by `jax_md.space.periodic_general`, which always returns displacements in real space because it knows whether the input is given in real or fractional coordinates.

In [28]:
# Difference between initial and relaxed positions, as stored (fractional coordinates).
delta_fractional = positions_init - fire_state.position
print('Difference in fractional coordinates: ')
print(delta_fractional[:3])
print('\n')

# Transform both sets of positions with the box first, then take the difference.
real_init = jax_md.space.transform(box, positions_init)
real_relaxed = jax_md.space.transform(box, fire_state.position)
print('Difference in real coordinates using transform function: ')
print((real_init - real_relaxed)[:3])
print('\n')

# Apply the displacement function atom by atom via vmap.
per_atom_displacement = jax.vmap(displacement)
print('Difference in real coordinates from displacement function: ')
print(per_atom_displacement(positions_init, fire_state.position)[:3])

Difference in fractional coordinates: 
[[-0.0078468   0.0001581   0.00202413]
 [-0.0088468   0.00183146  0.01245631]
 [-0.00348818  0.00351053  0.00309914]]


Difference in real coordinates using transform function: 
[[-0.19460106  0.00392097  0.05019855]
 [-0.2194004   0.04542017  0.30891657]
 [-0.08650684  0.08706188  0.07685852]]


Difference in real coordinates from displacement function: 
[[-0.1946007   0.00392165  0.05019803]
 [-0.21940112  0.0454205   0.30891618]
 [-0.08650693  0.08706126  0.07685876]]


# Nose-Hoover Chain NPT simulations

In [39]:
# Simulation parameters

timestep = 0.0005  # Time step in ps.
npt_cycles = 25  # Number of cycles in the NPT run.
npt_steps = 5  # Number of NPT steps per cycle. The total number of MD steps equals npt_cycles * npt_steps.

T_init = 300  # Initial temperature.
# NOTE(review): unlike `timestep` and `T_init`, `pressure` is not multiplied by a
# unit-system factor below - confirm that 1. is already in the units the integrator expects.
pressure = 1.  # Target pressure.

chain = 3  # Used as 'chain_length' of the Nose-Hoover chain.
chain_steps = 2  # Used as 'chain_steps' of the Nose-Hoover chain.
sy_steps = 3  # Used as 'sy_steps' of the Nose-Hoover chain.

thermo = 100  # Thermostat value in the Nose-Hoover chain.
baro = 1000  # Barostat value in the Nose-Hoover chain.

# Dictionary with the NHC settings.
new_nhc_kwargs = dict(
    chain_length=chain,
    chain_steps=chain_steps,
    sy_steps=sy_steps
)

# Convert time step and temperature to the metal unit system.
unit = units.metal_unit_system()
timestep *= unit['time']
T_init *= unit['temperature']

# Random key with a fixed seed.
rng_key = jax.random.PRNGKey(0)

  return asarray(x, dtype=self.dtype)


In [40]:
# Nose-Hoover chain settings for barostat and thermostat: the shared NHC settings plus a
# 'tau' of `baro * timestep` and `thermo * timestep`, respectively.
barostat_kwargs = default_nhc_kwargs(baro * timestep, new_nhc_kwargs)
thermostat_kwargs = default_nhc_kwargs(thermo * timestep, new_nhc_kwargs)

# Build the Nose-Hoover NPT integrator.
init_fn, apply_fn = jax_md.simulate.npt_nose_hoover(
    energy_fn,
    shift,
    dt=timestep,
    pressure=pressure,
    kT=T_init,
    barostat_kwargs=barostat_kwargs,
    thermostat_kwargs=thermostat_kwargs
)

init_fn = jax.jit(init_fn)
apply_fn = jax.jit(apply_fn)

# Initialize the state from the positions and neighbor lists of the structure relaxation.
state = init_fn(
    rng_key,
    fire_state.position,
    box=box,
    neighbor=nbrs.idx,
    neighbor_lr=nbrs_lr.idx,
    kT=T_init,
    mass=masses
)

@jax.jit
def step_npt_fn(i, state):
    """Perform one NPT step and refresh both neighbor lists.

    Written as a loop body for `jax.lax.fori_loop`.

    Args:
        i: Loop index (not used in the body).
        state: Carry tuple `(state, nbrs, nbrs_lr, box)`. The incoming box is
            not read; it is replaced by the box obtained from the new state.

    Returns:
        The updated carry tuple `(state, nbrs, nbrs_lr, box)`.
    """
    md_state, sr_nbrs, lr_nbrs, _previous_box = state

    # Advance the integrator using the current neighbor indices.
    md_state = apply_fn(
        md_state,
        neighbor=sr_nbrs.idx,
        neighbor_lr=lr_nbrs.idx,
        kT=T_init,
        pressure=pressure
    )

    # Box belonging to the new state; used for both neighbor-list updates.
    current_box = jax_md.simulate.npt_box(md_state)

    new_position = md_state.position
    sr_nbrs = sr_nbrs.update(new_position, neighbor=sr_nbrs.idx, box=current_box)
    lr_nbrs = lr_nbrs.update(new_position, neighbor_lr=lr_nbrs.idx, box=current_box)

    return md_state, sr_nbrs, lr_nbrs, current_box

# Driver loop of the NPT simulation: runs `npt_cycles` cycles of `npt_steps` steps each,
# prints one line of statistics per cycle and collects the positions in `positions_md`.

# Track total time and step times averaged over cycle.
total_time = time.time()

# One entry (positions as a NumPy array) is appended per cycle.
positions_md = []

print('Step\tKE\tPE\tTot. Energy\tTemp.\tH\ttime/steps\tInvariant drifts (H_i - H_0 , H - H_{i-1}) [meV/atom/ps]')
print('-------------------------------------------------------------------------------------------------------------------------------------')
for i in range(npt_cycles):

    # Invariant before the very first cycle; reference for the cumulative drift.
    if i == 0:
        initial_H_0 = jax_md.simulate.npt_nose_hoover_invariant(
            energy_fn, 
            state, 
            pressure=pressure,
            kT=T_init,
            neighbor=nbrs.idx, 
            neighbor_lr=nbrs_lr.idx
        )
    
    # Invariant at the start of this cycle; reference for the per-cycle drift.
    initial_H = jax_md.simulate.npt_nose_hoover_invariant(
        energy_fn, 
        state, 
        pressure=pressure,
        kT=T_init,
        neighbor=nbrs.idx,
        neighbor_lr=nbrs_lr.idx
    )

    old_time = time.time()
    
    # Do `npt_steps` NPT steps.
    # block_until_ready makes the wall-clock measurement below include the computation.
    # Note that `nbrs` and `nbrs_lr` are rebound to the loop results here, while the new
    # state and box are kept in separate names until the overflow check has passed.
    new_state, nbrs, nbrs_lr, new_box = jax.block_until_ready(
        jax.lax.fori_loop(
            0, 
            npt_steps, 
            step_npt_fn, 
            (state, nbrs, nbrs_lr, box)
        )
    )
    
    new_time = time.time()
    
    # Check for overflow of both sr and lr neighbors.
    # On overflow the affected list is reallocated from the pre-cycle `state` and `box`,
    # and `new_state` / `new_box` are discarded. Only when neither list overflowed is the
    # result of the cycle accepted.
    # NOTE(review): a discarded cycle is not retried - the code below still prints
    # statistics and appends a frame for the unchanged state, and the loop moves on to the
    # next `i`. Confirm whether this is intended.
    if nbrs.did_buffer_overflow:
        print('Neighbor list overflowed, reallocating.')
        nbrs = neighbor_fn.allocate(state.position, box = box)
        if nbrs_lr.did_buffer_overflow:
            print('Long-range neighbor list also overflowed, reallocating.')
            nbrs_lr = neighbor_fn_lr.allocate(state.position, box = box)
    elif nbrs_lr.did_buffer_overflow:
        print('Long-range neighbor list overflowed, reallocating.')
        nbrs_lr = neighbor_fn_lr.allocate(state.position, box = box)
    else:
        state = new_state
        box = new_box

    # Calculate some quantities for printing
    KE = jax_md.quantity.kinetic_energy(
        momentum=state.momentum,
        mass=state.mass
    )
    
    PE = energy_fn(
        state.position,
        neighbor=nbrs.idx,
        neighbor_lr=nbrs_lr.idx, 
        box=box
    )
    
    # Divide by the temperature unit to undo the unit-system scaling for printing.
    T = jax_md.quantity.temperature(
        momentum=state.momentum,
        mass = state.mass
    ) / unit['temperature']
    
    # Invariant at the end of this cycle.
    H = jax_md.simulate.npt_nose_hoover_invariant(
        energy_fn,
        state,
        pressure=pressure,
        kT=T_init,
        neighbor=nbrs.idx,
        neighbor_lr=nbrs_lr.idx
    )

    # Drift of the invariant, scaled by 1000 and normalized per atom and per unit of
    # simulated time (`timestep / unit['time']` undoes the earlier unit conversion).
    # `energy_drift_h` is relative to the start of this cycle, `energy_drift_h_0`
    # relative to the start of the whole run (i + 1 cycles).
    energy_drift_h = (H - initial_H) * 1000 / (timestep / unit['time'] * npt_steps * num_atoms)
    energy_drift_h_0 = (H - initial_H_0) * 1000 / (timestep / unit['time'] * npt_steps * (i + 1) * num_atoms)
    
    positions_md.append(np.array(state.position))
    
    # The first column is `i * npt_steps`; the time column is the wall time of the
    # cycle divided by `npt_steps`.
    print(
        f'{i*npt_steps}\t{KE:.2f}\t{PE:.2f}\t{KE+PE:.3f}\t{T:.1f}\t{H:.3f}\t{(new_time - old_time) / npt_steps:.2f}\t{energy_drift_h_0:.2f}  ,  {energy_drift_h:.2f}'
    )

print('Total_time: ', time.time() - total_time)

# Clear all caches
jax.clear_caches()

  return _reduction(a, "sum", np.sum, lax.add, 0, preproc=_cast_to_numeric,
  return _reduction(a, "sum", np.sum, lax.add, 0, preproc=_cast_to_numeric,


Step	KE	PE	Tot. Energy	Temp.	H	time/steps	Invariant drifts (H_i - H_0 , H - H_{i-1}) [meV/atom/ps]
-------------------------------------------------------------------------------------------------------------------------------------
jax-md partition: cell_size=0.18145161867141724, cl.id_buffer.size=2875, N=1536, cl.did_buffer_overflow=Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=3/0)>
0	42.20	-5115.29	-5073.087	212.5	-5048.314	10.26	108.72  ,  108.72
5	45.81	-5119.14	-5073.336	230.7	-5048.589	2.51	18.56  ,  -71.59
10	36.38	-5109.44	-5073.062	183.2	-5048.364	2.53	31.92  ,  58.62
15	41.94	-5115.02	-5073.080	211.2	-5048.548	2.51	11.98  ,  -47.81
20	35.01	-5107.85	-5072.839	176.3	-5048.478	2.51	13.22  ,  18.18
25	31.87	-5104.49	-5072.617	160.5	-5048.414	2.58	13.80  ,  16.66
30	39.49	-5112.03	-5072.544	198.9	-5048.551	2.59	6.72  ,  -35.73
35	39.53	-5111.62	-5072.090	199.1	-5048.416	2.59	10.28  ,  35.22
40	41.26	-5113.25	-5071.991	207.8	-5048.538	2.54	5.61  ,  -31.79
45	35.43	-510

# Visualization and IO

For now, we collected the positions in a simple Python list `positions_md`. This allows us to directly visualize the trajectory and write the frames to `xyz` after the simulation. However, for production runs it might be necessary to save the frames (and potentially other information) along the way. For large structures, using `ase.io.write` can be very slow and can slow down the simulation by several orders of magnitude. To learn how to efficiently write frames and other statistics during simulation, check out the `production_run.ipynb` example notebook. It uses `.hdf5` files to efficiently perform save operations.

In [45]:
# If you want to do visualization in the notebook, install nglview by doing `pip install nglview` in your virtualenv
import nglview as nv
from ase import Atoms

# Build one ASE `Atoms` object per stored frame, transforming the stored positions
# back from fractional coordinates.
# NOTE(review): every frame is transformed with the current value of `box`. If the box
# changed between frames during the NPT run, earlier frames may need their own box - confirm.
atoms_traj = [
    Atoms(
        numbers=np.array(species),
        positions=np.array(jax_md.space.transform(box=box, R=positions))
    )
    for positions in positions_md
]

nv.show_asetraj(atoms_traj)

NGLWidget(max_frame=24)

In [46]:
# Save the frames to xyz.
from ase.io import write

# Write the whole trajectory with a single call instead of appending frame by frame.
# Without `append=True` the file is written anew, so re-running this cell does not
# duplicate the frames in 'npt_water_md_trajectory.xyz', and the file is opened only once.
write(
    'npt_water_md_trajectory.xyz',
    atoms_traj
)