In [1]:
import torch
import torch.nn.functional as F

In [3]:
# Seed the global RNG so the random operands below are reproducible.
torch.manual_seed(95054)

# Two rank-4 operands with the same shape (3, 4, 5, 7).
A = torch.randn(3, 4, 5, 7)
B = torch.randn(3, 4, 5, 7)

In [8]:
# Sanity check: contracting axes 1..3 with tensordot should equal a plain
# matrix product of the operands flattened to 2-D; the difference should be ~0.
torch.tensordot(A, B, dims=[[1, 2, 3], [1, 2, 3]]) - A.flatten(1) @ B.flatten(1).t()

tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]])

In [19]:
# Axis indices of A, cyclically shifted by two positions.
tuple(torch.arange(A.ndim).roll(2).numpy())

(2, 3, 0, 1)

In [111]:
def matricize(A, ix):
    """Unfold tensor ``A`` along axis ``ix`` into a 2-D matrix.

    Axis ``ix`` becomes the row dimension; all remaining axes keep their
    relative order and are flattened into the column dimension.

    Args:
        A: Input tensor.
        ix: Index of the axis that becomes the rows of the result.

    Returns:
        A tensor of shape ``(A.shape[ix], -1)``.
    """
    n_rows = A.shape[ix]
    # Bring the chosen axis to the front, then collapse everything else.
    fronted = A.movedim(ix, 0)
    return fronted.reshape(n_rows, -1)
# Mode-1 check: tensordot over axes (0, 2, 3) vs. the product of the
# mode-1 unfoldings; the difference should be ~0 up to float rounding.
torch.tensordot(A, B, dims=[[0, 2, 3], [0, 2, 3]]) - matricize(A, 1) @ matricize(B, 1).t()

tensor([[-4.7684e-07, -1.9073e-06,  9.5367e-07, -7.1526e-07],
        [ 1.0729e-06,  1.1921e-06,  1.9073e-06,  1.0729e-06],
        [ 0.0000e+00, -4.7684e-07,  3.3379e-06,  0.0000e+00],
        [ 1.9073e-06,  0.0000e+00, -1.9073e-06,  0.0000e+00]], device='cuda:0')

In [112]:
# Mode-2 check: tensordot over axes (0, 1, 3) vs. the product of the
# mode-2 unfoldings; the difference should be ~0 up to float rounding.
torch.tensordot(A, B, dims=[[0, 1, 3], [0, 1, 3]]) - matricize(A, 2) @ matricize(B, 2).t()

tensor([[-9.5367e-07, -3.5763e-07,  9.5367e-07, -9.5367e-07,  0.0000e+00],
        [ 1.9073e-06, -2.9802e-08, -7.1526e-07,  2.8610e-06,  0.0000e+00],
        [-9.5367e-07,  0.0000e+00,  9.5367e-07,  9.5367e-07, -9.5367e-07],
        [-9.5367e-07,  1.1921e-06,  1.9073e-06, -1.4305e-06,  4.7684e-07],
        [-7.1526e-07, -9.5367e-07, -7.1526e-07,  9.5367e-07,  7.1526e-07]],
       device='cuda:0')

In [113]:
# Mode-3 check: tensordot over axes (0, 1, 2) vs. the product of the
# mode-3 unfoldings; the difference should be ~0 up to float rounding.
torch.tensordot(A, B, dims=[[0, 1, 2], [0, 1, 2]]) - matricize(A, 3) @ matricize(B, 3).t()

tensor([[0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0.]], device='cuda:0')

In [115]:
# Number of tensor modes, and both operands placed on the CPU.
rank = A.dim()
A, B = A.cpu(), B.cpu()

In [116]:
%%timeit
# Time one sweep over all modes using tensordot.
for i in range(rank):
    # Contract over every axis except i.
    axes = list(range(i)) + list(range(i + 1, rank))
    torch.tensordot(A, B, dims=[axes, axes])

    

54.5 µs ± 1.17 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [117]:
%%timeit   
# Time one sweep over all modes using matricize + matmul.
for i in range(rank):
    # Product of the mode-i unfolding of A with the transposed unfolding of B.
    torch.matmul(matricize(A, i), matricize(B, i).T)

61.9 µs ± 737 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [118]:
# Move both operands to the GPU, then block until pending GPU work is done.
A, B = A.to("cuda"), B.to("cuda")
torch.cuda.synchronize()

In [119]:
%%timeit
# Time one sweep over all modes using tensordot.
for i in range(rank):
    # Contract over every axis except i.
    axes = list(range(i)) + list(range(i + 1, rank))
    torch.tensordot(A, B, dims=[axes, axes])
# Synchronize inside the timed body so the measurement waits for queued GPU work.
torch.cuda.synchronize()

