In [13]:
import numpy as np
from numpy import load
import matplotlib.pyplot as plt

import requests, gzip, os, hashlib

# Seed NumPy's global RNG so the np.random.randn draws used for weight initialisation are reproducible.
np.random.seed(42)

In [84]:
# Number of examples kept from each split (small subset for quick experiments).
N_SAMPLES = 100


def _first_n_as_columns(images, labels, n):
    """Return the first `n` examples as a (features, n) matrix and a (1, n) label row."""
    head = images[:n]
    # Flatten every example to one row, then transpose so each example is a column.
    X = head.reshape(head.shape[0], -1).T
    Y = labels.reshape(-1, 1)[:n].T
    return X, Y


# load() on an .npz archive returns a lazily-read NpzFile that keeps the file open;
# copy the arrays into a plain dict inside `with` so the handle is closed afterwards.
with load('data/mnist.npz') as npz:
    mnist_data = dict(npz)

# NOTE(review): the images are passed on unscaled. If they are stored as 8-bit
# intensities (0-255) they probably want scaling to [0, 1] before training - confirm.
Xtr, Ytr = _first_n_as_columns(mnist_data['x_train'], mnist_data['y_train'], N_SAMPLES)
Xtest, Ytest = _first_n_as_columns(mnist_data['x_test'], mnist_data['y_test'], N_SAMPLES)
Xtr.shape, Ytr.shape, Xtest.shape, Ytest.shape

((784, 100), (1, 100), (784, 100), (1, 100))

In [85]:
def relu(x: np.ndarray):
    """Rectified linear unit: element-wise maximum of 0 and `x`."""
    rectified = np.maximum(0, x)
    return rectified
def d_relu(x: np.ndarray):
    """Derivative of ReLU: 1 where `x` is strictly positive, 0 elsewhere (integer valued)."""
    is_positive = x > 0
    # Multiplying the boolean mask by 1 turns it into 0/1 integers.
    return is_positive * 1
def loss(yh: float, y: float, eps: float = 1e-12):
    """Binary cross-entropy between prediction `yh` and target `y`.

    Computes ``-(y*log(yh) + (1-y)*log(1-yh))`` element-wise, so scalars and
    broadcast-compatible arrays are both accepted.

    Args:
        yh: predicted probability (scalar or array).
        y: target value (scalar or array).
        eps: predictions are clipped to ``[eps, 1 - eps]`` before the logs are
            taken.

    Returns:
        The loss, with the broadcast shape of `yh` and `y`.
    """
    # Without clipping, a prediction of exactly 0 or 1 hits log(0): the result is
    # inf, or nan through 0 * -inf even when the prediction is perfectly right.
    # Work in float64 so that 1 - eps stays distinguishable from 1.
    yh = np.clip(np.asarray(yh, dtype=np.float64), eps, 1.0 - eps)
    return -(y*np.log(yh) + (1-y)*np.log(1-yh))
def cost(yh: np.ndarray, y: np.ndarray):
    """Mean of `loss` over all examples.

    `yh` and `y` may be 1-D, row vectors ``(1, m)`` or column vectors ``(m, 1)``;
    both are flattened, so the layout does not matter as long as they hold the
    same number of elements.

    Args:
        yh: predictions, one per example.
        y: targets, one per example.

    Returns:
        The average loss as a scalar.

    Raises:
        ValueError: if the inputs are empty or differ in size.
    """
    # Flatten first: len() of a (1, m) row vector is 1, which previously made the
    # "mean" a single unreduced vector instead of an average over m examples.
    predictions = np.asarray(yh).reshape(-1)
    targets = np.asarray(y).reshape(-1)
    if predictions.size != targets.size:
        raise ValueError(
            f'yh and y must hold the same number of elements, got {predictions.size} and {targets.size}'
        )
    if targets.size == 0:
        raise ValueError('cost() needs at least one example')
    # loss() is element-wise, so one vectorised call replaces the Python loop.
    return np.mean(loss(predictions, targets))

