In [1]:
import numexpr as ne
import numpy as np
import torch
from numba import jit
from scipy.linalg.blas import sgemm


def app1(X, Y, gamma, var):
    X_norm = np.einsum("ij,ij->i", X, X)
    Y_norm = np.einsum("ij,ij->i", Y, Y)
    return ne.evaluate(
        "v * exp(-g * (A + B - 2 * C))",
        {
            "A": X_norm[:, None],
            "B": Y_norm[None, :],
            "C": np.dot(X, Y.T),
            "g": gamma,
            "v": var,
        },
    )


def app2(X, Y, gamma, var):
    X_norm = np.einsum("ij,ij->i", X, X)
    Y_norm = np.einsum("ij,ij->i", Y, Y)
    return ne.evaluate(
        "v * exp(-g * (A + B + C))",
        {
            "A": X_norm[:, None],
            "B": Y_norm[None, :],
            "C": sgemm(alpha=-2.0, a=X, b=Y, trans_b=True),
            "g": gamma,
            "v": var,
        },
    )


@torch.jit.script
def app3(X: torch.Tensor, Y: torch.Tensor, gamma:float, var:float):
    X_norm = torch.einsum("ij,ij->i", X, X)
    Y_norm = torch.einsum("ij,ij->i", Y, Y)
    X_YT = torch.einsum("ij,kj->ik", X, Y)
    return var * torch.exp(-gamma * (X_norm[:, None] + Y_norm[None, :] - 2 * X_YT))

def app4(X: torch.Tensor, Y: torch.Tensor, gamma:float, var:float):
    X_norm = torch.einsum("ij,ij->i", X, X)
    Y_norm = torch.einsum("ij,ij->i", Y, Y)
    X_YT = torch.einsum("ij,kj->ik", X, Y)
    return var * torch.exp(-gamma * (X_norm[:, None] + Y_norm[None, :] - 2 * X_YT))

def app5(X, Y, M, gamma:float, var:float):
    diff = X[:, None, :] - Y[None, :, :]
    dM = torch.einsum("bij,jk->bik", diff, M)
    dMdT = torch.einsum("bij,bij->bi", dM, diff)
    return var * torch.exp(-gamma * dMdT)

def app6(X, Y, M, gamma:float, var:float):
    diff = X[:, None, :] - Y[None, :, :]
    diff_M = diff @ M
    dMdT = (diff_M * diff).sum(-1)
    return var * torch.exp(-gamma * dMdT)

In [2]:
BATCH, DIM = [3000, 200]
torch.set_default_dtype(torch.double)
X = torch.randn(BATCH, DIM)
Y = torch.randn(BATCH, DIM)
M = torch.eye(DIM)
gamma = 0.1
var = 5.0


In [159]:
%timeit app1(X.numpy(), Y.numpy(), gamma, var)
%timeit app2(X.numpy(), Y.numpy(), gamma, var)
%timeit app3(X, Y, gamma, var)
%timeit app4(X, Y, gamma, var)
%timeit app5(X, Y, M, gamma, var)
%timeit app6(X, Y, M, gamma, var)

113 ms ± 3.91 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
73.4 ms ± 3.03 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
171 ms ± 1.06 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
173 ms ± 3.26 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
3min 33s ± 8.84 s per loop (mean ± std. dev. of 7 runs, 1 loop each)
4min 49s ± 9.62 s per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [3]:
res1 = app1(X.numpy(), Y.numpy(), gamma, var)
res2 = app2(X.numpy(), Y.numpy(), gamma, var)
res3 = app3(X, Y, gamma, var)
res4 = app4(X, Y, gamma, var)
# res5 = app5(X, Y, M, gamma, var)
# res6 = app6(X, Y, M, gamma, var)

print(np.allclose(res1, res2))
print(np.allclose(res2, res3))
print(np.allclose(res3, res4))
# print(np.allclose(res4, res5))
# print(np.allclose(res5, res6))

True
True
True


In [10]:
from sklearn.metrics.pairwise import euclidean_distances, paired_distances

res7 = var * np.exp(-gamma * euclidean_distances(X, Y, squared=True))
res8 = var * np.exp(-gamma * paired_distances(X, Y, metric=lambda X, Y: X @ M.numpy() @ Y.T) ** 2)
np.allclose(res7, res8)

False

In [16]:
res7

array([[1.23209562e-21, 1.71467348e-20, 6.57732083e-19, ...,
        5.78520745e-17, 3.26711483e-18, 4.37651596e-16],
       [2.04684470e-16, 9.75048355e-16, 2.17910272e-19, ...,
        2.48541964e-16, 1.77949414e-16, 4.78741975e-17],
       [4.77888523e-18, 1.64343976e-16, 1.19994201e-16, ...,
        1.07172281e-18, 1.14148006e-15, 6.17344102e-16],
       ...,
       [1.84529090e-19, 8.77244639e-17, 5.73377928e-17, ...,
        2.94019341e-15, 3.41572771e-17, 2.72535019e-18],
       [3.53770999e-17, 1.02515794e-15, 1.97449921e-18, ...,
        3.00431546e-15, 5.28883421e-16, 5.76570105e-16],
       [3.12631505e-18, 3.23616528e-20, 2.82122623e-18, ...,
        8.09796268e-17, 6.35229953e-17, 1.02331294e-17]])

In [15]:
res8

array([1.65213810e-30, 3.49667650e-11, 2.14928364e-01, ...,
       6.16201577e-15, 3.64440105e+00, 8.50200258e-06])