In [1]:
import pickle
import numpy as np
from scipy.special import softmax

rng = np.random.default_rng()

In [2]:
with open("data/train_data.pkl", "rb") as train_file:
    train_data = pickle.load(train_file)

with open("data/test_data.pkl", "rb") as test_file:
    test_data = pickle.load(test_file)

In [3]:
# Define the network

img_px = 28 # FashinMNIST images are 28x28 pixels
n_i = 512   # Input layer, from 28*28 to 512 units
n_1 = 128   # Layer 1, from 512 to 128 units
n_2 = 64    # Layer 2, from 128 to 64 units
n_L = 10    # Layer L, from 64 to 10 units, to match the 10 classes in FashionMNIST

# To initialize the weights and biases, 
# it's usual to draw from a uniform distribution
# with boundaries defined by the layer-input size,
# as below
k_i = np.sqrt( 1 /(img_px * img_px) )
k_1 = np.sqrt( 1 / n_i )
k_2 = np.sqrt( 1 / n_1)
k_L = np.sqrt( 1 / n_2 )

w_i = rng.uniform( -k_i, k_i, (n_i, img_px*img_px) )
b_i = rng.normal( -k_i, k_i, size=n_i )

w_1 = rng.normal( -k_1, k_1, size=(n_1, n_i) )
b_1 = rng.normal( -k_1, k_1, size=n_1 )

w_2 = rng.normal( -k_2, k_2, size=(n_2, n_1))
b_2 = rng.normal( -k_2, k_2, size=n_2 )

w_L = rng.normal( -k_L, k_L, size=(n_L, n_2) )
b_L = rng.normal( -k_L, k_L, size=n_L )

In [4]:
# forward function, used to make predictions

def forward(x, return_z_a=False):
    x = x.flatten()
    
    z_i = w_i @ x + b_i
    a_i = np.maximum(0, z_i)

    z_1 = w_1 @ a_i + b_1
    a_1 = np.maximum(0, z_1)

    z_2 = w_2 @ a_1 + b_2
    a_2 = np.maximum(0, z_2)

    z_L = w_L @ a_2 + b_L
    a_L = softmax(z_L)

    if return_z_a:
        return (z_i, z_1, z_2, z_L), (a_i, a_1, a_2, a_L)
    else:
        return a_L

forward(train_data[0][0])

array([0.09958217, 0.0978552 , 0.10078088, 0.09859859, 0.09980073,
       0.0980186 , 0.10177862, 0.10044439, 0.10187438, 0.10126645])

In [5]:
# Cross-entropy loss

def loss_fn(y, a_L):
    return - (y * np.log(a_L)).sum()   


def test():
    n_samples = len(test_data)
    correct = 0
    sum_loss = 0
    for x, y in test_data:
        y_hat = forward(x)
        correct += y.argmax() == y_hat.argmax()
        sum_loss += loss_fn(y, y_hat)
    return correct / n_samples, sum_loss / n_samples        

test()

(0.1, 2.302684260091983)

In [6]:
def ReLU_derivative(x):
    r = np.ones(x.shape)
    r[x == 0] = 0
    return r

def calc_derivatives(x, y):
    x = x.flatten()
    
    (z_i, z_1, z_2, z_L), (a_i, a_1, a_2, a_L) = forward(x, True)
    
    nabla_z_L = a_L - y
    nabla_w_L = np.outer(nabla_z_L, a_2)
    
    nabla_z_2 = ReLU_derivative(z_2) * (w_L.T @ nabla_z_L)
    nabla_w_2 = np.outer(nabla_z_2, a_1)
    
    nabla_z_1 = ReLU_derivative(z_1) * (w_2.T @ nabla_z_2)
    nabla_w_1 = np.outer(nabla_z_1, a_i)
    
    nabla_z_i = ReLU_derivative(z_i) * (w_1.T @ nabla_z_1)
    nabla_w_i = np.outer(nabla_z_i, x)

    return (nabla_w_i, nabla_w_1, nabla_w_2, nabla_w_L), (nabla_z_i, nabla_z_1, nabla_z_2, nabla_z_L), loss_fn(y, a_L)

In [7]:
# Hyperparamters

learning_rate = 1e-3
mini_batch_size = 64
training_epochs = 5

n_mini_batches = len(train_data) // mini_batch_size


# Performance baseline

