# Numba tutorial

In this example, we will analyze a dump file from a simulation of a single ten-bead polymer whose bonds
interact via the finitely extensible nonlinear elastic (FENE) potential

$$
u(r) = \begin{cases}
-\dfrac{1}{2} k R_0^2 \ln\left(1 - \left(\dfrac{r}{R_0}\right)^2\right) + 4\epsilon\left[\left(\dfrac{\sigma}{r}\right)^{12} - \left(\dfrac{\sigma}{r}\right)^6\right] + \epsilon, & r < R_0 \\[6pt]
\infty, & r \ge R_0.
\end{cases}
$$
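To make the piecewise definition concrete, here is a minimal sketch that evaluates $u(r)$ for a single bond. The helper `fene_energy` and its default parameters ($k = 30$, $R_0 = 1.5$, $\epsilon = \sigma = 1$, common Kremer-Grest choices) are assumptions for illustration, not values read from the tutorial's input file. LAMMPS's `bond_style fene` applies the Lennard-Jones term only for $r < 2^{1/6}\sigma$, so the sketch exposes that as an optional flag while defaulting to the equation exactly as written.

```python
import math


def fene_energy(r, k=30.0, r0=1.5, epsilon=1.0, sigma=1.0, truncate_lj=False):
    """Bond energy u(r) from the FENE equation above (hypothetical helper).

    Default parameters are assumed, not taken from the LAMMPS input file.
    """
    if r >= r0:
        return math.inf  # the bond cannot stretch to or beyond R_0
    fene = -0.5 * k * r0**2 * math.log(1.0 - (r / r0) ** 2)
    lj = 0.0
    # With truncate_lj=True, apply the LJ term only for r < 2^(1/6) sigma,
    # as LAMMPS's bond_style fene does; otherwise follow the equation as written.
    if not truncate_lj or r < 2.0 ** (1.0 / 6.0) * sigma:
        sr6 = (sigma / r) ** 6
        lj = 4.0 * epsilon * (sr6**2 - sr6) + epsilon
    return fene + lj


for r in (0.9, 0.97, 1.2, 1.5):
    print(f"u({r}) = {fene_energy(r):.4f}")
```

The energy diverges as $r \to R_0$, which is what keeps bonded beads from separating.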

We've included the [LAMMPS input file](lammps_input.in) if you want to run the
simulation yourself, but we've also included the [dump file](out.lammpstrj)
that we generated so you can follow along right away!

First, we import `lammpsio` and open the corresponding dump file.

```python
import lammpsio

# Open the LAMMPS dump file as an iterable sequence of snapshots.
traj = lammpsio.DumpFile("out.lammpstrj")
```
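The $R_g$ cells below unwrap bead positions with the image flags before computing $R_g$; without this step, a chain that straddles a periodic boundary would look artificially stretched. Here is a standalone sketch of that step on made-up arrays, assuming an orthorhombic (non-triclinic) box. The `unwrap` helper is hypothetical; with `lammpsio` you would pass it `snapshot.position`, `snapshot.image`, `snapshot.box.low`, and `snapshot.box.high`.

```python
import numpy

# Hypothetical data for illustration: two bonded beads in a 10 x 10 x 10 box.
box_low = numpy.array([-5.0, -5.0, -5.0])
box_high = numpy.array([5.0, 5.0, 5.0])
wrapped = numpy.array([[4.5, 0.0, 0.0],
                       [-4.5, 0.0, 0.0]])
image = numpy.array([[0, 0, 0],
                     [1, 0, 0]])  # bead 2 has crossed the +x boundary once


def unwrap(position, image, low, high):
    """Undo periodic wrapping by adding one box length per image count."""
    box_length = high - low  # shape (3,), broadcast over all beads
    return position + image * box_length


unwrapped = unwrap(wrapped, image, box_low, box_high)
print(unwrapped)
```

After unwrapping, the two beads sit 1 unit apart instead of 9, so they no longer appear to span the whole box.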

## Calculating the radius of gyration

We will compare three ways to calculate the radius of gyration $R_g$. The first, and most computationally expensive, uses the double-sum definition

$$
R_g^2 =
\dfrac{1}{N^2}\sum_{i=1}^{N}\sum_{j=1}^{N}\left(\vec{R}_i^{\,2} - \vec{R}_i\cdot\vec{R}_j\right)
$$

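where $\vec{R}_i$ and $\vec{R}_j$ are bead position vectors and $N$ is the number of beads in the polymer. `lammpsio` makes it easy to extract the positions, image flags, and box from each snapshot of the dump file. However, the nested Python loops evaluate $N^2$ pairs of dot products for every snapshot, so this method takes a lot of computation time.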
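The double sum is a rearrangement of the mean squared distance between all bead pairs: expanding $\lvert\vec{R}_i - \vec{R}_j\rvert^2 = \vec{R}_i^{\,2} + \vec{R}_j^{\,2} - 2\vec{R}_i\cdot\vec{R}_j$ and using the $i \leftrightarrow j$ symmetry gives $R_g^2 = \frac{1}{2N^2}\sum_i\sum_j \lvert\vec{R}_i - \vec{R}_j\rvert^2$. This sketch checks that numerically on a random (hypothetical) 10-bead conformation. Both forms still touch all $N^2$ pairs, which is why this route scales poorly.

```python
import numpy

rng = numpy.random.default_rng(0)
pos = rng.normal(size=(10, 3))  # hypothetical 10-bead conformation
N = len(pos)

# Double-sum form from the equation above: N^2 pairs of dot products.
rg_sqr_double_sum = sum(
    numpy.dot(pos[i], pos[i]) - numpy.dot(pos[i], pos[j])
    for i in range(N)
    for j in range(N)
) / N**2

# Equivalent pair-distance form, vectorized with broadcasting:
# (1 / 2N^2) * sum over i, j of |R_i - R_j|^2
diff = pos[:, None, :] - pos[None, :, :]  # shape (N, N, 3)
rg_sqr_pairs = numpy.sum(diff**2) / (2 * N**2)

print(rg_sqr_double_sum, rg_sqr_pairs)
```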

```python
import time

import numpy

def compute_rg(pos, N):
    """Radius of gyration from the O(N^2) double-sum definition."""
    rg_sqr = 0.0
    for i in range(N):
        for j in range(N):
            rg_sqr += numpy.dot(pos[i], pos[i]) - numpy.dot(pos[i], pos[j])
    return numpy.sqrt(rg_sqr / (N * N))

start_time = time.time()

rg = []
for snapshot in traj:
    # Unwrap positions with the image flags so the chain is not split across the periodic box.
    box_length = snapshot.box.high - snapshot.box.low
    pos = snapshot.position + snapshot.image * box_length
    rg.append(compute_rg(pos, snapshot.N))
print("Radius of gyration:", numpy.mean(rg), "Time taken:", time.time() - start_time)
```

```
Radius of gyration: 1.6496794663337495 Time taken: 16.29240918159485
```


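Instead, we can use an equivalent formulation of $R_g$ (for beads of equal mass) that needs only $N$ terms per snapshot and is easily vectorized with NumPy, making it significantly faster to compute than the first method:

$$
R_g^2 =
\dfrac{1}{N}\sum_{i=1}^{N}\left(\vec{R}_i - \vec{R}_{cm}\right)^2
$$

where $\vec{R}_{cm}$ is the position of the polymer's center of mass,

$$
\vec{R}_{cm} = \dfrac{\sum_{j=1}^{N} M_j\vec{R}_j}{\sum_{j=1}^{N} M_j},
$$

and $M_j$ is the mass of bead $j$. The code below assumes all beads have the same mass, so $\vec{R}_{cm}$ reduces to the mean bead position, computed with `numpy.mean`.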
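The vectorized code in this section uses an unweighted mean, i.e., equal bead masses. If the beads had different masses, the usual mass-weighted generalization also weights each squared distance, $R_g^2 = \sum_i M_i \lvert\vec{R}_i - \vec{R}_{cm}\rvert^2 / \sum_i M_i$. The sketch below implements that with `numpy.average`; the helper `compute_rg_weighted` and the mass values are hypothetical. With equal masses it reproduces the pair-distance form, which confirms the two $R_g$ formulas agree.

```python
import numpy


def compute_rg_weighted(pos, mass):
    """Mass-weighted radius of gyration about the center of mass (hypothetical helper)."""
    r_cm = numpy.average(pos, axis=0, weights=mass)  # sum_j M_j R_j / sum_j M_j
    sq_dist = numpy.sum((pos - r_cm) ** 2, axis=1)   # |R_i - R_cm|^2 for each bead
    return numpy.sqrt(numpy.average(sq_dist, weights=mass))


rng = numpy.random.default_rng(1)
pos = rng.normal(size=(10, 3))                # hypothetical 10-bead conformation
equal = numpy.ones(10)                        # equal masses, as in this tutorial
heavy_end = numpy.array([5.0] + [1.0] * 9)    # hypothetical: first bead 5x heavier

print(compute_rg_weighted(pos, equal), compute_rg_weighted(pos, heavy_end))
```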

```python
import time

import numpy

def compute_rg(pos):
    """Radius of gyration about the center of mass (equal bead masses assumed)."""
    r_cm = numpy.mean(pos, axis=0)
    rg_sqr = numpy.mean(numpy.sum((pos - r_cm)**2, axis=1))
    return numpy.sqrt(rg_sqr)

start_time = time.time()
traj = lammpsio.DumpFile("out.lammpstrj")
rg = numpy.zeros(len(traj))

for i, snapshot in enumerate(traj):
    # Unwrap positions with the image flags before computing R_g.
    box_length = snapshot.box.high - snapshot.box.low
    pos = snapshot.position + snapshot.image * box_length
    rg[i] = compute_rg(pos)
print("Radius of gyration:", numpy.mean(rg), "Time taken:", time.time() - start_time)
```

```
Radius of gyration: 1.6496794663337464 Time taken: 5.6428210735321045
```


This method of calculating $R_g$ is significantly faster than the first one, cutting the computation time by almost a factor of 3 (from 16.3 s to 5.6 s)! However, a just-in-time (JIT) compiler like [numba](https://numba.readthedocs.io/en/stable/index.html) can make this calculation even faster. We take ``compute_rg`` from the first implementation of the $R_g$ calculation and add the decorator ``@numba.njit`` so that numba compiles it to machine code the first time it is called.

```python
import numba

# numba's support for numpy.dot requires SciPy to be installed.
@numba.njit
def compute_rg(pos, N):
    """Same O(N^2) double sum as the first version, compiled by numba."""
    rg_sqr = 0.0
    for i in range(N):
        for j in range(N):
            rg_sqr += numpy.dot(pos[i], pos[i]) - numpy.dot(pos[i], pos[j])
    return numpy.sqrt(rg_sqr / (N * N))

start_time = time.time()
traj = lammpsio.DumpFile("out.lammpstrj")

rg = []
for snapshot in traj:
    # Unwrap positions with the image flags before computing R_g.
    box_length = snapshot.box.high - snapshot.box.low
    pos = snapshot.position + snapshot.image * box_length
    rg.append(compute_rg(pos, snapshot.N))
print("Radius of gyration:", numpy.mean(rg), "Time taken:", time.time() - start_time)
```

```
Radius of gyration: 1.6496794663337495 Time taken: 4.488613128662109
```


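This is about a 3.6x reduction in computation time compared to the first implementation (16.3 s to 4.5 s), and about a 1.26x reduction compared to the vectorized second implementation (5.6 s to 4.5 s). Note that each timing covers the whole loop, including reading snapshots from the dump file, and the numba timing also includes the one-time cost of compiling ``compute_rg`` on its first call.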
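A natural next step is to combine both ideas: write the $O(N)$ center-of-mass form as explicit loops, which numba compiles efficiently without creating temporary arrays, and time the first (compiling) call separately from later calls. This is a sketch, not something measured in this tutorial. `compute_rg_loop` is a hypothetical helper that assumes equal bead masses, and the sketch falls back to plain Python if numba is not installed so that it still runs.

```python
import time

import numpy

try:
    import numba
    jit = numba.njit
except ImportError:  # fall back to plain Python so the sketch still runs
    def jit(func):
        return func


@jit
def compute_rg_loop(pos):
    """O(N) center-of-mass form of R_g as explicit loops (equal bead masses assumed)."""
    n, dim = pos.shape
    center = numpy.zeros(dim)
    for i in range(n):
        for d in range(dim):
            center[d] += pos[i, d]
    center /= n
    rg_sqr = 0.0
    for i in range(n):
        for d in range(dim):
            delta = pos[i, d] - center[d]
            rg_sqr += delta * delta
    return numpy.sqrt(rg_sqr / n)


pos = numpy.random.default_rng(2).normal(size=(10, 3))  # hypothetical conformation

start = time.perf_counter()
compute_rg_loop(pos)  # first call: includes JIT compilation when numba is available
first_call = time.perf_counter() - start

start = time.perf_counter()
for _ in range(1000):
    compute_rg_loop(pos)  # later calls reuse the already-compiled function
steady_state = (time.perf_counter() - start) / 1000

print(f"first call: {first_call:.4f} s, later calls: {steady_state:.2e} s each")
```

Separating the two numbers makes it clear how much of a short benchmark is compilation overhead rather than the $R_g$ calculation itself.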