In [66]:
import numpy as np

from tensorflow.keras import datasets
from tqdm import tqdm

In [67]:
# Execute the project's source scripts with IPython's %run magic, which makes
# the names they define available in this notebook's namespace.
# Presumably these provide Network, Flatten, Linear, ReLU, Softmax and
# CrossEntropy used further down — confirm against ../src.
# NOTE(review): the order may matter if later scripts rely on earlier ones.
%run ../src/core.py
%run ../src/layers.py
%run ../src/loss.py
%run ../src/network.py

In [68]:
# Download MNIST via the Keras dataset helper. The call yields a
# (train, test) pair, and each member is itself an (images, labels) tuple.
_mnist_train, _mnist_test = datasets.mnist.load_data()
train_images, train_labels = _mnist_train
test_images, test_labels = _mnist_test

In [69]:
def _scale_to_unit_interval(images):
    """Return ``images`` scaled from 8-bit pixel values to floats in [0, 1].

    Integer-typed input (the raw pixel data) is divided by 255.0. Input of any
    other dtype is assumed to be scaled already and is returned unchanged, so
    re-running this cell cannot divide the data by 255 a second time.
    """
    images = np.asarray(images)
    if np.issubdtype(images.dtype, np.integer):
        return images / 255.0
    return images


# Normalize images
train_images = _scale_to_unit_interval(train_images)
test_images = _scale_to_unit_interval(test_images)

# Convert labels to one-hot encoding.
# Indexing the 10x10 identity matrix with the label array selects one row per
# label; the trailing `None` appends an axis of length 1, so a 1-D label array
# of length N becomes an array of shape (N, 10, 1).
train_labels_ohe = np.eye(10)[train_labels][:, :, None]
test_labels_ohe = np.eye(10)[test_labels][:, :, None]

In [70]:
class TwoLayerNetwork(Network):
    """Network with one hidden Linear layer between the input and the output.

    The layer stack is Flatten((28, 28)) -> Linear -> ReLU -> Linear -> Softmax,
    and a CrossEntropy instance is passed to ``Network`` as its second
    argument (presumably the training loss — confirm against Network).
    """

    def __init__(self, in_features: int, hid_features: int, out_features: int):
        """Assemble the layer stack and hand it to the ``Network`` base class.

        Args:
            in_features: Input width of the first Linear layer.
            hid_features: Output width of the first Linear layer and input
                width of the second.
            out_features: Output width of the second Linear layer.

        NOTE(review): Flatten is fixed to (28, 28), so ``in_features`` is
        presumably expected to be 784 — confirm.
        """
        # Layers are created in forward order, exactly as they are listed.
        layer_stack = [
            Flatten((28, 28)),
            Linear(in_features, hid_features),
            ReLU(),
            Linear(hid_features, out_features),
            Softmax(),
        ]
        criterion = CrossEntropy()
        super().__init__(layer_stack, criterion)

In [71]:
# Seed NumPy's global RNG before the model is built so the run is repeatable
# (assumes the layers and the trainer draw from np.random — TODO confirm).
np.random.seed(42)

# 784 inputs (28 * 28 pixels), 100 hidden units, 10 outputs (one per column of
# the one-hot labels).
model = TwoLayerNetwork(in_features=784, hid_features=100, out_features=10)
model.train(train_images, train_labels_ohe, epochs=15, lr=0.01)

Epoch 1/15


100%|██████████| 47837/47837 [00:25<00:00, 1898.83it/s, loss=0.232]


Evaluating... Validation accuracy: 94.73%
Epoch 2/15


100%|██████████| 47965/47965 [00:25<00:00, 1859.02it/s, loss=0.11] 


Evaluating... Validation accuracy: 96.75%
Epoch 3/15


100%|██████████| 47938/47938 [00:24<00:00, 1976.50it/s, loss=0.0813]


Evaluating... Validation accuracy: 97.34%
Epoch 4/15


100%|██████████| 48103/48103 [00:24<00:00, 1932.34it/s, loss=0.0604]


Evaluating... Validation accuracy: 96.95%
Epoch 5/15


100%|██████████| 48022/48022 [00:25<00:00, 1882.03it/s, loss=0.0503]


Evaluating... Validation accuracy: 97.46%
Epoch 6/15


100%|██████████| 48100/48100 [00:23<00:00, 2034.66it/s, loss=0.0417]


Evaluating... Validation accuracy: 98.55%
Epoch 7/15


100%|██████████| 47978/47978 [00:23<00:00, 2036.53it/s, loss=0.0355]


Evaluating... Validation accuracy: 98.34%
Epoch 8/15


100%|██████████| 47991/47991 [00:23<00:00, 2083.64it/s, loss=0.0293]


Evaluating... Validation accuracy: 98.45%
Epoch 9/15


100%|██████████| 48107/48107 [00:22<00:00, 2170.53it/s, loss=0.024] 


Evaluating... Validation accuracy: 99.03%
Epoch 10/15


100%|██████████| 47905/47905 [00:24<00:00, 1989.14it/s, loss=0.0199]


Evaluating... Validation accuracy: 98.93%
Epoch 11/15


100%|██████████| 47996/47996 [00:25<00:00, 1848.19it/s, loss=0.017] 


Evaluating... Validation accuracy: 98.83%
Epoch 12/15


100%|██████████| 47860/47860 [00:21<00:00, 2226.53it/s, loss=0.0152]


Evaluating... Validation accuracy: 99.03%
Epoch 13/15


100%|██████████| 47907/47907 [00:22<00:00, 2088.36it/s, loss=0.0112]


Evaluating... Validation accuracy: 99.23%
Epoch 14/15


100%|██████████| 47993/47993 [00:26<00:00, 1821.44it/s, loss=0.00942]


Evaluating... Validation accuracy: 99.31%
Epoch 15/15


100%|██████████| 47938/47938 [00:24<00:00, 1948.57it/s, loss=0.00708]


Evaluating... Validation accuracy: 99.44%


In [74]:
# Score the trained model on the MNIST test split, which was not passed to
# model.train above, and report the result to two decimal places.
acc = model.evaluate(
    test_images,
    test_labels_ohe,
)
print(f"Accuracy: {acc:.2f}")

Accuracy: 0.98
