In [5]:
import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils
import matplotlib.pyplot as plt
from Dense import Dense
from Activations import Tanh
from Losses import mse, mse_prime
from Network import train, predict

def preprocess_data(x, y, limit, num_classes=10):
    """Turn raw image/label arrays into column vectors for the network.

    Args:
        x: array of images; every sample must hold 28 * 28 values
            (e.g. shape (n, 28, 28)) with intensities in [0, 255].
        y: integer class labels, one per sample, in [0, num_classes).
        limit: keep only the first ``limit`` samples.
        num_classes: width of the one-hot label vector (10 digits by default).

    Returns:
        Tuple ``(x, y)``: ``x`` is float32 with shape (n, 784, 1) scaled to
        [0, 1]; ``y`` is float32 one-hot with shape (n, num_classes, 1).
    """
    # Truncate first so only the samples that are kept get converted.
    x = np.asarray(x)[:limit]
    y = np.asarray(y)[:limit]

    # reshape and normalize input data
    x = x.reshape(x.shape[0], 28 * 28, 1).astype("float32") / 255

    # encode output which is a number in range [0, num_classes) into a vector
    # e.g. number 3 will become [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]
    # The width is fixed explicitly: inferring it from max(y) + 1 would break
    # the reshape below whenever the highest class is absent from y.
    labels = y.astype("int64").reshape(-1)
    y = np.eye(num_classes, dtype="float32")[labels]
    y = y.reshape(y.shape[0], num_classes, 1)
    return x, y


# Fetch MNIST and keep a small subset: 1000 training and 200 test samples.
train_split, test_split = mnist.load_data()
x_train, y_train = preprocess_data(*train_split, 1000)
x_test, y_test = preprocess_data(*test_split, 200)

# Two Dense layers, each followed by a Tanh activation.
# NOTE(review): Dense presumably takes (input_size, output_size) - confirm.
network = [
    Dense(28 * 28, 40, regularization=0.01),
    Tanh(),
    Dense(40, 10, regularization=0.01),
    Tanh(),
]

# Fit the network; the value returned by train() is kept for plotting.
costs = train(
    network,
    mse,
    mse_prime,
    x_train,
    y_train,
    epochs=500,
    batch_size=1,
    learning_rate=0.01,
)

# Plot the cost values returned by train().
# `plt` is already bound by the import at the top of the file, so it is not
# imported a second time here.
plt.plot(costs)
plt.xlabel('Iterations')
plt.ylabel('Cost')
plt.show()

# Run the forward pass once per test sample and reuse the outputs for both the
# aggregate error and the per-sample report, instead of predicting twice.
# NOTE(review): assumes predict() is deterministic and does not mutate the
# network - confirm against Network.predict.
outputs = [predict(network, x) for x in x_test]

# calculate error on test set (mean MSE over all test samples)
error = sum(mse(y, output) for y, output in zip(y_test, outputs)) / len(x_test)
print("error:", error)


# test: predicted digit vs. true digit for every test sample
for y, output in zip(y_test, outputs):
    print('pred:', np.argmax(output), '\ttrue:', np.argmax(y))


1/500, error=0.9451583271036662
2/500, error=0.8595465721671504
3/500, error=0.7712459199676363
4/500, error=0.6478531014262534
5/500, error=0.4986427630104103
6/500, error=0.33662708941713293
7/500, error=0.19666036316440516
8/500, error=0.11354728199999488
9/500, error=0.07929164659915676
10/500, error=0.067397912592574
11/500, error=0.06264263062447036
12/500, error=0.06015890284384618
13/500, error=0.058523053688759936
14/500, error=0.057283453296261506
15/500, error=0.056319745130132895
16/500, error=0.0556142111543063
17/500, error=0.055131209326857486
18/500, error=0.054803438617930245
19/500, error=0.054568617144041445
20/500, error=0.05438657116345159
21/500, error=0.054235582416204305
22/500, error=0.05410487820630043
23/500, error=0.053989351459145116
24/500, error=0.05388656500971057
25/500, error=0.05379514647763927
26/500, error=0.053713987019712715
27/500, error=0.05364193560415926
28/500, error=0.053577770933230705
29/500, error=0.05352027640449202
30/500, error=0.05346