In [5]:
#include parent folder
import os, sys, inspect

currentdir = os.path.dirname(
    os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir)

import time
import timeit
import numpy as np
import pytest
import jax
import functools

from src.domain import Domain
from src.misc import *

In [6]:
def VLJ1(v, w):   
    r = jnp.linalg.norm(v-w) # euclidian distance between points
    #if(r == 0): return 0
    return 4 * (4*pow(1 / r, 12) - 2*pow(1 / r, 6)) # TODO: check if this is actually correct

VLJ1_jit = jax.jit(VLJ1)

In [7]:
def Epot_loops_nojit(pos): 
    # primitive Loop   
    v = to3D(pos)
    E = 0
    for i in range(len(v)):
        for j in range(i + 1, len(v)):
            E += VLJ1(v[i], v[j])    # using jit actually shaved off 20% form runtime  
    return E  

def Epot_loops_jit(pos): 
    # primitive Loop   
    v = to3D(pos)
    E = 0
    for i in range(len(v)):
        for j in range(i + 1, len(v)):
            E += VLJ1_jit(v[i], v[j])    # using jit actually shaved off 20% form runtime  
    return E    

In [21]:
def benchMark(domain, Epot_func):
    n = 10
    t = timeit.Timer(functools.partial(Epot_func, domain.pos))
    return t.timeit(n) / n  # Rrepeat n times and take average

domain = Domain()
domain.fill(5,10,1)
print(f"{benchMark(domain, Epot_loops_nojit):.4E}s: primitive loop, no jit ")
print(f"{benchMark(domain, Epot_loops_jit):.4E}s: primitive loop, jit ")

5.3482E-02s: primitive loop, no jit 
3.2912E-02s: primitive loop, jit 
