In [1]:
import torch
from torch import tensor


We now define a `matmul` that uses broadcasting to vectorize over the second dimension (the columns of `b`), leaving only the loop over the rows of `a`. This requires `ac` and `br` to be equal.

```
a[i, :, None]     : ac x 1  x trailing
b                 : br x bc x trailing
a[i, :, None] * b : ac x bc x trailing
c[i]              : bc x trailing
c                 : ar x bc x trailing
```
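
As a minimal sketch of the shape bookkeeping above, the snippet below checks each intermediate shape for the 2D case, where the trailing dimensions are empty. The tensors and their sizes (3x4 and 4x2) are made up for illustration and are not the MNIST data.

```python
import torch

# Illustrative sizes only (assumed for this sketch): a is 3x4, b is 4x2.
a = torch.arange(12, dtype=torch.float32).reshape(3, 4)
b = torch.arange(8, dtype=torch.float32).reshape(4, 2)

row = a[0, :, None]    # first row of a as a column: shape (ac, 1) = (4, 1)
prod = row * b         # broadcast across the columns of b: shape (br, bc) = (4, 2)
c0 = prod.sum(dim=0)   # sum over the shared dimension: shape (bc,) = (2,)

print(row.shape, prod.shape, c0.shape)
```

The summed result `c0` is the first row of the matrix product, which is what the loop body assigns to `c[i]`.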

In [2]:
def matmul(a, b):
    (ar, ac), (br, bc) = a.shape, b.shape
    assert ac == br, "inner dimensions must match"
    c = torch.zeros(ar, bc)
    for i in range(ar):
        # c[i,j] = (a[i,:] * b[:,j]).sum()       # previous version
        c[i] = (a[i, :, None] * b).sum(dim=0)    # broadcast version
    return c


We set up our test data: the MNIST images, random weights, and a reference result `t1` computed with a naive triple loop.

In [3]:
from pathlib import Path
from urllib.request import urlretrieve
import gzip, pickle

MNIST_URL='https://github.com/mnielsen/neural-networks-and-deep-learning/blob/master/data/mnist.pkl.gz?raw=true'
path_data = Path('data')
path_data.mkdir(exist_ok=True)
path_gz = path_data/'mnist.pkl.gz'

if not path_gz.exists():
    urlretrieve(MNIST_URL, path_gz)

with gzip.open(path_gz, 'rb') as f:
    ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding='latin-1')

x_train,y_train,x_valid,y_valid = map(tensor, (x_train,y_train,x_valid,y_valid))

torch.manual_seed(1)
weights = torch.randn(784, 10)
bias = torch.zeros(10)

m1 = x_valid[:5]
m2 = weights
ar, ac = m1.shape 
br, bc = m2.shape

t1 = torch.zeros(ar, bc)

for i in range(ar):         # 5
    for j in range(bc):     # 10
        for k in range(ac): # 784
            t1[i, j] += m1[i, k] * m2[k, j]


We verify that the broadcast version matches the reference result `t1`.

In [4]:
from fastcore.all import *

test_close(t1, matmul(m1, m2))
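
If fastcore is not installed, a similar check can be sketched with `torch.allclose`. The snippet below repeats the broadcast `matmul` so that it runs on its own, and compares it against a triple-loop reference on small random matrices. The names `naive_matmul`, `x`, and `w`, the matrix sizes, and the tolerance are assumptions made for this illustration.

```python
import torch

def matmul(a, b):
    # Broadcast version from above: one loop over the rows of a.
    (ar, ac), (br, bc) = a.shape, b.shape
    c = torch.zeros(ar, bc)
    for i in range(ar):
        c[i] = (a[i, :, None] * b).sum(dim=0)
    return c

def naive_matmul(a, b):
    # Hypothetical helper: triple-loop reference, one scalar product at a time.
    (ar, ac), (br, bc) = a.shape, b.shape
    c = torch.zeros(ar, bc)
    for i in range(ar):
        for j in range(bc):
            for k in range(ac):
                c[i, j] += a[i, k] * b[k, j]
    return c

torch.manual_seed(0)
x = torch.randn(5, 7)   # small stand-in for a batch of inputs
w = torch.randn(7, 3)   # small stand-in for the weights

# Float32 sums taken in a different order can differ slightly, so compare with a tolerance.
same = torch.allclose(naive_matmul(x, w), matmul(x, w), atol=1e-4)
print(same)
```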


Timing the broadcast version shows about a 5x speedup over the previous version.

In [5]:
%timeit -n 50 _ = matmul(m1, m2)


122 µs ± 29.3 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)


We are now ready to run this over the whole training set:

In [6]:
tr = matmul(x_train, weights)
tr, tr.shape


(tensor([[  0.9577,  -2.9557,  -2.1148,  ..., -15.0880, -17.6923,   0.6006],
         [  6.8906,  -0.3425,   0.7923,  ..., -17.1321, -25.3573,  16.2312],
         [-10.1834,   7.3808,   4.1311,  ...,  -6.7345,  -6.7882,  -1.5788],
         ...,
         [  7.4047,   7.6421,  -3.4977,  ...,  -1.0175, -16.2215,   2.0750],
         [  3.2491,   9.5212,  -9.3684,  ...,   2.9811, -19.5786,  -1.9565],
         [ 15.6963,   4.1212,  -5.6201,  ...,   8.0785, -12.2060,   0.4163]]),
 torch.Size([50000, 10]))

In [7]:
%time _ = matmul(x_train, weights)


CPU times: user 1.04 s, sys: 1.2 ms, total: 1.04 s
Wall time: 955 ms


PyTorch provides its own `matmul` function, which can also be invoked with the `@` operator.

In [8]:
test_close(tr, x_train@weights, eps=1e-3)
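
As a small illustration of the equivalent spellings, the sketch below computes the same product with the `@` operator, with `torch.matmul`, and with the `Tensor.matmul` method. The random tensors `x` and `w` are assumptions made for this example and stand in for `x_train` and `weights`.

```python
import torch

torch.manual_seed(0)
x = torch.randn(4, 6)   # stand-in for a batch of inputs
w = torch.randn(6, 3)   # stand-in for the weights

via_operator = x @ w               # operator form
via_function = torch.matmul(x, w)  # function form
via_method = x.matmul(w)           # method form

print(torch.allclose(via_operator, via_function),
      torch.allclose(via_operator, via_method))
```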


In [11]:
%timeit -n 5 _=torch.matmul(x_train, weights)


The slowest run took 10.26 times longer than the fastest. This could mean that an intermediate result is being cached.
17.6 ms ± 19.8 ms per loop (mean ± std. dev. of 7 runs, 5 loops each)
