In [None]:
# Reload edited modules automatically before each cell executes.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
from pathlib import Path
from IPython.core.debugger import set_trace
from fastai import datasets
import pickle, gzip, math, torch, matplotlib as mpl
import matplotlib.pyplot as plt
from torch import tensor

In [None]:
#export
import operator

def test(a, b, cmp, cname=None):
    if cname is None: cname = cmp.__name__
    assert cmp(a, b), f"{cname}:\n{a}\n{b}"
def test_eq(a, b):
    """Assert `a == b`; a failure message is labelled with '=='."""
    test(a, b, operator.eq, cname='==')

In [None]:
# Fetch the pickled MNIST archive; `path` is what the next cell opens with gzip.
# NOTE(review): the host is plain HTTP — confirm it is still reachable and trusted.
MNIST_URL = 'http://deeplearning.net/data/mnist/mnist.pkl'
path = datasets.download_data(MNIST_URL, ext='.gz')
path

In [None]:
# NOTE(review): pickle.load can run arbitrary code while unpickling, so this is
# only acceptable if the downloaded archive is trusted.
with gzip.open(path, 'rb') as f:
    # Unpack the first two (x, y) pairs; the third element is ignored.
    (x_tr, y_tr), (x_vl, y_vl), _ = pickle.load(f, encoding='latin-1')

In [None]:
# Convert the four loaded objects to torch tensors, rebinding the same names.
x_tr, y_tr, x_vl, y_vl = (tensor(arr) for arr in (x_tr, y_tr, x_vl, y_vl))

In [None]:
# First row of x_tr, kept as a single sample to look at.
x = x_tr[0]

In [None]:
# Make grayscale the default colormap for matplotlib images.
mpl.rcParams['image.cmap'] = 'gray'

In [None]:
# View the flat sample as a 28x28 grid and draw it; the trailing `;` keeps the
# AxesImage repr out of the cell output.
plt.imshow(x.view(28, 28));

In [None]:
# Random 784x10 weights, a zero bias of length 10, and the first five rows of x_tr.
# NOTE(review): no seed is set in the cells shown, so `ws` differs between runs.
ws = torch.randn(784, 10)
b = torch.zeros(10)
xs = x_tr[:5]
ws.shape, b.shape, xs.shape

In [None]:
def matmul(a, b):
    """Multiply two 2-D tensors using three nested Python loops.

    Args:
        a: tensor whose shape unpacks to (rows, inner).
        b: tensor whose shape unpacks to (inner, cols).

    Returns:
        A `torch.zeros(rows, cols)` tensor filled with the matrix product.

    Raises:
        AssertionError: if the inner dimensions of `a` and `b` differ.
    """
    n_rows, inner = a.shape
    inner_b, n_cols = b.shape
    assert inner == inner_b
    out = torch.zeros(n_rows, n_cols)
    for row in range(n_rows):
        for col in range(n_cols):
            # Accumulate the dot product of a's row and b's column term by term.
            for k in range(inner):
                out[row, col] += a[row, k] * b[k, col]
    return out

In [None]:
# Product of the five samples and the weights via the triple-loop implementation.
c1 = matmul(xs, ws)

In [None]:
# Same product computed with PyTorch's built-in `@` operator.
c2 = xs @ ws

In [None]:
# Shapes of one row of xs and one column of ws.
xs[0].shape, ws[:, 0].shape

In [None]:
# Compare the shapes of the loop-based and built-in results.
c1.shape, c2.shape

In [None]:
# Time one run of the triple-loop implementation.
%time c = matmul(xs, ws)

In [None]:
def matmul1(a, b):
    """Multiply two 2-D tensors, vectorising only the innermost loop.

    Each output element is the sum of the element-wise product of one row
    of `a` and one column of `b`.

    Args:
        a: tensor whose shape unpacks to (rows, inner).
        b: tensor whose shape unpacks to (inner, cols).

    Returns:
        A `torch.zeros(rows, cols)` tensor filled with the matrix product.

    Raises:
        AssertionError: if the inner dimensions of `a` and `b` differ.
    """
    n_rows, inner = a.shape
    inner_b, n_cols = b.shape
    assert inner == inner_b
    out = torch.zeros(n_rows, n_cols)
    for row in range(n_rows):
        for col in range(n_cols):
            out[row, col] = (a[row] * b[:, col]).sum()
    return out

