# Neural Networks 
In the first part of this exercise we will be implementing feedforward propagation as we did in the second part of the previous exercise. Let's first load in the data set

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.io import loadmat
from scipy import optimize

In [2]:
data = loadmat('ex4data1.mat')
X = data['X']
X = np.insert(X, 0, 1, axis=1) # insert column of 1s into X to account for bias
y = data['y']
y[y == 10] = 0 # replace all of the 10s with 0s
display(X.shape, y.shape)

(5000, 401)

(5000, 1)

Just like the previous exercise, we are provided the initial weights. There are two sets of weights, meaning that there are a total of three layers in our neural network. Let's load these weights in. 

In [11]:
weights = loadmat('ex4weights.mat')
theta1 = weights['Theta1']
theta2 = weights['Theta2']
display(theta1.shape, theta2.shape)

(25, 401)

(10, 26)

Now, we will implement the cost function and gradient for the neural network. The function `feedForward` is used as a helper function to get the predicted output.

In [34]:
def cost(nn_params, input_layer_size, hidden_layer_size, num_labels, X, y, regParam):
    theta1 = np.reshape(nn_params[:hidden_layer_size*(input_layer_size+1)], (hidden_layer_size, input_layer_size+1))
    theta2 = np.reshape(nn_params[hidden_layer_size*(input_layer_size+1):], (num_labels, hidden_layer_size+1))
    
    

In order to test out our cost function, we need to do a little bit of initialization first.

In [35]:
nn_params = np.concatenate([theta1.flatten(), theta2.flatten()])
input_layer_size = 400 # 20x20 matrix of pixels 
hidden_layer_size = 25 # 25 hidden layer units 
num_labels = 10 # 10 output units
regParam = 0

In [36]:
cost(nn_params, input_layer_size, hidden_layer_size, num_labels, X, y, regParam)

array([[-0.76100352, -1.21244498, -0.10187131, -2.36850085, -1.05778129,
        -2.20823629,  0.56383834,  1.21105294,  2.21030997,  0.44456156,
        -1.18244872,  1.04289112, -1.60558756,  1.30419943,  1.37175046,
         1.74825095, -0.23365648, -1.52014483,  1.15324176,  0.10368082,
        -0.37207719, -0.61530019, -0.1256836 , -2.27193038, -0.71836208,
        -1.29690315],
       [-0.61785176,  0.61559207, -1.26550639,  1.85745418, -0.91853319,
        -0.05502589, -0.38589806,  1.29520853, -1.56843297, -0.97026419,
        -2.18334895, -2.85033578, -2.07733086,  1.63163164,  0.3490229 ,
         1.82789117, -2.44174379, -0.8563034 , -0.2982564 , -2.07947873,
        -1.2933238 ,  0.89982032,  0.28306578,  2.31180525, -2.46444086,
         1.45656548],
       [-0.68934072, -1.94538151,  2.01360618, -3.12316188, -0.2361763 ,
         1.38680947,  0.90982429, -1.54774416, -0.79830896, -0.65599834,
         0.7353833 , -2.58593294,  0.47210839,  0.55349499,  2.51255453,
       

In [18]:
theta1

array([[-2.25623899e-02, -1.05624163e-08,  2.19414684e-09, ...,
        -1.30529929e-05, -5.04175101e-06,  2.80464449e-09],
       [-9.83811294e-02,  7.66168682e-09, -9.75873689e-09, ...,
        -5.60134007e-05,  2.00940969e-07,  3.54422854e-09],
       [ 1.16156052e-01, -8.77654466e-09,  8.16037764e-09, ...,
        -1.20951657e-04, -2.33669661e-06, -7.50668099e-09],
       ...,
       [-1.83220638e-01, -8.89272060e-09, -9.81968100e-09, ...,
         2.35311186e-05, -3.25484493e-06,  9.02499060e-09],
       [-7.02096331e-01,  3.05178374e-10,  2.56061008e-09, ...,
        -8.61759744e-04,  9.43449909e-05,  3.83761998e-09],
       [-3.50933229e-01,  8.85876862e-09, -6.57515140e-10, ...,
        -1.80365926e-06, -8.14464807e-06,  8.79454531e-09]])

In [21]:
theta2

array([[-0.76100352, -1.21244498, -0.10187131, -2.36850085, -1.05778129,
        -2.20823629,  0.56383834,  1.21105294,  2.21030997,  0.44456156,
        -1.18244872,  1.04289112, -1.60558756,  1.30419943,  1.37175046,
         1.74825095, -0.23365648, -1.52014483,  1.15324176,  0.10368082,
        -0.37207719, -0.61530019, -0.1256836 , -2.27193038, -0.71836208,
        -1.29690315],
       [-0.61785176,  0.61559207, -1.26550639,  1.85745418, -0.91853319,
        -0.05502589, -0.38589806,  1.29520853, -1.56843297, -0.97026419,
        -2.18334895, -2.85033578, -2.07733086,  1.63163164,  0.3490229 ,
         1.82789117, -2.44174379, -0.8563034 , -0.2982564 , -2.07947873,
        -1.2933238 ,  0.89982032,  0.28306578,  2.31180525, -2.46444086,
         1.45656548],
       [-0.68934072, -1.94538151,  2.01360618, -3.12316188, -0.2361763 ,
         1.38680947,  0.90982429, -1.54774416, -0.79830896, -0.65599834,
         0.7353833 , -2.58593294,  0.47210839,  0.55349499,  2.51255453,
       