print("Pre-train stats")
accuracy, avg_loss = test()
print(f"==> Accuracy: {100*accuracy:0.1f}%, Avg loss: {avg_loss:>8f}\n")

Pre-train stats
==> Accuracy: 10.0%, Avg loss: 2.302684



In [8]:
# Training loop

for epoch in range(training_epochs):
    print(f"Epoch {epoch}\n" + 20*'-')
    rng.shuffle(training_index)
    for mini_batch_number in range(n_mini_batches):
        # batch average derivatives
        sum_nabla_w_i = np.zeros(w_i.shape)
        sum_nabla_w_1 = np.zeros(w_1.shape)
        sum_nabla_w_2 = np.zeros(w_2.shape)
        sum_nabla_w_L = np.zeros(w_L.shape)
        sum_nabla_b_i = np.zeros(b_i.shape)
        sum_nabla_b_1 = np.zeros(b_1.shape)
        sum_nabla_b_2 = np.zeros(b_2.shape)
        sum_nabla_b_L = np.zeros(b_L.shape)
        sum_loss = 0
        mini_batch = train_data[mini_batch_number*mini_batch_size : (mini_batch_number+1)*mini_batch_size]
        for x, y in mini_batch:
            (nabla_w_i, nabla_w_1, nabla_w_2, nabla_w_L), (nabla_b_i, nabla_b_1, nabla_b_2, nabla_b_L), loss = calc_derivatives(x, y)
            sum_nabla_w_i += nabla_w_i
            sum_nabla_w_1 += nabla_w_1
            sum_nabla_w_2 += nabla_w_2
            sum_nabla_w_L += nabla_w_L
            sum_nabla_b_i += nabla_b_i
            sum_nabla_b_1 += nabla_b_1
            sum_nabla_b_2 += nabla_b_2
            sum_nabla_b_L += nabla_b_L
            sum_loss += loss
        w_i -= learning_rate * (sum_nabla_w_i / len(mini_batch))
        w_1 -= learning_rate * (sum_nabla_w_1 / len(mini_batch))
        w_2 -= learning_rate * (sum_nabla_w_2 / len(mini_batch))
        w_L -= learning_rate * (sum_nabla_w_L / len(mini_batch))
        b_i -= learning_rate * (sum_nabla_b_i / len(mini_batch))
        b_1 -= learning_rate * (sum_nabla_b_1 / len(mini_batch))
        b_2 -= learning_rate * (sum_nabla_b_2 / len(mini_batch))
        b_L -= learning_rate * (sum_nabla_b_L / len(mini_batch))
        if mini_batch_number % 100 == 0:
            print(f"loss: {sum_loss/len(mini_batch):>8f} [mini-batch {mini_batch_number} / {n_mini_batches}]")
    # Test at the end of each epoch
    accuracy, avg_loss = test()
    print(f"==> Accuracy: {100*accuracy:0.1f}%, Avg loss: {avg_loss:>8f}\n")

Epoch 0
--------------------
loss: 2.304699 [mini-batch 0 / 937]
loss: 2.302388 [mini-batch 100 / 937]
loss: 2.301739 [mini-batch 200 / 937]
loss: 2.218086 [mini-batch 300 / 937]
loss: 2.143030 [mini-batch 400 / 937]
loss: 1.750400 [mini-batch 500 / 937]
loss: 1.242420 [mini-batch 600 / 937]
loss: 1.079896 [mini-batch 700 / 937]
loss: 0.949014 [mini-batch 800 / 937]
loss: 0.928119 [mini-batch 900 / 937]
==> Accuracy: 65.2%, Avg loss: 0.908876

Epoch 1
--------------------
loss: 0.876424 [mini-batch 0 / 937]
loss: 0.933463 [mini-batch 100 / 937]
loss: 0.676525 [mini-batch 200 / 937]
loss: 0.914220 [mini-batch 300 / 937]
loss: 0.775485 [mini-batch 400 / 937]
loss: 0.653926 [mini-batch 500 / 937]
loss: 0.749385 [mini-batch 600 / 937]
loss: 0.764485 [mini-batch 700 / 937]
loss: 0.728813 [mini-batch 800 / 937]
loss: 0.710892 [mini-batch 900 / 937]
==> Accuracy: 73.8%, Avg loss: 0.696800

Epoch 2
--------------------
loss: 0.685634 [mini-batch 0 / 937]
loss: 0.726217 [mini-batch 100 / 937]
l