# Training a *SotA MLP SoftMax classifier* on the *MNIST* dataset

In [1]:
# Data download:
import os

# NNets & co.:
import numpy as np
import torch as th
import torch.nn.functional as F
from torch.optim.lr_scheduler import StepLR
from ebtorch.nn import FCBlock

# Data(set) handling
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor, Normalize, Compose, Lambda


In [2]:
# MNIST DataLoader(s) builder


def spawn_mnist_loaders(
    data_root="datasets/",
    batch_size_train=256,
    batch_size_test=512,
    cuda_accel=False,
    **kwargs
):

    os.makedirs(data_root, exist_ok=True)

    transforms = Compose(
        [
            ToTensor(),
            Normalize((0.1307,), (0.3081,)),  # usual magic constants for MNIST
            Lambda(lambda x: th.flatten(x)),
        ]
    )

    trainset = MNIST(data_root, train=True, transform=transforms, download=True)
    testset = MNIST(data_root, train=False, transform=transforms, download=True)

    cuda_args = {}
    if cuda_accel:
        cuda_args = {"num_workers": 1, "pin_memory": True}

    trainloader = DataLoader(
        trainset, batch_size=batch_size_train, shuffle=True, **cuda_args
    )
    testloader = DataLoader(
        trainset, batch_size=batch_size_test, shuffle=False, **cuda_args
    )

    return trainloader, testloader


In [3]:
# Train / Test tooling


def train_epoch(
    model, device, train_loader, loss_fn, optimizer, epoch, print_every_nep
):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % print_every_nep == 0:
            print(
                "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                    epoch,
                    batch_idx * len(data),
                    len(train_loader.dataset),
                    100.0 * batch_idx / len(train_loader),
                    loss.item(),
                )
            )


def test(model, device, test_loader, loss_fn):
    model.eval()
    test_loss = 0
    correct = 0
    with th.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += loss_fn(
                output, target, reduction="sum"
            ).item()  # sum up batch loss
            pred = output.argmax(
                dim=1, keepdim=True
            )  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print(
        "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
            test_loss,
            correct,
            len(test_loader.dataset),
            100.0 * correct / len(test_loader.dataset),
        )
    )


In [4]:
device = th.device("cuda" if th.cuda.is_available() else "cpu")


In [5]:
# Hyperparameters & co.

minibatch_size_train: int = 32  # (cfr. Masters & Luschi, 2018)
minibatch_size_test: int = 512

nrepochs = 20

lossfn = F.nll_loss


In [6]:
train_loader, test_loader = spawn_mnist_loaders(
    batch_size_train=minibatch_size_test,
    batch_size_test=minibatch_size_test,
    cuda_accel=True,
)


In [7]:
model = FCBlock(
    28 * 28,
    [150, 40],
    10,
    hactiv=lambda x: F.leaky_relu_(x, negative_slope=0.06),
    oactiv=lambda x: F.log_softmax(x, dim=1),
    bias=True,
).to(device)
optimizer = th.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, nesterov=True)
scheduler = StepLR(optimizer, step_size=2, gamma=0.7)  # or a lot of patience :)


In [8]:
for epoch in range(1, nrepochs + 1):
    train_epoch(
        model, device, train_loader, lossfn, optimizer, epoch, print_every_nep=10
    )
    test(model, device, test_loader, lossfn)
    test(model, device, test_loader, lossfn)
    scheduler.step()



Test set: Average loss: 0.1617, Accuracy: 56918/60000 (95%)


Test set: Average loss: 0.1617, Accuracy: 56918/60000 (95%)


Test set: Average loss: 0.0819, Accuracy: 58541/60000 (98%)


Test set: Average loss: 0.0819, Accuracy: 58541/60000 (98%)


Test set: Average loss: 0.0676, Accuracy: 58770/60000 (98%)


Test set: Average loss: 0.0676, Accuracy: 58770/60000 (98%)


Test set: Average loss: 0.0567, Accuracy: 58931/60000 (98%)


Test set: Average loss: 0.0567, Accuracy: 58931/60000 (98%)


Test set: Average loss: 0.0332, Accuracy: 59433/60000 (99%)


Test set: Average loss: 0.0332, Accuracy: 59433/60000 (99%)


Test set: Average loss: 0.0287, Accuracy: 59526/60000 (99%)


Test set: Average loss: 0.0287, Accuracy: 59526/60000 (99%)


Test set: Average loss: 0.0204, Accuracy: 59721/60000 (100%)


Test set: Average loss: 0.0204, Accuracy: 59721/60000 (100%)


Test set: Average loss: 0.0178, Accuracy: 59768/60000 (100%)


Test set: Average loss: 0.0178, Accuracy: 59768/60000 (100%)


Tes