In [6]:
import numpy as np
from numba import njit
import time

def compute_distance_numpy(position_i, positions_j, box_size, only_norm=True):
    """Minimum-image distances from one particle to a set of particles (NumPy).

    Parameters
    ----------
    position_i : array_like
        Coordinates of the reference particle; broadcast against ``positions_j``.
    positions_j : array_like
        Coordinates of the other particles. The norm is taken along ``axis=1``,
        so this is expected to be 2-D (presumably ``(n, 3)`` — confirm with callers).
    box_size : array_like
        Periodic box description; only the first three entries are used, as
        the edge lengths.
    only_norm : bool, optional
        If True (default), the displacement vectors are not returned.

    Returns
    -------
    tuple
        ``(distances, None)`` when ``only_norm`` is true, otherwise
        ``(distances, displacements)``.
    """
    edges = box_size[:3]
    half_edges = edges / 2.0
    # Shift by half a box, wrap with remainder (result lies in [0, L) for a
    # positive edge L), then shift back: each component ends up in [-L/2, L/2).
    wrapped = np.remainder(position_i - positions_j + half_edges, edges) - half_edges
    # nan_to_num maps NaN (e.g. a zero edge length, where the remainder is
    # undefined) to 0 and infinities to large finite values.
    displacements = np.nan_to_num(wrapped)
    distances = np.linalg.norm(displacements, axis=1)
    if not only_norm:
        return distances, displacements
    return distances, None

@njit
def compute_distance_numba(position_i, positions_j, box_size, only_norm=True):
    """Minimum-image distances from one particle to a set of particles (Numba).

    Same periodic wrapping as the NumPy version (shift by half a box,
    ``np.remainder``, shift back), compiled with ``@njit``. Unlike
    ``compute_distance_numpy``, no ``nan_to_num`` is applied, so a zero edge
    length propagates NaN into the result.

    Parameters
    ----------
    position_i : array
        Coordinates of the reference particle; broadcast against ``positions_j``.
    positions_j : array
        Coordinates of the other particles; summed along ``axis=1``, so expected
        to be 2-D (presumably ``(n, 3)`` — confirm with callers).
    box_size : array
        Periodic box; only the first three entries are used as edge lengths.
    only_norm : bool, optional
        If True (default), the displacement vectors are not returned.

    Returns
    -------
    tuple
        ``(distances, None)`` when ``only_norm`` is true, otherwise
        ``(distances, displacements)``.
    """
    box = box_size[:3]
    half_box = box / 2.0
    shifted = position_i - positions_j + half_box
    displacement = np.remainder(shifted, box) - half_box
    squared = displacement**2
    distances = np.sqrt(np.sum(squared, axis=1))
    if only_norm:
        # Both branches return a 2-tuple so the compiled return types line up.
        return distances, None
    else:
        return distances, displacement



In [10]:
def compare_functions(n_particles=100_000, box_length=10.0, repeats=5, seed=42,
                      numpy_impl=None, numba_impl=None):
    """Benchmark the NumPy and Numba distance kernels and check that they agree.

    Random positions are drawn inside a cubic periodic box. Both implementations
    run on the same data, and their distance arrays are compared. The fastest
    time over ``repeats`` runs of each is printed.

    Parameters
    ----------
    n_particles : int, optional
        Number of rows in ``positions_j`` (default 100 000).
    box_length : float, optional
        Edge length of the cubic box; positions are drawn in ``[0, box_length)``.
    repeats : int, optional
        Timed runs per implementation. The best is reported, which filters
        scheduler/cache noise better than a single measurement.
    seed : int or None, optional
        Seed for ``np.random.default_rng`` so the test data is reproducible;
        pass None for fresh data on every call.
    numpy_impl, numba_impl : callable, optional
        Implementations called as ``f(position_i, positions_j, box_size)`` and
        returning ``(distances, _)``. Default to ``compute_distance_numpy`` and
        ``compute_distance_numba``.

    Raises
    ------
    ValueError
        If ``repeats`` is less than 1.
    AssertionError
        If the distances differ beyond ``np.allclose``'s default tolerances.
    """
    if repeats < 1:
        raise ValueError(f"repeats must be >= 1, got {repeats}")
    # Resolve defaults at call time so the functions defined earlier are used.
    if numpy_impl is None:
        numpy_impl = compute_distance_numpy
    if numba_impl is None:
        numba_impl = compute_distance_numba

    rng = np.random.default_rng(seed)
    position_i = rng.random(3) * box_length
    positions_j = rng.random((n_particles, 3)) * box_length
    box_size = np.full(3, box_length, dtype=float)

    def best_time(func):
        """Return (fastest elapsed seconds, last result) over `repeats` calls."""
        best = float("inf")
        result = None
        for _ in range(repeats):
            # perf_counter is monotonic and high-resolution, unlike time.time().
            start = time.perf_counter()
            result = func(position_i, positions_j, box_size)
            best = min(best, time.perf_counter() - start)
        return best, result

    # Numba compiles on the first call; run once untimed so compilation is
    # excluded from its measurement.
    numba_impl(position_i, positions_j, box_size)

    numpy_time, (result_numpy, _) = best_time(numpy_impl)
    numba_time, (result_numba, _) = best_time(numba_impl)

    # Same tolerances as np.allclose(result_numpy, result_numba), but always
    # active (not stripped under -O) and reports the mismatching elements.
    # NOTE(review): only the only_norm=True path is compared; displacement
    # vectors (only_norm=False) are not checked here.
    np.testing.assert_allclose(result_numpy, result_numba, rtol=1e-5, atol=1e-8,
                               err_msg="Norms do not match for only_norm=True!")

    print(f"Numpy function performance: {numpy_time:.6f} seconds")
    print(f"Numba function performance: {numba_time:.6f} seconds")

# Benchmark both distance implementations and check that their norms agree.
compare_functions()

Numpy function performance: 0.007386 seconds
Numba function performance: 0.002582 seconds
