In [1]:
import pickle
import numpy as np

In [2]:
# Load the pickled training set from disk.
# NOTE(review): pickle.load executes arbitrary code embedded in the file, so
# this must only ever be pointed at a trusted, locally produced pickle.
# Later cells index the result as train_data[0][0] (fed to forward) and
# train_data[0][1] (used as the loss target) — presumably a sequence of
# (image, label) pairs; TODO confirm against the script that wrote the file.
with open("data/train_data.pkl", "rb") as train_file:
    train_data = pickle.load(train_file)

In [3]:
# Define the network: a fully connected classifier with layer widths
# (img_px * img_px) -> n_i -> n_1 -> n_2 -> n_L.

img_px = 28  # side length; inputs are flattened to img_px * img_px values
n_i = 512    # width of the input layer
n_1 = 128    # width of the first hidden layer
n_2 = 64     # width of the second hidden layer
n_L = 10 # there are 10 classes in FashionMNIST

# Initialisation bound per layer: k = 1 / sqrt(fan_in), where fan_in is the
# number of inputs feeding the layer. (img_px * img_px is the fan-in of the
# first layer; for the others the fan-in is the previous layer's width, not
# its square.)
k_i = 1 / np.sqrt(img_px * img_px)
k_1 = 1 / np.sqrt(n_i)
k_2 = 1 / np.sqrt(n_1)
k_L = 1 / np.sqrt(n_2)

# Every weight and bias is drawn from U(-k, k) so the parameters are centred
# on zero and bounded. np.random.normal(-k, k) would instead mean
# "mean -k, standard deviation k", which biases all parameters negative.
w_i = np.random.uniform(-k_i, k_i, size=(img_px * img_px, n_i))
b_i = np.random.uniform(-k_i, k_i, size=n_i)

w_1 = np.random.uniform(-k_1, k_1, size=(n_i, n_1))
b_1 = np.random.uniform(-k_1, k_1, size=n_1)

w_2 = np.random.uniform(-k_2, k_2, size=(n_1, n_2))
b_2 = np.random.uniform(-k_2, k_2, size=n_2)

w_L = np.random.uniform(-k_L, k_L, size=(n_2, n_L))
b_L = np.random.uniform(-k_L, k_L, size=n_L)

In [14]:
# forward function, used to make predictions

def softmax(x):
    """Return the softmax of x, normalised over all of its elements.

    The maximum is subtracted before exponentiating. This leaves the result
    mathematically unchanged but keeps np.exp from overflowing to inf (and
    the division from producing NaN) when the logits are large.

    Args:
        x: array-like of logits.

    Returns:
        Float array with the same shape as x whose entries sum to 1.
    """
    x = np.asarray(x)
    if x.size == 0:
        # Nothing to normalise; np.max would raise on an empty array.
        return np.exp(x)
    exp_x = np.exp(x - np.max(x))
    return exp_x / exp_x.sum()

def forward(x, return_activations=False):
    """Run one sample through the network and return its class probabilities.

    The input is flattened to 1-D, passed through three ReLU layers
    (w_i/b_i, w_1/b_1, w_2/b_2) and a final softmax layer (w_L/b_L). All
    weights and biases are read from module-level variables.

    Args:
        x: array holding a single sample; once flattened, its length must
            equal the number of rows of w_i.
        return_activations: when True, also return the hidden activations.

    Returns:
        a_L, the softmax output of the last layer, or the pair
        (a_L, (a_i, a_1, a_2)) when return_activations is True.
    """
    x = x.flatten()

    # Each hidden layer is an affine map followed by a plain ReLU.
    z_i = x @ w_i + b_i
    a_i = np.maximum(0, z_i)

    z_1 = a_i @ w_1 + b_1
    a_1 = np.maximum(0, z_1)

    z_2 = a_1 @ w_2 + b_2
    a_2 = np.maximum(0, z_2)

    # Output layer: logits turned into probabilities by softmax.
    z_L = a_2 @ w_L + b_L
    a_L = softmax(z_L)

    if return_activations:
        return a_L, (a_i, a_1, a_2)
    return a_L

# Smoke check: run the first training sample through the untrained network.
# The recorded output is ten probabilities that are all close to 0.1.
forward(train_data[0][0])

array([0.09910156, 0.10134875, 0.10126736, 0.09972461, 0.09966782,
       0.09801581, 0.0988231 , 0.10154293, 0.10076158, 0.09974649])

In [5]:
# Cross-entropy loss

def loss(y, a_L, eps=1e-12):
    """Cross-entropy between a target distribution and predicted probabilities.

    Args:
        y: target weights per class (presumably a one-hot vector — TODO
            confirm against how train_data labels are stored).
        a_L: predicted class probabilities, same shape as y.
        eps: lower bound applied to a_L before taking the log. Without it a
            probability of exactly 0 gives log(0) = -inf, and 0 * -inf = NaN
            even for classes whose target weight is 0.

    Returns:
        The scalar -sum(y * log(a_L)).
    """
    safe_probs = np.maximum(a_L, eps)
    return -(y * np.log(safe_probs)).sum()

# Loss of the first training sample under the untrained network;
# train_data[0][1] is passed as the target y. The recorded value (~2.305) is
# close to ln(10) ~= 2.3026.
loss(train_data[0][1], forward(train_data[0][0].flatten()))

2.3051234281895785

In [17]:
# Toy operands for checking how `vector @ matrix` behaves.
m = np.arange(1, 10).reshape(3, 3)  # rows: [1 2 3], [4 5 6], [7 8 9]

a = np.arange(1, 4)  # [1 2 3]

# Vector-matrix product: entry j is the sum over k of a[k] * m[k, j].
np.matmul(a, m)

array([30, 36, 42])

In [19]:
# Column index of `a @ m` to reproduce by hand.
i = 2

# Entry i of the product is the dot product of `a` with column i of `m`.
sum(a[k] * m[k, i] for k in range(len(a)))

42

In [29]:
# Broadcasting check: a column vector of shape (3, 1) times a (3, 4) array
# scales row r of the array by a[r].
a[:, np.newaxis] * np.ones((3, 4))

array([[1., 1., 1., 1.],
       [2., 2., 2., 2.],
       [3., 3., 3., 3.]])

In [30]:
# Show the 1-D vector next to its (len(a), 1) column-vector form.
(a, a.reshape(len(a), 1))

(array([1, 2, 3]),
 array([[1],
        [2],
        [3]]))