In [11]:
import time

import numpy as np
from numba import jit
from sklearn.metrics.pairwise import pairwise_distances

In [62]:
def weighted_center_mat1(distx, weights):
    """Weighted double-centering, ``np.average`` + ``np.repeat`` variant.

    With ``m = np.average(distx, axis=0, weights=weights)`` and ``g`` the
    weighted average of ``m``, the result is
    ``distx[i, j] - m[i] - m[j] + g``.

    Parameters
    ----------
    distx : ndarray
        2-D matrix to centre. Expected to be square: ``distx.shape[0]`` is
        used as the size of both axes when the means are expanded.
    weights : array-like
        One weight per row of ``distx`` (forwarded to ``np.average``).

    Returns
    -------
    ndarray
        A new centred matrix; ``distx`` itself is not modified.
    """
    size = distx.shape[0]

    # Weighted mean taken over axis 0, then the weighted mean of those means.
    marginal = np.average(distx, axis=0, weights=weights)
    grand = np.average(marginal, weights=weights)

    # Expansion borrowed from hyppo: materialise two full (size, size)
    # matrices holding marginal[i] and marginal[j] respectively.
    by_row = np.repeat(marginal, size).reshape(-1, size)
    by_col = np.repeat(marginal, size).reshape(-1, size).T
    expected = by_col + by_row - grand

    return distx - expected

def weighted_center_mat2(distx, weights):
    """Weighted double-centering, explicit-sum + ``np.repeat`` variant.

    With ``m[i] = sum_j(distx[i, j] * weights[j]) / sum(weights)`` and
    ``g = (weights @ m) / sum(weights)``, the result is
    ``distx[i, j] - m[i] - m[j] + g``.

    Parameters
    ----------
    distx : ndarray
        2-D matrix to centre. Expected to be square: ``distx.shape[0]`` is
        used as the size of both axes when the means are expanded.
    weights : ndarray
        Weights, broadcast against ``distx`` by ``np.multiply`` and summed
        along axis 1.

    Returns
    -------
    ndarray
        A new centred matrix; ``distx`` itself is not modified.
    """
    size = distx.shape[0]
    weight_total = np.sum(weights)

    # Weighted mean of each row, then the weighted mean of those means.
    weighted = np.multiply(distx, weights)
    marginal = np.sum(weighted, axis=1) / weight_total
    grand = weights @ marginal / weight_total

    # Expansion borrowed from hyppo: materialise two full (size, size)
    # matrices holding marginal[i] and marginal[j] respectively.
    by_row = np.repeat(marginal, size).reshape(-1, size)
    by_col = np.repeat(marginal, size).reshape(-1, size).T
    expected = by_col + by_row - grand

    return distx - expected


def weighted_center_mat3(distx, weights):
    """Weighted double-centering, ``np.average`` + broadcasting variant.

    With ``m = np.average(distx, axis=0, weights=weights)`` and ``g`` the
    weighted average of ``m``, the result is
    ``distx[i, j] - m[i] - m[j] + g``. The means are subtracted through
    broadcast views rather than expanded into full matrices.

    Parameters
    ----------
    distx : ndarray
        2-D matrix to centre. Expected to be square: ``distx.shape[0]`` is
        used as the row length when reshaping the means.
    weights : array-like
        One weight per row of ``distx`` (forwarded to ``np.average``).

    Returns
    -------
    ndarray
        A new centred matrix; ``distx`` itself is not modified.
    """
    size = distx.shape[0]

    marginal = np.average(distx, axis=0, weights=weights)
    grand = np.average(marginal, weights=weights)

    # Two views of the same means: one laid out as a single row, the other
    # (its transpose) as a single column.
    as_row = marginal.reshape(-1, size)
    as_col = marginal.reshape(-1, size).T

    return distx - as_col - as_row + grand

def weighted_center_mat4(distx, weights):
    """Weighted double-centering, explicit-sum + broadcasting variant.

    With ``m[i] = sum_j(distx[i, j] * weights[j]) / sum(weights)`` and
    ``g = (weights @ m) / sum(weights)``, the result is
    ``distx[i, j] - m[i] - m[j] + g``. The means are subtracted through
    broadcast views rather than expanded into full matrices.

    Parameters
    ----------
    distx : ndarray
        2-D matrix to centre. Expected to be square: ``distx.shape[0]`` is
        used as the row length when reshaping the means.
    weights : ndarray
        Weights, broadcast against ``distx`` by ``np.multiply`` and summed
        along axis 1.

    Returns
    -------
    ndarray
        A new centred matrix; ``distx`` itself is not modified.
    """
    size = distx.shape[0]
    weight_total = np.sum(weights)

    weighted = np.multiply(distx, weights)
    marginal = np.sum(weighted, axis=1) / weight_total
    grand = weights @ marginal / weight_total

    # Two views of the same means: one laid out as a single row, the other
    # (its transpose) as a single column.
    as_row = marginal.reshape(-1, size)
    as_col = marginal.reshape(-1, size).T

    return distx - as_col - as_row + grand

