In [1]:
import torch
from torch import tensor


We now define a `matmul` that is vectorized over the second dimension, using broadcasting. We require that `ac` and `br` be equal, i.e. the number of columns of `a` must match the number of rows of `b`.

```
a[i, :, None]     : ac x 1  x trailing
b                 : br x bc x trailing
a[i, :, None] * b : ac x bc x trailing
c[i]              : bc x trailing
c                 : ar x bc x trailing
```

In [2]:
def matmul(a, b):
    """Matrix-multiply two 2-D tensors, one output row at a time via broadcasting.

    For each row index `i`, `a[i, :, None]` has shape (ac, 1) and broadcasts
    against `b` of shape (br, bc); summing the elementwise product over dim 0
    gives row `i` of the result.

    Args:
        a: 2-D tensor of shape (ar, ac).
        b: 2-D tensor of shape (br, bc); `br` must equal `ac`.

    Returns:
        A tensor of shape (ar, bc) allocated with `torch.zeros(ar, bc)`.

    Raises:
        ValueError: if `a` or `b` is not 2-D (the shape unpacking fails).
        RuntimeError: if `ac != br`.
    """
    (ar, ac), (br, bc) = a.shape, b.shape
    # Broadcasting alone does not enforce the inner-dimension match: when
    # either `ac` or `br` is 1 the product below still broadcasts and would
    # silently return a wrong result, so reject the mismatch explicitly.
    if ac != br:
        raise RuntimeError(
            f"matmul: inner dimensions must match, got a: {ar}x{ac} and b: {br}x{bc}"
        )
    c = torch.zeros(ar, bc)
    for i in range(ar):
        # Previous (per-element) version: c[i, j] = (a[i, :] * b[:, j]).sum()
        # Broadcast version: (ac, 1) * (ac, bc) -> (ac, bc), summed over dim 0.
        c[i] = (a[i, :, None] * b).sum(dim=0)
    return c


We set up our test data.

In [3]:
from fastai.vision.all import *

# Location of the cached arrays, and the extracted MNIST 'training' image folder.
pickle_path = URLs.path('mnist_png')/'mnist_png.pkl'
path = untar_data(URLs.MNIST)/'training'

# Build the pickle cache only when it is missing; otherwise this cell is a no-op.
if not pickle_path.exists():
    pickle_path.parent.mkdir(parents=True, exist_ok=True)

    # Greyscale image -> category datasets, labelled by parent folder name,
    # with a seeded random 1/6 validation split.
    mnist_block = DataBlock(
        blocks=(ImageBlock(PILImageBW), CategoryBlock),
        get_items=get_image_files,
        get_y=parent_label,
        splitter=RandomSplitter(1/6, seed=0),
    )
    splits = mnist_block.datasets(path)
    n_train = len(splits.train)

    # Transpose the (image, label) pairs -- training items first, then
    # validation -- into one tuple of images and one tuple of labels.
    images, labels = zip(*splits.train, *splits.valid)

    # Flatten every image to a 1-D float32 vector and divide by 255.
    pixels = np.stack(
        [np.array(img, dtype=np.float32).reshape(-1) for img in images]
    ) / 255.
    targets = np.array(labels, dtype=np.int64)

    # Stored order: x_train, y_train, x_valid, y_valid.
    save_pickle(
        pickle_path,
        [pixels[:n_train], targets[:n_train], pixels[n_train:], targets[n_train:]],
    )

    # Drop the temporaries so nothing from the build step lingers in the kernel.
    del splits, n_train, images, labels, pixels, targets

import torch
from torch import tensor

# Read the four cached arrays back and wrap each one in a tensor.
x_train, y_train, x_valid, y_valid = (tensor(arr) for arr in load_pickle(pickle_path))

# Fix the RNG seed before drawing the random 784 x 10 weight matrix.
torch.manual_seed(1)
weights, bias = torch.randn(784, 10), torch.zeros(10)

# Small test operands: the first five validation rows against the weights.
m1, m2 = x_valid[:5], weights
(ar, ac), (br, bc) = m1.shape, m2.shape

# Reference result: plain triple-loop matrix product of m1 and m2,
# accumulated element by element into a zero-initialised tensor.
t1 = torch.zeros((ar, bc))

for i in range(ar):             # rows of m1: 5
    for j in range(bc):         # columns of m2: 10
        for k in range(ac):     # shared inner dimension: 784
            t1[i, j] = t1[i, j] + m1[i, k] * m2[k, j]


Verify correctness.

In [4]:
from fastcore.all import *

# The broadcast matmul must agree with the triple-loop reference `t1`
# to within a tolerance of 2e-5.
test_close(t1, matmul(m1, m2), 2e-5)


Timing the broadcast version shows about a 5x speedup over the previous version.

In [5]:
# Time the broadcast matmul on the five-row sample, 50 loops per run.
%timeit -n 50 _ = matmul(m1, m2)


74.4 µs ± 7.34 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)


We are now ready to run this over the whole dataset:

In [6]:
# Apply the broadcast matmul to the full training set, then display the
# result together with its shape.
tr = matmul(x_train, weights)
tr, tr.shape


(tensor([[  3.7516,  11.3573, -11.6573,  ...,   0.1467,  -5.5152,  27.7418],
         [  0.5857,   1.7986, -10.1528,  ...,   6.3370, -31.0563,   8.5035],
         [ -5.6874,  -2.5619,   6.2130,  ...,  -1.1058,  -8.1474,  -7.1438],
         ...,
         [  0.2533,   7.9083, -19.5304,  ...,  -6.1203, -12.0303,   8.3984],
         [  5.4290,   1.5903, -12.1485,  ...,   2.1751,  -7.8525,   5.2241],
         [  2.8835,  -1.4121,  -5.7172,  ..., -16.5278, -30.9194,  17.2718]]),
 torch.Size([50000, 10]))

In [7]:
# Single timed run of the broadcast matmul over the full training set.
%time _ = matmul(x_train, weights)


CPU times: user 710 ms, sys: 0 ns, total: 710 ms
Wall time: 664 ms


PyTorch provides its own `matmul` function, which can also be invoked with the `@` operator.

In [8]:
# Our full-dataset result must match PyTorch's built-in `@` to within 1e-4.
test_close(tr, x_train @ weights, eps=1e-4)


In [9]:
# Time PyTorch's own matmul on the same inputs, 5 loops per run, for comparison.
%timeit -n 5 _=torch.matmul(x_train, weights)


27.9 ms ± 1.65 ms per loop (mean ± std. dev. of 7 runs, 5 loops each)
