# In [124]:
import numpy as np

# Network architecture: R^50 --> R^20 --> R^10 --> R^1
# (three dense layers). Every weight and bias entry is drawn i.i.d. from
# N(0, init_std**2); the draw order below is significant for reproducibility
# under a fixed np.random seed.
init_std = 0.05

w_1 = np.random.normal(0, init_std, (20, 50))  # layer 1: 50 -> 20
b_1 = np.random.normal(0, init_std, (20, 1))

w_2 = np.random.normal(0, init_std, (10, 20))  # layer 2: 20 -> 10
b_2 = np.random.normal(0, init_std, (10, 1))

w_3 = np.random.normal(0, init_std, (1, 10))   # layer 3 (output): 10 -> 1
b_3 = np.random.normal(0, init_std, (1, 1))

n_samples = 10000  # total number of generated examples
n_train = 8000     # first n_train columns train, the remainder test

# Inputs: one example per column, 50 standard-normal features each.
X = np.random.normal(0, 1, (50, n_samples))
# Regression target: squared Euclidean norm of each input column (>= 0).
y = np.sum(X ** 2, axis=0)

X_train = X[:, :n_train]
X_test = X[:, n_train:]

# Targets are sliced column-aligned with X_train / X_test and shaped as
# (1, n) row vectors to match the network's (1, n) output. Slicing X by rows
# here (X_test[8000:]) would select no features and yield all-zero targets.
y_train = y[:n_train].reshape(1, -1)
y_test = y[n_train:].reshape(1, -1)

def loss(y, y_pred):
    """Return the elementwise half squared error, 0.5 * (y - y_pred)**2.

    No reduction is applied; callers average (e.g. with np.mean) themselves.
    """
    residual = y - y_pred
    return 0.5 * residual ** 2

def dloss(y, y_pred):
    """Return the gradient of loss(y, y_pred) = 0.5 * (y - y_pred)**2 w.r.t. y_pred.

    d/dy_pred [0.5 * (y - y_pred)**2] = -(y - y_pred) = y_pred - y.
    The sign matters: parameter updates of the form ``w -= lr * grad`` only
    descend the loss when this returns the true gradient (the former
    ``y - y_pred`` made training climb the loss instead).
    """
    return y_pred - y

def activation(z):
    """ReLU: keep entries of z that are > 0 and zero out the rest.

    Implemented by multiplying z with a boolean mask. Note that an entry of
    -inf becomes nan under this formulation (-inf * 0).
    """
    positive_mask = z > 0
    return z * positive_mask

def dactivation(z):
    """Derivative of the ReLU activation: 1 where z > 0, else 0.

    At z == 0 the (sub)gradient is taken to be 0. The result is an integer
    mask (bool multiplied by 1).
    """
    is_active = z > 0
    return is_active * 1

# Per-epoch mean losses, appended by the training loop.
train_losses, test_losses = [], []

def forward(X, params):
    """Run a forward pass through the three-layer network.

    Args:
        X: inputs with one example per column; its row count must match the
            column count of the first weight matrix (it is used as ``w_1 @ X``).
        params: sequence of exactly six arrays ``(w_1, b_1, w_2, b_2, w_3, b_3)``.

    Returns:
        Tuple ``(a_3, a_2, a_1, z_3, z_2, z_1)``: post-activations from the
        last layer to the first, then pre-activations in the same order.

    The activation is applied to the output layer too, so predictions are
    nonnegative. NOTE(review): confirm this is intended for the regression target.
    """
    w_1, b_1, w_2, b_2, w_3, b_3 = params

    pre_activations = []
    post_activations = []
    layer_input = X
    for weight, bias in ((w_1, b_1), (w_2, b_2), (w_3, b_3)):
        z = weight @ layer_input + bias
        layer_input = activation(z)
        pre_activations.append(z)
        post_activations.append(layer_input)

    # Reverse so the output layer comes first, matching the callers' unpacking.
    return (*post_activations[::-1], *pre_activations[::-1])

epochs = 100
learning_rate = 0.01  # single step size used by every parameter update below
# Full-batch gradient descent: gradients are averaged over all training columns.
n_examples = X_train.shape[1]

for _ in range(epochs):
    params = (w_1, b_1, w_2, b_2, w_3, b_3)

    # Both forward passes use this epoch's pre-update parameters, so the
    # recorded train and test losses describe the same model.
    a_3, a_2, a_1, z_3, z_2, z_1 = forward(X_train, params)
    out_test = forward(X_test, params)[0]  # only the network output is needed

    l = loss(y_train, a_3)  # per-example training losses

    # Backward pass (chain rule), output layer first.
    dl = dloss(y_train, a_3)  # 1 x n_examples

    dZ3 = dl * dactivation(z_3)  # 1 x n_examples
    dW3 = np.dot(dZ3, a_2.T) / n_examples
    db3 = np.sum(dZ3, axis=1, keepdims=True) / n_examples

    dA2 = np.dot(w_3.T, dZ3)  # propagate through layer-3 weights
    dZ2 = dA2 * dactivation(z_2)
    dW2 = np.dot(dZ2, a_1.T) / n_examples
    db2 = np.sum(dZ2, axis=1, keepdims=True) / n_examples

    dA1 = np.dot(w_2.T, dZ2)  # propagate through layer-2 weights
    dZ1 = dA1 * dactivation(z_1)
    dW1 = np.dot(dZ1, X_train.T) / n_examples
    db1 = np.sum(dZ1, axis=1, keepdims=True) / n_examples

    # Descent step, applied in place (the global arrays are mutated, not
    # rebound). All gradients were computed above, before any update.
    # This only descends the loss if dloss returns d(loss)/d(y_pred).
    w_3 -= learning_rate * dW3
    w_2 -= learning_rate * dW2
    w_1 -= learning_rate * dW1

    b_3 -= learning_rate * db3
    b_2 -= learning_rate * db2
    b_1 -= learning_rate * db1

    train_losses.append(np.mean(l))
    test_losses.append(np.mean(loss(y_test, out_test)))
