We illustrate using Numba's `njit` decorator, which compiles a subset of Python to machine code for a speedup. We implement the dot product as an `njit`-compiled function.

In [20]:
from numba import njit


In [21]:
@njit
def dot(a, b):
    res = 0.
    for i in range(len(a)):
        res += a[i] * b[i]
    return res


In [22]:
from numpy import array


The following illustrates the difference between the first call, during which the function is compiled, and the second call, which reuses the compiled code.

In [23]:
%time dot(array([1., 2, 3]), array([2., 3, 4]))


CPU times: user 36.5 ms, sys: 0 ns, total: 36.5 ms
Wall time: 35.9 ms


20.0

In [24]:
%time dot(array([1., 2, 3]), array([2., 3, 4]))


CPU times: user 20 µs, sys: 0 ns, total: 20 µs
Wall time: 22.9 µs


20.0
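
Outside IPython, where `%time` is not available, the same comparison can be sketched with `time.perf_counter`. This is a minimal sketch rather than part of the original notebook: the helper `timed` is a hypothetical name, and the fallback to an identity decorator when Numba is not installed is an assumption made so that the snippet runs anywhere (in that case nothing is compiled and both calls take roughly the same time).

```python
import time

import numpy as np

try:
    from numba import njit
except ImportError:  # assumption: fall back to plain Python if Numba is missing
    def njit(func):
        return func


@njit
def dot(a, b):
    res = 0.
    for i in range(len(a)):
        res += a[i] * b[i]
    return res


def timed(func, *args):
    """Return (result, elapsed seconds) for a single call. Hypothetical helper."""
    start = time.perf_counter()
    result = func(*args)
    return result, time.perf_counter() - start


a = np.array([1., 2., 3.])
b = np.array([2., 3., 4.])

# With Numba present, the first call includes JIT compilation.
first_result, first_time = timed(dot, a, b)
# The second call has the same argument types, so no recompilation is needed.
second_result, second_time = timed(dot, a, b)

print(f"first call:  {first_time * 1e6:.1f} µs")
print(f"second call: {second_time * 1e6:.1f} µs")
```

Numba compiles one specialization per combination of argument types, so calling `dot` with arrays of a different dtype would trigger a fresh compilation on that first call.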

We now use `dot` for the innermost loop of the matrix multiplication, so only two of the three loops are executed in Python.

In [25]:
import torch

def matmul(a, b):
    (ar, ac), (br, bc) = a.shape, b.shape
    c = torch.zeros(ar, bc)
    for i in range(ar):          # rows of a, looped in Python
        for j in range(bc):      # columns of b, looped in Python
            c[i, j] = dot(a[i, :], b[:, j])  # innermost loop runs inside dot
    return c
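
A quick way to sanity-check this two-loop structure is to compare it with NumPy's `@` operator on small random matrices. The sketch below is illustrative and not from the original notebook: `matmul_np` is a hypothetical variant that writes into a NumPy array instead of a torch tensor so the snippet has no torch dependency, the shape assertion is an addition, and the identity-decorator fallback for a missing Numba install is an assumption.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # assumption: fall back to plain Python if Numba is missing
    def njit(func):
        return func


@njit
def dot(a, b):
    res = 0.
    for i in range(len(a)):
        res += a[i] * b[i]
    return res


def matmul_np(a, b):
    """Two Python loops around dot; hypothetical NumPy-only variant of matmul."""
    (ar, ac), (br, bc) = a.shape, b.shape
    assert ac == br, "inner dimensions must match"
    c = np.zeros((ar, bc))
    for i in range(ar):
        for j in range(bc):
            # Row i of a times column j of b
            c[i, j] = dot(a[i, :], b[:, j])
    return c


rng = np.random.default_rng(0)
a = rng.standard_normal((5, 7))
b = rng.standard_normal((7, 3))

result = matmul_np(a, b)
matches = np.allclose(result, a @ b)
print(matches)
```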


In [26]:
from fastai.vision.all import *

pickle_path = URLs.path('mnist_png')/'mnist_png.pkl'
path = untar_data(URLs.MNIST)/'training'

if not pickle_path.exists():
    pickle_path.parent.mkdir(parents=True, exist_ok=True)
    ds = DataBlock(
        blocks = (ImageBlock(PILImageBW), CategoryBlock),
        get_items = get_image_files,
        get_y = parent_label,
        splitter = RandomSplitter(1/6, seed=0)
    ).datasets(path)

    xs, ys = zip(*ds.train, *ds.valid)
    xs = np.stack(L(map(lambda x: np.array(x, dtype=np.float32).reshape(-1), xs))) / 255.
    ys = np.array(ys, dtype=np.int64)

    x_train, x_valid = xs[:len(ds.train)], xs[len(ds.train):]
    y_train, y_valid = ys[:len(ds.train)], ys[len(ds.train):]

    save_pickle(pickle_path, [x_train, y_train, x_valid, y_valid])

    del ds, xs, ys, x_train, y_train, x_valid, y_valid

from torch import tensor

x_train, y_train, x_valid, y_valid = map(tensor, load_pickle(pickle_path))

torch.manual_seed(1)
weights = torch.randn(784, 10)
bias = torch.zeros(10)

m1 = x_valid[:5]
m2 = weights
ar, ac = m1.shape 
br, bc = m2.shape

t1 = torch.zeros(ar, bc)

for i in range(ar):         # 5
    for j in range(bc):     # 10
        for k in range(ac): # 784
            t1[i, j] += m1[i, k] * m2[k, j]


We replicate the test that we did earlier for the pure-Python matmul, but now for the `njit` version. Numba works with NumPy arrays rather than torch tensors, so we first convert `m1` and `m2` to NumPy arrays.

In [27]:
m1a, m2a = m1.numpy(), m2.numpy()


We verify correctness.

In [28]:
from fastcore.test import *

test_close(t1, matmul(m1a, m2a), eps=2e-5)


We then test performance, which seems to be about 100x as fast as the pure-Python version.

In [29]:
%timeit -n 50 matmul(m1a, m2a)


222 µs ± 21.6 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
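
The "7 runs, 50 loops each" summary can be approximated outside IPython with the standard-library `timeit` module. The following is a rough sketch under stated assumptions: `timeit_stats` is a hypothetical helper, the use of the population standard deviation is an assumption about how the spread is computed, and a plain NumPy `@` on random arrays of the same shapes stands in for `matmul(m1a, m2a)` so that the snippet is self-contained.

```python
import statistics
import timeit

import numpy as np


def timeit_stats(func, repeat=7, number=50):
    """Return (mean, std) of the per-loop time in seconds. Hypothetical helper."""
    # Each entry of totals is the time for `number` consecutive calls.
    totals = timeit.repeat(func, repeat=repeat, number=number)
    per_loop = [t / number for t in totals]
    return statistics.mean(per_loop), statistics.pstdev(per_loop)


# Stand-in data with the same shapes as m1a (5 x 784) and m2a (784 x 10).
rng = np.random.default_rng(1)
m1_demo = rng.standard_normal((5, 784)).astype(np.float32)
m2_demo = rng.standard_normal((784, 10)).astype(np.float32)

mean_s, std_s = timeit_stats(lambda: m1_demo @ m2_demo)
print(f"{mean_s * 1e6:.1f} µs ± {std_s * 1e6:.1f} µs per loop")
```

To time the notebook's own function, pass `lambda: matmul(m1a, m2a)` instead of the stand-in.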
