In [1]:
%config IPCompleter.greedy=True

import numpy as np
from mnist import load_data
import utils

In [277]:
training, validation, testing = load_data()
train_x, train_y, test_x, test_y = training[0], training[1], testing[0], testing[1]
train_y = utils.one_hot_encoding(train_y)
test_y = utils.one_hot_encoding(test_y)

class Network():
    def __init__(self, layers, lr=0.00001, epochs=10):
        self.n_layers = len(layers)
        self.layers = layers
        w1 = np.random.randn(784, layers[0])
        w_last = np.random.randn(layers[len(layers)-1], 10)
        self.weights = [np.random.randn(x, y) for x, y in zip(layers[:-1], layers[1:])]
        self.weights.insert(0, w1)
        self.weights.append(w_last)
        self.weights = np.asarray(self.weights)
        
        self.biases = [np.random.randn(y) for x, y in zip(layers[:-1], layers[1:])]
        b1 = np.random.randn(1, layers[0])
        b_last = np.random.randn(1, 10)
        self.biases.insert(0, b1)
        self.biases.append(b_last)
        self.biases = np.array(self.biases)
        
        self.lr = lr
        self.epochs = epochs
        self.batch_size = 50000

    def feed_forward(self, inputs):
        activations = [inputs]
        z_vec = []
        for w, b in zip(self.weights, self.biases):
            z = np.dot(inputs, w) + b
            inputs = utils.sigmoid(z)
            z_vec.append(z)
            activations.append(inputs)
        inputs = utils.softmax(inputs)
        return inputs, activations, z_vec
    
    def compute_loss(self, logits, labels, epsilon=np.finfo(float).eps):
        return -np.sum(np.multiply(labels, np.log10(logits+epsilon)))/logits.shape[0]
    
    def get_gradients(self, logits, labels, activations, z_vec):
        nabla_w = [np.empty(w.shape) for w in self.weights]
        nabla_b = [np.empty(b.shape) for b in self.biases]
        w_last = self.weights[-1]
        error = logits - labels
        w_last_g = np.dot(activations[-2].T, error)
        nabla_w[-1] = w_last_g
        error = error * utils.sigmoid_prime(z_vec[-1])
        errors = []
        errors.append(error)
        for i in range(len(activations) - 3, -1, -1):
            activation = activations[i]
            weight = self.weights[i+1].T
            new_error = np.dot(errors[-1], weight) * utils.sigmoid_prime(z_vec[i])
            dB = new_error
            dW = np.dot(activation.T, new_error)
            nabla_w[i] = dW
            nabla_b[i] = dB
            errors.append(new_error)
        return nabla_w, nabla_b
    
    def update_weights_and_biases(self, nabla_w, nabla_b):
        for dW in nabla_w:
            dW *= self.lr/self.batch_size
        for dB in nabla_b:
            dB *- self.lr/self.batch_size
        self.weights -= nabla_w
        self.biases -= nabla_b
    
    def train(self):
        for i in range(self.epochs):
            logits, activations, z = self.feed_forward(train_x)
            nabla_w, nabla_b = self.get_gradients(logits=logits, labels=train_y, activations=activations, z_vec=z)
            self.update_weights_and_biases(nabla_w, nabla_b)
    
    def test(self):
        for i in range(5):
            logits, activations, z = self.feed_forward(test_x[i])
            prediction = np.argmax(logits, axis=1)
            print(prediction)


In [278]:
model = Network([10, 20])
model.train()
model.test()
t = np.array([1, 2, 3, 4, 5, 6])

[[2.39731197e-06 2.06357817e-06 1.91964464e-06 ... 2.36678569e-06
  2.02681642e-06 2.19020918e-06]
 [2.41283494e-06 1.95849668e-06 2.02677022e-06 ... 2.35136095e-06
  2.07366892e-06 2.21471978e-06]
 [2.40193194e-06 2.03554605e-06 1.97593215e-06 ... 2.34866356e-06
  2.03644556e-06 2.19774093e-06]
 ...
 [2.40225201e-06 2.00591946e-06 1.97573306e-06 ... 2.33927564e-06
  2.10016004e-06 2.16329584e-06]
 [2.40187598e-06 2.03650918e-06 1.97506326e-06 ... 2.34763048e-06
  2.04071241e-06 2.19547451e-06]
 [2.40220847e-06 2.01264269e-06 1.97194872e-06 ... 2.33884800e-06
  2.09785627e-06 2.15729624e-06]]
[0 0 0 ... 0 0 0]
[[2.08217695e-06 2.16059593e-06 2.13219900e-06 ... 2.42091853e-06
  1.70241381e-06 1.87953090e-06]
 [2.24613769e-06 2.08367819e-06 2.20759708e-06 ... 2.41970615e-06
  1.74938024e-06 1.93913173e-06]
 [2.11042975e-06 2.13964375e-06 2.18201165e-06 ... 2.41976857e-06
  1.72677906e-06 1.89079865e-06]
 ...
 [2.12003886e-06 2.12650120e-06 2.18081220e-06 ... 2.41923884e-06
  1.79027993e-