## MNIST Handwritten Digit Classification

### Imports

In [None]:
import os

import numpy as np
from PIL import Image
from sklearn.utils import shuffle

from net.activations.relu import ReLU
from net.activations.softmax import Softmax
from net.layers import Parameter, Dropout, Dense, Flatten, MaxPool2D, Conv2D
from net.losses.cross_entropy import CrossEntropy
from net.models.sequential import Sequential
from net.optimizers.adam import Adam

### Load Dataset

In [None]:
def load_dataset(root_dir, image_size=(28, 28), max_per_class=50):
    """Load a digit-image dataset laid out as ``root_dir/<label>/<image>``.

    For each label 0-9, the ``.png``/``.jpg`` files in the matching
    sub-folder are sorted by name and the first ``max_per_class`` of them are
    read, converted to 8-bit grayscale, resized and scaled to ``[0, 1]``.

    Args:
        root_dir: Directory containing one sub-folder per digit ("0".."9").
        image_size: ``(width, height)`` handed to ``Image.resize``.
        max_per_class: Maximum number of images read from each sub-folder.

    Returns:
        ``(X, y)``: ``X`` is a float array of shape ``(N, 1, H, W)`` and ``y``
        an integer array of shape ``(N,)`` holding the folder labels. When no
        image is found, ``X`` has shape ``(0, 1, H, W)`` and ``y`` is empty.

    Raises:
        OSError: If one of the ten label sub-folders cannot be listed.
    """
    X, y = [], []
    for label in range(10):
        folder = os.path.join(root_dir, str(label))
        # Sort before truncating so the chosen subset is deterministic.
        files = sorted(
            f for f in os.listdir(folder) if f.endswith((".png", ".jpg"))
        )[:max_per_class]

        for filename in files:
            # The context manager closes the underlying file handle as soon
            # as the pixels are copied out, instead of leaving it to the GC.
            with Image.open(os.path.join(folder, filename)) as image:
                pixels = np.array(image.convert("L").resize(image_size))
            X.append(pixels / 255.0)
            y.append(label)

    if not X:
        # Without this guard an empty dataset yields a 1-D array that cannot
        # be given a channel axis; return correctly shaped empty arrays.
        width, height = image_size
        return np.empty((0, 1, height, width)), np.empty(0, dtype=int)

    X = np.array(X)[:, np.newaxis, :, :]  # (N, H, W) -> (N, 1, H, W)
    y = np.array(y)
    return X, y

In [None]:
# Load each split and shuffle images and labels together in one step; the
# fixed random_state keeps the resulting order repeatable across runs.
train_X, train_y = shuffle(
    *load_dataset("data/training", max_per_class=50), random_state=42
)
test_X, test_y = shuffle(*load_dataset("data/test"), random_state=42)

# Quick sanity check: array shapes and how many samples each label has.
print("Train: {}, {}".format(train_X.shape, train_y.shape))
label_counts = np.unique(train_y, return_counts=True)
print("Unique labels in train:", label_counts)

### Data Preprocessing

In [None]:
def one_hot(y, num_classes=10):
    """Return one-hot encoded rows for the integer class indices in ``y``.

    Args:
        y: Integer index, or array of integer indices, used to select rows
            of a ``num_classes`` x ``num_classes`` identity matrix.
        num_classes: Length of each one-hot row.

    Returns:
        Float array whose last axis has length ``num_classes``.
    """
    identity = np.identity(num_classes)
    return identity[y]


def get_params_from_model(model):
    """Collect every ``Parameter`` exposed as an attribute of the model's layers.

    Layers are visited in ``model.layers`` order and, within a layer,
    attributes are visited in the (alphabetical) order returned by ``dir``.

    Args:
        model: Object with an iterable ``layers`` attribute.

    Returns:
        List of the ``Parameter`` instances found, in visiting order.
    """
    return [
        candidate
        for layer in model.layers
        for candidate in (getattr(layer, name) for name in dir(layer))
        if isinstance(candidate, Parameter)
    ]

### Model