180 µs ± 11.4 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [120]:
%%timeit   
# Time one sweep over all modes using matricize + matmul.
for i in range(rank):
    # Product of the mode-i unfolding of A with the transposed unfolding of B.
    torch.matmul(matricize(A, i), matricize(B, i).T)
# Synchronize inside the timed body so the measurement waits for queued GPU work.
torch.cuda.synchronize()

199 µs ± 8.8 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [46]:
# A second pair of rank-4 operands, shaped (4, 5, 8, 6).
C = torch.randn(4, 5, 8, 6)
D = torch.randn(4, 5, 8, 6)

In [58]:
# Per-axis maximum extent over the shapes of A and C, as a list of plain
# Python ints. Comparing the sizes directly avoids a float32 round-trip
# through the legacy torch.Tensor constructor and numpy, and the resulting
# ints can be concatenated into a shape list or subtracted from sizes as-is.
max_dims = [max(a_len, c_len) for a_len, c_len in zip(A.size(), C.size())]

In [70]:
def flatten(x):
    """Concatenate the items of every iterable in ``x`` into one flat tuple.

    Only one level of nesting is removed; items keep their original order.
    """
    flat = []
    for inner in x:
        for item in inner:
            flat.append(item)
    return tuple(flat)


(0, 1, 2, 3)

In [121]:
def pad_to(tensor, shape):
    """Zero-pad ``tensor`` at the end of each axis up to size ``shape``.

    Args:
        tensor: Tensor whose rank equals ``len(shape)``.
        shape: Target extent per axis; each entry is expected to be at
            least the tensor's extent on that axis.

    Returns:
        A new tensor holding ``tensor`` in its leading corner and zeros
        in the added trailing positions.
    """
    # F.pad takes (before, after) amounts starting from the LAST axis, so
    # walk the axes back to front and pad only on the "after" side.
    # int() keeps this working whether `shape` holds Python or numpy ints.
    pads = tuple(
        amount
        for current, target in zip(reversed(tensor.size()), reversed(shape))
        for amount in (0, int(target) - current)
    )
    return F.pad(tensor, pads)


# Two zero-initialised batches of size 2: P holds (A, C) and Q holds (B, D),
# each operand padded up to the common per-axis maximum `max_dims`.
P = torch.zeros([2] + max_dims)
Q = torch.zeros([2] + max_dims)
for slot, (p_src, q_src) in enumerate([(A, B), (C, D)]):
    P[slot] = pad_to(p_src, max_dims)
    Q[slot] = pad_to(q_src, max_dims)


In [122]:
# Display which device the padded batch P lives on.
P.device

device(type='cpu')

In [123]:
%%timeit
# Time the per-sample approach: one tensordot per mode and per batch entry.
for i in range(rank):
    # Contract over every axis except i.
    axes = list(range(i)) + list(range(i + 1, rank))
    for b in range(2):        
        torch.tensordot(P[b], Q[b], dims=[axes, axes])


136 µs ± 3.61 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [124]:
def batch_matricize(A, ix):
    """Unfold every sample of a batched tensor along one of its modes.

    Axis 0 of ``A`` is kept as the batch axis. Axis ``ix + 1`` (mode ``ix``
    of each sample) becomes the row dimension and the remaining axes are
    flattened into the column dimension.

    Args:
        A: Batched input tensor with the batch on axis 0.
        ix: Mode index, counted without the batch axis.

    Returns:
        A tensor of shape ``(A.shape[0], A.shape[ix + 1], -1)``.
    """
    mode_axis = ix + 1
    batch, n_rows = A.shape[0], A.shape[mode_axis]
    # Put the chosen mode right after the batch axis, then collapse the rest.
    return A.movedim(mode_axis, 1).reshape(batch, n_rows, -1)

In [125]:
%%timeit
# Time the batched approach: one bmm per mode covering the whole batch.
for i in range(rank):
    # Batched product of the mode-i unfoldings of P and (transposed) Q.
    torch.bmm(batch_matricize(P, i), batch_matricize(Q, i).transpose(-2,-1))

110 µs ± 4.9 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [126]:
# Move both padded batches to the GPU, then block until pending GPU work is done.
P, Q = P.to("cuda"), Q.to("cuda")
torch.cuda.synchronize()

In [127]:
%%timeit
# Time the per-sample approach: one tensordot per mode and per batch entry.
for i in range(rank):
    # Contract over every axis except i.
    axes = list(range(i)) + list(range(i + 1, rank))
    for b in range(2):        
        torch.tensordot(P[b], Q[b], dims=[axes, axes])
# Synchronize inside the timed body so the measurement waits for queued GPU work.
torch.cuda.synchronize()

337 µs ± 8.99 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [128]:
%%timeit
for i in range(rank):
    torch.bmm(batch_matricize(P, i), batch_matricize(Q, i).transpose(-2,-1))
torch.cuda.synchronize()

216 µs ± 6 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