@jit(cache=True)
def weighted_center_mat2jit(distx, weights):
    """Numba-jitted weighted double-centering, explicit-sum + ``np.repeat`` variant.

    With ``m[i] = sum_j(distx[i, j] * weights[j]) / sum(weights)`` and
    ``g = (weights @ m) / sum(weights)``, the result is
    ``distx[i, j] - m[i] - m[j] + g``.

    Parameters
    ----------
    distx : ndarray
        2-D matrix to centre. Expected to be square: ``distx.shape[0]`` is
        used as the size of both axes when the means are expanded.
    weights : ndarray
        Weights, broadcast against ``distx`` by ``np.multiply`` and summed
        along axis 1.

    Returns
    -------
    ndarray
        A new centred matrix; ``distx`` itself is not modified.
    """
    size = distx.shape[0]
    weight_total = np.sum(weights)

    # Weighted mean of each row, then the weighted mean of those means.
    weighted = np.multiply(distx, weights)
    marginal = np.sum(weighted, axis=1) / weight_total
    grand = weights @ marginal / weight_total

    # Expansion borrowed from hyppo: materialise two full (size, size)
    # matrices holding marginal[i] and marginal[j] respectively.
    by_row = np.repeat(marginal, size).reshape(-1, size)
    by_col = np.repeat(marginal, size).reshape(-1, size).T
    expected = by_col + by_row - grand

    return distx - expected

@jit(cache=True)
def weighted_center_mat4jit(distx, weights):
    """Numba-jitted weighted double-centering, explicit-sum + broadcasting variant.

    With ``m[i] = sum_j(distx[i, j] * weights[j]) / sum(weights)`` and
    ``g = (weights @ m) / sum(weights)``, the result is
    ``distx[i, j] - m[i] - m[j] + g``. The means are subtracted through
    broadcast views rather than expanded into full matrices.

    Parameters
    ----------
    distx : ndarray
        2-D matrix to centre. Expected to be square: ``distx.shape[0]`` is
        used as the row length when reshaping the means.
    weights : ndarray
        Weights, broadcast against ``distx`` by ``np.multiply`` and summed
        along axis 1.

    Returns
    -------
    ndarray
        A new centred matrix; ``distx`` itself is not modified.
    """
    size = distx.shape[0]
    weight_total = np.sum(weights)

    weighted = np.multiply(distx, weights)
    marginal = np.sum(weighted, axis=1) / weight_total
    grand = weights @ marginal / weight_total

    # Two views of the same means: one laid out as a single row, the other
    # (its transpose) as a single column.
    as_row = marginal.reshape(-1, size)
    as_col = marginal.reshape(-1, size).T

    return distx - as_col - as_row + grand

In [54]:
# Problem size: number of sample points, so dx below is an n x n matrix.
n = 2000

# Seed NumPy's global RNG so that "Restart Kernel -> Run All" benchmarks the
# same data every time.
np.random.seed(0)

# n points in 2-D drawn from a standard normal distribution.
X = np.random.normal(size=(n, 2))
# Pairwise distances between the rows of X (scikit-learn's default metric).
dx = pairwise_distances(X)
# One uniform(0, 1) draw per point; the timing cells pass this as `weights`.
dy = np.random.uniform(size=n)

In [55]:
# Time the np.average + np.repeat variant: 3 runs of 100 loops each.
# -o makes %timeit return its result object, which is bound to r.
r = %timeit -o -n 100 -r 3 weighted_center_mat1(dx, dy)

61.8 ms ± 241 µs per loop (mean ± std. dev. of 3 runs, 100 loops each)


In [56]:
# Time the explicit-sum + np.repeat variant: 3 runs of 100 loops each.
# -o makes %timeit return its result object, which is bound to r.
r = %timeit -o -n 100 -r 3 weighted_center_mat2(dx, dy)

61.4 ms ± 97.3 µs per loop (mean ± std. dev. of 3 runs, 100 loops each)


In [57]:
# Time the np.average + broadcasting variant: 3 runs of 100 loops each.
# -o makes %timeit return its result object, which is bound to r.
r = %timeit -o -n 100 -r 3 weighted_center_mat3(dx, dy)

19.1 ms ± 157 µs per loop (mean ± std. dev. of 3 runs, 100 loops each)


In [58]:
# Time the explicit-sum + broadcasting variant: 3 runs of 100 loops each.
# -o makes %timeit return its result object, which is bound to r.
r = %timeit -o -n 100 -r 3 weighted_center_mat4(dx, dy)

18.6 ms ± 48.1 µs per loop (mean ± std. dev. of 3 runs, 100 loops each)


In [59]:
r = %timeit -o -n 100 -r 3 weighted_center_mat4jit(dx, dy)

33.1 ms ± 58.8 µs per loop (mean ± std. dev. of 3 runs, 100 loops each)


In [64]:
r = %timeit -o -n 100 -r 3 weighted_center_mat2jit(dx, dy)

63.7 ms ± 1.08 ms per loop (mean ± std. dev. of 3 runs, 100 loops each)
