In [4]:
import numpy as np
from numba import njit
import time
from MDAnalysis.analysis import distances

@njit
def contact_matrix(positions, cutoff, box=None):
    n_atoms = positions.shape[0]
    matrix = np.zeros((n_atoms, n_atoms), dtype=np.bool_)
    
    for i in range(n_atoms - 1):
        for j in range(i + 1, n_atoms):
            diff = positions[i] - positions[j]
            if box is not None:
                diff -= np.round(diff / box) * box
            dist_sq = np.dot(diff, diff)
            if dist_sq < cutoff ** 2:
                matrix[i, j] = True
                matrix[j, i] = True  # Symmetric matrix
    return matrix

@njit
def update_neighbor_lists_numba(positions, cut_off, box_mdanalysis):
    n_atoms = positions.shape[0]
    neighbor_lists = []
    matrix = contact_matrix(positions, cutoff=cut_off, box=box_mdanalysis)
    for cpt in range(n_atoms - 1):
        neighbor_list = np.where(matrix[cpt])[0]
        neighbor_list = neighbor_list[neighbor_list > cpt]
        neighbor_lists.append(neighbor_list)
    return neighbor_lists

def update_neighbor_lists_mda(positions, cut_off, box_mdanalysis):
    neighbor_lists = []
    matrix = distances.contact_matrix(positions, cutoff=cut_off, returntype="numpy", box=box_mdanalysis)
    for cpt, array in enumerate(matrix[:-1]):
        neighbor_list = np.where(array)[0].tolist()
        neighbor_list = [ele for ele in neighbor_list if ele > cpt]
        neighbor_lists.append(neighbor_list)
    return neighbor_lists

def compare_functions():
    n_atoms = 1000
    box = np.array([10.0, 10.0, 10.0])
    box_mdanalysis = np.concatenate([box, [90.0, 90.0, 90.0]])
    cut_off = 1.5
    positions = np.random.rand(n_atoms, 3) * box
    
    # Ensure the positions and box_mdanalysis arrays are of type float64 for numba
    positions = positions.astype(np.float64)
    box_mdanalysis = box_mdanalysis.astype(np.float64)

    # Original method
    start_time = time.time()
    mda_neighbor_lists = update_neighbor_lists_mda(positions, cut_off, box_mdanalysis)
    original_time = time.time() - start_time

    # Numba method
    update_neighbor_lists_numba(positions, cut_off, box)
    start_time = time.time()
    numba_neighbor_lists = update_neighbor_lists_numba(positions, cut_off, box)
    numba_time = time.time() - start_time

    # Compare results and timings
    print(f"Original time: {original_time:.6f} seconds")
    print(f"Numba time: {numba_time:.6f} seconds")
    print(f"Results match: {len(mda_neighbor_lists) == len(numba_neighbor_lists)}")

compare_functions()


Original time: 0.042530 seconds
Numba time: 0.151321 seconds
Results match: True
