# Convolutional Neural Network: MNIST

Based on: https://towardsdatascience.com/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392

This example trains a CNN model on the MNIST dataset, using the predefined neural network building blocks from `auto_compyute.nn`. It shows how you would go about training a basic neural network.

In [None]:
! pip install -q pandas

In [None]:
import auto_compyute as ac
import auto_compyute.nn.functional as F

# Fix the backend random seed for this run.
ac.backends.set_random_seed(0)

# Prefer the GPU when the backend reports one, otherwise use the CPU.
if ac.backends.gpu_available():
    device = "cuda"
else:
    device = "cpu"
device

## Prepare Data

In [None]:
import pandas as pd

# download the datasets, this might take a few seconds
# (header=None: the first row of each CSV is read as data, not as column names)
train_url = "https://pjreddie.com/media/files/mnist_train.csv"
test_url = "https://pjreddie.com/media/files/mnist_test.csv"
train_data = pd.read_csv(train_url, header=None)
test_data = pd.read_csv(test_url, header=None)
train_tensor = ac.tensor(train_data.to_numpy())
test = ac.tensor(test_data.to_numpy())

# shuffle the training rows: the first 80% become the training split,
# the remaining rows the validation split
idx = ac.randperm(len(train_tensor))
n = int(0.8 * len(train_tensor))
train = train_tensor[idx[:n]]
val = train_tensor[idx[n:]]

# targets: column 0 of every split, cast to int
y_train = train[:, 0].int()
y_val = val[:, 0].int()
y_test = test[:, 0].int()

# features: all remaining columns, reshaped into an image format
# (B, 784) -> (B, 1, 28, 28) and cast to float
X_train = train[:, 1:].view(train.shape[0], 1, 28, -1).float()
X_val = val[:, 1:].view(val.shape[0], 1, 28, -1).float()
X_test = test[:, 1:].view(test.shape[0], 1, 28, -1).float()

from typing import Optional


# scaling
def scale(
    x: ac.Tensor,
    mean: Optional[ac.Tensor] = None,
    std: Optional[ac.Tensor] = None,
) -> ac.Tensor:
    """Standardize ``x`` by subtracting a mean and dividing by a standard deviation.

    Args:
        x: Tensor to standardize.
        mean: Value to subtract. Defaults to ``x.mean()``.
        std: Value to divide by. Defaults to ``x.std()``.

    Returns:
        ``(x - mean) / std``.
    """
    if mean is None:
        mean = x.mean()
    if std is None:
        std = x.std()
    return (x - mean) / std


# Use the statistics of the training split for every split, so that the
# validation and test data go through exactly the same transformation as the
# data the model is trained on. They must be computed before X_train is
# overwritten with its scaled version.
train_mean, train_std = X_train.mean(), X_train.std()
X_train = scale(X_train, train_mean, train_std)
X_val = scale(X_val, train_mean, train_std)
X_test = scale(X_test, train_mean, train_std)

# report the final shapes of all splits
print(f'{X_train.shape=}')
print(f'{y_train.shape=}')
print(f'{X_val.shape=}')
print(f'{y_val.shape=}')
print(f'{X_test.shape=}')
print(f'{y_test.shape=}')

## Build the Neural Network

In [None]:
from auto_compyute import nn

# B = batch size
# Every layer that is directly followed by a Batchnorm is created with bias=False.

# convolutional part: (B, 1, 28, 28) -> (B, 64, 3, 3)
conv_layers = [
    nn.Conv2D(1, 32, 5),                # out: (B, 32, 24, 24)
    nn.ReLU(),
    nn.Conv2D(32, 32, 5, bias=False),   # out: (B, 32, 20, 20)
    nn.Batchnorm(32),
    nn.ReLU(),
    nn.MaxPooling2D(2),                 # out: (B, 32, 10, 10)
    nn.Dropout(0.25),
    nn.Conv2D(32, 64, 3),               # out: (B, 64, 8, 8)
    nn.ReLU(),
    nn.Conv2D(64, 64, 3, bias=False),   # out: (B, 64, 6, 6)
    nn.Batchnorm(64),
    nn.ReLU(),
    nn.MaxPooling2D(2),                 # out: (B, 64, 3, 3)
    nn.Dropout(0.25),
]

# fully connected part: (B, 64, 3, 3) -> (B, 10)
dense_layers = [
    nn.Flatten(),                       # out: (B, 64*3*3)
    nn.Linear(64*3*3, 256, bias=False), # out: (B, 256)
    nn.Batchnorm(256),
    nn.ReLU(),
    nn.Linear(256, 128, bias=False),    # out: (B, 128)
    nn.Batchnorm(128),
    nn.ReLU(),
    nn.Linear(128, 84, bias=False),     # out: (B, 84)
    nn.Batchnorm(84),
    nn.ReLU(),
    nn.Dropout(0.25),
    nn.Linear(84, 10),                  # out: (B, 10)
]

# the layers are created and chained in exactly the order listed above
model = nn.Sequential(*conv_layers, *dense_layers).to(device)

## Training

In [None]:
batch_size = 256

# loader for the training split
train_loader = nn.Dataloader(
    (X_train, y_train),
    batch_size,
    device,
    shuffle_data=True,
    drop_remaining=True,
)
# loader for the validation split, created with the loader's default options
val_loader = nn.Dataloader((X_val, y_val), batch_size, device)

# number of training steps per epoch, used for the progress output
train_steps = len(train_loader)

# Adam optimizer over all model parameters, default hyperparameters
optim = nn.optimizers.Adam(model.parameters())

In [None]:
def accuracy(y_pred, y_true):
    """Return the fraction of predictions that match the targets.

    The predicted class of each sample is the index of the largest value
    along the last axis of ``y_pred``; the result is the mean of the
    element-wise comparison with ``y_true``.
    """
    predicted_classes = y_pred.argmax(-1)
    is_correct = predicted_classes == y_true
    return is_correct.mean()

In [None]:
import time

epochs = 5

for e in range(1, epochs + 1):

    # training: one pass over the training loader
    model.train()
    start = time.perf_counter()
    for step, (x, y) in enumerate(train_loader(), start=1):
        print(f"step {step}/{train_steps}", end="\r")  # "\r" keeps the progress on one line
        optim.reset_param_grads()
        loss = F.cross_entropy_loss(model(x), y)
        loss.backward()
        optim.update_params()
    dt = time.perf_counter() - start  # wall-clock duration of the training pass

    # validation: metrics are accumulated weighted by the batch size and
    # divided by the number of samples seen, so a smaller final batch does not
    # get the same weight as a full one.
    # NOTE(review): assumes F.cross_entropy_loss returns the mean over the
    # batch, as the previous per-batch averaging already implied - confirm.
    model.eval()
    val_loss, val_acc, n_val = 0.0, 0.0, 0
    with ac.no_autograd_tracking():
        for x, y in val_loader():
            y_pred = model(x)
            n_batch = x.shape[0]
            val_loss += F.cross_entropy_loss(y_pred, y).item() * n_batch
            val_acc += accuracy(y_pred, y).item() * n_batch
            n_val += n_batch
    val_loss /= n_val
    val_acc /= n_val

    print(f"epoch {e}/{epochs} | {val_loss=:.4f} | {val_acc=:.4f} | {dt=:.4f}s")