We'll start by creating the dataset:

In [1]:
from torchvision import datasets
from torchvision.transforms import ToTensor

# Training split of MNIST; ToTensor() turns each image into a tensor on access.
train_data = datasets.MNIST(
    root='data',
    train=True,
    transform=ToTensor(),
    download=True,
)
# Held-out split (train=False), used in this notebook for validation.
# download=True keeps this call self-sufficient instead of relying on the
# training-set download having already fetched the t10k files; it is a no-op
# when the files are already present under `root`.
validation_data = datasets.MNIST(
    root='data',
    train=False,
    transform=ToTensor(),
    download=True,
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw



Next, we'll create the data loaders:

In [2]:
from torch.utils.data import DataLoader

# Training batches are reshuffled every epoch so that successive epochs see
# the samples in a different order.
train_loader = DataLoader(
    train_data,
    batch_size=100,
    shuffle=True)

# Validation only aggregates metrics over the whole split, so the order of the
# samples is irrelevant; keep it fixed to make evaluation deterministic.
validation_loader = DataLoader(
    validation_data,
    batch_size=100,
    shuffle=False)

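Then, we'll define the neural network (NN): one hidden layer with batch normalization and a ReLU activation, followed by an output layer with one unit per class: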

In [3]:
import torch

# Seed the global RNG first so the layer weights created below are reproducible.
torch.manual_seed(1234)

hidden_units, classes = 100, 10

# Classifier over flattened 28x28 inputs:
# affine -> batch norm -> ReLU -> affine (one output score per class).
layers = [
    torch.nn.Linear(28 * 28, hidden_units),
    torch.nn.BatchNorm1d(hidden_units),
    torch.nn.ReLU(),
    torch.nn.Linear(hidden_units, classes),
]
net = torch.nn.Sequential(*layers)

Next, let's define the loss function and the optimizer:

In [4]:
# Adam over every parameter of the network, with the library's default settings.
optimizer = torch.optim.Adam(params=net.parameters())
# Cross-entropy between the network's class scores and the integer targets.
cost_func = torch.nn.CrossEntropyLoss()

We'll train for 20 epochs:

In [5]:
epochs = 20

# Put the network in training mode explicitly. On a fresh kernel this is
# already the case, but if this cell is re-run after the evaluation cell has
# called net.eval(), BatchNorm would otherwise keep using its running
# statistics instead of the per-batch ones.
net.train()

for epoch in range(epochs):
    train_loss = 0.0

    for inputs, targets in train_loader:
        # Flatten 28x28 images into a 784 long vector
        inputs = inputs.view(inputs.shape[0], -1)

        optimizer.zero_grad()  # Zero the gradient
        out = net(inputs)  # Forward pass
        loss = cost_func(out, targets)  # Compute loss
        loss.backward()  # Backward pass
        optimizer.step()  # Weight updates

        # The loss is averaged over the batch; weight it by the batch size so
        # that the epoch total can be turned into a per-sample mean below.
        train_loss += loss.item() * inputs.size(0)

    train_loss /= len(train_loader.dataset)
    print('Epoch %d, Loss: %.4f' % (epoch + 1, train_loss))

Epoch 1, Loss: 0.3272
Epoch 2, Loss: 0.1421
Epoch 3, Loss: 0.0999
Epoch 4, Loss: 0.0760
Epoch 5, Loss: 0.0611
Epoch 6, Loss: 0.0495
Epoch 7, Loss: 0.0422
Epoch 8, Loss: 0.0358
Epoch 9, Loss: 0.0309
Epoch 10, Loss: 0.0262
Epoch 11, Loss: 0.0228
Epoch 12, Loss: 0.0201
Epoch 13, Loss: 0.0182
Epoch 14, Loss: 0.0174
Epoch 15, Loss: 0.0160
Epoch 16, Loss: 0.0131
Epoch 17, Loss: 0.0121
Epoch 18, Loss: 0.0113
Epoch 19, Loss: 0.0107
Epoch 20, Loss: 0.0101


Finally, we'll run the evaluation:

In [6]:
net.eval()  # evaluation mode: BatchNorm switches to its running statistics
validation_loss = 0.0
num_correct = 0

# No weights are updated during evaluation, so disable gradient tracking to
# avoid building an autograd graph for every batch.
with torch.no_grad():
    for inputs, target in validation_loader:
        # Flatten 28x28 images into a 784 long vector
        inputs = inputs.view(inputs.shape[0], -1)

        out = net(inputs)  # Forward pass
        loss = cost_func(out, target)  # Compute loss

        # Update running validation loss (weighted by batch size) and hit count
        validation_loss += loss.item() * inputs.size(0)
        pred = out.argmax(dim=1)  # predicted class = index of the largest logit
        num_correct += (pred == target).sum().item()

num_samples = len(validation_loader.dataset)
correct = 100 * num_correct / num_samples  # accuracy in percent
validation_loss /= num_samples
print('Accuracy: %.1f, Validation loss: %.4f' % (correct, validation_loss))

Accuracy: 97.7, Validation loss: 0.0886
