In [1]:
# Load the Cython IPython extension, which provides the %%cython cell magic.
%load_ext Cython

Name mangling / demangling
`*pointer`: dereferences the pointer, i.e. accesses the value it points to (written `pointer[0]` in Cython)

In [2]:
%%cython --annotate --force
# distutils: extra_compile_args=-fopenmp
# distutils: extra_link_args=-fopenmp

from scipy.spatial.distance import cdist
import itertools
import time
cimport numpy as np
import numpy as np
from libc.math cimport fabs
from cython.parallel cimport prange
from cython cimport floating, integral, numeric
cimport cython

@cython.boundscheck(False)
@cython.wraparound(False)
cdef void _compute_dist(
    floating[::1] X_a_row,
    floating[::1] X_b_row,
    integral n_features,
    floating *dist
) nogil:
    # Add the sum of absolute differences (L1 / cityblock distance) between
    # the first `n_features` entries of `X_a_row` and `X_b_row` to `dist[0]`.
    #
    # The result is *added* to whatever `dist[0]` already holds, so the caller
    # must zero the destination beforehand to obtain the plain distance.
    #
    # Bounds and negative-index checks are disabled: the caller must guarantee
    # that both rows contain at least `n_features` elements.
    cdef:
        integral i
        # Local accumulator: avoids a read-modify-write through the pointer on
        # every iteration (the pointer may alias the memoryview data).
        floating total = 0

    for i in range(n_features):
        total += fabs(X_a_row[i] - X_b_row[i])

    dist[0] += total

@cython.boundscheck(False)
@cython.wraparound(False)
cdef void _pairwise_dist(
    floating[:, ::1] X_a, # IN
    floating[:, ::1] X_b, # IN
    floating[:, ::1] distances, # OUT
    integral dummy
) nogil:
    # Fill `distances[i, j]` with the distance between row i of `X_a` and
    # row j of `X_b`, as accumulated by `_compute_dist`.
    #
    # `_compute_dist` adds into its destination, so `distances` must be
    # zero-initialised by the caller. Bounds and negative-index checks are
    # disabled: `distances` must have at least (X_a.shape[0], X_b.shape[0])
    # entries and `X_b` must have at least X_a.shape[1] columns.
    #
    # NOTE(review): `dummy` is never read; it appears to exist only so that
    # the `integral` fused type is part of the signature and can be used for
    # the local declarations below.
    cdef:
        integral i, j
        integral n_rows_X_a = X_a.shape[0]
        integral n_rows_X_b = X_b.shape[0]
        integral n_features = X_a.shape[1]

    # Rows of X_a are distributed across threads; every (i, j) pair writes to
    # its own element of `distances`, so iterations do not share output cells.
    for i in prange(n_rows_X_a, nogil=True):
        for j in range(n_rows_X_b):
            _compute_dist(X_a[i], X_b[j], n_features, &distances[i, j])

def pairwise_dist(
    floating[:, ::1] X_a,
    floating[:, ::1] X_b
):
    """Compute the pairwise L1 (cityblock) distances between rows of two arrays.

    Parameters
    ----------
    X_a : C-contiguous 2-D float32 or float64 array, shape (n_a, n_features)
    X_b : C-contiguous 2-D array of the same dtype, shape (n_b, n_features)

    Returns
    -------
    numpy.ndarray of shape (n_a, n_b) with the same dtype as the inputs;
    entry (i, j) is the sum of absolute differences between ``X_a[i]`` and
    ``X_b[j]``.

    Raises
    ------
    ValueError
        If ``X_a`` and ``X_b`` do not have the same number of columns.
    """
    cdef floating[:, ::1] distances

    # The kernel runs with bounds checking disabled and iterates over
    # X_a.shape[1] features, so a column mismatch would read out of bounds.
    if X_a.shape[1] != X_b.shape[1]:
        raise ValueError(
            "X_a and X_b must have the same number of columns, "
            f"got {X_a.shape[1]} and {X_b.shape[1]}"
        )

    float_dtype = np.float32 if floating is float else np.float64
    # Must start at zero: the kernel accumulates into the output buffer.
    distances = np.zeros([X_a.shape[0], X_b.shape[0]], dtype=float_dtype)

    # The trailing integer is an unused placeholder required by the
    # signature of _pairwise_dist.
    _pairwise_dist(X_a, X_b, distances, 42)

    return np.asarray(distances)

In [6]:
def time_func(func, *args, **kwargs):
    """Run ``func(*args, **kwargs)`` ten times and average the wall-clock time.

    Returns a ``(result, mean_seconds)`` tuple where ``result`` is the value
    returned by the final call and ``mean_seconds`` is the mean of the ten
    durations measured with ``time.perf_counter``.
    """
    durations = []
    for _ in range(10):
        tic = time.perf_counter()
        result = func(*args, **kwargs)
        toc = time.perf_counter()
        durations.append(toc - tic)
    return result, np.mean(durations)

In [7]:
def test_correctness(n, p, metric="cityblock", seed=None):
    """Check ``pairwise_dist`` against ``scipy``'s ``cdist`` and print timings.

    Two random ``(n, p)`` arrays are generated, both implementations are
    timed with ``time_func``, the mean timings are printed, and the results
    are compared with ``np.testing.assert_allclose``.

    Parameters
    ----------
    n, p : int
        Number of rows and columns of each random input array.
    metric : str
        Metric forwarded to ``cdist`` only; ``pairwise_dist`` is called
        without it, so the comparison is against its fixed computation.
    seed : int or None
        Seed for reproducible inputs. ``None`` (default) draws from NumPy's
        global random state, as before.

    Raises
    ------
    AssertionError
        If the two results differ beyond ``assert_allclose`` tolerances.
    """
    # seed=None keeps the original behaviour (global NumPy RNG); an explicit
    # seed uses an isolated generator so the run can be reproduced.
    rng = np.random if seed is None else np.random.RandomState(seed)
    X_a = rng.rand(n, p)
    X_b = rng.rand(n, p)

    expected, t1 = time_func(cdist, X_a, X_b, metric=metric)
    actual, t2 = time_func(pairwise_dist, X_a, X_b)
    print(f"t1={t1}s")
    print(f"t2={t2}s")

    np.testing.assert_allclose(actual, expected, verbose=True)

In [14]:
# Compare both implementations on two random 5000 x 100 arrays and print mean timings.
test_correctness(5000, 100)

t1=2.1033318763000013s
t2=1.6547433057000263s
