In [1]:
# lib imports
import pandas as pd
import numpy as np

In [2]:
# class imports

from Layers.DenseLayer import DenseLayer
from ActivationClasses.ReluActivation import ReluActivation
from ActivationClasses.SoftmaxActivation import Softmax
from Loss.CategoricalCrossEntropy import CategoricalCrossEntropy
from Optimizer.Adam import Optimizer_Adam

In [3]:
# data load

data_train = pd.read_csv("./Datasets/mnist_train.csv")
data_test = pd.read_csv("./Datasets/mnist_test.csv")

In [4]:
X_train, y_train = data_train.iloc[:, 1:].values / 255.0, data_train.iloc[:, 0].values
X_test, y_test =  data_test.iloc[:, 1:].values / 255.0, data_test.iloc[:, 0].values

In [11]:
# layers / network params

l1 = DenseLayer(784, 128)
act1 = ReluActivation()

l2 = DenseLayer(128, 64)
act2 = ReluActivation()

l3 = DenseLayer(64, 10)
act3 = Softmax()

loss = CategoricalCrossEntropy()
optimizer = Optimizer_Adam(learning_rate=0.001)

In [6]:
# train
epochs = 200

In [7]:
for epoch in range(epochs + 1):
    # forward prop
    l1.forward_prop(X_train)
    act1.forward(l1.output)

    l2.forward_prop(act1.output)
    act2.forward(l2.output)

    l3.forward_prop(act2.output)
    act3.forward(l3.output)

    # loss calc
    loss_fin = loss.calculate(act3.output, y_train)

    # back prop

    loss.backward(act3.output, y_train)
    act3.backward(loss.dinputs)
    l3.backward(act3.dinputs)
    act2.backward(l3.dinputs)
    l2.backward(act2.dinputs)
    act1.backward(l2.dinputs)
    l1.backward(act1.dinputs)

    # optimise

    optimizer.update_params(l1)
    optimizer.update_params(l2)
    optimizer.update_params(l3)
    predictions = np.argmax(act3.output, axis=1)
    accuracy = np.mean(predictions == y_train)
    print(f'Epoch: {epoch}, Loss: {loss_fin}, Accuracy: {accuracy}')

# result
predictions = np.argmax(act3.output, axis=1)
accuracy = np.mean(predictions == y_train)

print(f'Final accuracy: {accuracy}')
print(f'Final loss: {loss_fin}')
    

Epoch: 0, Loss: 2.381284479245227, Accuracy: 0.0656
Epoch: 1, Loss: 2.309415134296257, Accuracy: 0.09151666666666666
Epoch: 2, Loss: 2.2550914121932526, Accuracy: 0.13308333333333333
Epoch: 3, Loss: 2.2028763031357617, Accuracy: 0.19218333333333334
Epoch: 4, Loss: 2.150103367525906, Accuracy: 0.265
Epoch: 5, Loss: 2.095339609559719, Accuracy: 0.34446666666666664
Epoch: 6, Loss: 2.03747232763904, Accuracy: 0.41646666666666665
Epoch: 7, Loss: 1.9753637054590467, Accuracy: 0.4748
Epoch: 8, Loss: 1.907879474680996, Accuracy: 0.51775
Epoch: 9, Loss: 1.8340210196619533, Accuracy: 0.5491833333333334
Epoch: 10, Loss: 1.753294707888769, Accuracy: 0.5767166666666667
Final accuracy: 0.5767166666666667
Final loss: 1.753294707888769


In [8]:
# test

correct = 0
for i in range(len(X_test)):
    # Forward pass
    l1.forward_prop(X_test[i])
    act1.forward(l1.output)

    l2.forward_prop(act1.output)
    act2.forward(l2.output)

    l3.forward_prop(act2.output)
    act3.forward(l3.output)

    # Calculate loss and accuracy for each test sample

    prediction = np.argmax(act3.output)
    if prediction == y_test[i]:
        correct += 1

    # Print label and guess for each test sample
    print(f'Label: {y_test[i]}, Guess: {prediction}')

# Calculate overall accuracy and loss
accuracy = correct / len(X_test)
print(f'Test Accuracy: {accuracy}')


Label: 7, Guess: 7
Label: 2, Guess: 3
Label: 1, Guess: 1
Label: 0, Guess: 0
Label: 4, Guess: 2
Label: 1, Guess: 1
Label: 4, Guess: 9
Label: 9, Guess: 6
Label: 5, Guess: 0
Label: 9, Guess: 7
Label: 0, Guess: 0
Label: 6, Guess: 0
Label: 9, Guess: 9
Label: 0, Guess: 0
Label: 1, Guess: 1
Label: 5, Guess: 3
Label: 9, Guess: 9
Label: 7, Guess: 7
Label: 3, Guess: 3
Label: 4, Guess: 9
Label: 9, Guess: 7
Label: 6, Guess: 6
Label: 6, Guess: 0
Label: 5, Guess: 6
Label: 4, Guess: 9
Label: 0, Guess: 0
Label: 7, Guess: 7
Label: 4, Guess: 0
Label: 0, Guess: 0
Label: 1, Guess: 1
Label: 3, Guess: 3
Label: 1, Guess: 6
Label: 3, Guess: 3
Label: 4, Guess: 0
Label: 7, Guess: 7
Label: 2, Guess: 2
Label: 7, Guess: 7
Label: 1, Guess: 1
Label: 2, Guess: 3
Label: 1, Guess: 1
Label: 1, Guess: 1
Label: 7, Guess: 7
Label: 4, Guess: 4
Label: 2, Guess: 6
Label: 3, Guess: 3
Label: 5, Guess: 3
Label: 1, Guess: 3
Label: 2, Guess: 2
Label: 4, Guess: 9
Label: 4, Guess: 7
Label: 6, Guess: 6
Label: 3, Guess: 6
Label: 5, Gu