# MNIST with a Two-Layer Neural Network

In [1]:
import numpy as np

from tensorflow.keras import datasets
from tqdm import tqdm

In [2]:
# Execute the project's source files inside this kernel. IPython's %run loads
# each script's top-level names into the notebook namespace, which is presumably
# where Network, Flatten, Linear, ReLU, Softmax, CrossEntropy and argmax_equal
# (used further down) come from; which file defines which is not visible here.
# The paths are relative, so the kernel's working directory must be the folder
# that contains this notebook.
%run ../src/core.py
%run ../src/layers.py
%run ../src/loss.py
%run ../src/network.py
%run ../src/metrics.py

In [3]:
# Fetch the MNIST dataset through Keras.
# load_data() hands back a (train, test) pair, each an (images, labels) tuple.
mnist_train, mnist_test = datasets.mnist.load_data()
train_images, train_labels = mnist_train
test_images, test_labels = mnist_test

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
[1m11490434/11490434[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 0us/step


In [4]:
# Normalize images to the [0, 1] range.
# This cell rebinds train_images/test_images to their scaled versions, so a
# plain re-run would divide by 255 a second time. Freshly loaded pixels are
# integers and scaled ones are floats, so the dtype check below turns a re-run
# into a no-op and keeps the cell idempotent.
train_images = np.asarray(train_images)
test_images = np.asarray(test_images)
if np.issubdtype(train_images.dtype, np.integer):
    train_images = train_images / 255.0
if np.issubdtype(test_images.dtype, np.integer):
    test_images = test_images / 255.0

# Convert labels to one-hot encoding.
# Row i of the identity matrix is the one-hot vector for class i; the trailing
# None index appends an axis so every label becomes a (NUM_CLASSES, 1) column.
NUM_CLASSES = 10  # digits 0-9
train_labels_ohe = np.eye(NUM_CLASSES)[train_labels][:, :, None]
test_labels_ohe = np.eye(NUM_CLASSES)[test_labels][:, :, None]

## Model and training

In [5]:
class TwoLayerNetwork(Network):
    """Fully connected classifier: Flatten -> Linear -> ReLU -> Linear -> Softmax.

    The network is handed a CrossEntropy loss and the ``argmax_equal`` metric.

    Args:
        in_features: Input width of the first Linear layer.
        hid_features: Width of the hidden layer.
        out_features: Output width of the second Linear layer.

    NOTE(review): Flatten is hard-coded to (28, 28), so ``in_features`` is
    presumably expected to be 28 * 28 = 784 -- confirm against Flatten.
    """

    def __init__(self, in_features: int, hid_features: int, out_features: int):
        # Layers are created in forward order, the same order as before, so any
        # random initialisation inside Linear consumes the RNG identically.
        layer_stack = [
            Flatten((28, 28)),
            Linear(in_features, hid_features),
            ReLU(),
            Linear(hid_features, out_features),
            Softmax(),
        ]
        loss_fn = CrossEntropy()
        super().__init__(layer_stack, loss_fn, argmax_equal)

In [6]:
# Seed NumPy's global RNG before the model is built so the run is repeatable.
np.random.seed(42)

# 784 inputs (28 x 28 pixels) -> 100 hidden units -> 10 output classes.
model = TwoLayerNetwork(in_features=784, hid_features=100, out_features=10)

# NOTE(review): the progress-bar totals in the log below differ from epoch to
# epoch (47837, 47965, ...), which suggests train() re-draws its train/validation
# split every epoch. If so, validation samples leak into training over time and
# the validation score is optimistic -- confirm in ../src/network.py.
train_hyperparams = {"epochs": 15, "lr": 0.01}
model.train(train_images, train_labels_ohe, **train_hyperparams)

Epoch 1/15


100%|██████████| 47837/47837 [00:20<00:00, 2279.02it/s, loss=0.229]


Evaluating... Validation score: 94.86147%
Epoch 2/15


100%|██████████| 47965/47965 [00:20<00:00, 2298.85it/s, loss=0.109]


Evaluating... Validation score: 96.53511%
Epoch 3/15


100%|██████████| 47938/47938 [00:22<00:00, 2131.68it/s, loss=0.0809]


Evaluating... Validation score: 97.38849%
Epoch 4/15


100%|██████████| 48103/48103 [00:20<00:00, 2325.46it/s, loss=0.0597]


Evaluating... Validation score: 97.39430%
Epoch 5/15


100%|██████████| 48022/48022 [00:21<00:00, 2231.69it/s, loss=0.0517]


Evaluating... Validation score: 97.16146%
Epoch 6/15


100%|██████████| 48100/48100 [00:20<00:00, 2364.66it/s, loss=0.0431]


Evaluating... Validation score: 98.51261%
Epoch 7/15


100%|██████████| 47978/47978 [00:20<00:00, 2299.14it/s, loss=0.035] 


Evaluating... Validation score: 98.28647%
Epoch 8/15


100%|██████████| 47991/47991 [00:20<00:00, 2365.90it/s, loss=0.0285]


Evaluating... Validation score: 98.62603%
Epoch 9/15


100%|██████████| 48107/48107 [00:20<00:00, 2373.81it/s, loss=0.0244]


Evaluating... Validation score: 99.06668%
Epoch 10/15


100%|██████████| 47905/47905 [00:20<00:00, 2346.46it/s, loss=0.0218]


Evaluating... Validation score: 98.83423%
Epoch 11/15


100%|██████████| 47996/47996 [00:21<00:00, 2217.11it/s, loss=0.0166]


Evaluating... Validation score: 98.93369%
Epoch 12/15


100%|██████████| 47860/47860 [00:19<00:00, 2395.81it/s, loss=0.0154]


Evaluating... Validation score: 99.26689%
Epoch 13/15


100%|██████████| 47907/47907 [00:20<00:00, 2368.36it/s, loss=0.0121]


Evaluating... Validation score: 99.23923%
Epoch 14/15


100%|██████████| 47993/47993 [00:22<00:00, 2140.77it/s, loss=0.0108]


Evaluating... Validation score: 99.28375%
Epoch 15/15


100%|██████████| 47938/47938 [00:20<00:00, 2378.53it/s, loss=0.00848]


Evaluating... Validation score: 99.51915%


In [7]:
# Evaluate model on the test set, which was never passed to model.train().
# NOTE(review): evaluate() appears to return a fraction (this cell prints 0.98),
# whereas the training log reports validation scores as percentages -- keep the
# units in mind when comparing the two numbers.
acc = model.evaluate(test_images, test_labels_ohe)
print(f"Accuracy: {acc:.2f}")

Accuracy: 0.98
