Simple multi-layer perceptron (MLP) example for the MNIST dataset. Adapted from:
https://github.com/pytorch/examples/tree/main/mnist

The original pytorch code is BSD-3 licensed:
https://github.com/pytorch/examples/blob/main/LICENSE

In [27]:
# import dependencies
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
from prettytable import PrettyTable

In [28]:
# set up model
hidden_dim_1 = 16
hidden_dim_2 = 16

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(784, hidden_dim_1)
        self.fc2 = nn.Linear(hidden_dim_1, hidden_dim_2)
        self.fc3 = nn.Linear(hidden_dim_2, 10)

    def forward(self, x):
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        x = self.fc3(x)
        output = F.log_softmax(x, dim=1)
        return output

In [29]:
# training and testing functions
def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


In [30]:
# from: https://stackoverflow.com/a/62508086/4975218
def count_parameters(model):
    table = PrettyTable(["Modules", "Parameters"])
    total_params = 0
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad:
            continue
        params = parameter.numel()
        table.add_row([name, params])
        total_params += params
    print(table)
    print(f"Total Trainable Params: {total_params}\n")
    return total_params

In [31]:
# parameters controlling the training run
class args: pass
args.batch_size = 64
args.test_batch_size = 1000
args.epochs = 4
args.lr = 1.0 # learning rate
args.gamma = 0.7 # learning rate step gamma
args.seed = 1234 # random seed
args.log_interval = 10 # how frequently is training status reported
args.no_cuda = False  # set to True to not use CUDA for training (even if available)
args.no_mps = False # set to True to not use Mac OS GPUs for training (even if available)


In [32]:
torch.manual_seed(args.seed)

use_cuda = not args.no_cuda and torch.cuda.is_available()
use_mps = not args.no_mps and torch.backends.mps.is_available()

if use_cuda:
    device = torch.device("cuda")
elif use_mps:
    device = torch.device("mps")
else:
    device = torch.device("cpu")

train_kwargs = {'batch_size': args.batch_size}
test_kwargs = {'batch_size': args.test_batch_size}
if use_cuda:
    cuda_kwargs = {'num_workers': 1,
                   'pin_memory': True,
                   'shuffle': True}
    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)

transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

dataset1 = datasets.MNIST('./data', train=True, download=True, transform=transform)
dataset2 = datasets.MNIST('./data', train=False, transform=transform)
train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

model = Net().to(device)
count_parameters(model)
optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
for epoch in range(1, args.epochs + 1):
    train(args, model, device, train_loader, optimizer, epoch)
    test(model, device, test_loader)
    scheduler.step()


+------------+------------+
|  Modules   | Parameters |
+------------+------------+
| fc1.weight |   12544    |
|  fc1.bias  |     16     |
| fc2.weight |    256     |
|  fc2.bias  |     16     |
| fc3.weight |    160     |
|  fc3.bias  |     10     |
+------------+------------+
Total Trainable Params: 13002


Test set: Average loss: 0.2379, Accuracy: 9280/10000 (93%)


Test set: Average loss: 0.1949, Accuracy: 9421/10000 (94%)


Test set: Average loss: 0.1774, Accuracy: 9486/10000 (95%)


Test set: Average loss: 0.1690, Accuracy: 9514/10000 (95%)



In [33]:
# save the trained model
torch.save(model.state_dict(), "mnist_mlp.pt")