In [None]:
# Time one run of the version that vectorises the innermost loop.
%time c = matmul1(xs, ws)

[1,2,3] [1,2,3,4,5] <br />
[4,5,6] [1,2,3,4,5] <br />
[7,8,9] [1,2,3,4,5] <br />

[1,2,3] => [[1],[2],[3]] => broadcast across the columns: <br />
[1,1,1,1,1] <br />
[2,2,2,2,2] <br />
[3,3,3,3,3] <br />
which is multiplied element-wise by <br />
[1,2,3,4,5] <br />
[1,2,3,4,5] <br />
[1,2,3,4,5] <br />

Then sum along the columns (i.e. across the rows, axis 0). <br />
Rule of thumb for finding the axis: <br />
take the shape of the matrix, <br />
say shape (5,5); <br />
across rows (along columns) is axis 0, <br />
across columns (along rows) is axis 1. <br />
Think in terms of "across", not "along". <br />

In [None]:
# Small 3x3 example: summing the transpose over its last axis adds up each column of m.
m = tensor([[1., 2, 3], [4, 5, 6], [7, 8, 9]])
m.T.sum(-1)

In [None]:
# Shape of m indexed with a bare Ellipsis, next to the shape of its transpose.
m[...].shape, m.T.shape

In [None]:
# First row of m with a trailing unit axis added (a column), shown next to m itself.
m[0, :, None], m

In [None]:
def matmul2(a, b):
    """Multiply two 2-D tensors one output row at a time via broadcasting.

    Each row of `a` is turned into a column, broadcast against `b`, and the
    element-wise product is summed over axis 0 to give one row of the result.

    Args:
        a: tensor whose shape unpacks to (rows, inner).
        b: tensor whose shape unpacks to (inner, cols).

    Returns:
        A `torch.zeros(rows, cols)` tensor filled with the matrix product.

    Raises:
        AssertionError: if the inner dimensions of `a` and `b` differ.
    """
    n_rows, inner = a.shape
    inner_b, n_cols = b.shape
    assert inner == inner_b
    out = torch.zeros(n_rows, n_cols)
    for row in range(n_rows):
        # (inner, 1) * (inner, cols) -> (inner, cols); collapse the inner axis.
        out[row] = (a[row].unsqueeze(-1) * b).sum(dim=0)
    return out

In [None]:
# Time one run of the broadcasting version.
%time c = matmul2(xs, ws)

In [None]:
#export
def near(a, b):
    """Return True when tensors `a` and `b` agree within rtol=1e-3 and atol=1e-5."""
    return torch.allclose(a, b, rtol=1e-3, atol=1e-5)
def test_near(a, b):
    """Assert that `near(a, b)` holds, using `near` as the comparison for `test`."""
    test(a, b, near)

In [None]:
def matmul3(a, b):
    """Matrix product of two 2-D tensors computed with `torch.einsum`.

    Args:
        a: tensor indexed as (i, j).
        b: tensor indexed as (j, k).

    Returns:
        The (i, k) tensor obtained by summing `a[i, j] * b[j, k]` over j.
    """
    # Contract over the shared index j. The operands must be this function's
    # own arguments rather than the notebook globals `xs` and `ws`.
    return torch.einsum('ij,jk->ik', a, b)

In [None]:
# Check that the implementations agree within `test_near`'s tolerance.
# `matmul` runs three nested Python loops, so compute its result once and
# reuse it for both comparisons instead of calling it twice.
c_loop = matmul(xs, ws)
test_near(c_loop, matmul1(xs, ws))
test_near(c_loop, matmul2(xs, ws))
test_near(matmul3(xs, ws), matmul2(xs, ws))

In [None]:
# Time the broadcasting version, the einsum version, and PyTorch's built-in
# matrix product (method form and `@` operator form) on the same inputs.
%time t2 = matmul2(xs,ws)
%time t3 = matmul3(xs, ws)
%time t4 = xs.matmul(ws)
%time t5 = xs@ws

In [None]:
# Run notebook2script.py on this notebook with destination name `tests`.
# NOTE(review): presumably this collects the `#export` cells into a module — confirm against the script.
!python3 notebook2script.py matmul.ipynb --destName=tests