In [None]:
# Shape notes below assume (N, 1, 28, 28) inputs, stride-1 convolutions
# (3x3 kernel with padding=1 keeps H and W) and pooling that halves H and W
# -- TODO confirm against net.layers. They agree with the 32 * 7 * 7 input
# width of the first Dense layer.
feature_layers = [
    Conv2D(in_channels=1, out_channels=16, kernel_size=3, padding=1),  # -> (N, 16, 28, 28)
    ReLU(),
    MaxPool2D(kernel_size=2),  # -> (N, 16, 14, 14)

    Conv2D(in_channels=16, out_channels=32, kernel_size=3, padding=1),  # -> (N, 32, 14, 14)
    ReLU(),
    MaxPool2D(kernel_size=2),  # -> (N, 32, 7, 7)
]

classifier_layers = [
    Flatten(),  # -> (N, 32 * 7 * 7)
    Dense(in_features=32 * 7 * 7, out_features=128),  # -> (N, 128)
    ReLU(),
    Dropout(0.3),

    Dense(128, 10),  # -> (N, 10), one score per digit
    Softmax(),
]

model = Sequential(feature_layers + classifier_layers)

# The optimizer is handed every Parameter found on the model's layers.
optimizer = Adam(params=get_params_from_model(model), lr=1e-3)
loss_fn = CrossEntropy()

### Training

In [None]:
def accuracy(pred, target):
    """Return the fraction of rows whose arg-max class matches the target's.

    Args:
        pred: 2-D array of per-class scores, one row per sample.
        target: 2-D array of the same layout (e.g. one-hot labels).

    Returns:
        Mean of the element-wise match between the two arg-max vectors.
    """
    hits = np.argmax(pred, axis=1) == np.argmax(target, axis=1)
    return np.mean(hits)


# Training hyperparameters.
num_classes = 10
batch_size = 64
epochs = 10
# NOTE(review): nothing in the visible cells reads `lr`; the optimizer is
# constructed with its own literal learning rate.
lr = 1e-3

# One-hot encode the integer labels of both splits.
train_y_oh, test_y_oh = (
    one_hot(labels, num_classes) for labels in (train_y, test_y)
)

# Single-sample slices; the leading batch axis is kept.
x_one, y_one = train_X[:1], train_y_oh[:1]

In [None]:
# Train for `epochs` passes over the training set. After each pass, record the
# average batch loss and the accuracy on the full test set.
train_loss_log = []
test_acc_log = []

for epoch in range(epochs):
    # Presumably switches layers such as Dropout to training behaviour --
    # confirm in net.models.sequential.
    model.train()
    total_loss = 0
    total_batches = 0

    # Re-shuffle inputs and one-hot targets together every epoch (no
    # random_state, so the order differs between runs).
    # NOTE: the integer labels in train_y are not shuffled here, so they no
    # longer line up with train_X after this statement.
    train_X, train_y_oh = shuffle(train_X, train_y_oh)

    for i in range(0, len(train_X), batch_size):
        # The last slice may hold fewer than batch_size samples.
        x_batch = train_X[i:i + batch_size]
        y_batch = train_y_oh[i:i + batch_size]

        output = model.forward(x_batch)
        if i == 0:
            # Debug peek at the first batch of each epoch: predicted vs. true
            # classes and the set of classes the model currently predicts.
            print("Pred :", np.argmax(output[:5], axis=1))
            print("GT   :", np.argmax(y_batch[:5], axis=1))
            print("Unique preds in batch:", np.unique(np.argmax(output, axis=1)))

        # Loss forward/backward, back-propagation through the model, one
        # optimizer update, then clear gradients before the next batch.
        loss = loss_fn.forward(output, y_batch)
        grad = loss_fn.backward()
        model.backward(grad)
        optimizer.step()
        optimizer.zero_grad()

        total_loss += loss
        total_batches += 1

    # Average of the per-batch loss values; raises ZeroDivisionError when
    # train_X is empty because no batch was processed.
    avg_loss = total_loss / total_batches
    train_loss_log.append(avg_loss)

    # Evaluate on the whole test set in a single forward call.
    model.eval()
    test_out = model.forward(test_X)
    test_acc = accuracy(test_out, test_y_oh)
    test_acc_log.append(test_acc)

    print(f"[Epoch {epoch + 1:02d}] Loss: {avg_loss:.2f} | Test Accuracy: {test_acc:.2f}")

