## Classifying MNIST hand-written digits

In [24]:
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from torch.utils.data import TensorDataset, DataLoader

In [25]:
# Load the MNIST train/test splits; ToTensor converts each image to a tensor.
image_path = './'
transform = transforms.Compose([transforms.ToTensor()])

mnist_train_dataset = torchvision.datasets.MNIST(
    root=image_path, train=True, transform=transform, download=True
)
# The test split is read from the same root without requesting a download.
mnist_test_dataset = torchvision.datasets.MNIST(
    root=image_path, train=False, transform=transform, download=False
)

# Seed the global RNG, then create the shuffled mini-batch loader for training.
batch_size = 64
torch.manual_seed(1)
train_dl = DataLoader(mnist_train_dataset, batch_size=batch_size, shuffle=True)

### Define the neural network

In [26]:
# Hidden-layer widths for the network, and the shape of one training image.
hidden_units = [32, 16]
sample_image, _sample_label = mnist_train_dataset[0]
image_size = sample_image.shape
print(image_size)


torch.Size([1, 28, 28])


In [27]:
# Number of features per flattened image: the product of every dimension in
# image_size. numel() covers any rank instead of hard-coding three axes.
input_size = image_size.numel()

In [28]:
# Sanity check: expect 1 * 28 * 28 = 784 for the image shape printed above.
print(input_size)

784


In [29]:
# Start the layer stack with a Flatten module, then show the list so far.
flatten_layer = nn.Flatten()
all_layers = [flatten_layer]
print(all_layers)

[Flatten(start_dim=1, end_dim=-1)]


In [30]:
# Build the MLP: Flatten -> (Linear -> ReLU) per hidden width -> Linear(10).
# The stack is rebuilt from scratch and the running width is tracked in its
# own variable, so re-running this cell neither grows `all_layers` nor
# overwrites `input_size` with the last hidden width.
all_layers = [nn.Flatten()]
in_features = input_size
for hidden_unit in hidden_units:
    all_layers.append(nn.Linear(in_features, hidden_unit))
    all_layers.append(nn.ReLU())
    in_features = hidden_unit
# Output layer with 10 units, one per MNIST digit class. Using `in_features`
# (rather than hidden_units[-1]) also works when hidden_units is empty.
all_layers.append(nn.Linear(in_features, 10))
model = nn.Sequential(*all_layers)

model

Sequential(
  (0): Flatten(start_dim=1, end_dim=-1)
  (1): Linear(in_features=784, out_features=32, bias=True)
  (2): ReLU()
  (3): Linear(in_features=32, out_features=16, bias=True)
  (4): ReLU()
  (5): Linear(in_features=16, out_features=10, bias=True)
)

In [34]:
# Train with cross-entropy loss and Adam, printing training accuracy per epoch.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

torch.manual_seed(1)
num_epochs = 20

for epoch in range(num_epochs):
    # Running count of correct predictions for this epoch.
    accuracy_hist_train = 0
    for x_batch, y_batch in train_dl:
        # Forward pass and loss on the current mini-batch.
        logits = model(x_batch)
        loss = loss_fn(logits, y_batch)

        # Backward pass, parameter update, then clear gradients for the next batch.
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # The predicted class is the index of the largest output.
        hits = logits.argmax(dim=1).eq(y_batch).float()
        accuracy_hist_train += hits.sum()
    # Divide by the number of samples (not batches) to get the fraction correct.
    accuracy_hist_train /= len(train_dl.dataset)
    print(f'Epoch {epoch} Accuracy {accuracy_hist_train:.4f}')

Epoch 0 Accuracy 0.8558
Epoch 1 Accuracy 0.9303
Epoch 2 Accuracy 0.9442
Epoch 3 Accuracy 0.9534
Epoch 4 Accuracy 0.9580
Epoch 5 Accuracy 0.9622
Epoch 6 Accuracy 0.9657
Epoch 7 Accuracy 0.9678
Epoch 8 Accuracy 0.9698
Epoch 9 Accuracy 0.9723
Epoch 10 Accuracy 0.9740
Epoch 11 Accuracy 0.9750
Epoch 12 Accuracy 0.9768
Epoch 13 Accuracy 0.9782
Epoch 14 Accuracy 0.9791
Epoch 15 Accuracy 0.9808
Epoch 16 Accuracy 0.9819
Epoch 17 Accuracy 0.9826
Epoch 18 Accuracy 0.9829
Epoch 19 Accuracy 0.9847


In [36]:
# Evaluate on the whole test split in a single forward pass.
# The raw `.data` tensor bypasses the dataset transform, so it is rescaled by
# 1/255 here (presumably matching the scaling ToTensor applies during
# training; confirm against torchvision docs). Flatten collapses everything
# after the batch dimension, so the missing channel axis does not matter.
# no_grad(): inference needs no autograd graph, which would otherwise keep
# the activations for every test image alive.
with torch.no_grad():
    pred = model(mnist_test_dataset.data / 255.)
is_correct = (torch.argmax(pred, dim=1) == mnist_test_dataset.targets).float()
print(f'Test accuracy: {is_correct.mean():.4f}')

Test accuracy: 0.9679
