In [1]:
import numpy as np

from tensorflow.keras import datasets
from tqdm import tqdm

In [2]:
# Load the project's neural-network code into this notebook. `%run` executes each
# script and loads its top-level names into the interactive namespace; this is where
# Network, Flatten, Linear, ReLU, Softmax, CrossEntropy and argmax_equal (used below)
# come from. Paths resolve relative to the kernel's working directory. The order is
# presumably significant (later scripts using names from earlier ones), so keep it.
%run ../src/core.py
%run ../src/layers.py
%run ../src/loss.py
%run ../src/network.py
%run ../src/metrics.py

In [3]:
# Fetch the MNIST train/test splits through Keras (cached locally after the first download).
# Per the Keras docs, images are uint8 arrays of shape (N, 28, 28) and labels are
# integer arrays of shape (N,).
(train_images, train_labels), (test_images, test_labels) = (
    datasets.mnist.load_data()
)

In [4]:
def _scale_pixels(images):
    """Return ``images`` scaled to [0, 1] if they still hold raw integer pixels.

    Input that is already floating point is returned unchanged. This makes the
    cell idempotent: re-running it no longer divides the images by 255 a second time.
    """
    images = np.asarray(images)
    if np.issubdtype(images.dtype, np.integer):
        return images / 255.0
    return images


# Normalize images: raw MNIST pixels are 0-255 integers -> floats in [0, 1]
train_images = _scale_pixels(train_images)
test_images = _scale_pixels(test_images)

# Convert labels to one-hot encoding. For 1-D integer labels this maps
# (N,) -> (N, NUM_CLASSES, 1). The trailing axis makes each target a column vector,
# presumably the layout the ../src layers expect (confirm in layers.py).
NUM_CLASSES = 10
train_labels_ohe = np.eye(NUM_CLASSES)[train_labels][:, :, None]
test_labels_ohe = np.eye(NUM_CLASSES)[test_labels][:, :, None]

In [5]:
class TwoLayerNetwork(Network):
    """Fully connected classifier: Flatten -> Linear -> ReLU -> Linear -> Softmax.

    The model is trained with ``CrossEntropy`` loss. ``argmax_equal`` is handed to
    ``Network`` as its third argument, presumably the accuracy metric (confirm in
    network.py / metrics.py).

    Args:
        in_features: Size of the flattened input. Must equal the number of
            elements in ``input_shape``.
        hid_features: Width of the hidden ReLU layer.
        out_features: Number of outputs fed to the softmax (number of classes).
        input_shape: Per-sample input shape given to ``Flatten``. Defaults to
            ``(28, 28)``, the MNIST image size and the value previously hard-coded.

    Raises:
        ValueError: If ``in_features`` does not equal ``prod(input_shape)``.
    """

    def __init__(
        self,
        in_features: int,
        hid_features: int,
        out_features: int,
        input_shape: tuple[int, ...] = (28, 28),
    ):
        # Reject an inconsistent configuration here. Otherwise it would presumably
        # surface later as an opaque shape error inside the first Linear layer.
        # NOTE(review): assumes Flatten's argument is the per-sample input shape.
        flat_size = int(np.prod(input_shape))
        if in_features != flat_size:
            raise ValueError(
                f"in_features={in_features} does not match input_shape={input_shape} "
                f"(which flattens to {flat_size} features)"
            )
        super().__init__(
            [
                Flatten(input_shape),
                Linear(in_features, hid_features),
                ReLU(),
                Linear(hid_features, out_features),
                Softmax(),
            ],
            CrossEntropy(),
            argmax_equal,
        )

In [6]:
# Seed NumPy's global RNG before building the model so the run is repeatable
# (presumably the ../src code draws weights / shuffles via np.random; confirm).
np.random.seed(42)

# 784 = 28 * 28 flattened pixels -> 100 hidden units -> 10 digit classes.
model = TwoLayerNetwork(in_features=784, hid_features=100, out_features=10)
model.train(train_images, train_labels_ohe, epochs=15, lr=0.01)

Epoch 1/15


100%|██████████| 47837/47837 [00:24<00:00, 1928.38it/s, loss=0.23] 


Evaluating... Validation accuracy: 94.72992%
Epoch 2/15


100%|██████████| 47965/47965 [00:27<00:00, 1769.02it/s, loss=0.11] 


Evaluating... Validation accuracy: 96.75114%
Epoch 3/15


100%|██████████| 47938/47938 [00:23<00:00, 2018.23it/s, loss=0.0809]


Evaluating... Validation accuracy: 97.33875%
Epoch 4/15


100%|██████████| 48103/48103 [00:27<00:00, 1751.11it/s, loss=0.0607]


Evaluating... Validation accuracy: 96.94881%
Epoch 5/15


100%|██████████| 48022/48022 [00:26<00:00, 1811.19it/s, loss=0.0503]


Evaluating... Validation accuracy: 97.46201%
Epoch 6/15


100%|██████████| 48100/48100 [00:26<00:00, 1795.94it/s, loss=0.0417]


Evaluating... Validation accuracy: 98.55462%
Epoch 7/15


100%|██████████| 47978/47978 [00:25<00:00, 1894.26it/s, loss=0.0351]


Evaluating... Validation accuracy: 98.34470%
Epoch 8/15


100%|██████████| 47991/47991 [00:24<00:00, 1953.02it/s, loss=0.0291]


Evaluating... Validation accuracy: 98.45116%
Epoch 9/15


100%|██████████| 48107/48107 [00:25<00:00, 1877.95it/s, loss=0.024] 


Evaluating... Validation accuracy: 99.03304%
Epoch 10/15


100%|██████████| 47905/47905 [00:23<00:00, 2032.94it/s, loss=0.0201]


Evaluating... Validation accuracy: 98.92518%
Epoch 11/15


100%|██████████| 47996/47996 [00:25<00:00, 1878.72it/s, loss=0.0168]


Evaluating... Validation accuracy: 98.82539%
Epoch 12/15


 36%|███▋      | 17454/47860 [00:08<00:16, 1871.07it/s, loss=0.0155]

In [None]:
# Score the trained model on the held-out MNIST test split.
# NOTE(review): the training log reports validation accuracy as a percentage.
# Confirm whether evaluate() returns a fraction or a percentage before
# interpreting this 2-decimal value.
acc = model.evaluate(test_images, test_labels_ohe)
print("Accuracy: {:.2f}".format(acc))