In [1]:
import numpy as np
import torch
from numba import jit
from scipy.linalg.blas import sgemm
from scipy.spatial.distance import cdist, pdist, squareform
from torch.autograd import grad


@torch.jit.script
def app1(X: torch.Tensor, Y: torch.Tensor):
    X_norm = torch.einsum("ij,ij->i", X, X)
    Y_norm = torch.einsum("ij,ij->i", Y, Y)
    X_YT = torch.einsum("ij,kj->ik", X, Y)
    return X_norm[:, None] + Y_norm[None, :] - 2 * X_YT

def app2(X: torch.Tensor, Y: torch.Tensor):
    """Eager-mode squared Euclidean distances between rows of X and rows of Y.

    Same computation as ``app1`` but without TorchScript, so the two can be
    timed against each other.

    Returns an ``(len(X), len(Y))`` tensor with entry ``[i, k] = |X[i] - Y[k]|^2``,
    computed as ``|X[i]|^2 + |Y[k]|^2 - 2 <X[i], Y[k]>``.
    """
    gram = torch.einsum("ij,kj->ik", X, Y)
    x_sq = torch.einsum("ij,ij->i", X, X)
    y_sq = torch.einsum("ij,ij->i", Y, Y)
    # Broadcast the norms to a column and a row respectively.
    return x_sq.unsqueeze(-1) + y_sq.unsqueeze(0) - 2 * gram

def app3(X, Y, M):
    """Quadratic form ``(x - y) M (x - y)^T`` for every pair of rows, via einsum.

    Materializes the full ``(n, m, d)`` tensor of pairwise differences, so the
    memory cost grows with ``n * m * d``. With ``M`` the identity this equals
    the squared Euclidean distance.
    """
    pairwise_diff = X.unsqueeze(1) - Y.unsqueeze(0)                   # (n, m, d)
    projected = torch.einsum("bij,jk->bik", pairwise_diff, M)         # (n, m, d)
    return torch.einsum("bij,bij->bi", projected, pairwise_diff)      # (n, m)

def app4(X, Y, M):
    """Quadratic form ``(x - y) M (x - y)^T`` for every pair of rows, via matmul.

    Same result as ``app3``, but uses a batched matrix product followed by an
    elementwise product and a sum over the feature axis instead of einsum.
    """
    pairwise_diff = X.unsqueeze(1) - Y.unsqueeze(0)   # (n, m, d)
    weighted = torch.matmul(pairwise_diff, M)          # (n, m, d)
    return torch.mul(weighted, pairwise_diff).sum(dim=-1)

def app5(X, Y):
    """Squared Euclidean distances using a single fused ``torch.addmm`` call.

    ``addmm`` computes ``|Y|^2 (as a row) - 2 * X @ Y^T`` in one kernel, then the
    ``|X|^2`` column is added in place. Works on 2-D inputs.
    """
    x_sq_col = X.pow(2).sum(dim=-1, keepdim=True)                          # (n, 1)
    y_sq_row = Y.pow(2).sum(dim=-1, keepdim=True).transpose(-2, -1)        # (1, m)
    dists = torch.addmm(y_sq_row, X, Y.transpose(-2, -1), alpha=-2)       # (n, m)
    dists.add_(x_sq_col)
    return dists

def app6(X, Y=None):
    """Squared Euclidean distances computed with SciPy (reference implementation).

    Parameters
    ----------
    X : array_like, shape (n, d)
        Points; converted to a NumPy array by SciPy, so this path is not
        differentiable.
    Y : array_like, shape (m, d), optional
        Second point set. When omitted (the original behaviour), distances
        between the rows of ``X`` themselves are returned.

    Returns
    -------
    numpy.ndarray
        Shape ``(n, n)`` when ``Y`` is None, otherwise ``(n, m)``.
    """
    # Ask SciPy for squared distances directly. Taking the sqrt and squaring
    # again costs time and adds rounding error (e.g. 2 -> 2.0000000000000004).
    if Y is None:
        return squareform(pdist(X, "sqeuclidean"))
    return cdist(X, Y, "sqeuclidean")

In [2]:
# Benchmark configuration: problem size, precision and a fixed seed so that
# Restart & Run All reproduces exactly the same inputs.
BATCH, DIM = 300, 20
SEED = 0
torch.set_default_dtype(torch.double)  # every tensor created below is float64
torch.manual_seed(SEED)
X = torch.randn(BATCH, DIM)
Y = torch.randn(BATCH, DIM)
M = torch.eye(DIM)  # identity metric: app3/app4 then reduce to squared Euclidean distance

In [4]:
%timeit app1(X, Y)
%timeit app2(X, Y)
%timeit app3(X, Y)
%timeit app4(X, Y)
%timeit app5(X, Y)
%timeit app6(X)

417 µs ± 21.1 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
329 µs ± 31.6 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
221 µs ± 1.04 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
256 µs ± 915 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
5.88 ms ± 72 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
8.43 ms ± 30.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
152 µs ± 125 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
847 µs ± 3.47 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [3]:
# Forward check: every implementation must produce the same squared-distance
# matrix of X with itself. Each result is compared directly against the SciPy
# reference (app6), and a mismatch fails the cell instead of printing False.
res1 = app1(X, X)
res2 = app2(X, X)
res3 = app3(X, X, M)
res4 = app4(X, X, M)
res5 = app5(X, X)
res6 = app6(X)

for name, res in [
    ("app1", res1),
    ("app2", res2),
    ("app3", res3),
    ("app4", res4),
    ("app5", res5),
]:
    agrees = np.allclose(res, res6)
    print(agrees)
    assert agrees, f"{name} disagrees with the SciPy reference (app6)"

True
True
True
True
True


In [4]:
# Backward check: gradients with respect to the first argument (the second
# argument held constant) must agree as well. app6 goes through SciPy/NumPy and
# has no autograd support, so it is left out.
#
# Differentiate through a separate leaf tensor rather than calling
# X.requires_grad_() in place. That would leave X requiring grad, and re-running
# the earlier cells (which convert results to NumPy) would then fail.
X_leaf = X.detach().requires_grad_()
X_const = X.detach()
res1 = grad(app1(X_leaf, X_const).sum(), X_leaf)[0]
res2 = grad(app2(X_leaf, X_const).sum(), X_leaf)[0]
res3 = grad(app3(X_leaf, X_const, M).sum(), X_leaf)[0]
res4 = grad(app4(X_leaf, X_const, M).sum(), X_leaf)[0]
res5 = grad(app5(X_leaf, X_const).sum(), X_leaf)[0]

for name, res in [("app1", res1), ("app2", res2), ("app3", res3), ("app4", res4)]:
    agrees = np.allclose(res, res5)
    print(agrees)
    assert agrees, f"{name} gradient disagrees with app5"

True
True
True
True
