# Let's build a model to classify MNIST digits using the layers we created

In [6]:
import numpy as np

from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical

from dense import Dense
from activations import Tanh
from loss_functions import mse_loss, mse_loss_derivative
from engine import train, predict
from activations import ReLU

In [7]:
def preprocessing(x, y, limit=None, num_classes=10):
    """Convert raw MNIST arrays into the column-vector format the layers use.

    Parameters
    ----------
    x : array-like
        Images of shape (N, H, W) with pixel values in 0-255.
    y : array-like
        Integer class labels of shape (N,).
    limit : int or None, optional
        Keep only the first ``limit`` samples; ``None`` keeps all of them.
    num_classes : int, optional
        Length of each one-hot label vector (10 for MNIST digits).

    Returns
    -------
    x : np.ndarray
        float32 array of shape (n, H*W, 1) with values scaled to [0, 1].
    y : np.ndarray
        float32 one-hot array of shape (n, num_classes, 1).

    Raises
    ------
    ValueError
        If any label lies outside ``[0, num_classes)``.
    """
    # Slice before converting so we never cast/rescale samples that are
    # thrown away anyway.
    x = np.asarray(x)[:limit]
    y = np.asarray(y)[:limit]

    # Flatten each image into a column vector: (n, H, W) -> (n, H*W, 1).
    # Using prod() instead of -1 keeps this valid for an empty slice.
    n_features = int(np.prod(x.shape[1:]))
    x = x.reshape(x.shape[0], n_features, 1).astype(np.float32) / 255.0

    labels = y.astype(np.int64)
    # Fancy indexing into np.eye would silently wrap negative labels,
    # so reject anything out of range explicitly.
    if labels.size and (labels.min() < 0 or labels.max() >= num_classes):
        raise ValueError(
            f"labels must lie in [0, {num_classes}); "
            f"got range [{labels.min()}, {labels.max()}]"
        )

    # One-hot encode with a fixed width so the (n, num_classes, 1) shape
    # holds even when the slice does not contain every class.
    y = np.eye(num_classes, dtype=np.float32)[labels]
    y = y.reshape(y.shape[0], num_classes, 1)

    return x, y

In [8]:
# Only a slice of MNIST is used here; raise these for a fuller run.
# N_TEST is kept small because the evaluation cell prints one line per image.
N_TRAIN = 10000
N_TEST = 20

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, y_train = preprocessing(x_train, y_train, N_TRAIN)
x_test, y_test = preprocessing(x_test, y_test, N_TEST)

In [9]:
# Fix NumPy's global RNG so repeated runs start from the same state.
# NOTE(review): this only makes training repeatable if `Dense` (and `train`)
# draw their randomness from np.random's global state — confirm in dense.py.
np.random.seed(42)

INPUT_SIZE = 28 * 28   # flattened image length produced by preprocessing
HIDDEN_SIZE = 40
NUM_CLASSES = 10       # one output per digit, matching the one-hot labels

# Two fully connected layers, each followed by a Tanh activation; the
# 10 outputs are compared against one-hot targets by the MSE loss in training.
network = [
    Dense(INPUT_SIZE, HIDDEN_SIZE),
    Tanh(),
    Dense(HIDDEN_SIZE, NUM_CLASSES),
    Tanh(),
]

In [10]:
# Fit the network on the training slice: MSE loss, 100 epochs, lr = 0.001.
train(
    network,
    mse_loss,
    mse_loss_derivative,
    x_train,
    y_train,
    epochs=100,
    learning_rate=0.001,
)

Epoch 1/100, Error: 0.1064
Epoch 2/100, Error: 0.0651
Epoch 3/100, Error: 0.0564
Epoch 4/100, Error: 0.0518
Epoch 5/100, Error: 0.0488
Epoch 6/100, Error: 0.0465
Epoch 7/100, Error: 0.0446
Epoch 8/100, Error: 0.0430
Epoch 9/100, Error: 0.0416
Epoch 10/100, Error: 0.0404
Epoch 11/100, Error: 0.0393
Epoch 12/100, Error: 0.0383
Epoch 13/100, Error: 0.0374
Epoch 14/100, Error: 0.0366
Epoch 15/100, Error: 0.0358
Epoch 16/100, Error: 0.0351
Epoch 17/100, Error: 0.0345
Epoch 18/100, Error: 0.0339
Epoch 19/100, Error: 0.0333
Epoch 20/100, Error: 0.0328
Epoch 21/100, Error: 0.0323
Epoch 22/100, Error: 0.0318
Epoch 23/100, Error: 0.0314
Epoch 24/100, Error: 0.0310
Epoch 25/100, Error: 0.0306
Epoch 26/100, Error: 0.0302
Epoch 27/100, Error: 0.0298
Epoch 28/100, Error: 0.0295
Epoch 29/100, Error: 0.0291
Epoch 30/100, Error: 0.0288
Epoch 31/100, Error: 0.0285
Epoch 32/100, Error: 0.0282
Epoch 33/100, Error: 0.0279
Epoch 34/100, Error: 0.0277
Epoch 35/100, Error: 0.0274
Epoch 36/100, Error: 0.0271
E

In [11]:
# For each held-out image, compare the predicted class (argmax of the
# network output) with the true class (argmax of the one-hot label),
# then summarise the hit rate.
n_correct = 0
for x_sample, y_sample in zip(x_test, y_test):
    output = predict(network, x_sample)
    pred_label = np.argmax(output)
    true_label = np.argmax(y_sample)
    n_correct += int(pred_label == true_label)
    print('pred:', pred_label, '\ttrue:', true_label)

n_total = len(x_test)
if n_total:  # avoid dividing by zero on an empty test slice
    print(f'accuracy: {n_correct}/{n_total} ({n_correct / n_total:.1%})')

pred: 7 	true: 7
pred: 2 	true: 2
pred: 1 	true: 1
pred: 0 	true: 0
pred: 4 	true: 4
pred: 1 	true: 1
pred: 4 	true: 4
pred: 9 	true: 9
pred: 6 	true: 5
pred: 9 	true: 9
pred: 0 	true: 0
pred: 8 	true: 6
pred: 9 	true: 9
pred: 0 	true: 0
pred: 1 	true: 1
pred: 5 	true: 5
pred: 9 	true: 9
pred: 7 	true: 7
pred: 8 	true: 3
pred: 4 	true: 4
