In [1]:
import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils

from dense import Dense
from activations import Tanh
from losses import mse, mse_derivative
from network import train, predict


def preprocess_data(x, y, limit, num_classes=10):
    """Flatten, normalize and one-hot encode the first ``limit`` samples.

    Args:
        x: Array of images indexed by sample along axis 0 (e.g. MNIST's
            uint8 pixel grids with values in ``[0, 255]``).
        y: Integer class labels, one per sample, each in ``[0, num_classes)``.
        limit: Number of leading samples to keep; ``None`` keeps all of them.
        num_classes: Length of each one-hot label vector (10 digits for MNIST).

    Returns:
        ``(x, y)`` where ``x`` is float32 of shape ``(limit, pixels, 1)`` with
        pixel values divided by 255, and ``y`` is float32 of shape
        ``(limit, num_classes, 1)`` holding one-hot column vectors.
    """
    # Slice first so only the samples that are actually kept get converted.
    x = x[:limit]
    labels = np.asarray(y[:limit], dtype=int).ravel()

    # reshape each image into a column vector and normalize pixel values;
    # the pixel count is derived from the input instead of assuming 28x28,
    # and computed explicitly (not -1) so an empty slice still reshapes
    pixels = int(np.prod(x.shape[1:]))
    x = x.reshape(x.shape[0], pixels, 1).astype("float32") / 255

    # encode each label as a one-hot column vector,
    # e.g. number 3 will become [0, 0, 0, 1, 0, 0, 0, 0, 0, 0];
    # a fixed class count keeps the width constant even when some class
    # is absent from the kept subset
    one_hot = np.eye(num_classes, dtype="float32")[labels]
    return x, one_hot.reshape(one_hot.shape[0], num_classes, 1)


# Fetch the MNIST train/test splits through Keras, then keep only the first
# 1000 training samples and the first 20 test samples, preprocessed.
train_split, test_split = mnist.load_data()
x_train, y_train = preprocess_data(*train_split, 1000)
x_test, y_test = preprocess_data(*test_split, 20)

# Two fully connected layers (784 -> 40 -> 10), each followed by tanh.
hidden_units = 40
network = [Dense(28 * 28, hidden_units), Tanh(), Dense(hidden_units, 10), Tanh()]

# Fit with mean squared error for 100 epochs at learning rate 0.1.
train(network, mse, mse_derivative, x_train, y_train, epochs=100, learning_rate=0.1)

# For every held-out sample, report the predicted digit (index of the largest
# output) next to the true digit (index of the one-hot label).
for sample, target in zip(x_test, y_test):
    predicted_digit = np.argmax(predict(network, sample))
    print(f"pred: {predicted_digit} \ttrue: {np.argmax(target)}")

Epoch: 1/100, error=[0.91256262]
Epoch: 2/100, error=[0.81659155]
Epoch: 3/100, error=[0.77495594]
Epoch: 4/100, error=[0.73639467]
Epoch: 5/100, error=[0.68225538]
Epoch: 6/100, error=[0.56717781]
Epoch: 7/100, error=[0.38382506]
Epoch: 8/100, error=[0.2381113]
Epoch: 9/100, error=[0.16438316]
Epoch: 10/100, error=[0.13552506]
Epoch: 11/100, error=[0.12490674]
Epoch: 12/100, error=[0.11889775]
Epoch: 13/100, error=[0.11395518]
Epoch: 14/100, error=[0.10944577]
Epoch: 15/100, error=[0.10624034]
Epoch: 16/100, error=[0.10380694]
Epoch: 17/100, error=[0.10210835]
Epoch: 18/100, error=[0.10027465]
Epoch: 19/100, error=[0.09877937]
Epoch: 20/100, error=[0.09717585]
Epoch: 21/100, error=[0.09600268]
Epoch: 22/100, error=[0.0950293]
Epoch: 23/100, error=[0.09424859]
Epoch: 24/100, error=[0.09372633]
Epoch: 25/100, error=[0.09263469]
Epoch: 26/100, error=[0.09154988]
Epoch: 27/100, error=[0.090441]
Epoch: 28/100, error=[0.0891816]
Epoch: 29/100, error=[0.08857648]
Epoch: 30/100, error=[0.0880