In [86]:
class Model:
    """Two-layer fully connected network trained with plain gradient descent.

    Inputs are column-major: `x` has shape (n_in, m) for m examples and the
    output has shape (n_out, m).

    Attributes set by ``__init__``:
        W1, b1: hidden-layer weights (n_hidden, n_in) and bias (n_hidden, 1).
        W2, b2: output-layer weights (n_out, n_hidden) and bias (n_out, 1).
        a1, a2: activation functions of the hidden and output layer.
        dW1, db1, dW2, db2: gradients written by ``backward`` and consumed by
            ``step``; zero until ``backward`` has run.
    """

    def __init__(self, n_in: int, n_hidden: int, n_out: int):
        """Create the parameters for an n_in -> n_hidden -> n_out network."""
        # He initialisation: scale by sqrt(2 / fan_in) so ReLU pre-activations
        # keep a comparable variance from layer to layer.
        self.W1 = np.random.randn(n_hidden, n_in) * (2/n_in)**0.5
        self.b1 = np.zeros((n_hidden, 1))
        # TODO: add a dropout maybe
        self.a1 = relu
        # Same fan-in scaling as W1; unscaled draws make the output magnitude
        # grow with sqrt(n_hidden).
        self.W2 = np.random.randn(n_out, n_hidden) * (2/n_hidden)**0.5
        self.b2 = np.zeros((n_out, 1))
        self.a2 = relu # TODO: make this softmax

        # Start with zero gradients so step() before backward() is a harmless
        # no-op instead of an AttributeError.
        self.dW1 = np.zeros((n_hidden, n_in))
        self.db1 = np.zeros((n_hidden, 1))
        self.dW2 = np.zeros((n_out, n_hidden))
        self.db2 = np.zeros((n_out, 1))

    def forward(self, x: np.ndarray):
        """Return the network output for inputs `x` of shape (n_in, m)."""
        l1 = self.a1(np.dot(self.W1, x) + self.b1)
        l2 = self.a2(np.dot(self.W2, l1) + self.b2)
        return l2

    def backward(self, x: np.ndarray, yh: np.ndarray, y: np.ndarray):
        """Compute gradients for one batch and store them on the instance.

        Args:
            x: inputs, shape (n_in, m).
            yh: output of ``forward(x)``, shape (n_out, m).
            y: targets, broadcast-compatible with `yh`.

        Nothing is returned; dW1, db1, dW2 and db2 are overwritten.
        """
        m = x.shape[1]
        # The hidden layer is recomputed from x, so backward() does not rely on
        # any state left behind by forward().
        z1 = np.dot(self.W1, x) + self.b1
        a1 = self.a1(z1)

        # NOTE(review): `yh - y` is the output-layer error of a softmax/sigmoid +
        # cross-entropy pairing; while a2 is still relu the derivative of the
        # output activation is not applied here. Revisit together with the
        # softmax TODO in __init__.
        dz2 = yh - y
        self.dW2 = 1/m*np.dot(dz2, a1.T)
        self.db2 = 1/m*np.sum(dz2, axis=1, keepdims=True)

        # Back-propagate through W2 and the hidden ReLU.
        dz1 = np.dot(self.W2.T, dz2) * d_relu(z1)
        self.dW1 = 1/m*np.dot(dz1, x.T)
        self.db1 = 1/m*np.sum(dz1, axis=1, keepdims=True)

    def step(self, lr: float):
        """Apply one gradient-descent update in place with learning rate `lr`."""
        self.W1 -= lr*self.dW1
        self.W2 -= lr*self.dW2
        self.b1 -= lr*self.db1
        self.b2 -= lr*self.db2

In [122]:
# Passes over the training subset.
epochs = 1
# Intended logging interval, in epochs.
iepoch = 1
# 784 input features (the first dimension of Xtr), 128 hidden units, 10 outputs.
m = Model(n_in=784, n_hidden=128, n_out=10)

In [123]:
# Step size for the per-example gradient updates below.
LEARNING_RATE = 0.01
# One output unit per class; read from the model so targets always match its output size.
n_classes = m.b2.shape[0]
# zip() stops at the shorter of the two, so that is the number of examples seen per epoch.
n_examples = min(Xtr.shape[1], Ytr.shape[1])

for epoch in range(epochs):
    epoch_loss = 0.0
    for Xb, Yb in zip(Xtr.T, Ytr.T): # .T so that zip walks over examples, not features
        Xb = Xb.reshape(-1, 1)
        # One-hot encode the class index into an (n_classes, 1) target column.
        # Passing the raw label instead would broadcast it across every output
        # unit and train all of them toward the digit value itself.
        Yb = np.eye(n_classes)[:, Yb.astype(int).ravel()]
        yh = m.forward(Xb)
        # Progress metric: half the squared error, whose derivative with respect
        # to the prediction is the `yh - y` term that Model.backward starts from.
        epoch_loss += 0.5 * np.sum((yh - Yb) ** 2)

        m.backward(Xb, yh, Yb)
        m.step(LEARNING_RATE)

    if epoch % iepoch == 0:
        print(f'{epoch}: loss {epoch_loss / max(n_examples, 1):.4f}')