### Evaluation

In [None]:
import matplotlib.pyplot as plt

# One entry per panel: (series, legend label, line colour, y-axis label, title).
panels = [
    (train_loss_log, "Train Loss", "tab:red", "Loss", "Training Loss"),
    (test_acc_log, "Test Accuracy", "tab:blue", "Accuracy", "Test Accuracy"),
]

plt.figure(figsize=(12, 5))

# Draw the two curves side by side; both share the per-epoch x axis.
for position, (series, legend_label, line_color, y_label, heading) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    plt.plot(series, label=legend_label, color=line_color, linewidth=2)
    plt.xlabel("Epoch")
    plt.ylabel(y_label)
    plt.title(heading)
    plt.grid(True)
    plt.legend()

plt.tight_layout()
plt.show()

In [None]:
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import numpy as np

# Select inference mode explicitly so the report does not depend on which
# cell happened to run last.
model.eval()

test_out = model.forward(test_X)
y_pred = np.argmax(test_out, axis=1)
y_true = test_y

print("Classification Report:\n")
print(classification_report(y_true, y_pred, digits=4))

# Sorted union of the labels that actually occur. This is the same set
# confusion_matrix picks by default; passing it explicitly lets the tick
# labels name the real classes even when a digit is missing from the data.
class_labels = np.union1d(y_true, y_pred)
cm = confusion_matrix(y_true, y_pred, labels=class_labels)
num_classes = cm.shape[0]

fig, ax = plt.subplots(figsize=(8, 6))
im = ax.imshow(cm, interpolation="nearest", cmap="Blues")
plt.colorbar(im, ax=ax)

tick_positions = np.arange(num_classes)
ax.set(xticks=tick_positions,
       yticks=tick_positions,
       xlabel='Predicted Label',
       ylabel='True Label',
       title='Confusion Matrix')

ax.set_xticklabels(class_labels)
ax.set_yticklabels(class_labels)

plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

# Cells darker than half the maximum count get white text for contrast.
# The threshold is computed once instead of once per cell.
contrast_threshold = cm.max() * 0.5 if cm.size else 0

for (i, j), value in np.ndenumerate(cm):
    # Rows are true labels (y axis), columns are predictions (x axis).
    ax.text(j, i, str(value),
            ha="center", va="center",
            color="white" if value > contrast_threshold else "black")

plt.tight_layout()
plt.show()

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Show up to `samples_per_class` test images per digit, one row per digit,
# each titled with the model's prediction.
model.eval()

pred_probs = model.forward(test_X)
pred_labels = np.argmax(pred_probs, axis=1)

num_classes = 10
samples_per_class = 10
fig, axes = plt.subplots(num_classes, samples_per_class, figsize=(15, 12))

# Hide every frame up front so a digit with fewer than `samples_per_class`
# test images leaves blank cells instead of empty ticked axes.
for ax in axes.flat:
    ax.axis("off")

# Number of images already drawn in each digit's row.
collected = {cls: 0 for cls in range(num_classes)}

for sample, true_label, pred_label in zip(test_X, test_y, pred_labels):
    if collected[true_label] >= samples_per_class:
        continue  # this digit's row is already full

    ax = axes[true_label, collected[true_label]]
    ax.imshow(sample[0], cmap="gray")  # (1, H, W) -> (H, W)
    ax.set_title(
        f"P:{pred_label}",
        color="green" if pred_label == true_label else "red",
        fontsize=8,
    )
    collected[true_label] += 1

    # Stop scanning once every row is full.
    if all(v == samples_per_class for v in collected.values()):
        break

plt.suptitle("10 Samples per Class (Green = Correct, Red = Wrong)", fontsize=16)
plt.tight_layout(rect=(0, 0, 1, 0.97))
plt.show()