In [1]:
#export
import torch
from torch import tensor

In [2]:
# Fixed seed so a fresh-kernel rerun benchmarks exactly the same inputs.
torch.manual_seed(0)
# Batch of 10 matrices: m1 is (10, 56, 85), m2 is (10, 85, 56) -> product is (10, 56, 56).
m1=torch.rand([10,56,85])
m2=torch.rand([10,85,56])

In [3]:
# Baseline: time the built-in batched matmul, then display the result shape.
%timeit torch.matmul(m1,m2)
torch.matmul(m1,m2).shape

89.7 µs ± 1.51 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


torch.Size([10, 56, 56])

In [4]:
def naive_matmul(m1,m2):
    """Batched matrix multiply using three explicit Python loops (slowest reference version).

    Args:
        m1: tensor of shape (batch, i, k).
        m2: tensor of shape (batch, k, j).

    Returns:
        Tensor of shape (batch, i, j) with out[b, x, y] = sum_k m1[b, x, k] * m2[b, k, y].
        Its dtype is the promotion of the input dtypes, and it lives on m1's device.
    """
    # Read sizes from .shape so an empty batch returns an empty result instead of
    # raising IndexError on m1[0].
    batch, rows, cols = m1.shape[0], m1.shape[1], m2.shape[2]
    # Allocate with the promoted input dtype rather than the global default
    # (float32), so float64 inputs are not silently truncated.
    out = torch.zeros(batch, rows, cols, dtype=torch.result_type(m1, m2), device=m1.device)
    for b in range(batch):
        for x in range(rows):
            for y in range(cols):
                # Dot product of row x of m1[b] with column y of m2[b].
                out[b, x, y] = (m1[b, x] * m2[b, :, y]).sum()
    return out

In [5]:
# Baseline vs. the triple-loop implementation on the same inputs.
%timeit torch.matmul(m1,m2)
%timeit naive_matmul(m1,m2)

90.6 µs ± 816 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
1.06 s ± 8.95 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [6]:
def less_naive_matmul(m1,m2):
    """Batched matrix multiply with two Python loops; each output row is one broadcast op.

    Args:
        m1: tensor of shape (batch, i, k).
        m2: tensor of shape (batch, k, j).

    Returns:
        Tensor of shape (batch, i, j) equal to m1 @ m2. Its dtype is the promotion of the
        input dtypes, and it lives on m1's device.
    """
    # .shape works for an empty batch, whereas len(m1[0]) would raise.
    batch, rows, cols = m1.shape[0], m1.shape[1], m2.shape[2]
    # Use the promoted input dtype instead of the default float32.
    out = torch.zeros(batch, rows, cols, dtype=torch.result_type(m1, m2), device=m1.device)
    for b in range(batch):
        for x in range(rows):
            # (k, 1) * (k, j) broadcasts to (k, j); summing over k yields row x of the product.
            out[b, x] = (m1[b, x, :, None] * m2[b]).sum(dim=0)
    return out

In [7]:
# Add the row-at-a-time broadcast version to the comparison.
%timeit torch.matmul(m1,m2)
%timeit naive_matmul(m1,m2)
%timeit less_naive_matmul(m1,m2)

96.2 µs ± 3.86 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
1.06 s ± 3.56 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
26.7 ms ± 30.5 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [8]:
def lesser_naive_matmul(m1,m2):
    """Batched matrix multiply with one Python loop over the batch; each matrix is one broadcast op.

    Args:
        m1: tensor of shape (batch, i, k).
        m2: tensor of shape (batch, k, j).

    Returns:
        Tensor of shape (batch, i, j) equal to m1 @ m2. Its dtype is the promotion of the
        input dtypes, and it lives on m1's device.
    """
    # .shape works for an empty batch, whereas len(m1[0]) would raise.
    batch, rows, cols = m1.shape[0], m1.shape[1], m2.shape[2]
    # Use the promoted input dtype instead of the default float32.
    out = torch.zeros(batch, rows, cols, dtype=torch.result_type(m1, m2), device=m1.device)
    for b in range(batch):
        # (i, k, 1) * (k, j) -> (i, k, j); summing over k gives the (i, j) product.
        # The result must be stored into out[b]; otherwise the function returns zeros.
        out[b] = (m1[b].unsqueeze(-1) * m2[b]).sum(-2)
    return out

In [9]:
# Add the single-loop (per-batch broadcast) version to the comparison.
%timeit torch.matmul(m1,m2)
%timeit naive_matmul(m1,m2)
%timeit less_naive_matmul(m1,m2)
%timeit lesser_naive_matmul(m1,m2)

90.6 µs ± 3.05 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
1.07 s ± 2.77 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
26.7 ms ± 18.1 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
3.39 ms ± 946 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)


I thought it would be a good idea to go all the way to a fully broadcast version, with no Python loops at all.

In [10]:
def my_matmul(m1,m2):
    """Matrix multiply with a single broadcast multiply and sum (no Python loops).

    m1 of shape (..., i, k) is viewed as (..., i, k, 1), and m2 of shape (..., k, j) as
    (..., 1, k, j). Their product broadcasts to (..., i, k, j), and summing over k gives
    (..., i, j). Leading batch dims broadcast against each other, so plain 2-D matrices
    and batch-vs-single-matrix inputs also work.

    Note: the full (..., i, k, j) intermediate is materialized, so memory grows with i*k*j.
    """
    # Negative dims make this independent of how many batch dims there are.
    return (m1.unsqueeze(-1) * m2.unsqueeze(-3)).sum(-2)

In [11]:
# Add the fully broadcast, loop-free version to the comparison.
%timeit torch.matmul(m1,m2)
%timeit naive_matmul(m1,m2)
%timeit less_naive_matmul(m1,m2)
%timeit lesser_naive_matmul(m1,m2)
%timeit my_matmul(m1,m2)

90.8 µs ± 3.11 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
1.06 s ± 1.05 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
26.6 ms ± 25.7 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
3.38 ms ± 993 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)
3.02 ms ± 1.39 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


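I thought reshaping m1 to [10, 56, 1, 85] and the transposed m2 to [10, 1, 56, 85] would improve performance. With those shapes, each length-85 row of m1 multiplies a [56, 85] block and the reduction runs over the last axis. This was not the case: it runs at essentially the same speed as `my_matmul`.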

In [12]:
def alt_matmul(m1,m2):
    """Matrix multiply that reduces over the last axis of a broadcast product.

    m1 of shape (..., i, k) is viewed as (..., i, 1, k). m2 of shape (..., k, j) is
    transposed to (..., j, k) and viewed as (..., 1, j, k). Their product broadcasts to
    (..., i, j, k), and summing the last axis gives (..., i, j). Works for plain 2-D
    matrices as well as any number of (broadcastable) batch dims.
    """
    # transpose(-2, -1) swaps the matrix axes regardless of how many batch dims precede them.
    return (m1.unsqueeze(-2) * m2.transpose(-2, -1).unsqueeze(-3)).sum(-1)

In [13]:
# Add the last-axis-reduction variant to the comparison.
%timeit torch.matmul(m1,m2)
%timeit naive_matmul(m1,m2)
%timeit less_naive_matmul(m1,m2)
%timeit lesser_naive_matmul(m1,m2)
%timeit my_matmul(m1,m2)
%timeit alt_matmul(m1,m2)

88.2 µs ± 1.58 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
1.07 s ± 1.93 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
26.3 ms ± 52.4 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
3.38 ms ± 913 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)
3.02 ms ± 4.86 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
3.03 ms ± 3.73 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [27]:
# Add einsum with a batched-matmul spec ('bik,bkj->bij') to the comparison.
%timeit torch.matmul(m1,m2)
%timeit naive_matmul(m1,m2)
%timeit less_naive_matmul(m1,m2)
%timeit lesser_naive_matmul(m1,m2)
%timeit my_matmul(m1,m2)
%timeit alt_matmul(m1,m2)
%timeit torch.einsum('bik,bkj->bij',m1,m2)

89.6 µs ± 275 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
1.02 s ± 1.03 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
25.4 ms ± 54 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
3.28 ms ± 1.53 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
2.9 ms ± 1.14 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
2.9 ms ± 1.41 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
118 µs ± 417 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [28]:
# Add the `@` operator form of matmul to the comparison.
%timeit torch.matmul(m1,m2)
%timeit naive_matmul(m1,m2)
%timeit less_naive_matmul(m1,m2)
%timeit lesser_naive_matmul(m1,m2)
%timeit my_matmul(m1,m2)
%timeit alt_matmul(m1,m2)
%timeit torch.einsum('bik,bkj->bij',m1,m2)
%timeit m1@m2

90.1 µs ± 728 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
1.02 s ± 2.28 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
25.4 ms ± 41.5 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
3.28 ms ± 1.05 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
2.9 ms ± 1.61 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
2.91 ms ± 1.44 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
118 µs ± 191 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
89.8 µs ± 990 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


## Playing with Einsums

I want to get a basic understanding of einsum, as it seems very powerful.

In [29]:
# Small random batches for the einsum experiments; the shared inner size is 5.
e1 = torch.rand(10, 7, 5)
e2 = torch.rand(10, 5, 9)

In [35]:
# matmul: k appears in both inputs but not in the output, so it is summed out.
torch.einsum('bik,bkj->bij',e1,e2).shape

torch.Size([10, 7, 9])

Transpose

In [30]:
# Reordering output letters permutes axes: 'bik->bki' swaps the last two axes of each batch.
torch.einsum('bik->bki',e1).shape

torch.Size([10, 5, 7])

Keeping `k` in the output string adds a dimension instead of summing over it:

In [38]:
# k is kept in the output, so nothing is summed: result[b,i,j,k] = e1[b,i,k] * e2[b,k,j].
torch.einsum('bik,bkj->bijk',e1,e2).shape

torch.Size([10, 7, 9, 5])

Pointless but interesting: with no indices shared between the inputs, einsum returns the outer product of the two tensors.

In [41]:
# No shared letters and all kept in the output: every element of e1 times every element of e2.
torch.einsum('abc,def->abcdef',e1,e2).shape

torch.Size([10, 7, 5, 10, 5, 9])