In [1]:
import torch
from torch import tensor


In [2]:
from pathlib import Path
from urllib.request import urlretrieve
import gzip, pickle

MNIST_URL='https://github.com/mnielsen/neural-networks-and-deep-learning/blob/master/data/mnist.pkl.gz?raw=true'
path_data = Path('data')
path_data.mkdir(exist_ok=True)
path_gz = path_data/'mnist.pkl.gz'

if not path_gz.exists():
    urlretrieve(MNIST_URL, path_gz)

with gzip.open(path_gz, 'rb') as f:
    ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding='latin-1')

x_train,y_train,x_valid,y_valid = map(tensor, (x_train,y_train,x_valid,y_valid))

torch.manual_seed(1)
weights = torch.randn(784, 10)
bias = torch.zeros(10)

m1 = x_valid[:5]
m2 = weights
ar, ac = m1.shape 
br, bc = m2.shape

t1 = torch.zeros(ar, bc)

for i in range(ar):         # 5
    for j in range(bc):     # 10
        for k in range(ac): # 784
            t1[i, j] += m1[i, k] * m2[k, j]
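As a quick sanity check, the triple loop can be compared against PyTorch's `@` operator. This is a sketch that uses random stand-in tensors with a smaller inner dimension than the real 784, so it runs fast and needs no MNIST download. The helper name `matmul_loops` is hypothetical.

```python
import torch

def matmul_loops(a, b):
    # Naive triple-loop matrix product, same structure as the cell above
    ar, ac = a.shape
    br, bc = b.shape
    assert ac == br, "inner dimensions must match"
    c = torch.zeros(ar, bc)
    for i in range(ar):
        for j in range(bc):
            for k in range(ac):
                c[i, j] += a[i, k] * b[k, j]
    return c

torch.manual_seed(1)
a = torch.randn(5, 20)   # stand-in for x_valid[:5] (the real one is 5 x 784)
b = torch.randn(20, 10)  # stand-in for weights (the real one is 784 x 10)
t = matmul_loops(a, b)
print(t.shape)  # torch.Size([5, 10])
print(torch.allclose(t, a @ b, atol=1e-5))
```

A small tolerance is used because the loop and `@` may accumulate the sums in a different order, which changes the last few bits of float32 results.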


Einstein summation.

In einsum notation, the matrix product is written `ik,kj->ij`. The arrays are multiplied elementwise with the 2nd dim of the 1st tensor aligned to the 1st dim of the 2nd tensor (the shared index `k`), creating a 3-dim array that can be represented as `ijk`. That array is then summed along its 3rd dimension, `k`, which is the index missing from the output.
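The same two steps can be spelled out with broadcasting. This is a sketch with small random tensors; here the intermediate is laid out as `(i, k, j)` rather than `(i, j, k)`, which does not change the result, since only the choice of summed axis matters.

```python
import torch

torch.manual_seed(0)
a = torch.randn(5, 8)   # shape (i, k)
b = torch.randn(8, 3)   # shape (k, j)

# Elementwise product with k aligned: (i, k, 1) * (1, k, j) -> (i, k, j)
prod = a[:, :, None] * b[None, :, :]
print(prod.shape)  # torch.Size([5, 8, 3])

# Summing out the shared index k gives the matrix product, shape (i, j)
manual = prod.sum(dim=1)
via_einsum = torch.einsum('ik,kj->ij', a, b)
print(torch.allclose(manual, via_einsum, atol=1e-5))
print(torch.allclose(manual, a @ b, atol=1e-5))
```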

More general specs work the same way, e.g. `ikl,lkj->ik`. Here, the 2nd and 3rd dims of the 1st array are matched with the 2nd and 1st dims of the 2nd array (the shared indices `k` and `l`), forming a product that can be represented as `ijkl`; we then sum along the `j` and `l` dimensions, because neither appears in the output. Written out as loops:

```
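import torch

def einsum_ikl_lkj_ik(a, b):
    # Loop equivalent of torch.einsum('ikl,lkj->ik', a, b)
    sza0, sza1, sza2 = a.shape                # sizes of i, k, l
    szb2 = b.shape[2]                         # size of j
    p = torch.zeros(sza0, szb2, sza1, sza2)   # product indexed [i, j, k, l]
    for i in range(sza0):
        for j in range(szb2):
            for k in range(sza1):
                for l in range(sza2):
                    p[i, j, k, l] = a[i, k, l] * b[l, k, j]
    return p.sum(dim=(1, 3))                  # sum out j and l -> shape (i, k)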
```
There are more subtleties to einsum than are covered here; see the PyTorch `torch.einsum` documentation.
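The loops above can also be replaced by broadcasting, which is a handy way to check that `ikl,lkj->ik` really means "build the `ijkl` product, then sum out `j` and `l`". This is a sketch with small random tensors; the sizes `I, K, L, J` are arbitrary choices for illustration.

```python
import torch

torch.manual_seed(0)
I, K, L, J = 2, 3, 4, 5
a = torch.randn(I, K, L)  # indexed [i, k, l]
b = torch.randn(L, K, J)  # indexed [l, k, j]

# Rearrange b to [j, k, l] so both operands share the trailing (k, l) axes,
# then broadcast to the 4-D product indexed [i, j, k, l]
p = a[:, None, :, :] * b.permute(2, 1, 0)[None, :, :, :]
print(p.shape)  # torch.Size([2, 5, 3, 4])

manual = p.sum(dim=(1, 3))                    # sum out j and l -> [i, k]
via_einsum = torch.einsum('ikl,lkj->ik', a, b)
print(torch.allclose(manual, via_einsum, atol=1e-5))
```

The `permute(2, 1, 0)` reverses `b`'s axes from `[l, k, j]` to `[j, k, l]`, which is exactly the reordering needed to line up `k` and `l` with `a`.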

In [3]:
mr = torch.einsum('ik,kj->ikj', m1, m2)
mr.shape


torch.Size([5, 784, 10])
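Because `'ik,kj->ikj'` keeps `k` in the output, nothing has been summed yet; summing over the middle axis afterwards recovers the ordinary matrix product. This sketch uses random stand-ins with the same shapes as `m1` and `m2`; the tolerance is loose because 784 float32 terms are accumulated in different orders.

```python
import torch

torch.manual_seed(1)
m1 = torch.randn(5, 784)   # stand-in for x_valid[:5]
m2 = torch.randn(784, 10)  # stand-in for weights

mr = torch.einsum('ik,kj->ikj', m1, m2)  # per-k products, not yet summed
print(mr.shape)  # torch.Size([5, 784, 10])

# Summing over the middle (k) axis turns this into the matrix product
t = mr.sum(dim=1)
print(torch.allclose(t, m1 @ m2, atol=1e-3))
```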

In [4]:
def matmul(a, b):
    return torch.einsum('ik,kj->ij', a, b)
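Since any index that appears in the output is kept rather than summed, the same one-liner extends to a stack of matrices by adding a leading batch index. This is a sketch; `batched_matmul` is a hypothetical helper, checked against PyTorch's `torch.bmm`.

```python
import torch

def matmul(a, b):
    return torch.einsum('ik,kj->ij', a, b)

def batched_matmul(a, b):
    # Hypothetical helper: the batch index 'b' appears in the output, so it is not summed
    return torch.einsum('bik,bkj->bij', a, b)

torch.manual_seed(0)
x = torch.randn(4, 6)
w = torch.randn(6, 3)
xs = torch.randn(2, 4, 6)
ws = torch.randn(2, 6, 3)

print(torch.allclose(matmul(x, w), x @ w, atol=1e-5))
print(torch.allclose(batched_matmul(xs, ws), torch.bmm(xs, ws), atol=1e-5))
```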


In [5]:
%timeit -n 5 _ = matmul(x_train, weights)


The slowest run took 10.95 times longer than the fastest. This could mean that an intermediate result is being cached.
17.6 ms ± 21.8 ms per loop (mean ± std. dev. of 7 runs, 5 loops each)
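To compare the einsum version with PyTorch's built-in `@`, a rough timing sketch can use the standard-library `timeit` module. It uses a smaller random stand-in than `x_train` so it runs quickly; timings depend on hardware and PyTorch version, so no numbers are claimed here. Taking the minimum over repeats reduces the effect of a slow first run, which is what the `%timeit` warning above is flagging.

```python
import timeit
import torch

def matmul(a, b):
    return torch.einsum('ik,kj->ij', a, b)

torch.manual_seed(1)
# Smaller stand-in than x_train (50000 x 784) so this runs quickly anywhere
x = torch.randn(5000, 784)
w = torch.randn(784, 10)

# Best-of-3 repeats, 5 calls each, reported per call
t_einsum = min(timeit.repeat(lambda: matmul(x, w), number=5, repeat=3)) / 5
t_matmul = min(timeit.repeat(lambda: x @ w, number=5, repeat=3)) / 5
print(f"einsum: {t_einsum * 1e3:.2f} ms per call")
print(f"@:      {t_matmul * 1e3:.2f} ms per call")
```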
