# Implement matrix multiplication from scratch

# Get data

In [None]:
#export
import operator
from pathlib import Path
from IPython.core.debugger import set_trace
import pickle, gzip, math, torch, matplotlib as mpl
import matplotlib.pyplot as plt
from torch import tensor

In [None]:
path = 'mnist.pkl.gz'
# The pickle unpacks as ((x_train, y_train), (x_valid, y_valid), <third item, discarded>).
# encoding='latin-1' presumably lets Python 3 read arrays pickled under Python 2 -- confirm.
# NOTE: pickle.load can execute arbitrary code; only load a trusted copy of this file.
with gzip.open(path, mode='rb') as f:
    (x_train, y_train), (x_valid, y_valid), _ = pickle.load(f, encoding='latin-1')

In [None]:
#export
# unit test function
def test(a, b, cmp, cname=None):
    if cname == None:
        cname == cmp.__name__
    assert cmp(a, b), f'{cname}\n{a}\n{b}\n'

def test_eq(a, b):
    return test(a, b, operator.eq, '==')

In [None]:
# Smoke-test the equality helper on a trivially equal pair of strings.
a = 'LALALA'
test_eq(a=a, b='LALALA')

In [None]:
# Convert the loaded arrays (presumably NumPy) into torch tensors.
x_train, y_train, x_valid, y_valid = [tensor(arr) for arr in (x_train, y_train, x_valid, y_valid)]

In [None]:
tuple(t.shape for t in (x_train, y_train, x_valid, y_valid))

(torch.Size([50000, 784]),
 torch.Size([50000]),
 torch.Size([10000, 784]),
 torch.Size([10000]))

In [None]:
# The number of inputs must equal the number of labels in each split.
test_eq(len(x_train), len(y_train))
test_eq(len(x_valid), len(y_valid))

In [None]:
#export
# first try: three nested Python loops, one multiply-add per innermost step
def matmul(a, b):
    """Return the matrix product ``a @ b`` computed with three explicit loops.

    Args:
        a: 2-D tensor of shape (n_rows, n_inner).
        b: 2-D tensor of shape (n_inner, n_cols).

    Returns:
        New tensor of shape (n_rows, n_cols), allocated by ``torch.zeros`` with the
        default dtype.

    Raises:
        AssertionError: if ``a``'s column count differs from ``b``'s row count.
    """
    a_rows, a_cols = a.shape
    b_rows, b_cols = b.shape
    assert a_cols == b_rows, f'inner dimensions differ: a is {tuple(a.shape)}, b is {tuple(b.shape)}'
    c = torch.zeros(a_rows, b_cols)
    for i in range(a_rows):
        for j in range(b_cols):
            for k in range(a_cols):
                c[i, j] += a[i, k] * b[k, j]
    return c

In [None]:
# Small example: the first 5 training rows times a random (784, 10) matrix.
a = x_train[:5]
b = torch.randn(784, 10)
(a.shape, b.shape)

(torch.Size([5, 784]), torch.Size([784, 10]))

In [None]:
# Time the most recently defined matmul (the triple-loop version) on the (5, 784) @ (784, 10) example.
%time t1 = matmul(a, b)

CPU times: user 518 ms, sys: 84 µs, total: 518 ms
Wall time: 516 ms


In [None]:
t1.size()

torch.Size([5, 10])

In [None]:
#export
def near(a,b): return torch.allclose(a, b, rtol=1e-3, atol=1e-5)
def test_near(a,b): test(a,b,near)

In [None]:
# Compare against PyTorch's built-in matrix product.
test_near(t1, torch.matmul(a, b))

# Matrix multiplication

In [None]:
#export
# second try: replace the innermost loop with an elementwise product followed by a sum
def matmul(a, b):
    """Return ``a @ b``, computing each entry as a vectorized row-by-column dot product.

    Args:
        a: 2-D tensor of shape (n_rows, n_inner).
        b: 2-D tensor of shape (n_inner, n_cols).

    Returns:
        New tensor of shape (n_rows, n_cols), allocated by ``torch.zeros`` with the
        default dtype.

    Raises:
        AssertionError: if ``a``'s column count differs from ``b``'s row count.
    """
    a_rows, a_cols = a.shape
    b_rows, b_cols = b.shape
    assert a_cols == b_rows, f'inner dimensions differ: a is {tuple(a.shape)}, b is {tuple(b.shape)}'
    c = torch.zeros(a_rows, b_cols)
    for i in range(a_rows):
        for j in range(b_cols):
            # row i of a and column j of b both have n_inner elements
            c[i, j] = (a[i, :] * b[:, j]).sum()
    return c

In [None]:
# A row of `a` and a column of `b` have the same length, so they can be multiplied elementwise.
(a[0].shape, b[:, 0].shape)

(torch.Size([784]), torch.Size([784]))

In [None]:
# Time the most recently defined matmul (innermost loop replaced by an elementwise product + sum).
%time t2 = matmul(a, b)

CPU times: user 2.91 ms, sys: 42 µs, total: 2.95 ms
Wall time: 1.94 ms


In [None]:
# Compare against PyTorch's built-in matrix product.
test_near(t2, torch.matmul(a, b))

In [None]:
a[1, :, None].shape, b.shape

(torch.Size([784, 1]), torch.Size([784, 10]))

In [None]:
# Broadcast row 1 of `a`, reshaped to a column, against all of `b`;
# summing over dim 0 yields row 1 of a @ b.
c = a[1, :, None] * b
c.shape, c.sum(dim=0).shape, c.sum(dim=1).shape

(torch.Size([784, 10]), torch.Size([10]), torch.Size([784]))

In [None]:
#export
# third try: remove the column loop as well by broadcasting one row of `a` against all of `b`
def matmul(a, b):
    """Return ``a @ b`` using a single Python loop over the rows of ``a``.

    Row i of ``a``, reshaped to a column of shape (n_inner, 1), broadcasts against
    ``b`` of shape (n_inner, n_cols). Summing the product over dim 0 gives row i
    of the result.

    Args:
        a: 2-D tensor of shape (n_rows, n_inner).
        b: 2-D tensor of shape (n_inner, n_cols).

    Returns:
        New tensor of shape (n_rows, n_cols), allocated by ``torch.zeros`` with the
        default dtype.

    Raises:
        AssertionError: if ``a``'s column count differs from ``b``'s row count.
    """
    a_rows, a_cols = a.shape
    b_rows, b_cols = b.shape
    assert a_cols == b_rows, f'inner dimensions differ: a is {tuple(a.shape)}, b is {tuple(b.shape)}'
    c = torch.zeros(a_rows, b_cols)
    for i in range(a_rows):
        c[i, :] = (a[i, :].unsqueeze(-1) * b).sum(0)
    return c

In [None]:
# Time the most recently defined matmul (broadcasting version, one Python loop over rows).
%time t3 = matmul(a, b)

CPU times: user 0 ns, sys: 2.27 ms, total: 2.27 ms
Wall time: 1.19 ms


In [None]:
# Compare against PyTorch's built-in matrix product.
test_near(t3, torch.matmul(a, b))